From 192dc95ac06b90eacfbd480742f8870547122212 Mon Sep 17 00:00:00 2001
From: Jinjie Liu
Date: Wed, 4 Feb 2026 10:46:14 +0800
Subject: [PATCH] Support decorators for jit and wrapper

Signed-off-by: Jinjie Liu
---
 examples/add/add.py           | 28 ++++++++++------------------
 python/triton_tvm_ffi/wrap.py | 28 ++++++++++++++--------------
 2 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/examples/add/add.py b/examples/add/add.py
index a0c7579..8977432 100644
--- a/examples/add/add.py
+++ b/examples/add/add.py
@@ -10,8 +10,7 @@ import triton_tvm_ffi
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
-# Support decorators here like
-# @triton_tvm_ffi.jit
+@triton_tvm_ffi.jit
 @triton.jit
 def add_kernel(
     x_ptr,
@@ -33,7 +32,7 @@ def add_kernel(
 add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
 
 
-def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     output: torch.Tensor = torch.empty_like(x)
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements: int = output.numel()
@@ -43,19 +42,12 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return output
 
 
-# TODO: it woule be more user-friendly to define wrapper functions like below
-# @triton_tvm_ffi.torch_wrap(
-#     "add",
-#     [add_kernel_tvm_ffi],
-#     Path(__file__).parent / "add.cc",
-# )
-# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-#     ...
-add_tvm_ffi = triton_tvm_ffi.torch_wrap(
-    "add",
+@triton_tvm_ffi.torch_wrap(
     [add_kernel_tvm_ffi],
     Path(__file__).parent / "add.cc",
 )
+def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
+
 
 if __name__ == "__main__":
     torch.manual_seed(0)
@@ -63,11 +55,11 @@
     x = torch.rand(size, device=DEVICE)
     y = torch.rand(size, device=DEVICE)
     output_torch = x + y
-    output_triton = add(x, y)
-    output_tvm_ffi = add_tvm_ffi(x, y)
+    output_triton = add_triton(x, y)
+    output_tvm_ffi = add(x, y)
     assert torch.allclose(output_torch, output_triton)
     assert torch.allclose(output_torch, output_tvm_ffi)
-    output_tvm_ffi = add_tvm_ffi(x, y)
+    output_tvm_ffi = add(x, y)
     assert torch.allclose(output_torch, output_tvm_ffi)
 
     round = 1000
@@ -76,10 +68,10 @@
         x + y
     cp1 = time.perf_counter_ns()
     for _ in range(round):
-        add(x, y)
+        add_triton(x, y)
     cp2 = time.perf_counter_ns()
     for _ in range(round):
-        add_tvm_ffi(x, y)
+        add(x, y)
     cp3 = time.perf_counter_ns()
     print(
         f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"

diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py
index f0e8f6c..7340e2b 100644
--- a/python/triton_tvm_ffi/wrap.py
+++ b/python/triton_tvm_ffi/wrap.py
@@ -1,7 +1,7 @@
 from functools import cached_property
 from io import TextIOWrapper
 from pathlib import Path
-from typing import Final, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
 
 import torch.utils.cpp_extension
 import tvm_ffi
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
                 if fn.kernel is not None
             },
         )
-        return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
+        return tvm_ffi.get_global_func(self.uniquename)
 
     @staticmethod
     def gendef(fn: TVMFFIJITFunction) -> str:
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
 
 
 def wrap(
-    name: str,
     fns: List[TVMFFIJITFunction],
     code: Union[str, Path, TextIOWrapper],
     extra_cflags: Optional[Sequence[str]] = None,
@@ -135,19 +134,21 @@ def wrap(
     extra_ldflags: Optional[Sequence[str]] = None,
     extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
 ) -> TVMFFIWrapperFunction:
-    return TVMFFIWrapperFunction(
-        name,
-        fns,
-        code,
-        extra_cflags,
-        extra_cuda_cflags,
-        extra_ldflags,
-        extra_include_paths,
-    )
+    def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
+        return TVMFFIWrapperFunction(
+            fn if isinstance(fn, str) else fn.__name__,
+            fns,
+            code,
+            extra_cflags,
+            extra_cuda_cflags,
+            extra_ldflags,
+            extra_include_paths,
+        )
+
+    return decorate
 
 
 def torch_wrap(
-    name: str,
     fns: List[TVMFFIJITFunction],
     code: Union[str, Path, TextIOWrapper],
     extra_cflags: Optional[Sequence[str]] = None,
@@ -157,7 +158,6 @@
 ) -> TVMFFIWrapperFunction:
     cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
     return wrap(
-        name,
         fns,
         code,
         extra_ldflags=[
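
Usage sketch of the decorator API this patch introduces, mirroring
examples/add/add.py: `triton_tvm_ffi.jit` stacks on top of `triton.jit`, and
`torch_wrap` no longer takes an explicit name, deriving the exported symbol
from the decorated stub's `__name__` (or from a plain string, per `decorate`'s
`Union[str, Callable]` parameter). The names `mul_kernel`, `mul`, and `mul.cc`
here are hypothetical, a matching C++ wrapper source is still required, and
passing the jit-decorated kernel directly in the kernel list is an assumption
(the in-tree example still passes a separately wrapped `add_kernel_tvm_ffi`):

    from pathlib import Path

    import torch
    import triton
    import triton.language as tl
    import triton_tvm_ffi

    @triton_tvm_ffi.jit  # stacks on triton.jit, as in examples/add/add.py
    @triton.jit
    def mul_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x * y, mask=mask)

    # The exported function is named after the stub ("mul", via fn.__name__);
    # mul.cc is a hypothetical C++ wrapper source alongside this file.
    @triton_tvm_ffi.torch_wrap(
        [mul_kernel],  # assumes the jit-decorated kernel is a TVMFFIJITFunction
        Path(__file__).parent / "mul.cc",
    )
    def mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...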