mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
46 lines
2.2 KiB
Django/Jinja
{#
  Jinja template that renders the C++ stub layer launching Triton kernels
  through the TVM FFI.

  Render context (as used below):
    name, uniquename -- identifiers for the module-level *_NAME macro
    fns              -- kernel descriptors, each providing: fnname, fullname,
                        ctypes, signature, best_config, kernel_cstr, shmem,
                        num_warps
    code             -- verbatim code appended after all stubs

  NOTE(review): block/comment tags placed inside the backslash-continued
  macro bodies only work if the environment renders with trim_blocks
  enabled (a surviving newline would otherwise terminate the #define
  early) -- confirm against the Environment used to render this template.
#}
#include <cuda.h>
#include <tvm/ffi/function.h>
#include "triton_tvm_ffi/grid.h"
#include "triton_tvm_ffi/kernel.h"
#include "triton_tvm_ffi/macro.h"
#include "triton_tvm_ffi/meta.h"

#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
{# No C-level signature: dispatch through a globally registered tvm::ffi::Function, resolved by full name. #}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
/* Kernel entry name and embedded CUBIN image for {{ fn.fnname }}. */
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";

/* Launch stub for {{ fn.fnname }}: builds the meta map (autotuned config +
   per-argument metadata), loads/caches the kernel for __device, derives the
   grid from __grid and __meta, then launches via the CUDA driver API.
   Arguments are taken positionally from __args and fall back to a by-name
   lookup in __kwargs when the positional slot is absent. */
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
    tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% if fn.best_config is not none %}
{% for k, v in fn.best_config.items() %}
        { "{{ k }}", {{ v }} }, \
{% endfor %}
{% endif %}
    }; \
{% for type in fn.signature %}
    static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
{% endfor %}
    triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
    CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
    tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
    void *dummy = nullptr; /* filler for the two trailing launch-parameter slots below */ \
    const size_t __args_len = __args.size(); \
{# NOTE(review): the loop below reuses __varname<i> declared from fn.signature above -- assumes fn.ctypes and fn.signature are index-aligned; confirm in the generator. #}
{% for ctype in fn.ctypes %}
{% if ctype == "CUdeviceptr" %}
    void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
{% elif ctype != none %}
    {{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
{% endif %}
{% endfor %}
    void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
    __CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
} while (false)
{% endif %}
{% endfor %}

{{ code }}