Files
triton-tvm-ffi/python/triton_tvm_ffi/templates/gendef.cc.j2
2026-02-12 15:31:50 +08:00

46 lines
2.2 KiB
Django/Jinja

#include <cuda.h>
#include <tvm/ffi/function.h>
#include "triton_tvm_ffi/grid.h"
#include "triton_tvm_ffi/kernel.h"
#include "triton_tvm_ffi/macro.h"
#include "triton_tvm_ffi/meta.h"
{# <NAME>_NAME: string macro holding this generated module's unique name.
   `name` and `uniquename` are supplied by whatever renders this template. #}
#define {{ name | upper }}_NAME "{{ uniquename }}"
{#
  Per-function launch stubs, one <FNNAME>_STUB macro for every fn in `fns`:
    * fn.ctypes is none -> <FNNAME>_STUB expands to the tvm-ffi global function
      looked up by fn.fullname via tvm::ffi::Function::GetGlobalRequired.
    * otherwise         -> <FNNAME>_STUB(grid, device, stream, args, kwargs) is a
      statement macro (do { ... } while (false)) that launches the embedded
      cubin directly through cuLaunchKernel.
#}
{% for fn in fns %}
{% if fn.ctypes is none %}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
{#
  <FNNAME>_STUB(grid, device, stream, args, kwargs):
    1. `args` and `kwargs` are bound to references exactly once, so the
       expressions passed for them are evaluated a single time and cannot
       mis-parse when used with `.size()` / `[]`.
    2. A launch-metadata map is seeded from fn.best_config (when present) and
       then completed by triton_tvm_ffi::FillMeta<signature keys...>.
       NOTE(review): best_config values are emitted verbatim with {{ v }};
       this assumes they are numeric (a Python bool/None would render as
       True/False/None, which is not valid C++) - confirm.
    3. The CUfunction comes from triton_tvm_ffi::GetKernel, keyed on the
       function name, the embedded cubin and fn.shmem, for `device`.
    4. The grid is computed by triton_tvm_ffi::MakeGridDim(grid, meta).
    5. Kernel parameter i is read positionally from args[i] when
       i < args.size(), otherwise from kwargs under the i-th fn.signature key
       (presumably the parameter name - confirm against the generator).
       CUdeviceptr parameters are passed as TensorView::data_ptr()
       (NOTE(review): no DLPack byte_offset is applied here - confirm inputs
       never carry one). Entries whose ctype is none are not passed.
    6. Two trailing null pointers are appended to the parameter list
       (presumably Triton's global/profile scratch arguments - confirm for the
       Triton version in use).
    7. Launch uses 32 * fn.num_warps threads per block (x only), fn.shmem bytes
       of dynamic shared memory, on `stream`; the result goes through
       __CUDA_CHECK.
#}
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
auto &&__stub_args = (__args); \
auto &&__stub_kwargs = (__kwargs); \
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% if fn.best_config is not none %}
{% for k, v in fn.best_config.items() %}
{ "{{ k }}", {{ v }} }, \
{% endfor %}
{% endif %}
}; \
{% for param in fn.signature %}
static constexpr char __varname{{ loop.index0 }}[] = "{{ param }}"; \
{% endfor %}
triton_tvm_ffi::FillMeta<{% for param in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __stub_args, __stub_kwargs); \
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
void *__scratch_null = nullptr; \
const size_t __args_len = __stub_args.size(); \
{% for ctype in fn.ctypes %}
{% if ctype == "CUdeviceptr" %}
void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __stub_args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __stub_kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
{% elif ctype is not none %}
{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __stub_args[{{ loop.index0 }}].cast<{{ ctype }}>() : __stub_kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
{% endif %}
{% endfor %}
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype is not none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&__scratch_null, &__scratch_null }; \
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
} while (false)
{% endif %}
{% endfor %}
{# Caller-supplied C++ source, emitted after all *_NAME / *_STUB macro
   definitions so it can use them (rendered as-is, assuming autoescape is off
   for this template - confirm). #}
{{ code }}