{#
  Stub-generator template: renders the C++ prologue placed in front of the
  user-supplied `code` value, which is emitted verbatim at the very end.

  Emitted output, in order:
    * the contents of "grid.h", pulled in through a Jinja include;
    * a macro named after the upper-cased `name` value with a `_NAME` suffix,
      defined as the string literal of `uniquename`;
    * for every entry `fn` of `fns`, a macro named after the upper-cased
      `fn.fnname` with a `_STUB` suffix, in one of two forms:
        - `fn.ctypes` is none: an object-like macro that expands to
          tvm::ffi::Function::GetGlobalRequired("<fn.fullname>");
        - otherwise: a TVM_FFI_EMBED_CUBIN(triton_<fnname>) statement followed
          by a function-like macro
          STUB(__grid, __stream, __numWarps, __numStages, __arg0, __arg1, ...)
          with one __argN parameter per entry of `fn.ctypes`.

  Body of the function-like form (wrapped in do / while (false)):
    1. `__meta` is brace-initialised with one { "<name>", __argN } pair per
       entry of `fn.signature`, N being that entry's zero-based index.
    2. `__kernel` is a `static` local initialised from
       TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_<fnname>, "<fnname>").
    3. `__gridDim` is the result of MakeGridDim(__grid, __meta).
    4. `__block` is (W * 32, 1, 1); W is `fn.num_warps` baked in at render time
       when that value is set, otherwise the `__numWarps` macro argument.
    5. For each `fn.ctypes` entry equal to "CUdeviceptr", a void* named
       __argN_ptr is declared and set to __argN.data_ptr().
    6. `__params` holds, in `fn.ctypes` order and skipping entries that are
       none, the address of __argN_ptr (for "CUdeviceptr" entries) or of __argN
       (for every other entry), followed by two `&dummy` entries.
    7. __kernel.Launch(__params, __gridDim, __block, <stream>) is invoked inside
       TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR.

  `__numStages` is accepted by the macro but never referenced in its body.
  Macro arguments are used unparenthesised and, for non-pointer entries, have
  their address taken, so each such argument must be an addressable expression.

  NOTE(review): the three `#include` directives carry no header name,
  `tvm::ffi::Map` has no template arguments and `static_cast` has no target
  type. This looks like angle-bracketed text lost in extraction; restore it
  from the upstream file before rendering.
  NOTE(review): `__meta` numbers __argN by position in `fn.signature`, while the
  macro parameters are numbered by position in `fn.ctypes`; this presumably
  relies on the two sequences being aligned. Confirm against the generator.
  NOTE(review): the two trailing `&dummy` slots presumably stand in for extra
  kernel parameters the compiled kernel expects. Confirm against the kernel ABI.
  NOTE(review): in this view the template sits on one physical line; each
  backslash only acts as a preprocessor line continuation when a newline
  follows it. Confirm the on-disk line layout.
#}
#include #include #include {% include "grid.h" %} #define {{ name | upper }}_NAME "{{ uniquename }}" {% for fn in fns %} {% if fn.ctypes is none %} #define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}") {% else %} TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }}); #define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \ const tvm::ffi::Map __meta = { \ {% for name in fn.signature %} { "{{ name }}", __arg{{ loop.index0 }} }, \ {% endfor %} }; \ static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \ tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \ tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \ void *dummy = nullptr {%- for ctype in fn.ctypes -%} {%- if ctype == "CUdeviceptr" -%} , *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr() {%- endif -%} {%- endfor -%}; \ void *__params[] = { {%- for ctype in fn.ctypes -%} {%- if ctype != none -%} &__arg{{ loop.index0 }} {%- if ctype == "CUdeviceptr" -%} _ptr {%- endif -%}, {%- endif -%} {%- endfor -%}&dummy, &dummy }; \ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast(__stream))); \ } while (false) {% endif %} {% endfor %} {{ code }}