{#
  Jinja2 template for generated C++ that exposes Triton kernels through
  launch macros, followed by the caller-supplied `code` rendered verbatim.

  Rendered output:
    * `<NAME>_NAME`: a string macro holding `uniquename`.
    * For each `fn` in `fns`, a `<FNNAME>_STUB` macro:
        - if `fn.ctypes is none`: an object-like macro that looks up the
          global TVM FFI function `fn.fullname` via
          `tvm::ffi::Function::GetGlobalRequired`;
        - otherwise: a function-like macro that launches the embedded kernel
          directly with `cuLaunchKernel` (described before its definition).

  NOTE(review): several angle-bracketed pieces appear to be missing in this
  copy and must be restored from the source of truth: the targets of the two
  bare `#include` lines, and the template argument lists of `tvm::ffi::Map`,
  `tvm::ffi::Tuple` and the `.cast()` calls in the `CUdeviceptr` branch.
  As written, those constructs are ill-formed C++.

  NOTE(review): the Jinja block tags inside the macro body sit on their own
  lines. This layout assumes the environment renders with `trim_blocks=True`;
  otherwise the leftover blank lines would terminate the macro early.
#}
#include
#include

#include "triton_tvm_ffi/grid.h"
#include "triton_tvm_ffi/kernel.h"
#include "triton_tvm_ffi/macro.h"
#include "triton_tvm_ffi/meta.h"

#define {{ name | upper }}_NAME "{{ uniquename }}"

{% for fn in fns %}
{% if fn.ctypes is none %}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
{#
  Kernel name and embedded kernel image (`fn.kernel_cstr`, presumably an
  already-escaped cubin string; confirm in the generator). Both are used as
  non-type template arguments of `triton_tvm_ffi::GetKernel`.
#}
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
{#
  Launch macro. In order, it:
    1. Seeds `__meta` with the entries of `fn.best_config`, if any.
       NOTE(review): values are rendered verbatim, so a Python bool or string
       there would produce invalid C++; confirm best_config holds only numbers.
    2. Declares one name string per `fn.signature` entry and passes them to
       `FillMeta<...>::apply`, which presumably updates `__meta` from the
       call's `__args`/`__kwargs`; confirm in meta.h. The loop variable is
       called `type`, but its value is used as the kwargs key, so iterating
       `fn.signature` presumably yields parameter names; confirm.
    3. Resolves the `CUfunction` for `__device` via `GetKernel` and builds a
       3-element grid from `__grid` and `__meta` via `MakeGridDim`.
    4. For each `fn.ctypes` entry i with a non-none ctype, reads argument i
       from `__args[i]` when i < `__args.size()`, else from `__kwargs` keyed
       by the i-th signature name. This assumes `fn.ctypes` and
       `fn.signature` are index-aligned; TODO confirm. `CUdeviceptr` args
       pass the object's `data_ptr()`; other ctypes are cast to that C type.
       Entries whose ctype is none get no local and are not passed
       (presumably constexpr parameters).
    5. Calls `cuLaunchKernel` with block (32 * num_warps, 1, 1), `fn.shmem`
       bytes of dynamic shared memory and `__stream` as a `CUstream`. Two
       trailing null pointers are appended to the parameter list. These are
       presumably the extra scratch-pointer parameters that newer Triton
       versions add to compiled kernels; confirm against the Triton version
       in use.
  All macro-local names carry the `__` prefix, so they cannot shadow caller
  identifiers passed in through the macro arguments. `__args` and `__kwargs`
  are expanded several times, so pass plain variables, not expressions with
  side effects.
#}
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
  tvm::ffi::Map __meta = { \
{% if fn.best_config != none %}
{% for k, v in fn.best_config.items() %}
    { "{{ k }}", {{ v }} }, \
{% endfor %}
{% endif %}
  }; \
{% for type in fn.signature %}
  static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
{% endfor %}
  triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
  CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
  tvm::ffi::Tuple __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
  void *__dummy = nullptr; \
  const size_t __args_len = __args.size(); \
{% for ctype in fn.ctypes %}
{% if ctype == "CUdeviceptr" %}
  void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast().data_ptr(); \
{% elif ctype != none %}
  {{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? \
      __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
{% endif %}
{% endfor %}
  void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&__dummy, &__dummy }; \
  __CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
} while (false)
{% endif %}
{% endfor %}

{{ code }}