mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
use templates to substitute parts of macros
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -2,9 +2,10 @@
|
||||
#define TRITON_TVM_FFI_GRID_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/extra/cuda/base.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <typename T>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim(const T &grid,
|
||||
@@ -28,4 +29,6 @@ MakeGridDim<tvm::ffi::Function>(
|
||||
return MakeGridDim(tuple, meta);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
|
||||
39
include/triton_tvm_ffi/kernel.h
Normal file
39
include/triton_tvm_ffi/kernel.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_TVM_FFI_KERNEL_H_
|
||||
#define TRITON_TVM_FFI_KERNEL_H_
|
||||
|
||||
#include "macro.h"
|
||||
#include <cstdint>
|
||||
#include <cuda.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <const char kFnName[], const char kCubin[], size_t kSMem>
|
||||
inline CUfunction GetKernel(int32_t device) {
|
||||
static std::unordered_map<int32_t, CUfunction> functions = {};
|
||||
if (functions.find(device) == functions.end()) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, kCubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
|
||||
if (kSMem > 49152) {
|
||||
int32_t shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared_optin >= kSMem) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(
|
||||
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
}
|
||||
functions[device] = func;
|
||||
}
|
||||
return functions[device];
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
22
include/triton_tvm_ffi/macro.h
Normal file
22
include/triton_tvm_ffi/macro.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_TVM_FFI_MACRO_H_
|
||||
#define TRITON_TVM_FFI_MACRO_H_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#define __CUDA_CHECK(__code) \
|
||||
do { \
|
||||
if ((__code) != CUDA_SUCCESS) { \
|
||||
const char *errorName = nullptr, *errorStr = nullptr; \
|
||||
cuGetErrorName((__code), &errorName); \
|
||||
cuGetErrorString((__code), &errorStr); \
|
||||
std::ostringstream __oss; \
|
||||
__oss << "[" << errorName << "] " << errorStr << ", at " << __FILE__ \
|
||||
<< ":" << __LINE__; \
|
||||
throw std::runtime_error(__oss.str()); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
52
include/triton_tvm_ffi/meta.h
Normal file
52
include/triton_tvm_ffi/meta.h
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef TRITON_TVM_FFI_META_H_
|
||||
#define TRITON_TVM_FFI_META_H_
|
||||
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <const char... Ks[]> struct FillMetaImpl {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs);
|
||||
};
|
||||
|
||||
template <> struct FillMetaImpl<> {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {}
|
||||
};
|
||||
|
||||
template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
|
||||
if (argsBegin != argsEnd) {
|
||||
meta.Set(K, *argsBegin++);
|
||||
} else if (auto val = kwargs.Get(K)) {
|
||||
meta.Set(K, *val);
|
||||
}
|
||||
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
|
||||
}
|
||||
};
|
||||
|
||||
template <const char... Ks[]> struct FillMeta {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &args,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
|
||||
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user