Enable lambda functions for the grid descriptor

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-05 15:59:22 +08:00
parent 8b8aa6cb84
commit f6c7a48c1b
6 changed files with 104 additions and 57 deletions

View File

@@ -1,3 +1,4 @@
#include "tvm/ffi/function.h"
#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
@@ -15,10 +16,14 @@
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
at::Tensor otorch = at::empty_like(xtorch);
int64_t numel = otorch.numel();
int32_t numel = otorch.numel();
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
// TODO: check the performance loss after enabling `Optional`
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
[numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
});
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
DLDevice device = x.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);

View File

@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
n_elements: int = output.numel()
BLOCK_SIZE: int = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
x, y, output, n_elements, BLOCK_SIZE
)
return output

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
from functools import cached_property
from typing import Any, Dict, Final, List, Optional, Tuple
import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
import torch
from triton.compiler import CompiledKernel
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fn: Final[JITFunction] = fn
self.signature: Optional[List[str]] = None
self.ctypes: Optional[List[Optional[str]]] = None
self.kernel: Optional[bytes] = None
self.num_warps: Optional[int] = None
@tvm_ffi.register_global_func(self.fullname)
def _(
grid: Tuple[int, int, int],
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
_: int,
num_warps: Optional[int],
num_stages: Optional[int],
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
kwargs["num_stages"] = num_stages
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
self.num_warps, _, _ = kernel.packed_metadata
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
self.kernel = kernel.kernel
return kernel
def __getitem__(self, grid: Tuple[int, int, int]):
def __getitem__(
self,
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
):
return self.fn[grid]
@property

View File

@@ -0,0 +1,42 @@
{# Generated C++ wrapper source for TVM-FFI Triton kernels.
   For each function `fn` in `fns`, emit either a forwarding stub to a
   globally registered FFI function (when no compiled signature/ctypes is
   available) or an embedded-cubin launch macro that builds the kernel's
   meta map, resolves the launch grid via MakeGridDim (static tuple OR
   callable grid — see grid.h), and launches on the caller's stream.
   NOTE(review): the macro body below relies on `\` line continuations and
   `{%- -%}` whitespace stripping to stay on single logical lines — do not
   insert comments or blank lines inside it. #}
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
{% include "grid.h" %}
#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
{# JIT fallback path: no compiled signature, so dispatch through the
   registered global FFI function instead of launching a cubin. #}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
{# AOT path: embed the compiled cubin and define a launch macro taking
   (grid, stream, numWarps, numStages, arg0, arg1, ...). #}
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% for name in fn.signature %}
        { "{{ name }}", __arg{{ loop.index0 }} }, \
{% endfor %}
    }; \
    static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
    tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
    tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
    void *dummy = nullptr
{%- for ctype in fn.ctypes -%}
{%- if ctype == "CUdeviceptr" -%}
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
{%- endif -%}
{%- endfor -%}; \
    void *__params[] = {
{%- for ctype in fn.ctypes -%}
{%- if ctype != none -%}
&__arg{{ loop.index0 }}
{%- if ctype == "CUdeviceptr" -%}
_ptr
{%- endif -%},
{%- endif -%}
{%- endfor -%}&dummy, &dummy }; \
    TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
} while (false)
{% endif %}
{% endfor %}
{{ code }}

View File

@@ -0,0 +1,29 @@
#ifndef TRITON_TVM_FFI_GRID_H
#define TRITON_TVM_FFI_GRID_H
#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/tvm_ffi.h>
// Convert a grid descriptor into a tvm::ffi::dim3 launch dimension.
// The grid may be supplied either as a static (x, y, z) tuple or as a
// callable that computes the tuple from the kernel meta-parameter map
// (mirroring Triton's `lambda meta: ...` grid syntax); only the two
// specializations below are defined, so any other T fails to link.
template <typename T>
inline tvm::ffi::dim3
MakeGridDim(const T &grid,
            const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
// Static grid: the tuple already holds the launch dimensions; the meta
// map is not needed and its parameter is intentionally unnamed.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
    const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
  return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
}
// Callable grid: invoke the function with the meta map, cast its result
// to an (x, y, z) int32 tuple, then delegate to the tuple overload above.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
    const tvm::ffi::Function &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
  tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
      grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
  return MakeGridDim(tuple, meta);
}
#endif

View File

@@ -1,8 +1,9 @@
from functools import cached_property
from io import TextIOWrapper
from pathlib import Path
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Final, List, Optional, Sequence, Union
import jinja2
import torch.utils.cpp_extension
import tvm_ffi
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
extra_include_paths
)
self.env: Final[jinja2.Environment] = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
)
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
def __call__(self, *args, **kwargs) -> None:
func: tvm_ffi.Function = self.compile()
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
@property
def emit(self) -> str:
defs: str = "\n".join(
[
"#include <cuda.h>",
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
"#include <tvm/ffi/function.h>",
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
*map(
self.gendef,
self.fns,
),
]
return self.tpl.render(
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
)
return f"{defs}\n{self.code}"
@property
def uniquename(self) -> str:
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
)
return tvm_ffi.get_global_func(self.uniquename)
@staticmethod
def gendef(fn: TVMFFIJITFunction) -> str:
if fn.ctypes is None:
return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
else:
ctype_arg_list: List[Tuple[str, str]] = [
(ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
]
return """
TVM_FFI_EMBED_CUBIN(triton_{fnname});
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
void *dummy = nullptr, {}; \\
void *__params[] = {{{}, &dummy, &dummy}}; \\
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
}} while (false)
""".format(
fn.fnname.upper(),
", ".join(arg for _, arg in ctype_arg_list),
fn.num_warps if fn.num_warps is not None else "__numWarps",
", ".join(
f"*{arg}_ptr = {arg}.data_ptr()"
for ctype, arg in ctype_arg_list
if ctype == "CUdeviceptr"
),
", ".join(
f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
for ctype, arg in ctype_arg_list
if ctype is not None
),
fnname=fn.fnname,
).strip()
def wrap(
fns: List[TVMFFIJITFunction],