add length check to unwrap arguments

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-12 15:31:50 +08:00
parent 599957e156
commit a727526794
4 changed files with 16 additions and 15 deletions

View File

@@ -28,11 +28,12 @@ triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
void *dummy = nullptr; \
const size_t __args_len = __args.size(); \
{% for ctype in fn.ctypes %}
{% if ctype == "CUdeviceptr" %}
void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
{% elif ctype != none %}
{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
{% endif %}
{% endfor %}
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \