mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
supports decorator for jit and wrapper
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -10,8 +10,7 @@ import triton_tvm_ffi
|
|||||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||||
|
|
||||||
|
|
||||||
# Support decorators here like
|
@triton_tvm_ffi.jit
|
||||||
# @triton_tvm_ffi.jit
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def add_kernel(
|
def add_kernel(
|
||||||
x_ptr,
|
x_ptr,
|
||||||
@@ -33,7 +32,7 @@ def add_kernel(
|
|||||||
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
||||||
|
|
||||||
|
|
||||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
output: torch.Tensor = torch.empty_like(x)
|
output: torch.Tensor = torch.empty_like(x)
|
||||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||||
n_elements: int = output.numel()
|
n_elements: int = output.numel()
|
||||||
@@ -43,19 +42,12 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# TODO: it would be more user-friendly to define wrapper functions like below
|
@triton_tvm_ffi.torch_wrap(
|
||||||
# @triton_tvm_ffi.torch_wrap(
|
|
||||||
# "add",
|
|
||||||
# [add_kernel_tvm_ffi],
|
|
||||||
# Path(__file__).parent / "add.cc",
|
|
||||||
# )
|
|
||||||
# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
||||||
# ...
|
|
||||||
add_tvm_ffi = triton_tvm_ffi.torch_wrap(
|
|
||||||
"add",
|
|
||||||
[add_kernel_tvm_ffi],
|
[add_kernel_tvm_ffi],
|
||||||
Path(__file__).parent / "add.cc",
|
Path(__file__).parent / "add.cc",
|
||||||
)
|
)
|
||||||
|
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
@@ -63,11 +55,11 @@ if __name__ == "__main__":
|
|||||||
x = torch.rand(size, device=DEVICE)
|
x = torch.rand(size, device=DEVICE)
|
||||||
y = torch.rand(size, device=DEVICE)
|
y = torch.rand(size, device=DEVICE)
|
||||||
output_torch = x + y
|
output_torch = x + y
|
||||||
output_triton = add(x, y)
|
output_triton = add_triton(x, y)
|
||||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
output_tvm_ffi = add(x, y)
|
||||||
assert torch.allclose(output_torch, output_triton)
|
assert torch.allclose(output_torch, output_triton)
|
||||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
output_tvm_ffi = add(x, y)
|
||||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||||
|
|
||||||
round = 1000
|
round = 1000
|
||||||
@@ -76,10 +68,10 @@ if __name__ == "__main__":
|
|||||||
x + y
|
x + y
|
||||||
cp1 = time.perf_counter_ns()
|
cp1 = time.perf_counter_ns()
|
||||||
for _ in range(round):
|
for _ in range(round):
|
||||||
add(x, y)
|
add_triton(x, y)
|
||||||
cp2 = time.perf_counter_ns()
|
cp2 = time.perf_counter_ns()
|
||||||
for _ in range(round):
|
for _ in range(round):
|
||||||
add_tvm_ffi(x, y)
|
add(x, y)
|
||||||
cp3 = time.perf_counter_ns()
|
cp3 = time.perf_counter_ns()
|
||||||
print(
|
print(
|
||||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Final, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
import tvm_ffi
|
import tvm_ffi
|
||||||
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
|
|||||||
if fn.kernel is not None
|
if fn.kernel is not None
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
|
return tvm_ffi.get_global_func(self.uniquename)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gendef(fn: TVMFFIJITFunction) -> str:
|
def gendef(fn: TVMFFIJITFunction) -> str:
|
||||||
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
|
|||||||
|
|
||||||
|
|
||||||
def wrap(
|
def wrap(
|
||||||
name: str,
|
|
||||||
fns: List[TVMFFIJITFunction],
|
fns: List[TVMFFIJITFunction],
|
||||||
code: Union[str, Path, TextIOWrapper],
|
code: Union[str, Path, TextIOWrapper],
|
||||||
extra_cflags: Optional[Sequence[str]] = None,
|
extra_cflags: Optional[Sequence[str]] = None,
|
||||||
@@ -135,8 +134,9 @@ def wrap(
|
|||||||
extra_ldflags: Optional[Sequence[str]] = None,
|
extra_ldflags: Optional[Sequence[str]] = None,
|
||||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||||
) -> TVMFFIWrapperFunction:
|
) -> TVMFFIWrapperFunction:
|
||||||
|
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
|
||||||
return TVMFFIWrapperFunction(
|
return TVMFFIWrapperFunction(
|
||||||
name,
|
fn if isinstance(fn, str) else fn.__name__,
|
||||||
fns,
|
fns,
|
||||||
code,
|
code,
|
||||||
extra_cflags,
|
extra_cflags,
|
||||||
@@ -145,9 +145,10 @@ def wrap(
|
|||||||
extra_include_paths,
|
extra_include_paths,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return decorate
|
||||||
|
|
||||||
|
|
||||||
def torch_wrap(
|
def torch_wrap(
|
||||||
name: str,
|
|
||||||
fns: List[TVMFFIJITFunction],
|
fns: List[TVMFFIJITFunction],
|
||||||
code: Union[str, Path, TextIOWrapper],
|
code: Union[str, Path, TextIOWrapper],
|
||||||
extra_cflags: Optional[Sequence[str]] = None,
|
extra_cflags: Optional[Sequence[str]] = None,
|
||||||
@@ -157,7 +158,6 @@ def torch_wrap(
|
|||||||
) -> TVMFFIWrapperFunction:
|
) -> TVMFFIWrapperFunction:
|
||||||
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
|
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
|
||||||
return wrap(
|
return wrap(
|
||||||
name,
|
|
||||||
fns,
|
fns,
|
||||||
code,
|
code,
|
||||||
extra_ldflags=[
|
extra_ldflags=[
|
||||||
|
|||||||
Reference in New Issue
Block a user