mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
fix bugs on illegal memory access on Release
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
+26
-6
@@ -4,6 +4,7 @@
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/object.h>
|
||||
#include <tvm/ffi/string.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
@@ -23,14 +24,14 @@ namespace triton_tvm_ffi {
|
||||
V(FP16, "fp16", double) \
|
||||
V(BF16, "bf16", double) \
|
||||
V(FP32, "f32", double) \
|
||||
V(FP64, "fp64", double)
|
||||
V(FP64, "fp64", double) \
|
||||
V(PTR, "*?", void *) \
|
||||
V(CONSTEXPR, "constexpr", void)
|
||||
|
||||
enum class Type : int64_t {
|
||||
#define DEFINE_ENUM(type, str, ctype) type,
|
||||
TYPE_TABLE(DEFINE_ENUM)
|
||||
#undef DEFINE_ENUM
|
||||
PTR,
|
||||
CONSTEXPR,
|
||||
};
|
||||
|
||||
const char *TypeToString(Type type);
|
||||
@@ -41,11 +42,30 @@ template <Type T> struct type_to_ctype;
|
||||
template <> struct type_to_ctype<Type::type> { using t = ctype; };
|
||||
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
|
||||
#undef DEFINE_TYPE_TO_CTYPE
|
||||
template <> struct type_to_ctype<Type::PTR> { using t = void *; };
|
||||
// TODO: check whether CUtensorMap* is correct
|
||||
template <> struct type_to_ctype<Type::CONSTEXPR> { using t = void; };
|
||||
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
||||
|
||||
template <typename T, typename = void> struct type_size {
|
||||
static constexpr size_t value = 0;
|
||||
};
|
||||
template <typename T>
|
||||
struct type_size<T, std::enable_if_t<!std::is_void_v<decltype(sizeof(T))>>> {
|
||||
static constexpr size_t value = sizeof(T);
|
||||
};
|
||||
template <typename T> constexpr size_t type_size_v = type_size<T>::value;
|
||||
|
||||
template <size_t... Ns> struct max;
|
||||
template <size_t... Ns> constexpr size_t max_v = max<Ns...>::value;
|
||||
template <size_t N> struct max<N> { static constexpr size_t value = N; };
|
||||
template <size_t N, size_t... Ns> struct max<N, Ns...> {
|
||||
static constexpr size_t value = N > max_v<Ns...> ? N : max_v<Ns...>;
|
||||
};
|
||||
|
||||
static constexpr size_t kMaxOpaqueSize = max_v<
|
||||
#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v<ctype>,
|
||||
TYPE_TABLE(DEFINE_TYPE_SIZE)
|
||||
#undef DEFINE_TYPE_SIZE
|
||||
0>;
|
||||
|
||||
// --------------- Implementations --------------- //
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
|
||||
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
|
||||
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue")
|
||||
@@ -35,6 +36,7 @@ class TypedValue(_ffi_Object):
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# tvm-ffi-stubgen(begin): __all__
|
||||
"LIB",
|
||||
|
||||
@@ -2,11 +2,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Type
|
||||
from triton.backends.nvidia.driver import CudaDriver
|
||||
from triton.runtime import _allocation
|
||||
from . import TypedValue, utils, string_to_type
|
||||
|
||||
|
||||
class TVMLauncher(object):
|
||||
def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher:
|
||||
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.signature: List[str] = src.signature.values()
|
||||
@@ -32,8 +33,6 @@ class TVMLauncher(object):
|
||||
launch_exit_hook,
|
||||
*args,
|
||||
):
|
||||
from triton.runtime import _allocation
|
||||
|
||||
def allocate_scratch(size, align, allocator):
|
||||
if size > 0:
|
||||
grid_size = gridX * gridY * gridZ
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ target_include_directories(
|
||||
target_compile_options(
|
||||
${TARGET_NAME}
|
||||
PRIVATE
|
||||
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
|
||||
$<$<CONFIG:Debug>:-O0 -g>
|
||||
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
||||
)
|
||||
target_link_libraries(
|
||||
|
||||
+5
-4
@@ -11,18 +11,19 @@ const char *CUDAException::what() const noexcept {
|
||||
}
|
||||
|
||||
NotImplementedException::NotImplementedException(std::string_view name)
|
||||
: message_("\"" + std::string(name) + "\" is not implemented") {}
|
||||
: message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
|
||||
|
||||
const char *NotImplementedException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
UnknownTypeException::UnknownTypeException(Type type)
|
||||
: message_("unknown type: " + std::string(TypeToString(type))) {}
|
||||
: message_("[UnknownTypeException]: unknown type: \"" +
|
||||
std::string(TypeToString(type)) + "\"") {}
|
||||
|
||||
UnknownTypeException::UnknownTypeException(std::string_view type)
|
||||
: message_("unknown type: " + std::string(type)) {}
|
||||
|
||||
: message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
|
||||
"\"") {}
|
||||
const char *UnknownTypeException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
+8
-12
@@ -13,31 +13,27 @@ const char *TypeToString(Type type) {
|
||||
return str;
|
||||
TYPE_TABLE(CASE_ENUM)
|
||||
#undef CASE_ENUM
|
||||
case Type::PTR:
|
||||
return "*?";
|
||||
case Type::CONSTEXPR:
|
||||
return "constexpr";
|
||||
default:
|
||||
throw UnknownTypeException(type);
|
||||
}
|
||||
}
|
||||
|
||||
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
|
||||
#define IF_ENUM(type, str, ctype) \
|
||||
if (name == str) { \
|
||||
return Type::type; \
|
||||
}
|
||||
TYPE_TABLE(IF_ENUM)
|
||||
#undef IF_ENUM
|
||||
if (name.starts_with("*")) {
|
||||
return Type::PTR;
|
||||
}
|
||||
if (name == "constexpr") {
|
||||
return Type::CONSTEXPR;
|
||||
}
|
||||
#define IF_ENUM(type, str, ctype) \
|
||||
if (name == str) { \
|
||||
return Type::type; \
|
||||
}
|
||||
TYPE_TABLE(IF_ENUM)
|
||||
#undef IF_ENUM
|
||||
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
|
||||
// TODO:
|
||||
assert(false);
|
||||
throw NotImplementedException(
|
||||
"tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
+27
-13
@@ -5,6 +5,9 @@
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/reflection/registry.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
#ifndef NDEBUG
|
||||
#include <cassert>
|
||||
#endif
|
||||
|
||||
#define CUDA_CHECK(code) \
|
||||
do { \
|
||||
@@ -20,11 +23,11 @@ using namespace triton_tvm_ffi;
|
||||
// --------------- Definitions ---------------
|
||||
|
||||
template <Type T> struct ValueCast {
|
||||
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
|
||||
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
||||
};
|
||||
|
||||
template <Type... Ts> struct ValueCastSet {
|
||||
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
|
||||
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
||||
};
|
||||
|
||||
using GenericValueCastSet = ValueCastSet<
|
||||
@@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet<
|
||||
// --------------- Implementations ---------------
|
||||
|
||||
template <Type T>
|
||||
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr,
|
||||
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
|
||||
const TypedValue &value) {
|
||||
if (value.GetType() == T) {
|
||||
if constexpr (T == Type::PTR) {
|
||||
tvm::ffi::TensorView cvalue =
|
||||
value.GetValue().cast<tvm::ffi::TensorView>();
|
||||
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
|
||||
*ptr = cvalue.data_ptr();
|
||||
*addr = ptr;
|
||||
void **p = reinterpret_cast<void **>(ptr);
|
||||
*p = cvalue.data_ptr();
|
||||
} else if constexpr (T == Type::CONSTEXPR) {
|
||||
#ifdef NDEBUG
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw NotImplementedException("CONSTEXPR for value casting");
|
||||
#endif
|
||||
} else {
|
||||
using ctype = type_to_ctype_t<T>;
|
||||
ctype cvalue = value.GetValue().cast<ctype>();
|
||||
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype)));
|
||||
*ptr = cvalue;
|
||||
*addr = ptr;
|
||||
ctype *p = reinterpret_cast<ctype *>(ptr);
|
||||
*p = cvalue;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
template <Type... Ts>
|
||||
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr,
|
||||
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
|
||||
const TypedValue &value) {
|
||||
return (ValueCast<Ts>::apply(ptr, value) || ...);
|
||||
}
|
||||
@@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
}
|
||||
const int32_t kernelArgNum = kernelArgs.size();
|
||||
uint8_t *buffer =
|
||||
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
|
||||
void **params =
|
||||
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < kernelArgNum; ++i) {
|
||||
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
||||
if (value.GetType() != Type::CONSTEXPR) {
|
||||
if (!GenericValueCastSet::apply(¶ms[j++], value)) {
|
||||
#ifndef NDEBUG
|
||||
assert(j < kernelArgNum);
|
||||
#endif
|
||||
void *ptr = buffer + j * kMaxOpaqueSize;
|
||||
params[j] = ptr;
|
||||
if (!GenericValueCastSet::apply(ptr, value)) {
|
||||
throw UnknownTypeException(value.GetType());
|
||||
}
|
||||
++j;
|
||||
}
|
||||
}
|
||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||
|
||||
Reference in New Issue
Block a user