mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
verify tvm-ffi cpp wrapper on vector-add.py
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
.clangd
|
||||||
|
.python-version
|
||||||
|
uv.lock
|
||||||
31
examples/add/add.cc
Normal file
31
examples/add/add.cc
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#include <ATen/DLConvertor.h>
|
||||||
|
#include <ATen/dlpack.h>
|
||||||
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|
||||||
|
#ifndef ADD_KERNEL_STUB
|
||||||
|
#define ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, \
|
||||||
|
numel, BLOCK_SIZE)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef ADD_NAME
|
||||||
|
#define ADD_NAME ""
|
||||||
|
#endif
|
||||||
|
|
||||||
|
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
||||||
|
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||||
|
at::Tensor otorch = at::empty_like(xtorch);
|
||||||
|
int64_t numel = otorch.numel();
|
||||||
|
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
||||||
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
|
||||||
|
size_t numWarps = 4, numStages = 3;
|
||||||
|
DLDevice device = x.device();
|
||||||
|
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||||
|
ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||||
|
namespace refl = tvm::ffi::reflection;
|
||||||
|
refl::GlobalDef().def(ADD_NAME, Add);
|
||||||
|
}
|
||||||
86
examples/add/add.py
Normal file
86
examples/add/add.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
import triton_tvm_ffi
|
||||||
|
|
||||||
|
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||||
|
|
||||||
|
|
||||||
|
# Support decorators here like
|
||||||
|
# @triton_tvm_ffi.jit
|
||||||
|
@triton.jit
|
||||||
|
def add_kernel(
|
||||||
|
x_ptr,
|
||||||
|
y_ptr,
|
||||||
|
output_ptr,
|
||||||
|
n_elements,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < n_elements
|
||||||
|
x = tl.load(x_ptr + offsets, mask=mask)
|
||||||
|
y = tl.load(y_ptr + offsets, mask=mask)
|
||||||
|
output = x + y
|
||||||
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
||||||
|
|
||||||
|
|
||||||
|
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
output: torch.Tensor = torch.empty_like(x)
|
||||||
|
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||||
|
n_elements: int = output.numel()
|
||||||
|
BLOCK_SIZE: int = 1024
|
||||||
|
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
|
||||||
|
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: it woule be more user-friendly to define wrapper functions like below
|
||||||
|
# @triton_tvm_ffi.torch_wrap(
|
||||||
|
# "add",
|
||||||
|
# [add_kernel_tvm_ffi],
|
||||||
|
# Path(__file__).parent / "add.cc",
|
||||||
|
# )
|
||||||
|
# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
# ...
|
||||||
|
add_tvm_ffi = triton_tvm_ffi.torch_wrap(
|
||||||
|
"add",
|
||||||
|
[add_kernel_tvm_ffi],
|
||||||
|
Path(__file__).parent / "add.cc",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(0)
|
||||||
|
size = 98432
|
||||||
|
x = torch.rand(size, device=DEVICE)
|
||||||
|
y = torch.rand(size, device=DEVICE)
|
||||||
|
output_torch = x + y
|
||||||
|
output_triton = add(x, y)
|
||||||
|
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||||
|
assert torch.allclose(output_torch, output_triton)
|
||||||
|
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||||
|
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||||
|
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||||
|
|
||||||
|
round = 1000
|
||||||
|
cp0 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
x + y
|
||||||
|
cp1 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
add(x, y)
|
||||||
|
cp2 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
add_tvm_ffi(x, y)
|
||||||
|
cp3 = time.perf_counter_ns()
|
||||||
|
print(
|
||||||
|
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||||
|
)
|
||||||
6
main.py
Normal file
6
main.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from triton-tvm-ffi!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[project]
|
||||||
|
name = "triton-tvm-ffi"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
dependencies = [
|
||||||
|
"apache-tvm-ffi",
|
||||||
|
"jinja2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = ["triton_tvm_ffi"]
|
||||||
|
package-dir = {"" = "python"}
|
||||||
4
python/triton_tvm_ffi/__init__.py
Normal file
4
python/triton_tvm_ffi/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .jit import jit
|
||||||
|
from .wrap import torch_wrap, wrap
|
||||||
|
|
||||||
|
__all__ = ["jit", "torch_wrap", "wrap"]
|
||||||
77
python/triton_tvm_ffi/jit.py
Normal file
77
python/triton_tvm_ffi/jit.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Dict, Final, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from triton.compiler import CompiledKernel
|
||||||
|
from triton.runtime import JITFunction
|
||||||
|
import tvm_ffi
|
||||||
|
|
||||||
|
from .utils import type_canonicalize
|
||||||
|
|
||||||
|
|
||||||
|
class TVMFFIJITFunction(object):
|
||||||
|
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.fn: Final[JITFunction] = fn
|
||||||
|
self.ctypes: Optional[List[Optional[str]]] = None
|
||||||
|
self.kernel: Optional[bytes] = None
|
||||||
|
|
||||||
|
@tvm_ffi.register_global_func(self.fullname)
|
||||||
|
def _(
|
||||||
|
grid: Tuple[int, int, int],
|
||||||
|
_: int,
|
||||||
|
num_warps: int,
|
||||||
|
num_stages: int,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
args: List[Any] = map(self.canonicalize, args)
|
||||||
|
kwargs: Dict[str, Any] = {
|
||||||
|
k: self.canonicalize(v) for k, v in kwargs.items()
|
||||||
|
}
|
||||||
|
kernel: CompiledKernel = self.fn[grid](
|
||||||
|
*args, **kwargs, num_warps=num_warps, num_stages=num_stages
|
||||||
|
)
|
||||||
|
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||||
|
self.kernel = kernel.kernel
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
def __getitem__(self, grid: Tuple[int, int, int]):
|
||||||
|
return self.fn[grid]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_hash(self) -> int:
|
||||||
|
return self.ctypes_hash ^ self.kernel_hash
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ctypes_hash(self) -> int:
|
||||||
|
return hash(tuple(self.ctypes) if self.ctypes is not None else None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kernel_hash(self) -> int:
|
||||||
|
return hash(self.kernel)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def fnname(self) -> str:
|
||||||
|
return self.fn.fn.__name__
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def fullname(self) -> str:
|
||||||
|
return f"triton.{self.name}"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def name(self) -> str:
|
||||||
|
return f"{self.fnname}_{hash(self.fn.fn)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def canonicalize(val: Any) -> Any:
|
||||||
|
if hasattr(val, "__dlpack__"):
|
||||||
|
return torch.from_dlpack(val)
|
||||||
|
else:
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def jit(fn: JITFunction) -> TVMFFIJITFunction:
|
||||||
|
return TVMFFIJITFunction(fn)
|
||||||
10
python/triton_tvm_ffi/utils.py
Normal file
10
python/triton_tvm_ffi/utils.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from triton.backends.nvidia.driver import ty_to_cpp
|
||||||
|
|
||||||
|
|
||||||
|
def type_canonicalize(ty: str) -> Optional[str]:
|
||||||
|
if ty == "constexpr":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return ty_to_cpp(ty)
|
||||||
176
python/triton_tvm_ffi/wrap.py
Normal file
176
python/triton_tvm_ffi/wrap.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from io import TextIOWrapper
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import torch.utils.cpp_extension
|
||||||
|
import tvm_ffi
|
||||||
|
|
||||||
|
from .jit import TVMFFIJITFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TVMFFIWrapperFunction(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
fns: List[TVMFFIJITFunction],
|
||||||
|
code: Union[str, Path, TextIOWrapper],
|
||||||
|
extra_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_ldflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.name: Final[str] = name
|
||||||
|
self.fns: List[TVMFFIJITFunction] = [*fns]
|
||||||
|
if isinstance(code, Path):
|
||||||
|
with open(code, "r") as f:
|
||||||
|
self.code: Final[str] = f.read()
|
||||||
|
elif isinstance(code, TextIOWrapper):
|
||||||
|
self.code: Final[str] = code.read()
|
||||||
|
else:
|
||||||
|
self.code: Final[str] = f"{code}"
|
||||||
|
self.extra_cflags: Optional[Sequence[str]] = extra_cflags
|
||||||
|
self.extra_cuda_cflags: Optional[Sequence[str]] = extra_cuda_cflags
|
||||||
|
self.extra_ldflags: Optional[Sequence[str]] = extra_ldflags
|
||||||
|
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
||||||
|
extra_include_paths
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> None:
|
||||||
|
func: tvm_ffi.Function = self.compile()
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fns_hash(self) -> int:
|
||||||
|
return hash(tuple(fn.cache_hash for fn in self.fns))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def fullname(self) -> str:
|
||||||
|
return f"triton.{self.name}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def emit(self) -> str:
|
||||||
|
defs: str = "\n".join(
|
||||||
|
[
|
||||||
|
"#include <cuda.h>",
|
||||||
|
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
|
||||||
|
"#include <tvm/ffi/function.h>",
|
||||||
|
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
|
||||||
|
*map(
|
||||||
|
self.gendef,
|
||||||
|
self.fns,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return f"{defs}\n{self.code}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uniquename(self) -> str:
|
||||||
|
return f"{self.name}_{self.fns_hash}"
|
||||||
|
|
||||||
|
def compile(self) -> tvm_ffi.Function:
|
||||||
|
if func := tvm_ffi.get_global_func(self.uniquename, allow_missing=True):
|
||||||
|
return func
|
||||||
|
else:
|
||||||
|
tvm_ffi.cpp.load_inline(
|
||||||
|
self.name,
|
||||||
|
cpp_sources=[self.emit],
|
||||||
|
extra_cflags=self.extra_cflags,
|
||||||
|
extra_cuda_cflags=self.extra_cuda_cflags,
|
||||||
|
extra_ldflags=self.extra_ldflags,
|
||||||
|
extra_include_paths=self.extra_include_paths,
|
||||||
|
embed_cubin={
|
||||||
|
f"triton_{fn.fnname}": fn.kernel
|
||||||
|
for fn in self.fns
|
||||||
|
if fn.kernel is not None
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gendef(fn: TVMFFIJITFunction) -> str:
|
||||||
|
if fn.ctypes is None:
|
||||||
|
return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
|
||||||
|
else:
|
||||||
|
ctype_arg_list: List[Tuple[str, str]] = [
|
||||||
|
(ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
|
||||||
|
]
|
||||||
|
|
||||||
|
return """
|
||||||
|
TVM_FFI_EMBED_CUBIN(triton_{fnname});
|
||||||
|
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
||||||
|
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
||||||
|
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
||||||
|
tvm::ffi::dim3 __block(__numWarps * 32, 1, 1); \\
|
||||||
|
void *dummy = nullptr, {}; \\
|
||||||
|
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
||||||
|
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
||||||
|
}} while (false)
|
||||||
|
""".format(
|
||||||
|
fn.fnname.upper(),
|
||||||
|
", ".join(arg for _, arg in ctype_arg_list),
|
||||||
|
", ".join(
|
||||||
|
f"*{arg}_ptr = {arg}.data_ptr()"
|
||||||
|
for ctype, arg in ctype_arg_list
|
||||||
|
if ctype == "CUdeviceptr"
|
||||||
|
),
|
||||||
|
", ".join(
|
||||||
|
f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
|
||||||
|
for ctype, arg in ctype_arg_list
|
||||||
|
if ctype is not None
|
||||||
|
),
|
||||||
|
fnname=fn.fnname,
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def wrap(
|
||||||
|
name: str,
|
||||||
|
fns: List[TVMFFIJITFunction],
|
||||||
|
code: Union[str, Path, TextIOWrapper],
|
||||||
|
extra_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_ldflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||||
|
) -> TVMFFIWrapperFunction:
|
||||||
|
return TVMFFIWrapperFunction(
|
||||||
|
name,
|
||||||
|
fns,
|
||||||
|
code,
|
||||||
|
extra_cflags,
|
||||||
|
extra_cuda_cflags,
|
||||||
|
extra_ldflags,
|
||||||
|
extra_include_paths,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_wrap(
|
||||||
|
name: str,
|
||||||
|
fns: List[TVMFFIJITFunction],
|
||||||
|
code: Union[str, Path, TextIOWrapper],
|
||||||
|
extra_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_ldflags: Optional[Sequence[str]] = None,
|
||||||
|
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||||
|
) -> TVMFFIWrapperFunction:
|
||||||
|
return wrap(
|
||||||
|
name,
|
||||||
|
fns,
|
||||||
|
code,
|
||||||
|
extra_ldflags=[
|
||||||
|
"-Wl,--no-as-needed",
|
||||||
|
*map(
|
||||||
|
lambda path: f"-L{path}",
|
||||||
|
torch.utils.cpp_extension.library_paths(),
|
||||||
|
),
|
||||||
|
"-lc10",
|
||||||
|
"-ltorch",
|
||||||
|
]
|
||||||
|
+ (extra_ldflags or []),
|
||||||
|
extra_cflags=extra_cflags,
|
||||||
|
extra_cuda_cflags=extra_cuda_cflags,
|
||||||
|
extra_include_paths=[*torch.utils.cpp_extension.include_paths()]
|
||||||
|
+ (extra_include_paths or []),
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user