diff --git a/examples/add/add.cc b/examples/add/add.cc index 6810ccd..ed67885 100644 --- a/examples/add/add.cc +++ b/examples/add/add.cc @@ -5,8 +5,7 @@ #include #ifndef ADD_KERNEL_STUB -#define ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, \ - numel, BLOCK_SIZE) +#define ADD_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) #endif #ifndef ADD_NAME @@ -27,7 +26,9 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt; DLDevice device = x.device(); void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); - ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024); + tvm::ffi::Array args = {x, y, output, numel, 1024}; + tvm::ffi::Map kwargs = {}; + ADD_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); return output; } diff --git a/examples/mm/mm.cc b/examples/mm/mm.cc new file mode 100644 index 0000000..dbb8543 --- /dev/null +++ b/examples/mm/mm.cc @@ -0,0 +1,56 @@ +#include +#include +#include +#include + +#ifndef MATMUL_KERNEL_STUB +#define MATMUL_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) +#endif + +#ifndef MATMUL_NAME +#define MATMUL_NAME "" +#endif + +tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b, + tvm::ffi::String activation) { + at::Tensor atorch = at::fromDLPack(a.ToDLPack()), + btorch = at::fromDLPack(b.ToDLPack()); + const int32_t M = atorch.size(0), K = atorch.size(1), N = btorch.size(1); + at::Tensor ctorch = at::empty({M, N}, atorch.options()); + tvm::ffi::Function grid = tvm::ffi::Function::FromTyped( + [M, N](const tvm::ffi::Map &meta) + -> tvm::ffi::Tuple { + const int32_t BLOCK_SIZE_M = meta["BLOCK_SIZE_M"].cast(), + BLOCK_SIZE_N = meta["BLOCK_SIZE_N"].cast(); + return tvm::ffi::Tuple{ + (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M * + ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N), + 1, 1}; + }); + tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt; + DLDevice device = a.device(); + void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); + tvm::ffi::Tensor c = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ctorch)); + tvm::ffi::Array args = {a, + b, + c, + M, + N, + K, + atorch.stride(0), + atorch.stride(1), + btorch.stride(0), + btorch.stride(1), + ctorch.stride(0), + ctorch.stride(1)}; + tvm::ffi::Map kwargs = { + {"ACTIVATION", activation}, + }; + MATMUL_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); + return c; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def(MATMUL_NAME, Matmul); +} diff --git a/examples/mm/mm.py b/examples/mm/mm.py new file mode 100644 index 0000000..20caa01 --- /dev/null +++ b/examples/mm/mm.py @@ -0,0 +1,307 @@ +from pathlib import Path +import time +import torch + +import triton +import triton.language as tl +import triton_tvm_ffi + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def get_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + ] + + +@triton_tvm_ffi.jit +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +def matmul_triton(a, b, activation=""): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + matmul_kernel[ + lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + ]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ACTIVATION=activation, + ) + return c + + +@triton_tvm_ffi.torch_wrap( + [matmul_kernel], + Path(__file__).parent / "mm.cc", +) +def matmul(a: torch.Tensor, b: torch.Tensor, activation: str = "") -> torch.Tensor: ... + + +if __name__ == "__main__": + torch.manual_seed(0) + a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5 + b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5 + torch_output = torch.matmul(a, b) + triton_output = matmul_triton(a, b, "") + tvm_ffi_output = matmul(a, b, "") + assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=1e-2) + assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2) + tvm_ffi_output = matmul(a, b, "") + assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2) + + round = 1000 + cp0 = time.perf_counter_ns() + for _ in range(round): + a @ b + cp1 = time.perf_counter_ns() + for _ in range(round): + matmul_triton(a, b, "") + cp2 = time.perf_counter_ns() + for _ in range(round): + matmul(a, b, "") + cp3 = time.perf_counter_ns() + print( + f"PyTorch matmul: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton matmul: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI matmul: {(cp3 - cp2) / round * 1e-6:.3f} ms" + ) diff --git a/examples/softmax/softmax.cc b/examples/softmax/softmax.cc index a060572..4892252 100644 --- a/examples/softmax/softmax.cc +++ b/examples/softmax/softmax.cc @@ -4,9 +4,7 @@ #include #ifndef SOFTMAX_KERNEL_STUB -#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, output, input, \ - inputStride, outputStride, nRows, nCols, \ - BLOCK_SIZE) +#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) #endif #ifndef SOFTMAX_NAME @@ -23,9 +21,12 @@ tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) { tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch)); tvm::ffi::Tuple grid{nRows / 1024, 1, 1}; DLDevice device = x.device(); - void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); - SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, y, x, xStride, yStride, - nRows, nCols, BLOCK_SIZE); + void* stream = + TVMFFIEnvGetStream(device.device_type, device.device_id); + tvm::ffi::Array args = {y, x, xStride, yStride, + nRows, nCols, BLOCK_SIZE}; + tvm::ffi::Map kwargs = {}; + SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); return y; } diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py index a262726..ae95af0 100644 --- a/python/triton_tvm_ffi/jit.py +++ b/python/triton_tvm_ffi/jit.py @@ -2,24 +2,38 @@ from __future__ import annotations from functools import cached_property import inspect -from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Final, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import torch from triton.compiler import CompiledKernel -from triton.runtime import JITFunction +from triton.runtime import Autotuner, JITFunction import tvm_ffi from .utils import type_canonicalize class TVMFFIJITFunction(object): - def __init__(self, fn: JITFunction, *args, **kwargs) -> None: + def __init__(self, fn: Union[Autotuner, JITFunction], *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.fn: Final[JITFunction] = fn - self.signature: Optional[List[str]] = None + self.fn: Final[Union[Autotuner, JITFunction]] = fn + self.signature: List[str] = [*inspect.signature(self.basefn).parameters.keys()] + self.best_config: Optional[Dict[str, Any]] = None self.ctypes: Optional[List[Optional[str]]] = None self.kernel: Optional[bytes] = None self.num_warps: Optional[int] = None + self.shmem: int = 0 @tvm_ffi.register_global_func(self.fullname) def _( @@ -29,22 +43,23 @@ class TVMFFIJITFunction(object): _: int, num_warps: Optional[int], num_stages: Optional[int], - *args, - **kwargs, + args: Sequence[Any], + kwargs: Mapping[str, Any], ): - args: List[Any] = map(self.canonicalize, args) + args: Iterator[Any] = map(self.canonicalize, args) kwargs: Dict[str, Any] = { - k: self.canonicalize(v) for k, v in kwargs.items() - } + k: v for k, v in zip(self.signature, args) if v is not None + } | {k: self.canonicalize(v) for k, v in kwargs.items()} if num_warps is not None: kwargs["num_warps"] = num_warps if num_stages is not None: kwargs["num_stages"] = num_stages kernel: CompiledKernel = self.fn[grid](*args, **kwargs) - self.num_warps, _, _ = kernel.packed_metadata - self.signature = [*inspect.signature(self.fn.fn).parameters.keys()] + self.num_warps, _, self.shmem = kernel.packed_metadata self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()] self.kernel = kernel.kernel + if isinstance(self.fn, Autotuner): + self.best_config = self.fn.best_config.all_kwargs() return kernel def __getitem__( @@ -55,6 +70,10 @@ class TVMFFIJITFunction(object): ): return self.fn[grid] + @cached_property + def basefn(self) -> Callable: + return self.jitfn.fn + @property def cache_hash(self) -> int: return self.ctypes_hash ^ self.kernel_hash @@ -63,21 +82,35 @@ class TVMFFIJITFunction(object): def ctypes_hash(self) -> int: return hash(tuple(self.ctypes) if self.ctypes is not None else None) - @property - def kernel_hash(self) -> int: - return hash(self.kernel) - @cached_property def fnname(self) -> str: - return self.fn.fn.__name__ + return self.basefn.__name__ @cached_property def fullname(self) -> str: return f"triton.{self.name}" + @cached_property + def jitfn(self) -> JITFunction: + fn: Union[Autotuner, JITFunction] = self.fn + while not isinstance(fn, JITFunction): + fn = fn.fn + return fn + + @property + def kernel_hash(self) -> int: + return hash(self.kernel) + + @property + def kernel_cstr(self) -> Optional[str]: + if self.kernel is not None: + return "".join(f"\\x{byte:02x}" for byte in self.kernel) + else: + return None + @cached_property def name(self) -> str: - return f"{self.fnname}_{hash(self.fn.fn)}" + return f"{self.fnname}_{hash(self.basefn)}" @staticmethod def canonicalize(val: Any) -> Any: diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2 index 3f51f56..ffa09eb 100644 --- a/python/triton_tvm_ffi/templates/gendef.cc.j2 +++ b/python/triton_tvm_ffi/templates/gendef.cc.j2 @@ -1,4 +1,6 @@ +#include #include +#include #include #include @@ -9,32 +11,60 @@ {% if fn.ctypes is none %} #define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}") {% else %} -TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }}); -#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \ -const tvm::ffi::Map __meta = { \ -{% for name in fn.signature %} - { "{{ name }}", __arg{{ loop.index0 }} }, \ +static const char __cubin[] = "{{ fn.kernel_cstr }}"; + +#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS) + +static CUfunction __Get{{ fn.fnname }}Kernel() { + static std::optional function = std::nullopt; + if (!function) { + CUmodule module; + CUfunction func; + __CUDA_CHECK(cuModuleLoadData(&module, __cubin)); + __CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}")); +{% if fn.shmem > 49152 %} + int shared_optin, shared_static; + __CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0)); + if (shared_optin >= 49152) { + __CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); + __CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); + } +{% endif %} + function = func; + } + return *function; +} + +#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __numWarps, __numStages, __args, __kwargs) do { \ +const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \ +tvm::ffi::Map __meta = { \ +{% if fn.best_config != none %} +{% for k, v in fn.best_config.items() %} + { "{{ k }}", {{ v }} }, \ {% endfor %} - }; \ -static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \ -tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \ -tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \ -void *dummy = nullptr -{%- for ctype in fn.ctypes -%} - {%- if ctype == "CUdeviceptr" -%} - , *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr() - {%- endif -%} -{%- endfor -%}; \ -void *__params[] = { -{%- for ctype in fn.ctypes -%} - {%- if ctype != none -%} - &__arg{{ loop.index0 }} - {%- if ctype == "CUdeviceptr" -%} - _ptr - {%- endif -%}, - {%- endif -%} -{%- endfor -%}&dummy, &dummy }; \ -TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast(__stream))); \ +{% endif %} +}; \ +for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \ + if (__i < __size_args) { \ + __meta.Set(__signature[__i], __args[__i]); \ + } else if (auto __val = __kwargs.Get(__signature[__i])) { \ + __meta.Set(__signature[__i], *__val); \ + } \ +} \ +CUfunction __function = __Get{{ fn.fnname }}Kernel(); \ +tvm::ffi::Tuple __gridDim = MakeGridDim(__grid, __meta); \ +void *dummy = nullptr; \ +{% for ctype in fn.ctypes %} +{% if ctype != none %} +{% if ctype == "CUdeviceptr" %} +void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast().data_ptr(); \ +{% else %} +{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \ +{% endif %} +{% endif %} +{% endfor %} +void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \ +__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast(__stream), __params, nullptr)); \ } while (false) {% endif %} {% endfor %} diff --git a/python/triton_tvm_ffi/templates/grid.h b/python/triton_tvm_ffi/templates/grid.h index b7e52e9..ea3ad13 100644 --- a/python/triton_tvm_ffi/templates/grid.h +++ b/python/triton_tvm_ffi/templates/grid.h @@ -6,19 +6,21 @@ #include template -inline tvm::ffi::dim3 +inline tvm::ffi::Tuple MakeGridDim(const T &grid, const tvm::ffi::Map &meta); template <> -inline tvm::ffi::dim3 MakeGridDim>( +inline tvm::ffi::Tuple +MakeGridDim>( const tvm::ffi::Tuple &grid, const tvm::ffi::Map &) { - return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>()); + return grid; } template <> -inline tvm::ffi::dim3 MakeGridDim( +inline tvm::ffi::Tuple +MakeGridDim( const tvm::ffi::Function &grid, const tvm::ffi::Map &meta) { tvm::ffi::Tuple tuple =