commit dc8c2c17e0a399bf29fb87a5e8c0523df275d1a1 Author: jinjieliu Date: Wed Feb 4 02:30:26 2026 +0800 verify tvm-ffi cpp wrapper on vector-add.py Signed-off-by: jinjieliu diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4e2888b --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# Virtual environments +.venv + +.vscode/ + +.clangd +.python-version +uv.lock diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/examples/add/add.cc b/examples/add/add.cc new file mode 100644 index 0000000..4e4aadd --- /dev/null +++ b/examples/add/add.cc @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +#ifndef ADD_KERNEL_STUB +#define ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, \ + numel, BLOCK_SIZE) +#endif + +#ifndef ADD_NAME +#define ADD_NAME "" +#endif + +tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { + at::Tensor xtorch = at::fromDLPack(x.ToDLPack()); + at::Tensor otorch = at::empty_like(xtorch); + int64_t numel = otorch.numel(); + tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch)); + tvm::ffi::Tuple grid{(numel + 1023) / 1024, 1, 1}; + size_t numWarps = 4, numStages = 3; + DLDevice device = x.device(); + void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); + ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024); + return output; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def(ADD_NAME, Add); +} diff --git a/examples/add/add.py b/examples/add/add.py new file mode 100644 index 0000000..a0c7579 --- /dev/null +++ b/examples/add/add.py @@ -0,0 +1,86 @@ +from pathlib import Path +import time + +import torch +import triton +import triton.language as tl + +import triton_tvm_ffi + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +# Support decorators here like +# @triton_tvm_ffi.jit +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel) + + +def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output: torch.Tensor = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements: int = output.numel() + BLOCK_SIZE: int = 1024 + grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE) + return output + + +# TODO: it woule be more user-friendly to define wrapper functions like below +# @triton_tvm_ffi.torch_wrap( +# "add", +# [add_kernel_tvm_ffi], +# Path(__file__).parent / "add.cc", +# ) +# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# ... +add_tvm_ffi = triton_tvm_ffi.torch_wrap( + "add", + [add_kernel_tvm_ffi], + Path(__file__).parent / "add.cc", +) + +if __name__ == "__main__": + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=DEVICE) + y = torch.rand(size, device=DEVICE) + output_torch = x + y + output_triton = add(x, y) + output_tvm_ffi = add_tvm_ffi(x, y) + assert torch.allclose(output_torch, output_triton) + assert torch.allclose(output_torch, output_tvm_ffi) + output_tvm_ffi = add_tvm_ffi(x, y) + assert torch.allclose(output_torch, output_tvm_ffi) + + round = 1000 + cp0 = time.perf_counter_ns() + for _ in range(round): + x + y + cp1 = time.perf_counter_ns() + for _ in range(round): + add(x, y) + cp2 = time.perf_counter_ns() + for _ in range(round): + add_tvm_ffi(x, y) + cp3 = time.perf_counter_ns() + print( + f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms" + ) diff --git a/main.py b/main.py new file mode 100644 index 0000000..d20d5d4 --- /dev/null +++ b/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from triton-tvm-ffi!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b0fb141 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "triton-tvm-ffi" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +dependencies = [ + "apache-tvm-ffi", + "jinja2", +] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["triton_tvm_ffi"] +package-dir = {"" = "python"} diff --git a/python/triton_tvm_ffi/__init__.py b/python/triton_tvm_ffi/__init__.py new file mode 100644 index 0000000..2200bf7 --- /dev/null +++ b/python/triton_tvm_ffi/__init__.py @@ -0,0 +1,4 @@ +from .jit import jit +from .wrap import torch_wrap, wrap + +__all__ = ["jit", "torch_wrap", "wrap"] diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py new file mode 100644 index 0000000..1611a0f --- /dev/null +++ b/python/triton_tvm_ffi/jit.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from functools import cached_property +from typing import Any, Dict, Final, List, Optional, Tuple + +import torch +from triton.compiler import CompiledKernel +from triton.runtime import JITFunction +import tvm_ffi + +from .utils import type_canonicalize + + +class TVMFFIJITFunction(object): + def __init__(self, fn: JITFunction, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.fn: Final[JITFunction] = fn + self.ctypes: Optional[List[Optional[str]]] = None + self.kernel: Optional[bytes] = None + + @tvm_ffi.register_global_func(self.fullname) + def _( + grid: Tuple[int, int, int], + _: int, + num_warps: int, + num_stages: int, + *args, + **kwargs, + ): + args: List[Any] = map(self.canonicalize, args) + kwargs: Dict[str, Any] = { + k: self.canonicalize(v) for k, v in kwargs.items() + } + kernel: CompiledKernel = self.fn[grid]( + *args, **kwargs, num_warps=num_warps, num_stages=num_stages + ) + self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()] + self.kernel = kernel.kernel + return kernel + + def __getitem__(self, grid: Tuple[int, int, int]): + return self.fn[grid] + + @property + def cache_hash(self) -> int: + return self.ctypes_hash ^ self.kernel_hash + + @property + def ctypes_hash(self) -> int: + return hash(tuple(self.ctypes) if self.ctypes is not None else None) + + @property + def kernel_hash(self) -> int: + return hash(self.kernel) + + @cached_property + def fnname(self) -> str: + return self.fn.fn.__name__ + + @cached_property + def fullname(self) -> str: + return f"triton.{self.name}" + + @cached_property + def name(self) -> str: + return f"{self.fnname}_{hash(self.fn.fn)}" + + @staticmethod + def canonicalize(val: Any) -> Any: + if hasattr(val, "__dlpack__"): + return torch.from_dlpack(val) + else: + return val + + +def jit(fn: JITFunction) -> TVMFFIJITFunction: + return TVMFFIJITFunction(fn) diff --git a/python/triton_tvm_ffi/utils.py b/python/triton_tvm_ffi/utils.py new file mode 100644 index 0000000..3874f2a --- /dev/null +++ b/python/triton_tvm_ffi/utils.py @@ -0,0 +1,10 @@ +from typing import Optional + +from triton.backends.nvidia.driver import ty_to_cpp + + +def type_canonicalize(ty: str) -> Optional[str]: + if ty == "constexpr": + return None + else: + return ty_to_cpp(ty) diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py new file mode 100644 index 0000000..c5cda5c --- /dev/null +++ b/python/triton_tvm_ffi/wrap.py @@ -0,0 +1,176 @@ +from functools import cached_property +from io import TextIOWrapper +from pathlib import Path +from typing import Final, List, Optional, Sequence, Tuple, Union + +import torch.utils.cpp_extension +import tvm_ffi + +from .jit import TVMFFIJITFunction + + +class TVMFFIWrapperFunction(object): + def __init__( + self, + name: str, + fns: List[TVMFFIJITFunction], + code: Union[str, Path, TextIOWrapper], + extra_cflags: Optional[Sequence[str]] = None, + extra_cuda_cflags: Optional[Sequence[str]] = None, + extra_ldflags: Optional[Sequence[str]] = None, + extra_include_paths: Optional[Sequence[Union[str, Path]]] = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.name: Final[str] = name + self.fns: List[TVMFFIJITFunction] = [*fns] + if isinstance(code, Path): + with open(code, "r") as f: + self.code: Final[str] = f.read() + elif isinstance(code, TextIOWrapper): + self.code: Final[str] = code.read() + else: + self.code: Final[str] = f"{code}" + self.extra_cflags: Optional[Sequence[str]] = extra_cflags + self.extra_cuda_cflags: Optional[Sequence[str]] = extra_cuda_cflags + self.extra_ldflags: Optional[Sequence[str]] = extra_ldflags + self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = ( + extra_include_paths + ) + + def __call__(self, *args, **kwargs) -> None: + func: tvm_ffi.Function = self.compile() + return func(*args, **kwargs) + + @property + def fns_hash(self) -> int: + return hash(tuple(fn.cache_hash for fn in self.fns)) + + @cached_property + def fullname(self) -> str: + return f"triton.{self.name}" + + @property + def emit(self) -> str: + defs: str = "\n".join( + [ + "#include ", + "#include ", + "#include ", + f'#define {self.name.upper()}_NAME "{self.uniquename}"', + *map( + self.gendef, + self.fns, + ), + ] + ) + return f"{defs}\n{self.code}" + + @property + def uniquename(self) -> str: + return f"{self.name}_{self.fns_hash}" + + def compile(self) -> tvm_ffi.Function: + if func := tvm_ffi.get_global_func(self.uniquename, allow_missing=True): + return func + else: + tvm_ffi.cpp.load_inline( + self.name, + cpp_sources=[self.emit], + extra_cflags=self.extra_cflags, + extra_cuda_cflags=self.extra_cuda_cflags, + extra_ldflags=self.extra_ldflags, + extra_include_paths=self.extra_include_paths, + embed_cubin={ + f"triton_{fn.fnname}": fn.kernel + for fn in self.fns + if fn.kernel is not None + }, + ) + return tvm_ffi.get_global_func(self.uniquename, allow_missing=True) + + @staticmethod + def gendef(fn: TVMFFIJITFunction) -> str: + if fn.ctypes is None: + return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")' + else: + ctype_arg_list: List[Tuple[str, str]] = [ + (ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes) + ] + + return """ +TVM_FFI_EMBED_CUBIN(triton_{fnname}); +#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\ +static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\ +tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\ +tvm::ffi::dim3 __block(__numWarps * 32, 1, 1); \\ +void *dummy = nullptr, {}; \\ +void *__params[] = {{{}, &dummy, &dummy}}; \\ +TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast(__stream))); \\ +}} while (false) +""".format( + fn.fnname.upper(), + ", ".join(arg for _, arg in ctype_arg_list), + ", ".join( + f"*{arg}_ptr = {arg}.data_ptr()" + for ctype, arg in ctype_arg_list + if ctype == "CUdeviceptr" + ), + ", ".join( + f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr" + for ctype, arg in ctype_arg_list + if ctype is not None + ), + fnname=fn.fnname, + ).strip() + + +def wrap( + name: str, + fns: List[TVMFFIJITFunction], + code: Union[str, Path, TextIOWrapper], + extra_cflags: Optional[Sequence[str]] = None, + extra_cuda_cflags: Optional[Sequence[str]] = None, + extra_ldflags: Optional[Sequence[str]] = None, + extra_include_paths: Optional[Sequence[Union[str, Path]]] = None, +) -> TVMFFIWrapperFunction: + return TVMFFIWrapperFunction( + name, + fns, + code, + extra_cflags, + extra_cuda_cflags, + extra_ldflags, + extra_include_paths, + ) + + +def torch_wrap( + name: str, + fns: List[TVMFFIJITFunction], + code: Union[str, Path, TextIOWrapper], + extra_cflags: Optional[Sequence[str]] = None, + extra_cuda_cflags: Optional[Sequence[str]] = None, + extra_ldflags: Optional[Sequence[str]] = None, + extra_include_paths: Optional[Sequence[Union[str, Path]]] = None, +) -> TVMFFIWrapperFunction: + return wrap( + name, + fns, + code, + extra_ldflags=[ + "-Wl,--no-as-needed", + *map( + lambda path: f"-L{path}", + torch.utils.cpp_extension.library_paths(), + ), + "-lc10", + "-ltorch", + ] + + (extra_ldflags or []), + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=[*torch.utils.cpp_extension.include_paths()] + + (extra_include_paths or []), + )