diff --git a/examples/add/add.cc b/examples/add/add.cc
index cbfb364..6810ccd 100644
--- a/examples/add/add.cc
+++ b/examples/add/add.cc
@@ -1,3 +1,4 @@
+#include "tvm/ffi/function.h"
 #include
 #include
 #include
@@ -15,10 +16,14 @@
 tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
   at::Tensor otorch = at::empty_like(xtorch);
-  int64_t numel = otorch.numel();
+  int32_t numel = otorch.numel();
   tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
-  tvm::ffi::Tuple grid{(numel + 1023) / 1024, 1, 1};
-  // TODO: check the performance loss after enabling `Optional`
+  tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
+      [numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
+          -> tvm::ffi::Tuple<int64_t, int64_t, int64_t> {
+        const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
+        return tvm::ffi::Tuple<int64_t, int64_t, int64_t>((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
+      });
   tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt;
   DLDevice device = x.device();
   void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
diff --git a/examples/add/add.py b/examples/add/add.py
index 050ee5d..3e5daa5 100644
--- a/examples/add/add.py
+++ b/examples/add/add.py
@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements: int = output.numel()
     BLOCK_SIZE: int = 1024
-    grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
-    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
+    add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
+        x, y, output, n_elements, BLOCK_SIZE
+    )
     return output
 
 
diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py
index ac2c0a5..a262726 100644
--- a/python/triton_tvm_ffi/jit.py
+++ b/python/triton_tvm_ffi/jit.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import Any, Dict, Final, List, Optional, Tuple
+import inspect
+from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
 
 import torch
 from triton.compiler import CompiledKernel
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
     def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.fn: Final[JITFunction] = fn
+        self.signature: Optional[List[str]] = None
         self.ctypes: Optional[List[Optional[str]]] = None
         self.kernel: Optional[bytes] = None
         self.num_warps: Optional[int] = None
 
         @tvm_ffi.register_global_func(self.fullname)
         def _(
-            grid: Tuple[int, int, int],
+            grid: Union[
+                Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
+            ],
             _: int,
             num_warps: Optional[int],
             num_stages: Optional[int],
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
                 kwargs["num_stages"] = num_stages
             kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
             self.num_warps, _, _ = kernel.packed_metadata
+            self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
             self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
             self.kernel = kernel.kernel
             return kernel
 
-    def __getitem__(self, grid: Tuple[int, int, int]):
+    def __getitem__(
+        self,
+        grid: Union[
+            Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
+        ],
+    ):
         return self.fn[grid]
 
     @property
diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2
new file mode 100644
index 0000000..3f51f56
--- /dev/null
+++ b/python/triton_tvm_ffi/templates/gendef.cc.j2
@@ -0,0 +1,42 @@
+#include
+#include
+#include
+
+{% include "grid.h" %}
+
+#define {{ name | upper }}_NAME "{{ uniquename }}"
+{% for fn in fns %}
+{% if fn.ctypes is none %}
+#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
+{% else %}
+TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
+#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
+const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
+{% for name in fn.signature %}
+  { "{{ name }}", __arg{{ loop.index0 }} }, \
+{% endfor %}
+  }; \
+static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
+tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
+tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
+void *dummy = nullptr
+{%- for ctype in fn.ctypes -%}
+  {%- if ctype == "CUdeviceptr" -%}
+    , *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
+  {%- endif -%}
+{%- endfor -%}; \
+void *__params[] = {
+{%- for ctype in fn.ctypes -%}
+  {%- if ctype != none -%}
+    &__arg{{ loop.index0 }}
+    {%- if ctype == "CUdeviceptr" -%}
+      _ptr
+    {%- endif -%},
+  {%- endif -%}
+{%- endfor -%}&dummy, &dummy }; \
+TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast(__stream))); \
+} while (false)
+{% endif %}
+{% endfor %}
+
+{{ code }}
diff --git a/python/triton_tvm_ffi/templates/grid.h b/python/triton_tvm_ffi/templates/grid.h
new file mode 100644
index 0000000..b7e52e9
--- /dev/null
+++ b/python/triton_tvm_ffi/templates/grid.h
@@ -0,0 +1,29 @@
+#ifndef TRITON_TVM_FFI_GRID_H
+#define TRITON_TVM_FFI_GRID_H
+
+#include
+#include
+#include
+
+template <typename T>
+inline tvm::ffi::dim3
+MakeGridDim(const T &grid,
+            const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
+
+template <>
+inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int64_t, int64_t, int64_t>>(
+    const tvm::ffi::Tuple<int64_t, int64_t, int64_t> &grid,
+    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
+  return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
+}
+
+template <>
+inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
+    const tvm::ffi::Function &grid,
+    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
+  tvm::ffi::Tuple<int64_t, int64_t, int64_t> tuple =
+      grid(meta).cast<tvm::ffi::Tuple<int64_t, int64_t, int64_t>>();
+  return MakeGridDim(tuple, meta);
+}
+
+#endif
diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py
index 6b96b3d..2908991 100644
--- a/python/triton_tvm_ffi/wrap.py
+++ b/python/triton_tvm_ffi/wrap.py
@@ -1,8 +1,9 @@
 from functools import cached_property
 from io import TextIOWrapper
 from pathlib import Path
-from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Final, List, Optional, Sequence, Union
 
+import jinja2
 import torch.utils.cpp_extension
 import tvm_ffi
 
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
         self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
             extra_include_paths
         )
+        self.env: Final[jinja2.Environment] = jinja2.Environment(
+            loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
+            trim_blocks=True,
+        )
+        self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
 
     def __call__(self, *args, **kwargs) -> None:
         func: tvm_ffi.Function = self.compile()
@@ -53,19 +59,9 @@
 
     @property
     def emit(self) -> str:
-        defs: str = "\n".join(
-            [
-                "#include ",
-                "#include ",
-                "#include ",
-                f'#define {self.name.upper()}_NAME "{self.uniquename}"',
-                *map(
-                    self.gendef,
-                    self.fns,
-                ),
-            ]
+        return self.tpl.render(
+            code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
         )
-        return f"{defs}\n{self.code}"
 
     @property
     def uniquename(self) -> str:
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
         )
         return tvm_ffi.get_global_func(self.uniquename)
 
-    @staticmethod
-    def gendef(fn: TVMFFIJITFunction) -> str:
-        if fn.ctypes is None:
-            return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
-        else:
-            ctype_arg_list: List[Tuple[str, str]] = [
-                (ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
-            ]
-
-            return """
-TVM_FFI_EMBED_CUBIN(triton_{fnname});
-#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
-static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
-tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
-tvm::ffi::dim3 __block({} * 32, 1, 1); \\
-void *dummy = nullptr, {}; \\
-void *__params[] = {{{}, &dummy, &dummy}}; \\
-TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast(__stream))); \\
-}} while (false)
-""".format(
-                fn.fnname.upper(),
-                ", ".join(arg for _, arg in ctype_arg_list),
-                fn.num_warps if fn.num_warps is not None else "__numWarps",
-                ", ".join(
-                    f"*{arg}_ptr = {arg}.data_ptr()"
-                    for ctype, arg in ctype_arg_list
-                    if ctype == "CUdeviceptr"
-                ),
-                ", ".join(
-                    f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
-                    for ctype, arg in ctype_arg_list
-                    if ctype is not None
-                ),
-                fnname=fn.fnname,
-            ).strip()
-
 
 def wrap(
    fns: List[TVMFFIJITFunction],