supports decorator for jit and wrapper

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-04 10:46:14 +08:00
parent 6e4c2d4a43
commit 192dc95ac0
2 changed files with 24 additions and 32 deletions

View File

@@ -10,8 +10,7 @@ import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
# Support decorators here like
# @triton_tvm_ffi.jit
@triton_tvm_ffi.jit
@triton.jit
def add_kernel(
x_ptr,
@@ -33,7 +32,7 @@ def add_kernel(
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.empty_like(x)
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
n_elements: int = output.numel()
@@ -43,19 +42,12 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return output
# TODO: it woule be more user-friendly to define wrapper functions like below
# @triton_tvm_ffi.torch_wrap(
# "add",
# [add_kernel_tvm_ffi],
# Path(__file__).parent / "add.cc",
# )
# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# ...
add_tvm_ffi = triton_tvm_ffi.torch_wrap(
"add",
@triton_tvm_ffi.torch_wrap(
[add_kernel_tvm_ffi],
Path(__file__).parent / "add.cc",
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
if __name__ == "__main__":
torch.manual_seed(0)
@@ -63,11 +55,11 @@ if __name__ == "__main__":
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
output_tvm_ffi = add_tvm_ffi(x, y)
output_triton = add_triton(x, y)
output_tvm_ffi = add(x, y)
assert torch.allclose(output_torch, output_triton)
assert torch.allclose(output_torch, output_tvm_ffi)
output_tvm_ffi = add_tvm_ffi(x, y)
output_tvm_ffi = add(x, y)
assert torch.allclose(output_torch, output_tvm_ffi)
round = 1000
@@ -76,10 +68,10 @@ if __name__ == "__main__":
x + y
cp1 = time.perf_counter_ns()
for _ in range(round):
add(x, y)
add_triton(x, y)
cp2 = time.perf_counter_ns()
for _ in range(round):
add_tvm_ffi(x, y)
add(x, y)
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"