mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
add length check to unwrap arguments
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,10 +1,9 @@
|
|||||||
#include "ATen/core/ATen_fwd.h"
|
|
||||||
#include "ATen/ops/empty.h"
|
|
||||||
#include "c10/core/Device.h"
|
|
||||||
#include "torch/headeronly/core/DeviceType.h"
|
|
||||||
#include "tvm/ffi/container/tensor.h"
|
|
||||||
#include <ATen/DLConvertor.h>
|
#include <ATen/DLConvertor.h>
|
||||||
|
#include <ATen/core/ATen_fwd.h>
|
||||||
#include <ATen/dlpack.h>
|
#include <ATen/dlpack.h>
|
||||||
|
#include <ATen/ops/empty.h>
|
||||||
|
#include <torch/headeronly/core/DeviceType.h>
|
||||||
|
#include <tvm/ffi/container/tensor.h>
|
||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/function.h>
|
#include <tvm/ffi/function.h>
|
||||||
#include <tvm/ffi/tvm_ffi.h>
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
#include "ATen/core/ATen_fwd.h"
|
|
||||||
#include "ATen/ops/empty.h"
|
|
||||||
#include "torch/headeronly/core/DeviceType.h"
|
|
||||||
#include "tvm/ffi/container/tensor.h"
|
|
||||||
#include <ATen/DLConvertor.h>
|
#include <ATen/DLConvertor.h>
|
||||||
|
#include <ATen/core/ATen_fwd.h>
|
||||||
#include <ATen/dlpack.h>
|
#include <ATen/dlpack.h>
|
||||||
|
#include <ATen/ops/empty.h>
|
||||||
|
#include <torch/headeronly/core/DeviceType.h>
|
||||||
|
#include <tvm/ffi/container/tensor.h>
|
||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/function.h>
|
#include <tvm/ffi/function.h>
|
||||||
#include <tvm/ffi/tvm_ffi.h>
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
@@ -28,8 +28,9 @@ tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_,
|
|||||||
tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
|
tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kPreBlock,
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kPreBlock,
|
||||||
kBatch * kNHead, 1);
|
kBatch * kNHead, 1);
|
||||||
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead, kNCtx};
|
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead};
|
||||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||||
|
{"N_CTX", kNCtx},
|
||||||
{"BLOCK_M", kPreBlock},
|
{"BLOCK_M", kPreBlock},
|
||||||
{"HEAD_DIM", kHeadDim},
|
{"HEAD_DIM", kHeadDim},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
#include "ATen/core/ATen_fwd.h"
|
|
||||||
#include "ATen/ops/empty.h"
|
|
||||||
#include "tvm/ffi/container/tensor.h"
|
|
||||||
#include <ATen/DLConvertor.h>
|
#include <ATen/DLConvertor.h>
|
||||||
|
#include <ATen/core/ATen_fwd.h>
|
||||||
#include <ATen/dlpack.h>
|
#include <ATen/dlpack.h>
|
||||||
|
#include <ATen/ops/empty.h>
|
||||||
|
#include <tvm/ffi/container/tensor.h>
|
||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/function.h>
|
#include <tvm/ffi/function.h>
|
||||||
#include <tvm/ffi/tvm_ffi.h>
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|||||||
@@ -28,11 +28,12 @@ triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }
|
|||||||
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
|
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
|
||||||
void *dummy = nullptr; \
|
void *dummy = nullptr; \
|
||||||
|
const size_t __args_len = __args.size(); \
|
||||||
{% for ctype in fn.ctypes %}
|
{% for ctype in fn.ctypes %}
|
||||||
{% if ctype == "CUdeviceptr" %}
|
{% if ctype == "CUdeviceptr" %}
|
||||||
void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
||||||
{% elif ctype != none %}
|
{% elif ctype != none %}
|
||||||
{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
||||||
|
|||||||
Reference in New Issue
Block a user