From a727526794accf4b1cae3047d408d0d3f1da0f5c Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Thu, 12 Feb 2026 15:31:50 +0800 Subject: [PATCH] add length check to unwrap arguments Signed-off-by: Jinjie Liu --- examples/attention/attnbwd.cc | 9 ++++----- examples/attention/attnbwdpre.cc | 11 ++++++----- examples/attention/attnfwd.cc | 6 +++--- python/triton_tvm_ffi/templates/gendef.cc.j2 | 5 +++-- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/attention/attnbwd.cc b/examples/attention/attnbwd.cc index 394ff62..e888496 100644 --- a/examples/attention/attnbwd.cc +++ b/examples/attention/attnbwd.cc @@ -1,10 +1,9 @@ -#include "ATen/core/ATen_fwd.h" -#include "ATen/ops/empty.h" -#include "c10/core/Device.h" -#include "torch/headeronly/core/DeviceType.h" -#include "tvm/ffi/container/tensor.h" #include +#include #include +#include +#include +#include #include #include #include diff --git a/examples/attention/attnbwdpre.cc b/examples/attention/attnbwdpre.cc index b9d518b..2b1cb77 100644 --- a/examples/attention/attnbwdpre.cc +++ b/examples/attention/attnbwdpre.cc @@ -1,9 +1,9 @@ -#include "ATen/core/ATen_fwd.h" -#include "ATen/ops/empty.h" -#include "torch/headeronly/core/DeviceType.h" -#include "tvm/ffi/container/tensor.h" #include +#include #include +#include +#include +#include #include #include #include @@ -28,8 +28,9 @@ tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_, tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch)); tvm::ffi::Tuple grid(kNCtx / kPreBlock, kBatch * kNHead, 1); - tvm::ffi::Array args = {o, do_, delta, kBatch, kNHead, kNCtx}; + tvm::ffi::Array args = {o, do_, delta, kBatch, kNHead}; tvm::ffi::Map kwargs = { + {"N_CTX", kNCtx}, {"BLOCK_M", kPreBlock}, {"HEAD_DIM", kHeadDim}, }; diff --git a/examples/attention/attnfwd.cc b/examples/attention/attnfwd.cc index eec91cb..48c700d 100644 --- a/examples/attention/attnfwd.cc +++ b/examples/attention/attnfwd.cc @@ -1,8 +1,8 @@ -#include "ATen/core/ATen_fwd.h" -#include "ATen/ops/empty.h" -#include "tvm/ffi/container/tensor.h" #include +#include #include +#include +#include #include #include #include diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2 index fba40ca..a690a61 100644 --- a/python/triton_tvm_ffi/templates/gendef.cc.j2 +++ b/python/triton_tvm_ffi/templates/gendef.cc.j2 @@ -28,11 +28,12 @@ triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 } CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \ tvm::ffi::Tuple __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \ void *dummy = nullptr; \ +const size_t __args_len = __args.size(); \ {% for ctype in fn.ctypes %} {% if ctype == "CUdeviceptr" %} -void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast().data_ptr(); \ +void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast().data_ptr(); \ {% elif ctype != none %} -{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \ +{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \ {% endif %} {% endfor %} void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \