fix bugs on illegal memory access on Release

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 00:15:47 +08:00
parent 1a01c9f2d8
commit bdc9c03b75
7 changed files with 71 additions and 39 deletions
+26 -6
View File
@@ -4,6 +4,7 @@
#include <cstdint> #include <cstdint>
#include <tvm/ffi/object.h> #include <tvm/ffi/object.h>
#include <tvm/ffi/string.h> #include <tvm/ffi/string.h>
#include <type_traits>
namespace triton_tvm_ffi { namespace triton_tvm_ffi {
@@ -23,14 +24,14 @@ namespace triton_tvm_ffi {
V(FP16, "fp16", double) \ V(FP16, "fp16", double) \
V(BF16, "bf16", double) \ V(BF16, "bf16", double) \
V(FP32, "f32", double) \ V(FP32, "f32", double) \
V(FP64, "fp64", double) V(FP64, "fp64", double) \
V(PTR, "*?", void *) \
V(CONSTEXPR, "constexpr", void)
enum class Type : int64_t { enum class Type : int64_t {
#define DEFINE_ENUM(type, str, ctype) type, #define DEFINE_ENUM(type, str, ctype) type,
TYPE_TABLE(DEFINE_ENUM) TYPE_TABLE(DEFINE_ENUM)
#undef DEFINE_ENUM #undef DEFINE_ENUM
PTR,
CONSTEXPR,
}; };
const char *TypeToString(Type type); const char *TypeToString(Type type);
@@ -41,11 +42,30 @@ template <Type T> struct type_to_ctype;
template <> struct type_to_ctype<Type::type> { using t = ctype; }; template <> struct type_to_ctype<Type::type> { using t = ctype; };
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE) TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
#undef DEFINE_TYPE_TO_CTYPE #undef DEFINE_TYPE_TO_CTYPE
template <> struct type_to_ctype<Type::PTR> { using t = void *; };
// TODO: check whether CUtensorMap* is correct
template <> struct type_to_ctype<Type::CONSTEXPR> { using t = void; };
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t; template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
template <typename T, typename = void> struct type_size {
static constexpr size_t value = 0;
};
template <typename T>
struct type_size<T, std::enable_if_t<!std::is_void_v<decltype(sizeof(T))>>> {
static constexpr size_t value = sizeof(T);
};
template <typename T> constexpr size_t type_size_v = type_size<T>::value;
template <size_t... Ns> struct max;
template <size_t... Ns> constexpr size_t max_v = max<Ns...>::value;
template <size_t N> struct max<N> { static constexpr size_t value = N; };
template <size_t N, size_t... Ns> struct max<N, Ns...> {
static constexpr size_t value = N > max_v<Ns...> ? N : max_v<Ns...>;
};
static constexpr size_t kMaxOpaqueSize = max_v<
#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v<ctype>,
TYPE_TABLE(DEFINE_TYPE_SIZE)
#undef DEFINE_TYPE_SIZE
0>;
// --------------- Implementations --------------- // // --------------- Implementations --------------- //
} // namespace triton_tvm_ffi } // namespace triton_tvm_ffi
+2
View File
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
# fmt: on # fmt: on
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ # tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object # tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") @_FFI_REG_OBJ("triton_tvm_ffi.TypedValue")
@@ -35,6 +36,7 @@ class TypedValue(_ffi_Object):
# fmt: on # fmt: on
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
__all__ = [ __all__ = [
# tvm-ffi-stubgen(begin): __all__ # tvm-ffi-stubgen(begin): __all__
"LIB", "LIB",
+2 -3
View File
@@ -2,11 +2,12 @@ from __future__ import annotations
from typing import Any, List, Optional, Type from typing import Any, List, Optional, Type
from triton.backends.nvidia.driver import CudaDriver from triton.backends.nvidia.driver import CudaDriver
from triton.runtime import _allocation
from . import TypedValue, utils, string_to_type from . import TypedValue, utils, string_to_type
class TVMLauncher(object): class TVMLauncher(object):
def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher: def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.signature: List[str] = src.signature.values() self.signature: List[str] = src.signature.values()
@@ -32,8 +33,6 @@ class TVMLauncher(object):
launch_exit_hook, launch_exit_hook,
*args, *args,
): ):
from triton.runtime import _allocation
def allocate_scratch(size, align, allocator): def allocate_scratch(size, align, allocator):
if size > 0: if size > 0:
grid_size = gridX * gridY * gridZ grid_size = gridX * gridY * gridZ
+1 -1
View File
@@ -14,7 +14,7 @@ target_include_directories(
target_compile_options( target_compile_options(
${TARGET_NAME} ${TARGET_NAME}
PRIVATE PRIVATE
$<$<CONFIG:Debug>:-O0 -g -DDEBUG> $<$<CONFIG:Debug>:-O0 -g>
$<$<CONFIG:Release>:-O3 -DNDEBUG> $<$<CONFIG:Release>:-O3 -DNDEBUG>
) )
target_link_libraries( target_link_libraries(
+5 -4
View File
@@ -11,18 +11,19 @@ const char *CUDAException::what() const noexcept {
} }
NotImplementedException::NotImplementedException(std::string_view name) NotImplementedException::NotImplementedException(std::string_view name)
: message_("\"" + std::string(name) + "\" is not implemented") {} : message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
const char *NotImplementedException::what() const noexcept { const char *NotImplementedException::what() const noexcept {
return message_.c_str(); return message_.c_str();
} }
UnknownTypeException::UnknownTypeException(Type type) UnknownTypeException::UnknownTypeException(Type type)
: message_("unknown type: " + std::string(TypeToString(type))) {} : message_("[UnknownTypeException]: unknown type: \"" +
std::string(TypeToString(type)) + "\"") {}
UnknownTypeException::UnknownTypeException(std::string_view type) UnknownTypeException::UnknownTypeException(std::string_view type)
: message_("unknown type: " + std::string(type)) {} : message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
"\"") {}
const char *UnknownTypeException::what() const noexcept { const char *UnknownTypeException::what() const noexcept {
return message_.c_str(); return message_.c_str();
} }
+8 -12
View File
@@ -13,31 +13,27 @@ const char *TypeToString(Type type) {
return str; return str;
TYPE_TABLE(CASE_ENUM) TYPE_TABLE(CASE_ENUM)
#undef CASE_ENUM #undef CASE_ENUM
case Type::PTR:
return "*?";
case Type::CONSTEXPR:
return "constexpr";
default: default:
throw UnknownTypeException(type); throw UnknownTypeException(type);
} }
} }
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) { tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
#define IF_ENUM(type, str, ctype) \
if (name == str) { \
return Type::type; \
}
TYPE_TABLE(IF_ENUM)
#undef IF_ENUM
if (name.starts_with("*")) { if (name.starts_with("*")) {
return Type::PTR; return Type::PTR;
} }
if (name == "constexpr") { if (name == "constexpr") {
return Type::CONSTEXPR; return Type::CONSTEXPR;
} }
#define IF_ENUM(type, str, ctype) \
if (name == str) { \
return Type::type; \
}
TYPE_TABLE(IF_ENUM)
#undef IF_ENUM
if (name.starts_with("tensordesc") || name == "nvTmaDesc") { if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
// TODO: throw NotImplementedException(
assert(false); "tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
} }
return std::nullopt; return std::nullopt;
} }
+27 -13
View File
@@ -5,6 +5,9 @@
#include <tvm/ffi/extra/cuda/cubin_launcher.h> #include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h> #include <tvm/ffi/tvm_ffi.h>
#ifndef NDEBUG
#include <cassert>
#endif
#define CUDA_CHECK(code) \ #define CUDA_CHECK(code) \
do { \ do { \
@@ -20,11 +23,11 @@ using namespace triton_tvm_ffi;
// --------------- Definitions --------------- // --------------- Definitions ---------------
template <Type T> struct ValueCast { template <Type T> struct ValueCast {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
}; };
template <Type... Ts> struct ValueCastSet { template <Type... Ts> struct ValueCastSet {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
}; };
using GenericValueCastSet = ValueCastSet< using GenericValueCastSet = ValueCastSet<
@@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet<
// --------------- Implementations --------------- // --------------- Implementations ---------------
template <Type T> template <Type T>
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr, TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
const TypedValue &value) { const TypedValue &value) {
if (value.GetType() == T) { if (value.GetType() == T) {
if constexpr (T == Type::PTR) { if constexpr (T == Type::PTR) {
tvm::ffi::TensorView cvalue = tvm::ffi::TensorView cvalue =
value.GetValue().cast<tvm::ffi::TensorView>(); value.GetValue().cast<tvm::ffi::TensorView>();
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *))); void **p = reinterpret_cast<void **>(ptr);
*ptr = cvalue.data_ptr(); *p = cvalue.data_ptr();
*addr = ptr; } else if constexpr (T == Type::CONSTEXPR) {
#ifdef NDEBUG
__builtin_unreachable();
#else
throw NotImplementedException("CONSTEXPR for value casting");
#endif
} else { } else {
using ctype = type_to_ctype_t<T>; using ctype = type_to_ctype_t<T>;
ctype cvalue = value.GetValue().cast<ctype>(); ctype cvalue = value.GetValue().cast<ctype>();
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype))); ctype *p = reinterpret_cast<ctype *>(ptr);
*ptr = cvalue; *p = cvalue;
*addr = ptr;
} }
return true; return true;
} else { }
return false; return false;
} }
}
template <Type... Ts> template <Type... Ts>
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr, TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
const TypedValue &value) { const TypedValue &value) {
return (ValueCast<Ts>::apply(ptr, value) || ...); return (ValueCast<Ts>::apply(ptr, value) || ...);
} }
@@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
} }
const int32_t kernelArgNum = kernelArgs.size(); const int32_t kernelArgNum = kernelArgs.size();
uint8_t *buffer =
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
void **params = void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2))); reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0; size_t j = 0;
for (size_t i = 0; i < kernelArgNum; ++i) { for (size_t i = 0; i < kernelArgNum; ++i) {
TypedValue value = kernelArgs[i].cast<TypedValue>(); TypedValue value = kernelArgs[i].cast<TypedValue>();
if (value.GetType() != Type::CONSTEXPR) { if (value.GetType() != Type::CONSTEXPR) {
if (!GenericValueCastSet::apply(&params[j++], value)) { #ifndef NDEBUG
assert(j < kernelArgNum);
#endif
void *ptr = buffer + j * kMaxOpaqueSize;
params[j] = ptr;
if (!GenericValueCastSet::apply(ptr, value)) {
throw UnknownTypeException(value.GetType()); throw UnknownTypeException(value.GetType());
} }
++j;
} }
} }
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args // TODO: unwrap PyObject* from scratch pointers and assign to kernel args