From 213e4fc060f8ed42e6553e4781caf53236dd9319 Mon Sep 17 00:00:00 2001 From: jinjieliu Date: Sun, 8 Feb 2026 22:24:12 +0800 Subject: [PATCH] use templates to substitute parts of macros Signed-off-by: jinjieliu --- examples/add/add.cc | 6 +-- examples/mm/mm.cc | 4 +- examples/softmax/softmax.cc | 4 +- include/triton_tvm_ffi/grid.h | 5 +- include/triton_tvm_ffi/kernel.h | 39 ++++++++++++++ include/triton_tvm_ffi/macro.h | 22 ++++++++ include/triton_tvm_ffi/meta.h | 52 +++++++++++++++++++ python/triton_tvm_ffi/jit.py | 3 +- python/triton_tvm_ffi/templates/gendef.cc.j2 | 53 +++++--------------- 9 files changed, 139 insertions(+), 49 deletions(-) create mode 100644 include/triton_tvm_ffi/kernel.h create mode 100644 include/triton_tvm_ffi/macro.h create mode 100644 include/triton_tvm_ffi/meta.h diff --git a/examples/add/add.cc b/examples/add/add.cc index bddda09..7a05646 100644 --- a/examples/add/add.cc +++ b/examples/add/add.cc @@ -1,11 +1,11 @@ -#include "tvm/ffi/function.h" #include #include #include +#include #include #ifndef ADD_KERNEL_STUB -#define ADD_KERNEL_STUB(grid, stream, args, kwargs) +#define ADD_KERNEL_STUB(grid, device, stream, args, kwargs) #endif #ifndef ADD_NAME @@ -27,7 +27,7 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); tvm::ffi::Array args = {x, y, output, numel, 1024}; tvm::ffi::Map kwargs = {}; - ADD_KERNEL_STUB(grid, stream, args, kwargs); + ADD_KERNEL_STUB(grid, device.device_id, stream, args, kwargs); return output; } diff --git a/examples/mm/mm.cc b/examples/mm/mm.cc index c4b84a9..ed23ca8 100644 --- a/examples/mm/mm.cc +++ b/examples/mm/mm.cc @@ -4,7 +4,7 @@ #include #ifndef MATMUL_KERNEL_STUB -#define MATMUL_KERNEL_STUB(grid, stream, args, kwargs) +#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs) #endif #ifndef MATMUL_NAME @@ -45,7 +45,7 @@ tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b, tvm::ffi::Map kwargs = { 
{"ACTIVATION", activation}, }; - MATMUL_KERNEL_STUB(grid, stream, args, kwargs); + MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs); return c; } diff --git a/examples/softmax/softmax.cc b/examples/softmax/softmax.cc index a3c5709..a11e361 100644 --- a/examples/softmax/softmax.cc +++ b/examples/softmax/softmax.cc @@ -4,7 +4,7 @@ #include #ifndef SOFTMAX_KERNEL_STUB -#define SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs) +#define SOFTMAX_KERNEL_STUB(grid, device, stream, args, kwargs) #endif #ifndef SOFTMAX_NAME @@ -24,7 +24,7 @@ tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) { tvm::ffi::Array args = {y, x, xStride, yStride, nRows, nCols, BLOCK_SIZE}; tvm::ffi::Map kwargs = {}; - SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs); + SOFTMAX_KERNEL_STUB(grid, device.device_id, stream, args, kwargs); return y; } diff --git a/include/triton_tvm_ffi/grid.h b/include/triton_tvm_ffi/grid.h index 3480158..0b528db 100644 --- a/include/triton_tvm_ffi/grid.h +++ b/include/triton_tvm_ffi/grid.h @@ -2,9 +2,10 @@ #define TRITON_TVM_FFI_GRID_H_ #include -#include #include +namespace triton_tvm_ffi { + template inline tvm::ffi::Tuple MakeGridDim(const T &grid, @@ -28,4 +29,6 @@ MakeGridDim( return MakeGridDim(tuple, meta); } +} // namespace triton_tvm_ffi + #endif diff --git a/include/triton_tvm_ffi/kernel.h b/include/triton_tvm_ffi/kernel.h new file mode 100644 index 0000000..baa08d7 --- /dev/null +++ b/include/triton_tvm_ffi/kernel.h @@ -0,0 +1,39 @@ +#ifndef TRITON_TVM_FFI_KERNEL_H_ +#define TRITON_TVM_FFI_KERNEL_H_ + +#include "macro.h" +#include +#include +#include + +namespace triton_tvm_ffi { + +template +inline CUfunction GetKernel(int32_t device) { + static std::unordered_map functions = {}; + if (functions.find(device) == functions.end()) { + CUmodule module; + CUfunction func; + __CUDA_CHECK(cuModuleLoadData(&module, kCubin)); + __CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName)); + if (kSMem > 49152) { + int32_t shared_optin, shared_static; + 
__CUDA_CHECK(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared_optin >= kSMem) { + __CUDA_CHECK(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); + __CUDA_CHECK(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + } + functions[device] = func; + } + return functions[device]; +}; + +} // namespace triton_tvm_ffi + +#endif diff --git a/include/triton_tvm_ffi/macro.h b/include/triton_tvm_ffi/macro.h new file mode 100644 index 0000000..36432e9 --- /dev/null +++ b/include/triton_tvm_ffi/macro.h @@ -0,0 +1,22 @@ +#ifndef TRITON_TVM_FFI_MACRO_H_ +#define TRITON_TVM_FFI_MACRO_H_ + +#include +#include +#include +#include + +#define __CUDA_CHECK(__code) \ + do { /* __code is evaluated exactly once; throws std::runtime_error on failure */ \ + if (const CUresult __result = (__code); __result != CUDA_SUCCESS) { \ + const char *errorName = nullptr, *errorStr = nullptr; \ + cuGetErrorName(__result, &errorName); \ + cuGetErrorString(__result, &errorStr); \ + std::ostringstream __oss; \ + __oss << "[" << (errorName ? errorName : "CUDA_ERROR_UNKNOWN") << "] " \ + << (errorStr ? errorStr : "unrecognized error code") << ", at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(__oss.str()); \ + } \ + } while (false) + +#endif diff --git a/include/triton_tvm_ffi/meta.h b/include/triton_tvm_ffi/meta.h new file mode 100644 index 0000000..c877a83 --- /dev/null +++ b/include/triton_tvm_ffi/meta.h @@ -0,0 +1,52 @@ +#ifndef TRITON_TVM_FFI_META_H_ +#define TRITON_TVM_FFI_META_H_ + +#include + +namespace triton_tvm_ffi { + +template struct FillMetaImpl { + static inline void + apply(tvm::ffi::Map &meta, + tvm::ffi::Array::iterator &argsBegin, + const tvm::ffi::Array::iterator &argsEnd, + const tvm::ffi::Map &kwargs); +}; + +template <> struct FillMetaImpl<> { + static inline void + apply(tvm::ffi::Map &meta, + tvm::ffi::Array::iterator &argsBegin, + const tvm::ffi::Array::iterator &argsEnd, + const tvm::ffi::Map &kwargs) {} +}; + +template struct FillMetaImpl { + static inline void + apply(tvm::ffi::Map 
&meta, + tvm::ffi::Array::iterator &argsBegin, + const tvm::ffi::Array::iterator &argsEnd, + const tvm::ffi::Map &kwargs) { + if (argsBegin != argsEnd) { + meta.Set(K, *argsBegin++); + } else if (auto val = kwargs.Get(K)) { + meta.Set(K, *val); + } + FillMetaImpl::apply(meta, argsBegin, argsEnd, kwargs); + } +}; + +template struct FillMeta { + static inline void + apply(tvm::ffi::Map &meta, + const tvm::ffi::Array &args, + const tvm::ffi::Map &kwargs) { + tvm::ffi::Array::iterator argsBegin = args.begin(); + tvm::ffi::Array::iterator argsEnd = args.end(); + FillMetaImpl::apply(meta, argsBegin, argsEnd, kwargs); + } +}; + +} // namespace triton_tvm_ffi + +#endif diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py index cd6b115..31e07f3 100644 --- a/python/triton_tvm_ffi/jit.py +++ b/python/triton_tvm_ffi/jit.py @@ -40,7 +40,8 @@ class TVMFFIJITFunction(object): grid: Union[ Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int] ], - _: int, + _device: int, + _stream: int, args: Sequence[Any], kwargs: Mapping[str, Any], ): diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2 index 09c4e6e..fba40ca 100644 --- a/python/triton_tvm_ffi/templates/gendef.cc.j2 +++ b/python/triton_tvm_ffi/templates/gendef.cc.j2 @@ -1,41 +1,19 @@ -#include #include -#include -#include #include #include "triton_tvm_ffi/grid.h" +#include "triton_tvm_ffi/kernel.h" +#include "triton_tvm_ffi/macro.h" +#include "triton_tvm_ffi/meta.h" #define {{ name | upper }}_NAME "{{ uniquename }}" {% for fn in fns %} {% if fn.ctypes is none %} #define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}") {% else %} -static const char __cubin[] = "{{ fn.kernel_cstr }}"; +static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}"; +static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}"; -#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS) - -static 
CUfunction __Get{{ fn.fnname }}Kernel() { - static std::optional function = std::nullopt; - if (!function) { - CUmodule module; - CUfunction func; - __CUDA_CHECK(cuModuleLoadData(&module, __cubin)); - __CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}")); -{% if fn.shmem > 49152 %} - int shared_optin, shared_static; - __CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0)); - if (shared_optin >= 49152) { - __CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); - __CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); - } -{% endif %} - function = func; - } - return *function; -} - -#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \ -const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \ +#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \ tvm::ffi::Map __meta = { \ {% if fn.best_config != none %} {% for k, v in fn.best_config.items() %} @@ -43,24 +21,19 @@ tvm::ffi::Map __meta = { \ {% endfor %} {% endif %} }; \ -for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \ - if (__i < __size_args) { \ - __meta.Set(__signature[__i], __args[__i]); \ - } else if (auto __val = __kwargs.Get(__signature[__i])) { \ - __meta.Set(__signature[__i], *__val); \ - } \ -} \ -CUfunction __function = __Get{{ fn.fnname }}Kernel(); \ -tvm::ffi::Tuple __gridDim = MakeGridDim(__grid, __meta); \ +{% for type in fn.signature %} +static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \ +{% endfor %} +triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \ +CUfunction 
__function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \ +tvm::ffi::Tuple __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \ void *dummy = nullptr; \ {% for ctype in fn.ctypes %} -{% if ctype != none %} {% if ctype == "CUdeviceptr" %} void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast().data_ptr(); \ -{% else %} +{% elif ctype != none %} {{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \ {% endif %} -{% endif %} {% endfor %} void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \ __CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast(__stream), __params, nullptr)); \