add length check to unwrap arguments

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-12 15:31:50 +08:00
parent 599957e156
commit a727526794
4 changed files with 16 additions and 15 deletions

View File

@@ -1,9 +1,9 @@
#include "ATen/core/ATen_fwd.h"
#include "ATen/ops/empty.h"
#include "torch/headeronly/core/DeviceType.h"
#include "tvm/ffi/container/tensor.h"
#include <ATen/DLConvertor.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/dlpack.h>
#include <ATen/ops/empty.h>
#include <torch/headeronly/core/DeviceType.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/tvm_ffi.h>
@@ -28,8 +28,9 @@ tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_,
tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kPreBlock,
kBatch * kNHead, 1);
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead, kNCtx};
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
{"N_CTX", kNCtx},
{"BLOCK_M", kPreBlock},
{"HEAD_DIM", kHeadDim},
};