mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
use templates to substitute parts of macros
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
39
include/triton_tvm_ffi/kernel.h
Normal file
39
include/triton_tvm_ffi/kernel.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_TVM_FFI_KERNEL_H_
|
||||
#define TRITON_TVM_FFI_KERNEL_H_
|
||||
|
||||
#include "macro.h"
|
||||
#include <cstdint>
|
||||
#include <cuda.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <const char kFnName[], const char kCubin[], size_t kSMem>
|
||||
inline CUfunction GetKernel(int32_t device) {
|
||||
static std::unordered_map<int32_t, CUfunction> functions = {};
|
||||
if (functions.find(device) == functions.end()) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, kCubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
|
||||
if (kSMem > 49152) {
|
||||
int32_t shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared_optin >= kSMem) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(
|
||||
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
}
|
||||
functions[device] = func;
|
||||
}
|
||||
return functions[device];
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user