enable lambda function for grid descriptor

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-05 15:59:22 +08:00
parent 8b8aa6cb84
commit f6c7a48c1b
6 changed files with 104 additions and 57 deletions

View File

@@ -0,0 +1,42 @@
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
{% include "grid.h" %}
#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% for name in fn.signature %}
{ "{{ name }}", __arg{{ loop.index0 }} }, \
{% endfor %}
}; \
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
void *dummy = nullptr
{%- for ctype in fn.ctypes -%}
{%- if ctype == "CUdeviceptr" -%}
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
{%- endif -%}
{%- endfor -%}; \
void *__params[] = {
{%- for ctype in fn.ctypes -%}
{%- if ctype != none -%}
&__arg{{ loop.index0 }}
{%- if ctype == "CUdeviceptr" -%}
_ptr
{%- endif -%},
{%- endif -%}
{%- endfor -%}&dummy, &dummy }; \
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
} while (false)
{% endif %}
{% endfor %}
{{ code }}

View File

@@ -0,0 +1,29 @@
#ifndef TRITON_TVM_FFI_GRID_H
#define TRITON_TVM_FFI_GRID_H
#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/tvm_ffi.h>
template <typename T>
inline tvm::ffi::dim3
MakeGridDim(const T &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
}
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
const tvm::ffi::Function &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
return MakeGridDim(tuple, meta);
}
#endif