enable lambda function for grid descriptor

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-05 15:59:22 +08:00
parent 8b8aa6cb84
commit f6c7a48c1b
6 changed files with 104 additions and 57 deletions

View File

@@ -1,3 +1,4 @@
#include "tvm/ffi/function.h"
#include <ATen/DLConvertor.h> #include <ATen/DLConvertor.h>
#include <ATen/dlpack.h> #include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h> #include <tvm/ffi/extra/cuda/cubin_launcher.h>
@@ -15,10 +16,14 @@
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
at::Tensor xtorch = at::fromDLPack(x.ToDLPack()); at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
at::Tensor otorch = at::empty_like(xtorch); at::Tensor otorch = at::empty_like(xtorch);
int64_t numel = otorch.numel(); int32_t numel = otorch.numel();
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch)); tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1}; tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
// TODO: check the performance loss after enabling `Optional` [numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
});
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt; tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
DLDevice device = x.device(); DLDevice device = x.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);

View File

@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
n_elements: int = output.numel() n_elements: int = output.numel()
BLOCK_SIZE: int = 1024 BLOCK_SIZE: int = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1) add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE) x, y, output, n_elements, BLOCK_SIZE
)
return output return output

View File

@@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from functools import cached_property from functools import cached_property
from typing import Any, Dict, Final, List, Optional, Tuple import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
import torch import torch
from triton.compiler import CompiledKernel from triton.compiler import CompiledKernel
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
def __init__(self, fn: JITFunction, *args, **kwargs) -> None: def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.fn: Final[JITFunction] = fn self.fn: Final[JITFunction] = fn
self.signature: Optional[List[str]] = None
self.ctypes: Optional[List[Optional[str]]] = None self.ctypes: Optional[List[Optional[str]]] = None
self.kernel: Optional[bytes] = None self.kernel: Optional[bytes] = None
self.num_warps: Optional[int] = None self.num_warps: Optional[int] = None
@tvm_ffi.register_global_func(self.fullname) @tvm_ffi.register_global_func(self.fullname)
def _( def _(
grid: Tuple[int, int, int], grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
_: int, _: int,
num_warps: Optional[int], num_warps: Optional[int],
num_stages: Optional[int], num_stages: Optional[int],
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
kwargs["num_stages"] = num_stages kwargs["num_stages"] = num_stages
kernel: CompiledKernel = self.fn[grid](*args, **kwargs) kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
self.num_warps, _, _ = kernel.packed_metadata self.num_warps, _, _ = kernel.packed_metadata
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()] self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
self.kernel = kernel.kernel self.kernel = kernel.kernel
return kernel return kernel
def __getitem__(self, grid: Tuple[int, int, int]): def __getitem__(
self,
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
):
return self.fn[grid] return self.fn[grid]
@property @property

View File

@@ -0,0 +1,42 @@
{#
  Translation-unit template for a wrapped C++ source.

  Context variables read below:
    name, uniquename -- used for the <NAME>_NAME macro.
    fns              -- iterable of objects exposing fnname, fullname, ctypes,
                        signature and num_warps.
    code             -- user source appended verbatim at the end.

  For every entry in `fns` one <FNNAME>_STUB macro is emitted:
    * ctypes is none  -> the stub resolves to a global function lookup by
                         fullname;
    * otherwise       -> the stub is a function-like macro that launches the
                         embedded cubin. Its `__grid` argument is forwarded to
                         MakeGridDim together with a name -> argument map built
                         from fn.signature.

  NOTE(review): the macro body relies on `\` line continuations and on the
  `-` whitespace-control markers, so Jinja comments are kept out of it.
  NOTE(review): the two trailing `&__dummy` entries are placeholder kernel
  parameters -- presumably extra pointers the compiled kernel expects; confirm.
#}
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
{% include "grid.h" %}
{# Keep the blank line below: it guarantees a line break after the included header even when newlines around the include tag are trimmed. #}

#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% for param in fn.signature %}
{ "{{ param }}", __arg{{ loop.index0 }} }, \
{% endfor %}
}; \
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
tvm::ffi::dim3 __block({% if fn.num_warps is not none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
void *__dummy = nullptr
{%- for ctype in fn.ctypes -%}
{%- if ctype == "CUdeviceptr" -%}
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
{%- endif -%}
{%- endfor -%}; \
void *__params[] = {
{%- for ctype in fn.ctypes -%}
{%- if ctype is not none -%}
&__arg{{ loop.index0 }}
{%- if ctype == "CUdeviceptr" -%}
_ptr
{%- endif -%},
{%- endif -%}
{%- endfor -%}&__dummy, &__dummy }; \
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
} while (false)
{% endif %}
{% endfor %}
{{ code }}

View File

@@ -0,0 +1,29 @@
#ifndef TRITON_TVM_FFI_GRID_H
#define TRITON_TVM_FFI_GRID_H
#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/tvm_ffi.h>
/// Converts a grid descriptor into launch dimensions.
///
/// The primary template is deleted on purpose: only the explicit
/// specializations below are usable, and passing any other descriptor type is
/// rejected at compile time instead of surfacing later as an unresolved
/// symbol.
///
/// \param grid Grid descriptor (a 3-tuple or a callable producing one).
/// \param meta Name -> argument map handed to callable descriptors.
template <typename T>
tvm::ffi::dim3
MakeGridDim(const T &grid,
            const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) = delete;

/// Static grid: the three tuple elements are used as-is; `meta` is ignored.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
    const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
  return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
}

/// Callable grid: `grid` is invoked with `meta`; its result is cast to a
/// 3-tuple of int32_t and then converted by the tuple specialization above.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
    const tvm::ffi::Function &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
  tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
      grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
  return MakeGridDim(tuple, meta);
}
#endif

View File

@@ -1,8 +1,9 @@
from functools import cached_property from functools import cached_property
from io import TextIOWrapper from io import TextIOWrapper
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, Final, List, Optional, Sequence, Union
import jinja2
import torch.utils.cpp_extension import torch.utils.cpp_extension
import tvm_ffi import tvm_ffi
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = ( self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
extra_include_paths extra_include_paths
) )
self.env: Final[jinja2.Environment] = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
)
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
def __call__(self, *args, **kwargs) -> None: def __call__(self, *args, **kwargs) -> None:
func: tvm_ffi.Function = self.compile() func: tvm_ffi.Function = self.compile()
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
@property @property
def emit(self) -> str: def emit(self) -> str:
defs: str = "\n".join( return self.tpl.render(
[ code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
"#include <cuda.h>",
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
"#include <tvm/ffi/function.h>",
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
*map(
self.gendef,
self.fns,
),
]
) )
return f"{defs}\n{self.code}"
@property @property
def uniquename(self) -> str: def uniquename(self) -> str:
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
) )
return tvm_ffi.get_global_func(self.uniquename) return tvm_ffi.get_global_func(self.uniquename)
@staticmethod
def gendef(fn: "TVMFFIJITFunction") -> str:
    """Return the C++ ``<FNNAME>_STUB`` macro definition for one kernel.

    Args:
        fn: Object exposing ``ctypes``, ``fnname``, ``fullname`` and
            ``num_warps``.

    Returns:
        When ``fn.ctypes`` is ``None``, an object-like macro that looks the
        function up by ``fn.fullname``. Otherwise the cubin embed statement
        followed by a function-like macro that launches the kernel.

    The comma-separated pieces are joined together with their fixed
    neighbours, so a kernel with no pointer arguments, or with no arguments
    at all, still yields well-formed C++ (no dangling ``,``).
    """
    if fn.ctypes is None:
        return (
            f"#define {fn.fnname.upper()}_STUB "
            f'tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
        )

    args: List[str] = [f"__arg{idx}" for idx in range(len(fn.ctypes))]
    typed_args = list(zip(fn.ctypes, args))

    # Fixed macro parameters first, then one parameter per kernel argument.
    macro_params = ", ".join(
        ["__gtuple", "__stream", "__numWarps", "__numStages", *args]
    )
    # Use the recorded warp count when known, else the macro's runtime argument.
    num_warps = fn.num_warps if fn.num_warps is not None else "__numWarps"
    # One declaration statement: the dummy slot plus a raw pointer per tensor.
    declarations = ", ".join(
        [
            "*dummy = nullptr",
            *(
                f"*{arg}_ptr = {arg}.data_ptr()"
                for ctype, arg in typed_args
                if ctype == "CUdeviceptr"
            ),
        ]
    )
    # Arguments whose ctype is None are not passed to the kernel; the two
    # dummy slots always close the list.
    params = ", ".join(
        [
            *(
                f"&{arg}_ptr" if ctype == "CUdeviceptr" else f"&{arg}"
                for ctype, arg in typed_args
                if ctype is not None
            ),
            "&dummy",
            "&dummy",
        ]
    )

    lines: List[str] = [
        f"TVM_FFI_EMBED_CUBIN(triton_{fn.fnname});",
        f"#define {fn.fnname.upper()}_STUB({macro_params}) do {{ \\",
        f'static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fn.fnname}, "{fn.fnname}"); \\',
        "tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\",
        f"tvm::ffi::dim3 __block({num_warps} * 32, 1, 1); \\",
        f"void {declarations}; \\",
        f"void *__params[] = {{{params}}}; \\",
        "TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\",
        "} while (false)",
    ]
    return "\n".join(lines)
def wrap( def wrap(
fns: List[TVMFFIJITFunction], fns: List[TVMFFIJITFunction],