use templates to substitute parts of macros

Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
jinjieliu
2026-02-08 22:24:12 +08:00
parent 1c4f13c8f0
commit 213e4fc060
9 changed files with 139 additions and 49 deletions

View File

@@ -2,9 +2,10 @@
#define TRITON_TVM_FFI_GRID_H_
#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
template <typename T>
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
MakeGridDim(const T &grid,
@@ -28,4 +29,6 @@ MakeGridDim<tvm::ffi::Function>(
return MakeGridDim(tuple, meta);
}
} // namespace triton_tvm_ffi
#endif

View File

@@ -0,0 +1,39 @@
#ifndef TRITON_TVM_FFI_KERNEL_H_
#define TRITON_TVM_FFI_KERNEL_H_
#include "macro.h"
#include <cstdint>
#include <cuda.h>
#include <unordered_map>
namespace triton_tvm_ffi {
/// Returns the CUfunction handle for kernel `kFnName` contained in the image
/// `kCubin`, loading it the first time it is requested for `device` and
/// serving it from a per-instantiation cache afterwards.
///
/// Template parameters:
///   kFnName  kernel name handed to cuModuleGetFunction.
///   kCubin   module image handed to cuModuleLoadData.
///   kSMem    shared-memory requirement in bytes; when it exceeds 49152 the
///            function tries to raise the kernel's dynamic shared-memory
///            limit to what the device allows per block (opt-in).
///
/// \param device  device ordinal used both as the cache key and for the
///                device attribute query.
/// \return the cached or newly loaded CUfunction.
/// \throws std::runtime_error (via __CUDA_CHECK) if a driver call fails. In
///         that case the freshly loaded module is unloaded again and nothing
///         is cached, so a later call retries from scratch.
///
/// The cache is a function-local static map that is read and written without
/// any locking; concurrent first-time calls are not synchronized here.
/// NOTE(review): cuModuleLoadData loads into the current CUDA context; the
/// caller presumably makes the context of `device` current - confirm.
template <const char kFnName[], const char kCubin[], size_t kSMem>
inline CUfunction GetKernel(int32_t device) {
  static std::unordered_map<int32_t, CUfunction> functions = {};

  // Fast path: a single lookup serves the already-loaded case.
  const auto cached = functions.find(device);
  if (cached != functions.end()) {
    return cached->second;
  }

  CUmodule module;
  CUfunction func;
  __CUDA_CHECK(cuModuleLoadData(&module, kCubin));
  try {
    __CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
    if (kSMem > 49152) {
      int32_t shared_optin, shared_static;
      __CUDA_CHECK(cuDeviceGetAttribute(
          &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
          device));
      // Compare in the unsigned domain explicitly instead of relying on an
      // implicit signed -> unsigned conversion of the queried attribute.
      if (shared_optin >= 0 && static_cast<size_t>(shared_optin) >= kSMem) {
        __CUDA_CHECK(cuFuncGetAttribute(
            &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
        // Dynamic budget = opt-in per-block maximum minus the kernel's
        // statically allocated shared memory.
        __CUDA_CHECK(cuFuncSetAttribute(
            func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_optin - shared_static));
      }
    }
    // Publish only after every setup step succeeded.
    functions.emplace(device, func);
  } catch (...) {
    // The module would otherwise be unreachable (never cached): release it.
    // The unload result is ignored so the original error is the one reported.
    cuModuleUnload(module);
    throw;
  }
  return func;
}
} // namespace triton_tvm_ffi
#endif

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_TVM_FFI_MACRO_H_
#define TRITON_TVM_FFI_MACRO_H_
#include <cuda.h>
#include <sstream>
#include <stdexcept>
#include <string>
// Checks the result of a CUDA driver API call.
//
// `__code` is evaluated exactly once. If the result is not CUDA_SUCCESS, a
// std::runtime_error is thrown whose message has the form
//   "[<error name>] <error string>, at <file>:<line>".
// cuGetErrorName / cuGetErrorString may leave the output pointer null for an
// unrecognized code, so placeholders are substituted instead of streaming a
// null `const char *`.
#define __CUDA_CHECK(__code) \
  do { \
    const CUresult __result = (__code); \
    if (__result != CUDA_SUCCESS) { \
      const char *errorName = nullptr, *errorStr = nullptr; \
      cuGetErrorName(__result, &errorName); \
      cuGetErrorString(__result, &errorStr); \
      std::ostringstream __oss; \
      __oss << "[" << (errorName ? errorName : "UNRECOGNIZED_CUDA_ERROR") \
            << "] " << (errorStr ? errorStr : "no description available") \
            << ", at " << __FILE__ << ":" << __LINE__; \
      throw std::runtime_error(__oss.str()); \
    } \
  } while (false)
#endif

View File

@@ -0,0 +1,52 @@
#ifndef TRITON_TVM_FFI_META_H_
#define TRITON_TVM_FFI_META_H_
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
/// Primary template of the key-by-key meta filler. It only declares `apply`;
/// the bodies are supplied by the specializations for the empty key list and
/// for a non-empty key list.
///
/// `apply` receives the map being filled, a mutable cursor into the
/// positional values, the end of the positional values, and the keyword map.
template <const char... Ks[]> struct FillMetaImpl {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> & /*meta*/,
        tvm::ffi::Array<tvm::ffi::Any>::iterator & /*argsBegin*/,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator & /*argsEnd*/,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> & /*kwargs*/);
};
/// Base case for an empty key list: there is nothing left to bind, so
/// `apply` leaves every argument untouched.
template <> struct FillMetaImpl<> {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> & /*meta*/,
        tvm::ffi::Array<tvm::ffi::Any>::iterator & /*argsBegin*/,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator & /*argsEnd*/,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> & /*kwargs*/) {
    // Intentionally empty.
  }
};
/// Recursive case: binds the first key `K`, then hands the remaining keys
/// `Ks...` to the next instantiation.
///
/// For `K`, the next positional value is used (and the cursor advanced) while
/// positional values remain; otherwise the entry for `K` in `kwargs` is used
/// if one exists. When neither is available, `meta` gets no entry for `K`.
template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
    const bool hasPositional = argsBegin != argsEnd;
    if (hasPositional) {
      // Consume one positional value; the caller's cursor moves forward.
      meta.Set(K, *argsBegin++);
    } else {
      // Positional values are exhausted: look the key up by name instead.
      auto fromKwargs = kwargs.Get(K);
      if (fromKwargs) {
        meta.Set(K, *fromKwargs);
      }
    }
    // Continue with the remaining keys, sharing the same cursor.
    FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
  }
};
template <const char... Ks[]> struct FillMeta {
static inline void
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
const tvm::ffi::Array<tvm::ffi::Any> &args,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
}
};
} // namespace triton_tvm_ffi
#endif