fix bugs on illegal memory access on Release

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 00:15:47 +08:00
parent 1a01c9f2d8
commit bdc9c03b75
7 changed files with 71 additions and 39 deletions
+27 -13
View File
@@ -5,6 +5,9 @@
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef NDEBUG
#include <cassert>
#endif
#define CUDA_CHECK(code) \
do { \
@@ -20,11 +23,11 @@ using namespace triton_tvm_ffi;
// --------------- Definitions ---------------
template <Type T> struct ValueCast {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
};
template <Type... Ts> struct ValueCastSet {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
};
using GenericValueCastSet = ValueCastSet<
@@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet<
// --------------- Implementations ---------------
template <Type T>
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr,
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
const TypedValue &value) {
if (value.GetType() == T) {
if constexpr (T == Type::PTR) {
tvm::ffi::TensorView cvalue =
value.GetValue().cast<tvm::ffi::TensorView>();
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
*ptr = cvalue.data_ptr();
*addr = ptr;
void **p = reinterpret_cast<void **>(ptr);
*p = cvalue.data_ptr();
} else if constexpr (T == Type::CONSTEXPR) {
#ifdef NDEBUG
__builtin_unreachable();
#else
throw NotImplementedException("CONSTEXPR for value casting");
#endif
} else {
using ctype = type_to_ctype_t<T>;
ctype cvalue = value.GetValue().cast<ctype>();
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype)));
*ptr = cvalue;
*addr = ptr;
ctype *p = reinterpret_cast<ctype *>(ptr);
*p = cvalue;
}
return true;
} else {
return false;
}
return false;
}
template <Type... Ts>
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr,
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
const TypedValue &value) {
return (ValueCast<Ts>::apply(ptr, value) || ...);
}
@@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
}
const int32_t kernelArgNum = kernelArgs.size();
uint8_t *buffer =
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0;
for (size_t i = 0; i < kernelArgNum; ++i) {
TypedValue value = kernelArgs[i].cast<TypedValue>();
if (value.GetType() != Type::CONSTEXPR) {
if (!GenericValueCastSet::apply(&params[j++], value)) {
#ifndef NDEBUG
assert(j < kernelArgNum);
#endif
void *ptr = buffer + j * kMaxOpaqueSize;
params[j] = ptr;
if (!GenericValueCastSet::apply(ptr, value)) {
throw UnknownTypeException(value.GetType());
}
++j;
}
}
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args