mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
supports decorator for jit and wrapper
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -10,8 +10,7 @@ import triton_tvm_ffi
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
# Support decorators here like
|
||||
# @triton_tvm_ffi.jit
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr,
|
||||
@@ -33,7 +32,7 @@ def add_kernel(
|
||||
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
||||
|
||||
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output: torch.Tensor = torch.empty_like(x)
|
||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||
n_elements: int = output.numel()
|
||||
@@ -43,19 +42,12 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return output
|
||||
|
||||
|
||||
# TODO: it woule be more user-friendly to define wrapper functions like below
|
||||
# @triton_tvm_ffi.torch_wrap(
|
||||
# "add",
|
||||
# [add_kernel_tvm_ffi],
|
||||
# Path(__file__).parent / "add.cc",
|
||||
# )
|
||||
# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
# ...
|
||||
add_tvm_ffi = triton_tvm_ffi.torch_wrap(
|
||||
"add",
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[add_kernel_tvm_ffi],
|
||||
Path(__file__).parent / "add.cc",
|
||||
)
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
@@ -63,11 +55,11 @@ if __name__ == "__main__":
|
||||
x = torch.rand(size, device=DEVICE)
|
||||
y = torch.rand(size, device=DEVICE)
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||
output_triton = add_triton(x, y)
|
||||
output_tvm_ffi = add(x, y)
|
||||
assert torch.allclose(output_torch, output_triton)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||
output_tvm_ffi = add(x, y)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
|
||||
round = 1000
|
||||
@@ -76,10 +68,10 @@ if __name__ == "__main__":
|
||||
x + y
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add(x, y)
|
||||
add_triton(x, y)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add_tvm_ffi(x, y)
|
||||
add(x, y)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from functools import cached_property
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Final, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
|
||||
if fn.kernel is not None
|
||||
},
|
||||
)
|
||||
return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
|
||||
return tvm_ffi.get_global_func(self.uniquename)
|
||||
|
||||
@staticmethod
|
||||
def gendef(fn: TVMFFIJITFunction) -> str:
|
||||
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
|
||||
|
||||
|
||||
def wrap(
|
||||
name: str,
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
@@ -135,19 +134,21 @@ def wrap(
|
||||
extra_ldflags: Optional[Sequence[str]] = None,
|
||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||
) -> TVMFFIWrapperFunction:
|
||||
return TVMFFIWrapperFunction(
|
||||
name,
|
||||
fns,
|
||||
code,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
)
|
||||
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
|
||||
return TVMFFIWrapperFunction(
|
||||
fn if isinstance(fn, str) else fn.__name__,
|
||||
fns,
|
||||
code,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def torch_wrap(
|
||||
name: str,
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
@@ -157,7 +158,6 @@ def torch_wrap(
|
||||
) -> TVMFFIWrapperFunction:
|
||||
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
|
||||
return wrap(
|
||||
name,
|
||||
fns,
|
||||
code,
|
||||
extra_ldflags=[
|
||||
|
||||
Reference in New Issue
Block a user