enable lambda function for grid descriptor

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-05 15:59:22 +08:00
parent 8b8aa6cb84
commit f6c7a48c1b
6 changed files with 104 additions and 57 deletions

View File

@@ -1,8 +1,9 @@
from functools import cached_property
from io import TextIOWrapper
from pathlib import Path
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Final, List, Optional, Sequence, Union
import jinja2
import torch.utils.cpp_extension
import tvm_ffi
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
extra_include_paths
)
self.env: Final[jinja2.Environment] = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
)
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
def __call__(self, *args, **kwargs) -> None:
func: tvm_ffi.Function = self.compile()
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
@property
def emit(self) -> str:
# NOTE(review): this span carries two alternative bodies for `emit`, one after
# the other, and is not valid Python as written; confirm which one is meant to
# be live before relying on either.
#
# Variant A (the `defs` assignment plus the final `return` at the bottom):
# newline-joins three fixed `#include` lines, a
# `#define <NAME>_NAME "<uniquename>"` line and one `self.gendef(fn)` result per
# entry of `self.fns`, then returns that header text followed by a newline and
# `self.code`.
defs: str = "\n".join(
[
"#include <cuda.h>",
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
"#include <tvm/ffi/function.h>",
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
*map(
self.gendef,
self.fns,
),
]
# NOTE(review): the `join(` call opened above is never closed before the
# `return` below.
#
# Variant B: hand `code`, `fns`, `name` and `uniquename` to `self.tpl.render`
# and return whatever the template produces.
return self.tpl.render(
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
)
# Tail of variant A: header definitions first, then the user-supplied code.
return f"{defs}\n{self.code}"
@property
def uniquename(self) -> str:
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
)
return tvm_ffi.get_global_func(self.uniquename)
@staticmethod
def gendef(fn: "TVMFFIJITFunction") -> str:
    """Return the C preprocessor text that defines ``<FNNAME>_STUB`` for ``fn``.

    Two shapes are produced, selected by ``fn.ctypes``:

    * ``fn.ctypes is None``: a one-line object-like macro that expands to
      ``tvm::ffi::Function::GetGlobalRequired("<fn.fullname>")``.
    * otherwise: a ``TVM_FFI_EMBED_CUBIN(triton_<fnname>)`` declaration
      followed by a function-like macro that fetches the embedded kernel and
      launches it with the macro arguments as kernel parameters.

    Args:
        fn: Object exposing ``fnname``, ``fullname``, ``ctypes`` (``None`` or
            a sequence with one entry per kernel argument, each a C type name
            or ``None``) and ``num_warps`` (``None`` or a value rendered
            verbatim into the block-size expression).

    Returns:
        The macro definition with surrounding whitespace stripped.
    """
    if fn.ctypes is None:
        return (
            f"#define {fn.fnname.upper()}_STUB "
            f'tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
        )

    typed_args = [(ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)]

    # Every list below is joined together with its fixed items so that an
    # empty variable part cannot leave a dangling comma in the generated C
    # (e.g. `void *dummy = nullptr, ;` or `{, &dummy, &dummy}`).

    # Every argument becomes a macro parameter, including those whose ctype
    # is None.
    macro_params = ", ".join(
        ["__gtuple", "__stream", "__numWarps", "__numStages"]
        + [arg for _, arg in typed_args]
    )
    # CUdeviceptr arguments are passed through a local `void *` that holds
    # the result of `.data_ptr()`; all locals share one `void` declaration.
    pointer_decls = ", ".join(
        ["*dummy = nullptr"]
        + [
            f"*{arg}_ptr = {arg}.data_ptr()"
            for ctype, arg in typed_args
            if ctype == "CUdeviceptr"
        ]
    )
    # Arguments whose ctype is None are not forwarded to the kernel. The two
    # trailing `&dummy` slots are always appended.
    # NOTE(review): presumably placeholders for extra trailing kernel
    # parameters expected by the compiled kernel -- confirm.
    launch_params = ", ".join(
        [
            f"&{arg}_ptr" if ctype == "CUdeviceptr" else f"&{arg}"
            for ctype, arg in typed_args
            if ctype is not None
        ]
        + ["&dummy", "&dummy"]
    )
    # A fixed warp count overrides the macro's own `__numWarps` parameter.
    warps = fn.num_warps if fn.num_warps is not None else "__numWarps"

    # The template body is flush-left because it sits inside a string
    # literal; `{{`/`}}` are literal braces and `\\` is a C line continuation.
    return """
TVM_FFI_EMBED_CUBIN(triton_{fnname});
#define {stub}_STUB({macro_params}) do {{ \\
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
tvm::ffi::dim3 __block({warps} * 32, 1, 1); \\
void {pointer_decls}; \\
void *__params[] = {{{launch_params}}}; \\
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
}} while (false)
""".format(
        stub=fn.fnname.upper(),
        macro_params=macro_params,
        warps=warps,
        pointer_decls=pointer_decls,
        launch_params=launch_params,
        fnname=fn.fnname,
    ).strip()
def wrap(
fns: List[TVMFFIJITFunction],