mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
fix bugs on illegal memory access on Release
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
+26
-6
@@ -4,6 +4,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <tvm/ffi/object.h>
|
#include <tvm/ffi/object.h>
|
||||||
#include <tvm/ffi/string.h>
|
#include <tvm/ffi/string.h>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
namespace triton_tvm_ffi {
|
namespace triton_tvm_ffi {
|
||||||
|
|
||||||
@@ -23,14 +24,14 @@ namespace triton_tvm_ffi {
|
|||||||
V(FP16, "fp16", double) \
|
V(FP16, "fp16", double) \
|
||||||
V(BF16, "bf16", double) \
|
V(BF16, "bf16", double) \
|
||||||
V(FP32, "f32", double) \
|
V(FP32, "f32", double) \
|
||||||
V(FP64, "fp64", double)
|
V(FP64, "fp64", double) \
|
||||||
|
V(PTR, "*?", void *) \
|
||||||
|
V(CONSTEXPR, "constexpr", void)
|
||||||
|
|
||||||
enum class Type : int64_t {
|
enum class Type : int64_t {
|
||||||
#define DEFINE_ENUM(type, str, ctype) type,
|
#define DEFINE_ENUM(type, str, ctype) type,
|
||||||
TYPE_TABLE(DEFINE_ENUM)
|
TYPE_TABLE(DEFINE_ENUM)
|
||||||
#undef DEFINE_ENUM
|
#undef DEFINE_ENUM
|
||||||
PTR,
|
|
||||||
CONSTEXPR,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const char *TypeToString(Type type);
|
const char *TypeToString(Type type);
|
||||||
@@ -41,11 +42,30 @@ template <Type T> struct type_to_ctype;
|
|||||||
template <> struct type_to_ctype<Type::type> { using t = ctype; };
|
template <> struct type_to_ctype<Type::type> { using t = ctype; };
|
||||||
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
|
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
|
||||||
#undef DEFINE_TYPE_TO_CTYPE
|
#undef DEFINE_TYPE_TO_CTYPE
|
||||||
template <> struct type_to_ctype<Type::PTR> { using t = void *; };
|
|
||||||
// TODO: check whether CUtensorMap* is correct
|
|
||||||
template <> struct type_to_ctype<Type::CONSTEXPR> { using t = void; };
|
|
||||||
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
||||||
|
|
||||||
|
template <typename T, typename = void> struct type_size {
|
||||||
|
static constexpr size_t value = 0;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
struct type_size<T, std::enable_if_t<!std::is_void_v<decltype(sizeof(T))>>> {
|
||||||
|
static constexpr size_t value = sizeof(T);
|
||||||
|
};
|
||||||
|
template <typename T> constexpr size_t type_size_v = type_size<T>::value;
|
||||||
|
|
||||||
|
template <size_t... Ns> struct max;
|
||||||
|
template <size_t... Ns> constexpr size_t max_v = max<Ns...>::value;
|
||||||
|
template <size_t N> struct max<N> { static constexpr size_t value = N; };
|
||||||
|
template <size_t N, size_t... Ns> struct max<N, Ns...> {
|
||||||
|
static constexpr size_t value = N > max_v<Ns...> ? N : max_v<Ns...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr size_t kMaxOpaqueSize = max_v<
|
||||||
|
#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v<ctype>,
|
||||||
|
TYPE_TABLE(DEFINE_TYPE_SIZE)
|
||||||
|
#undef DEFINE_TYPE_SIZE
|
||||||
|
0>;
|
||||||
|
|
||||||
// --------------- Implementations --------------- //
|
// --------------- Implementations --------------- //
|
||||||
|
|
||||||
} // namespace triton_tvm_ffi
|
} // namespace triton_tvm_ffi
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
|
|
||||||
|
|
||||||
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
|
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
|
||||||
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
|
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
|
||||||
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue")
|
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue")
|
||||||
@@ -35,6 +36,7 @@ class TypedValue(_ffi_Object):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# tvm-ffi-stubgen(begin): __all__
|
# tvm-ffi-stubgen(begin): __all__
|
||||||
"LIB",
|
"LIB",
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, List, Optional, Type
|
from typing import Any, List, Optional, Type
|
||||||
from triton.backends.nvidia.driver import CudaDriver
|
from triton.backends.nvidia.driver import CudaDriver
|
||||||
|
from triton.runtime import _allocation
|
||||||
from . import TypedValue, utils, string_to_type
|
from . import TypedValue, utils, string_to_type
|
||||||
|
|
||||||
|
|
||||||
class TVMLauncher(object):
|
class TVMLauncher(object):
|
||||||
def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher:
|
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.signature: List[str] = src.signature.values()
|
self.signature: List[str] = src.signature.values()
|
||||||
@@ -32,8 +33,6 @@ class TVMLauncher(object):
|
|||||||
launch_exit_hook,
|
launch_exit_hook,
|
||||||
*args,
|
*args,
|
||||||
):
|
):
|
||||||
from triton.runtime import _allocation
|
|
||||||
|
|
||||||
def allocate_scratch(size, align, allocator):
|
def allocate_scratch(size, align, allocator):
|
||||||
if size > 0:
|
if size > 0:
|
||||||
grid_size = gridX * gridY * gridZ
|
grid_size = gridX * gridY * gridZ
|
||||||
|
|||||||
+1
-1
@@ -14,7 +14,7 @@ target_include_directories(
|
|||||||
target_compile_options(
|
target_compile_options(
|
||||||
${TARGET_NAME}
|
${TARGET_NAME}
|
||||||
PRIVATE
|
PRIVATE
|
||||||
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
|
$<$<CONFIG:Debug>:-O0 -g>
|
||||||
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
||||||
)
|
)
|
||||||
target_link_libraries(
|
target_link_libraries(
|
||||||
|
|||||||
+5
-4
@@ -11,18 +11,19 @@ const char *CUDAException::what() const noexcept {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NotImplementedException::NotImplementedException(std::string_view name)
|
NotImplementedException::NotImplementedException(std::string_view name)
|
||||||
: message_("\"" + std::string(name) + "\" is not implemented") {}
|
: message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
|
||||||
|
|
||||||
const char *NotImplementedException::what() const noexcept {
|
const char *NotImplementedException::what() const noexcept {
|
||||||
return message_.c_str();
|
return message_.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
UnknownTypeException::UnknownTypeException(Type type)
|
UnknownTypeException::UnknownTypeException(Type type)
|
||||||
: message_("unknown type: " + std::string(TypeToString(type))) {}
|
: message_("[UnknownTypeException]: unknown type: \"" +
|
||||||
|
std::string(TypeToString(type)) + "\"") {}
|
||||||
|
|
||||||
UnknownTypeException::UnknownTypeException(std::string_view type)
|
UnknownTypeException::UnknownTypeException(std::string_view type)
|
||||||
: message_("unknown type: " + std::string(type)) {}
|
: message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
|
||||||
|
"\"") {}
|
||||||
const char *UnknownTypeException::what() const noexcept {
|
const char *UnknownTypeException::what() const noexcept {
|
||||||
return message_.c_str();
|
return message_.c_str();
|
||||||
}
|
}
|
||||||
|
|||||||
+8
-12
@@ -13,31 +13,27 @@ const char *TypeToString(Type type) {
|
|||||||
return str;
|
return str;
|
||||||
TYPE_TABLE(CASE_ENUM)
|
TYPE_TABLE(CASE_ENUM)
|
||||||
#undef CASE_ENUM
|
#undef CASE_ENUM
|
||||||
case Type::PTR:
|
|
||||||
return "*?";
|
|
||||||
case Type::CONSTEXPR:
|
|
||||||
return "constexpr";
|
|
||||||
default:
|
default:
|
||||||
throw UnknownTypeException(type);
|
throw UnknownTypeException(type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
|
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
|
||||||
#define IF_ENUM(type, str, ctype) \
|
|
||||||
if (name == str) { \
|
|
||||||
return Type::type; \
|
|
||||||
}
|
|
||||||
TYPE_TABLE(IF_ENUM)
|
|
||||||
#undef IF_ENUM
|
|
||||||
if (name.starts_with("*")) {
|
if (name.starts_with("*")) {
|
||||||
return Type::PTR;
|
return Type::PTR;
|
||||||
}
|
}
|
||||||
if (name == "constexpr") {
|
if (name == "constexpr") {
|
||||||
return Type::CONSTEXPR;
|
return Type::CONSTEXPR;
|
||||||
}
|
}
|
||||||
|
#define IF_ENUM(type, str, ctype) \
|
||||||
|
if (name == str) { \
|
||||||
|
return Type::type; \
|
||||||
|
}
|
||||||
|
TYPE_TABLE(IF_ENUM)
|
||||||
|
#undef IF_ENUM
|
||||||
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
|
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
|
||||||
// TODO:
|
throw NotImplementedException(
|
||||||
assert(false);
|
"tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
|
||||||
}
|
}
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|||||||
+27
-13
@@ -5,6 +5,9 @@
|
|||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/reflection/registry.h>
|
#include <tvm/ffi/reflection/registry.h>
|
||||||
#include <tvm/ffi/tvm_ffi.h>
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
#ifndef NDEBUG
|
||||||
|
#include <cassert>
|
||||||
|
#endif
|
||||||
|
|
||||||
#define CUDA_CHECK(code) \
|
#define CUDA_CHECK(code) \
|
||||||
do { \
|
do { \
|
||||||
@@ -20,11 +23,11 @@ using namespace triton_tvm_ffi;
|
|||||||
// --------------- Definitions ---------------
|
// --------------- Definitions ---------------
|
||||||
|
|
||||||
template <Type T> struct ValueCast {
|
template <Type T> struct ValueCast {
|
||||||
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
|
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <Type... Ts> struct ValueCastSet {
|
template <Type... Ts> struct ValueCastSet {
|
||||||
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
|
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
||||||
};
|
};
|
||||||
|
|
||||||
using GenericValueCastSet = ValueCastSet<
|
using GenericValueCastSet = ValueCastSet<
|
||||||
@@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet<
|
|||||||
// --------------- Implementations ---------------
|
// --------------- Implementations ---------------
|
||||||
|
|
||||||
template <Type T>
|
template <Type T>
|
||||||
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr,
|
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
|
||||||
const TypedValue &value) {
|
const TypedValue &value) {
|
||||||
if (value.GetType() == T) {
|
if (value.GetType() == T) {
|
||||||
if constexpr (T == Type::PTR) {
|
if constexpr (T == Type::PTR) {
|
||||||
tvm::ffi::TensorView cvalue =
|
tvm::ffi::TensorView cvalue =
|
||||||
value.GetValue().cast<tvm::ffi::TensorView>();
|
value.GetValue().cast<tvm::ffi::TensorView>();
|
||||||
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
|
void **p = reinterpret_cast<void **>(ptr);
|
||||||
*ptr = cvalue.data_ptr();
|
*p = cvalue.data_ptr();
|
||||||
*addr = ptr;
|
} else if constexpr (T == Type::CONSTEXPR) {
|
||||||
|
#ifdef NDEBUG
|
||||||
|
__builtin_unreachable();
|
||||||
|
#else
|
||||||
|
throw NotImplementedException("CONSTEXPR for value casting");
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
using ctype = type_to_ctype_t<T>;
|
using ctype = type_to_ctype_t<T>;
|
||||||
ctype cvalue = value.GetValue().cast<ctype>();
|
ctype cvalue = value.GetValue().cast<ctype>();
|
||||||
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype)));
|
ctype *p = reinterpret_cast<ctype *>(ptr);
|
||||||
*ptr = cvalue;
|
*p = cvalue;
|
||||||
*addr = ptr;
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} else {
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
template <Type... Ts>
|
template <Type... Ts>
|
||||||
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr,
|
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
|
||||||
const TypedValue &value) {
|
const TypedValue &value) {
|
||||||
return (ValueCast<Ts>::apply(ptr, value) || ...);
|
return (ValueCast<Ts>::apply(ptr, value) || ...);
|
||||||
}
|
}
|
||||||
@@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
|||||||
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||||
}
|
}
|
||||||
const int32_t kernelArgNum = kernelArgs.size();
|
const int32_t kernelArgNum = kernelArgs.size();
|
||||||
|
uint8_t *buffer =
|
||||||
|
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
|
||||||
void **params =
|
void **params =
|
||||||
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
for (size_t i = 0; i < kernelArgNum; ++i) {
|
for (size_t i = 0; i < kernelArgNum; ++i) {
|
||||||
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
||||||
if (value.GetType() != Type::CONSTEXPR) {
|
if (value.GetType() != Type::CONSTEXPR) {
|
||||||
if (!GenericValueCastSet::apply(¶ms[j++], value)) {
|
#ifndef NDEBUG
|
||||||
|
assert(j < kernelArgNum);
|
||||||
|
#endif
|
||||||
|
void *ptr = buffer + j * kMaxOpaqueSize;
|
||||||
|
params[j] = ptr;
|
||||||
|
if (!GenericValueCastSet::apply(ptr, value)) {
|
||||||
throw UnknownTypeException(value.GetType());
|
throw UnknownTypeException(value.GetType());
|
||||||
}
|
}
|
||||||
|
++j;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||||
|
|||||||
Reference in New Issue
Block a user