Support decorators for jit and wrapper

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-04 10:46:14 +08:00
parent 6e4c2d4a43
commit 192dc95ac0
2 changed files with 24 additions and 32 deletions
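The first file is the vector-add example. Its hunks replace the old "support decorators here" comment with a real `@triton_tvm_ffi.jit` stacked on `@triton.jit`, rename the Triton-only entry point to `add_triton`, and turn the `torch_wrap(...)` call into a decorator on a stub `add`. A minimal sketch of the new decorator form follows; the kernel body and the parameters after `x_ptr` are not shown in the hunks below, so the standard Triton vector-add kernel is assumed here:

import triton
import triton.language as tl
import triton_tvm_ffi

@triton_tvm_ffi.jit  # decorator form added by this commit
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)

# The call form is unchanged and still appears in the example:
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)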

@@ -10,8 +10,7 @@ import triton_tvm_ffi
 DEVICE = triton.runtime.driver.active.get_active_torch_device()

-# Support decorators here like
-# @triton_tvm_ffi.jit
+@triton_tvm_ffi.jit
 @triton.jit
 def add_kernel(
     x_ptr,
@@ -33,7 +32,7 @@ def add_kernel(
 add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)

-def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     output: torch.Tensor = torch.empty_like(x)
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements: int = output.numel()
@@ -43,19 +42,12 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return output

-# TODO: it would be more user-friendly to define wrapper functions like below
-# @triton_tvm_ffi.torch_wrap(
-#     "add",
-#     [add_kernel_tvm_ffi],
-#     Path(__file__).parent / "add.cc",
-# )
-# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-#     ...
-add_tvm_ffi = triton_tvm_ffi.torch_wrap(
-    "add",
+@triton_tvm_ffi.torch_wrap(
     [add_kernel_tvm_ffi],
     Path(__file__).parent / "add.cc",
 )
+def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...

 if __name__ == "__main__":
     torch.manual_seed(0)
@@ -63,11 +55,11 @@ if __name__ == "__main__":
     x = torch.rand(size, device=DEVICE)
     y = torch.rand(size, device=DEVICE)
     output_torch = x + y
-    output_triton = add(x, y)
-    output_tvm_ffi = add_tvm_ffi(x, y)
+    output_triton = add_triton(x, y)
+    output_tvm_ffi = add(x, y)
     assert torch.allclose(output_torch, output_triton)
     assert torch.allclose(output_torch, output_tvm_ffi)
-    output_tvm_ffi = add_tvm_ffi(x, y)
+    output_tvm_ffi = add(x, y)
     assert torch.allclose(output_torch, output_tvm_ffi)

     round = 1000
@@ -76,10 +68,10 @@ if __name__ == "__main__":
         x + y
     cp1 = time.perf_counter_ns()
     for _ in range(round):
-        add(x, y)
+        add_triton(x, y)
     cp2 = time.perf_counter_ns()
     for _ in range(round):
-        add_tvm_ffi(x, y)
+        add(x, y)
     cp3 = time.perf_counter_ns()
     print(
         f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"

@@ -1,7 +1,7 @@
 from functools import cached_property
 from io import TextIOWrapper
 from pathlib import Path
-from typing import Final, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union

 import torch.utils.cpp_extension
 import tvm_ffi
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
                 if fn.kernel is not None
             },
         )
-        return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
+        return tvm_ffi.get_global_func(self.uniquename)

     @staticmethod
     def gendef(fn: TVMFFIJITFunction) -> str:
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
 def wrap(
-    name: str,
     fns: List[TVMFFIJITFunction],
     code: Union[str, Path, TextIOWrapper],
     extra_cflags: Optional[Sequence[str]] = None,
@@ -135,19 +134,21 @@ def wrap(
     extra_ldflags: Optional[Sequence[str]] = None,
     extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
 ) -> TVMFFIWrapperFunction:
-    return TVMFFIWrapperFunction(
-        name,
-        fns,
-        code,
-        extra_cflags,
-        extra_cuda_cflags,
-        extra_ldflags,
-        extra_include_paths,
-    )
+    def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
+        return TVMFFIWrapperFunction(
+            fn if isinstance(fn, str) else fn.__name__,
+            fns,
+            code,
+            extra_cflags,
+            extra_cuda_cflags,
+            extra_ldflags,
+            extra_include_paths,
+        )
+
+    return decorate


 def torch_wrap(
-    name: str,
     fns: List[TVMFFIJITFunction],
     code: Union[str, Path, TextIOWrapper],
     extra_cflags: Optional[Sequence[str]] = None,
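`wrap` is now a decorator factory: it closes over the build arguments and returns `decorate`, which accepts either an explicit name string or the decorated callable, in which case the callable's `__name__` becomes the wrapper name. (The `-> TVMFFIWrapperFunction` annotation on `wrap` itself is kept as context in this hunk even though `wrap` now returns `decorate`.) A self-contained toy of the same dual-use pattern; none of these names come from the codebase:

from typing import Any, Callable, Union

def make_wrapper(prefix: str):
    # Factory: capture configuration, hand back the actual decorator.
    def decorate(fn: Union[str, Callable[..., Any]]) -> str:
        name = fn if isinstance(fn, str) else fn.__name__
        return f"{prefix}:{name}"  # stand-in for constructing a wrapper object
    return decorate

@make_wrapper("kernel")  # decorator form: the name is taken from __name__
def add(): ...
assert add == "kernel:add"

explicit = make_wrapper("kernel")("add")  # call form with an explicit name
assert explicit == "kernel:add"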
@@ -157,7 +158,6 @@ def torch_wrap(
 ) -> TVMFFIWrapperFunction:
     cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
     return wrap(
-        name,
         fns,
         code,
         extra_ldflags=[