use templates to substitute parts of macros

Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
jinjieliu
2026-02-08 22:24:12 +08:00
parent 1c4f13c8f0
commit 213e4fc060
9 changed files with 139 additions and 49 deletions

View File

@@ -1,11 +1,11 @@
-#include "tvm/ffi/function.h"
 #include <ATen/DLConvertor.h>
 #include <ATen/dlpack.h>
 #include <tvm/ffi/extra/cuda/cubin_launcher.h>
+#include <tvm/ffi/function.h>
 #include <tvm/ffi/tvm_ffi.h>
 #ifndef ADD_KERNEL_STUB
-#define ADD_KERNEL_STUB(grid, stream, args, kwargs)
+#define ADD_KERNEL_STUB(grid, device, stream, args, kwargs)
 #endif
 #ifndef ADD_NAME
@@ -27,7 +27,7 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
   tvm::ffi::Array<tvm::ffi::Any> args = {x, y, output, numel, 1024};
   tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
-  ADD_KERNEL_STUB(grid, stream, args, kwargs);
+  ADD_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
   return output;
 }
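The stub macros now receive the device ordinal between the grid and the stream, which is what lets the generated launcher cache one CUfunction per device. Each example guards its stub with #ifndef so the file still compiles stand-alone as a no-op; a real build defines the macro (and the matching *_NAME) ahead of this block. A minimal hand-written binding might look like the following sketch, where the name string and the empty body are purely illustrative:

    // Hypothetical manual stub; in a real build the generated header from the
    // Jinja template at the end of this commit defines this macro instead.
    #define ADD_NAME "add_kernel_demo"
    #define ADD_KERNEL_STUB(grid, device, stream, args, kwargs) \
      do { /* resolve the kernel for `device`, launch on `stream` */ } while (false)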

View File

@@ -4,7 +4,7 @@
 #include <tvm/ffi/tvm_ffi.h>
 #ifndef MATMUL_KERNEL_STUB
-#define MATMUL_KERNEL_STUB(grid, stream, args, kwargs)
+#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs)
 #endif
 #ifndef MATMUL_NAME
@@ -45,7 +45,7 @@ tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b,
   tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
       {"ACTIVATION", activation},
   };
-  MATMUL_KERNEL_STUB(grid, stream, args, kwargs);
+  MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
   return c;
 }

View File

@@ -4,7 +4,7 @@
 #include <tvm/ffi/tvm_ffi.h>
 #ifndef SOFTMAX_KERNEL_STUB
-#define SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs)
+#define SOFTMAX_KERNEL_STUB(grid, device, stream, args, kwargs)
 #endif
 #ifndef SOFTMAX_NAME
@@ -24,7 +24,7 @@ tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
   tvm::ffi::Array<tvm::ffi::Any> args = {y, x, xStride, yStride,
                                          nRows, nCols, BLOCK_SIZE};
   tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
-  SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs);
+  SOFTMAX_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
   return y;
 }

View File

@@ -2,9 +2,10 @@
 #define TRITON_TVM_FFI_GRID_H_
 #include <cstdint>
-#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/tvm_ffi.h>
+namespace triton_tvm_ffi {
 template <typename T>
 inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
 MakeGridDim(const T &grid,
@@ -28,4 +29,6 @@ MakeGridDim<tvm::ffi::Function>(
   return MakeGridDim(tuple, meta);
 }
+} // namespace triton_tvm_ffi
 #endif

View File

@@ -0,0 +1,39 @@
+#ifndef TRITON_TVM_FFI_KERNEL_H_
+#define TRITON_TVM_FFI_KERNEL_H_
+#include "macro.h"
+#include <cstdint>
+#include <cuda.h>
+#include <unordered_map>
+namespace triton_tvm_ffi {
+template <const char kFnName[], const char kCubin[], size_t kSMem>
+inline CUfunction GetKernel(int32_t device) {
+  static std::unordered_map<int32_t, CUfunction> functions = {};
+  if (functions.find(device) == functions.end()) {
+    CUmodule module;
+    CUfunction func;
+    __CUDA_CHECK(cuModuleLoadData(&module, kCubin));
+    __CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
+    if (kSMem > 49152) {
+      int32_t shared_optin, shared_static;
+      __CUDA_CHECK(cuDeviceGetAttribute(
+          &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+          device));
+      if (shared_optin >= kSMem) {
+        __CUDA_CHECK(cuFuncGetAttribute(
+            &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
+        __CUDA_CHECK(cuFuncSetAttribute(
+            func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+            shared_optin - shared_static));
+      }
+    }
+    functions[device] = func;
+  }
+  return functions[device];
+};
+} // namespace triton_tvm_ffi
+#endif
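GetKernel keys its per-device CUfunction cache on character-array template parameters, so the kernel name and cubin image must be objects with static storage duration rather than string literals. A sketch of the intended instantiation, with placeholder contents:

    #include "triton_tvm_ffi/kernel.h"

    // Placeholders: a real caller embeds the actual kernel name and cubin bytes.
    static constexpr char kAddName[] = "add_kernel";
    static constexpr char kAddCubin[] = "...cubin bytes...";

    CUfunction LookupAddKernel(int32_t device) {
      // First call per device loads the module and resolves the function;
      // kSMem = 0 keeps the shared-memory opt-in branch (> 48 KiB) inactive.
      return triton_tvm_ffi::GetKernel<kAddName, kAddCubin, 0>(device);
    }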

View File

@@ -0,0 +1,22 @@
+#ifndef TRITON_TVM_FFI_MACRO_H_
+#define TRITON_TVM_FFI_MACRO_H_
+#include <cuda.h>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#define __CUDA_CHECK(__code)                                                  \
+  do {                                                                        \
+    CUresult __err = (__code); /* evaluate the checked call only once */      \
+    if (__err != CUDA_SUCCESS) {                                              \
+      const char *errorName = nullptr, *errorStr = nullptr;                   \
+      cuGetErrorName(__err, &errorName);                                      \
+      cuGetErrorString(__err, &errorStr);                                     \
+      std::ostringstream __oss;                                               \
+      __oss << "[" << errorName << "] " << errorStr << ", at " << __FILE__    \
+            << ":" << __LINE__;                                               \
+      throw std::runtime_error(__oss.str());                                  \
+    }                                                                         \
+  } while (false)
+#endif
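__CUDA_CHECK turns a failing driver-API call into a std::runtime_error carrying the CUDA error name, message, and source location, replacing the assert-based check from the old template. A usage sketch (the allocation is illustrative):

    #include "triton_tvm_ffi/macro.h"

    void Touch() {
      CUdeviceptr ptr = 0;
      __CUDA_CHECK(cuMemAlloc(&ptr, 1024)); // throws "[name] string, at file:line" on failure
      __CUDA_CHECK(cuMemFree(ptr));
    }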

View File

@@ -0,0 +1,52 @@
+#ifndef TRITON_TVM_FFI_META_H_
+#define TRITON_TVM_FFI_META_H_
+#include <tvm/ffi/tvm_ffi.h>
+namespace triton_tvm_ffi {
+template <const char... Ks[]> struct FillMetaImpl {
+  static inline void
+  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
+        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
+        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
+        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs);
+};
+template <> struct FillMetaImpl<> {
+  static inline void
+  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
+        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
+        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
+        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {}
+};
+template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
+  static inline void
+  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
+        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
+        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
+        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
+    if (argsBegin != argsEnd) {
+      meta.Set(K, *argsBegin++);
+    } else if (auto val = kwargs.Get(K)) {
+      meta.Set(K, *val);
+    }
+    FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
+  }
+};
+template <const char... Ks[]> struct FillMeta {
+  static inline void
+  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
+        const tvm::ffi::Array<tvm::ffi::Any> &args,
+        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
+    tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
+    tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
+    FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
+  }
+};
+} // namespace triton_tvm_ffi
+#endif
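FillMeta unrolls at compile time the positional-then-keyword lookup that the old template performed with a runtime loop over a __signature array: each key consumes the next positional argument if one remains, otherwise it falls back to the same-named kwargs entry. A sketch of a direct instantiation, with illustrative keys:

    #include "triton_tvm_ffi/meta.h"

    static constexpr char kXPtr[] = "x_ptr";
    static constexpr char kBlock[] = "BLOCK_SIZE";

    void Demo(const tvm::ffi::Array<tvm::ffi::Any> &args,
              const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
      tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> meta;
      // "x_ptr" binds to args[0] if present; "BLOCK_SIZE" to args[1],
      // otherwise to kwargs["BLOCK_SIZE"] when supplied.
      triton_tvm_ffi::FillMeta<kXPtr, kBlock>::apply(meta, args, kwargs);
    }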

View File

@@ -40,7 +40,8 @@ class TVMFFIJITFunction(object):
         grid: Union[
             Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
         ],
-        _: int,
+        _device: int,
+        _stream: int,
         args: Sequence[Any],
         kwargs: Mapping[str, Any],
    ):

View File

@@ -1,41 +1,19 @@
-#include <cassert>
 #include <cuda.h>
-#include <optional>
-#include <tvm/ffi/extra/cuda/cubin_launcher.h>
 #include <tvm/ffi/function.h>
 #include "triton_tvm_ffi/grid.h"
+#include "triton_tvm_ffi/kernel.h"
+#include "triton_tvm_ffi/macro.h"
+#include "triton_tvm_ffi/meta.h"
 #define {{ name | upper }}_NAME "{{ uniquename }}"
 {% for fn in fns %}
 {% if fn.ctypes is none %}
 #define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
 {% else %}
-static const char __cubin[] = "{{ fn.kernel_cstr }}";
-#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS)
-static CUfunction __Get{{ fn.fnname }}Kernel() {
-  static std::optional<CUfunction> function = std::nullopt;
-  if (!function) {
-    CUmodule module;
-    CUfunction func;
-    __CUDA_CHECK(cuModuleLoadData(&module, __cubin));
-    __CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}"));
-{% if fn.shmem > 49152 %}
-    int shared_optin, shared_static;
-    __CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0));
-    if (shared_optin >= 49152) {
-      __CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
-      __CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
-    }
-{% endif %}
-    function = func;
-  }
-  return *function;
-}
-#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \
-const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \
+static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
+static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
+#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
 tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
 {% if fn.best_config != none %}
 {% for k, v in fn.best_config.items() %}
@@ -43,24 +21,19 @@ tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
 {% endfor %}
 {% endif %}
 }; \
-for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \
-  if (__i < __size_args) { \
-    __meta.Set(__signature[__i], __args[__i]); \
-  } else if (auto __val = __kwargs.Get(__signature[__i])) { \
-    __meta.Set(__signature[__i], *__val); \
-  } \
-} \
-CUfunction __function = __Get{{ fn.fnname }}Kernel(); \
-tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = MakeGridDim(__grid, __meta); \
+{% for type in fn.signature %}
+static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
+{% endfor %}
+triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
+CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
+tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
 void *dummy = nullptr; \
 {% for ctype in fn.ctypes %}
-{% if ctype != none %}
 {% if ctype == "CUdeviceptr" %}
 void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
-{% else %}
+{% elif ctype != none %}
 {{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
 {% endif %}
-{% endif %}
 {% endfor %}
 void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
 __CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
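For a kernel whose fn.signature rendered as ["x_ptr", "y_ptr", "BLOCK_SIZE"] (hypothetical values), the signature loop above would emit roughly:

    static constexpr char __varname0[] = "x_ptr"; \
    static constexpr char __varname1[] = "y_ptr"; \
    static constexpr char __varname2[] = "BLOCK_SIZE"; \
    triton_tvm_ffi::FillMeta<__varname0, __varname1, __varname2>::apply(__meta, __args, __kwargs); \

so the per-kernel key list lives in template arguments rather than a runtime __signature array, and the device-aware GetKernel replaces the per-kernel __Get...Kernel() helper that hard-coded device 0.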