fix bugs on illegal memory access on Release

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 00:15:47 +08:00
parent 1a01c9f2d8
commit bdc9c03b75
7 changed files with 71 additions and 39 deletions
+26 -6
View File
@@ -4,6 +4,7 @@
#include <cstdint>
#include <tvm/ffi/object.h>
#include <tvm/ffi/string.h>
#include <type_traits>
namespace triton_tvm_ffi {
@@ -23,14 +24,14 @@ namespace triton_tvm_ffi {
V(FP16, "fp16", double) \
V(BF16, "bf16", double) \
V(FP32, "f32", double) \
V(FP64, "fp64", double)
V(FP64, "fp64", double) \
V(PTR, "*?", void *) \
V(CONSTEXPR, "constexpr", void)
enum class Type : int64_t {
#define DEFINE_ENUM(type, str, ctype) type,
TYPE_TABLE(DEFINE_ENUM)
#undef DEFINE_ENUM
PTR,
CONSTEXPR,
};
const char *TypeToString(Type type);
@@ -41,11 +42,30 @@ template <Type T> struct type_to_ctype;
template <> struct type_to_ctype<Type::type> { using t = ctype; };
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
#undef DEFINE_TYPE_TO_CTYPE
template <> struct type_to_ctype<Type::PTR> { using t = void *; };
// TODO: check whether CUtensorMap* is correct
template <> struct type_to_ctype<Type::CONSTEXPR> { using t = void; };
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
template <typename T, typename = void> struct type_size {
static constexpr size_t value = 0;
};
template <typename T>
struct type_size<T, std::enable_if_t<!std::is_void_v<decltype(sizeof(T))>>> {
static constexpr size_t value = sizeof(T);
};
template <typename T> constexpr size_t type_size_v = type_size<T>::value;
template <size_t... Ns> struct max;
template <size_t... Ns> constexpr size_t max_v = max<Ns...>::value;
template <size_t N> struct max<N> { static constexpr size_t value = N; };
template <size_t N, size_t... Ns> struct max<N, Ns...> {
static constexpr size_t value = N > max_v<Ns...> ? N : max_v<Ns...>;
};
static constexpr size_t kMaxOpaqueSize = max_v<
#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v<ctype>,
TYPE_TABLE(DEFINE_TYPE_SIZE)
#undef DEFINE_TYPE_SIZE
0>;
// --------------- Implementations --------------- //
} // namespace triton_tvm_ffi
+2
View File
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
# fmt: on
# tvm-ffi-stubgen(end)
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue")
@@ -35,6 +36,7 @@ class TypedValue(_ffi_Object):
# fmt: on
# tvm-ffi-stubgen(end)
__all__ = [
# tvm-ffi-stubgen(begin): __all__
"LIB",
+2 -3
View File
@@ -2,11 +2,12 @@ from __future__ import annotations
from typing import Any, List, Optional, Type
from triton.backends.nvidia.driver import CudaDriver
from triton.runtime import _allocation
from . import TypedValue, utils, string_to_type
class TVMLauncher(object):
def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher:
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
super().__init__(*args, **kwargs)
self.signature: List[str] = src.signature.values()
@@ -32,8 +33,6 @@ class TVMLauncher(object):
launch_exit_hook,
*args,
):
from triton.runtime import _allocation
def allocate_scratch(size, align, allocator):
if size > 0:
grid_size = gridX * gridY * gridZ
+1 -1
View File
@@ -14,7 +14,7 @@ target_include_directories(
target_compile_options(
${TARGET_NAME}
PRIVATE
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
$<$<CONFIG:Debug>:-O0 -g>
$<$<CONFIG:Release>:-O3 -DNDEBUG>
)
target_link_libraries(
+5 -4
View File
@@ -11,18 +11,19 @@ const char *CUDAException::what() const noexcept {
}
NotImplementedException::NotImplementedException(std::string_view name)
: message_("\"" + std::string(name) + "\" is not implemented") {}
: message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
const char *NotImplementedException::what() const noexcept {
return message_.c_str();
}
UnknownTypeException::UnknownTypeException(Type type)
: message_("unknown type: " + std::string(TypeToString(type))) {}
: message_("[UnknownTypeException]: unknown type: \"" +
std::string(TypeToString(type)) + "\"") {}
UnknownTypeException::UnknownTypeException(std::string_view type)
: message_("unknown type: " + std::string(type)) {}
: message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
"\"") {}
const char *UnknownTypeException::what() const noexcept {
return message_.c_str();
}
+8 -12
View File
@@ -13,31 +13,27 @@ const char *TypeToString(Type type) {
return str;
TYPE_TABLE(CASE_ENUM)
#undef CASE_ENUM
case Type::PTR:
return "*?";
case Type::CONSTEXPR:
return "constexpr";
default:
throw UnknownTypeException(type);
}
}
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
#define IF_ENUM(type, str, ctype) \
if (name == str) { \
return Type::type; \
}
TYPE_TABLE(IF_ENUM)
#undef IF_ENUM
if (name.starts_with("*")) {
return Type::PTR;
}
if (name == "constexpr") {
return Type::CONSTEXPR;
}
#define IF_ENUM(type, str, ctype) \
if (name == str) { \
return Type::type; \
}
TYPE_TABLE(IF_ENUM)
#undef IF_ENUM
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
// TODO:
assert(false);
throw NotImplementedException(
"tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
}
return std::nullopt;
}
+27 -13
View File
@@ -5,6 +5,9 @@
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef NDEBUG
#include <cassert>
#endif
#define CUDA_CHECK(code) \
do { \
@@ -20,11 +23,11 @@ using namespace triton_tvm_ffi;
// --------------- Definitions ---------------
template <Type T> struct ValueCast {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
};
template <Type... Ts> struct ValueCastSet {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
};
using GenericValueCastSet = ValueCastSet<
@@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet<
// --------------- Implementations ---------------
template <Type T>
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr,
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
const TypedValue &value) {
if (value.GetType() == T) {
if constexpr (T == Type::PTR) {
tvm::ffi::TensorView cvalue =
value.GetValue().cast<tvm::ffi::TensorView>();
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
*ptr = cvalue.data_ptr();
*addr = ptr;
void **p = reinterpret_cast<void **>(ptr);
*p = cvalue.data_ptr();
} else if constexpr (T == Type::CONSTEXPR) {
#ifdef NDEBUG
__builtin_unreachable();
#else
throw NotImplementedException("CONSTEXPR for value casting");
#endif
} else {
using ctype = type_to_ctype_t<T>;
ctype cvalue = value.GetValue().cast<ctype>();
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype)));
*ptr = cvalue;
*addr = ptr;
ctype *p = reinterpret_cast<ctype *>(ptr);
*p = cvalue;
}
return true;
} else {
return false;
}
return false;
}
template <Type... Ts>
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr,
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
const TypedValue &value) {
return (ValueCast<Ts>::apply(ptr, value) || ...);
}
@@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
}
const int32_t kernelArgNum = kernelArgs.size();
uint8_t *buffer =
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0;
for (size_t i = 0; i < kernelArgNum; ++i) {
TypedValue value = kernelArgs[i].cast<TypedValue>();
if (value.GetType() != Type::CONSTEXPR) {
if (!GenericValueCastSet::apply(&params[j++], value)) {
#ifndef NDEBUG
assert(j < kernelArgNum);
#endif
void *ptr = buffer + j * kMaxOpaqueSize;
params[j] = ptr;
if (!GenericValueCastSet::apply(ptr, value)) {
throw UnknownTypeException(value.GetType());
}
++j;
}
}
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args