mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable lambda function for grid descriptor
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from functools import cached_property
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Union
|
||||
|
||||
import jinja2
|
||||
import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
|
||||
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
|
||||
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
||||
extra_include_paths
|
||||
)
|
||||
self.env: Final[jinja2.Environment] = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||
trim_blocks=True,
|
||||
)
|
||||
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> None:
|
||||
func: tvm_ffi.Function = self.compile()
|
||||
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
|
||||
|
||||
@property
|
||||
def emit(self) -> str:
|
||||
defs: str = "\n".join(
|
||||
[
|
||||
"#include <cuda.h>",
|
||||
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
|
||||
"#include <tvm/ffi/function.h>",
|
||||
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
|
||||
*map(
|
||||
self.gendef,
|
||||
self.fns,
|
||||
),
|
||||
]
|
||||
return self.tpl.render(
|
||||
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
|
||||
)
|
||||
return f"{defs}\n{self.code}"
|
||||
|
||||
@property
|
||||
def uniquename(self) -> str:
|
||||
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
|
||||
)
|
||||
return tvm_ffi.get_global_func(self.uniquename)
|
||||
|
||||
@staticmethod
|
||||
def gendef(fn: TVMFFIJITFunction) -> str:
|
||||
if fn.ctypes is None:
|
||||
return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
|
||||
else:
|
||||
ctype_arg_list: List[Tuple[str, str]] = [
|
||||
(ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
|
||||
]
|
||||
|
||||
return """
|
||||
TVM_FFI_EMBED_CUBIN(triton_{fnname});
|
||||
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
||||
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
||||
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
||||
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
|
||||
void *dummy = nullptr, {}; \\
|
||||
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
||||
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
||||
}} while (false)
|
||||
""".format(
|
||||
fn.fnname.upper(),
|
||||
", ".join(arg for _, arg in ctype_arg_list),
|
||||
fn.num_warps if fn.num_warps is not None else "__numWarps",
|
||||
", ".join(
|
||||
f"*{arg}_ptr = {arg}.data_ptr()"
|
||||
for ctype, arg in ctype_arg_list
|
||||
if ctype == "CUdeviceptr"
|
||||
),
|
||||
", ".join(
|
||||
f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
|
||||
for ctype, arg in ctype_arg_list
|
||||
if ctype is not None
|
||||
),
|
||||
fnname=fn.fnname,
|
||||
).strip()
|
||||
|
||||
|
||||
def wrap(
|
||||
fns: List[TVMFFIJITFunction],
|
||||
|
||||
Reference in New Issue
Block a user