mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
use templates to substitute parts of macros
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -1,41 +1,19 @@
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <optional>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include "triton_tvm_ffi/grid.h"
|
||||
#include "triton_tvm_ffi/kernel.h"
|
||||
#include "triton_tvm_ffi/macro.h"
|
||||
#include "triton_tvm_ffi/meta.h"
|
||||
|
||||
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||
{% for fn in fns %}
|
||||
{% if fn.ctypes is none %}
|
||||
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
|
||||
{% else %}
|
||||
static const char __cubin[] = "{{ fn.kernel_cstr }}";
|
||||
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
|
||||
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
|
||||
|
||||
#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS)
|
||||
|
||||
static CUfunction __Get{{ fn.fnname }}Kernel() {
|
||||
static std::optional<CUfunction> function = std::nullopt;
|
||||
if (!function) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, __cubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}"));
|
||||
{% if fn.shmem > 49152 %}
|
||||
int shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0));
|
||||
if (shared_optin >= 49152) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
|
||||
}
|
||||
{% endif %}
|
||||
function = func;
|
||||
}
|
||||
return *function;
|
||||
}
|
||||
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \
|
||||
const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% if fn.best_config != none %}
|
||||
{% for k, v in fn.best_config.items() %}
|
||||
@@ -43,24 +21,19 @@ tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
}; \
|
||||
for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \
|
||||
if (__i < __size_args) { \
|
||||
__meta.Set(__signature[__i], __args[__i]); \
|
||||
} else if (auto __val = __kwargs.Get(__signature[__i])) { \
|
||||
__meta.Set(__signature[__i], *__val); \
|
||||
} \
|
||||
} \
|
||||
CUfunction __function = __Get{{ fn.fnname }}Kernel(); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = MakeGridDim(__grid, __meta); \
|
||||
{% for type in fn.signature %}
|
||||
static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
|
||||
{% endfor %}
|
||||
triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
|
||||
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
|
||||
void *dummy = nullptr; \
|
||||
{% for ctype in fn.ctypes %}
|
||||
{% if ctype != none %}
|
||||
{% if ctype == "CUdeviceptr" %}
|
||||
void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
||||
{% else %}
|
||||
{% elif ctype != none %}
|
||||
{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
||||
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
|
||||
|
||||
Reference in New Issue
Block a user