diff --git a/include/type.h b/include/type.h index 1e5c929..176712b 100644 --- a/include/type.h +++ b/include/type.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace triton_tvm_ffi { @@ -23,14 +24,14 @@ namespace triton_tvm_ffi { V(FP16, "fp16", double) \ V(BF16, "bf16", double) \ V(FP32, "f32", double) \ - V(FP64, "fp64", double) + V(FP64, "fp64", double) \ + V(PTR, "*?", void *) \ + V(CONSTEXPR, "constexpr", void) enum class Type : int64_t { #define DEFINE_ENUM(type, str, ctype) type, TYPE_TABLE(DEFINE_ENUM) #undef DEFINE_ENUM - PTR, - CONSTEXPR, }; const char *TypeToString(Type type); @@ -41,11 +42,30 @@ template struct type_to_ctype; template <> struct type_to_ctype { using t = ctype; }; TYPE_TABLE(DEFINE_TYPE_TO_CTYPE) #undef DEFINE_TYPE_TO_CTYPE -template <> struct type_to_ctype { using t = void *; }; -// TODO: check whether CUtensorMap* is correct -template <> struct type_to_ctype { using t = void; }; template using type_to_ctype_t = typename type_to_ctype::t; +template struct type_size { + static constexpr size_t value = 0; +}; +template +struct type_size>> { + static constexpr size_t value = sizeof(T); +}; +template constexpr size_t type_size_v = type_size::value; + +template struct max; +template constexpr size_t max_v = max::value; +template struct max { static constexpr size_t value = N; }; +template struct max { + static constexpr size_t value = N > max_v ? N : max_v; +}; + +static constexpr size_t kMaxOpaqueSize = max_v< +#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v, + TYPE_TABLE(DEFINE_TYPE_SIZE) +#undef DEFINE_TYPE_SIZE + 0>; + // --------------- Implementations --------------- // } // namespace triton_tvm_ffi diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index 794fc89..9e1b326 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: # fmt: on # tvm-ffi-stubgen(end) + # tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ # tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object @_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") @@ -35,6 +36,7 @@ class TypedValue(_ffi_Object): # fmt: on # tvm-ffi-stubgen(end) + __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 22c29ff..b6dcea0 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -2,11 +2,12 @@ from __future__ import annotations from typing import Any, List, Optional, Type from triton.backends.nvidia.driver import CudaDriver +from triton.runtime import _allocation from . import TypedValue, utils, string_to_type class TVMLauncher(object): - def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher: + def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher: super().__init__(*args, **kwargs) self.signature: List[str] = src.signature.values() @@ -32,8 +33,6 @@ class TVMLauncher(object): launch_exit_hook, *args, ): - from triton.runtime import _allocation - def allocate_scratch(size, align, allocator): if size > 0: grid_size = gridX * gridY * gridZ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aef7cc9..e0f2423 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,7 +14,7 @@ target_include_directories( target_compile_options( ${TARGET_NAME} PRIVATE - $<$:-O0 -g -DDEBUG> + $<$:-O0 -g> $<$:-O3 -DNDEBUG> ) target_link_libraries( diff --git a/src/exception.cc b/src/exception.cc index 09021c2..1bb5757 100644 --- a/src/exception.cc +++ b/src/exception.cc @@ -11,18 +11,19 @@ const char *CUDAException::what() const noexcept { } NotImplementedException::NotImplementedException(std::string_view name) - : message_("\"" + std::string(name) + "\" is not implemented") {} + : message_("[NotImplementedException]: \"" + std::string(name) + "\"") {} const char *NotImplementedException::what() const noexcept { return message_.c_str(); } UnknownTypeException::UnknownTypeException(Type type) - : message_("unknown type: " + std::string(TypeToString(type))) {} + : message_("[UnknownTypeException]: unknown type: \"" + + std::string(TypeToString(type)) + "\"") {} UnknownTypeException::UnknownTypeException(std::string_view type) - : message_("unknown type: " + std::string(type)) {} - + : message_("[UnknownTypeException]: unknown type: \"" + std::string(type) + + "\"") {} const char *UnknownTypeException::what() const noexcept { return message_.c_str(); } diff --git a/src/type.cc b/src/type.cc index 0e7a426..9836b5d 100644 --- a/src/type.cc +++ b/src/type.cc @@ -13,31 +13,27 @@ const char *TypeToString(Type type) { return str; TYPE_TABLE(CASE_ENUM) #undef CASE_ENUM - case Type::PTR: - return "*?"; - case Type::CONSTEXPR: - return "constexpr"; default: throw UnknownTypeException(type); } } tvm::ffi::Optional StringToType(tvm::ffi::String name) { -#define IF_ENUM(type, str, ctype) \ - if (name == str) { \ - return Type::type; \ - } - TYPE_TABLE(IF_ENUM) -#undef IF_ENUM if (name.starts_with("*")) { return Type::PTR; } if (name == "constexpr") { return Type::CONSTEXPR; } +#define IF_ENUM(type, str, ctype) \ + if (name == str) { \ + return Type::type; \ + } + TYPE_TABLE(IF_ENUM) +#undef IF_ENUM if (name.starts_with("tensordesc") || name == "nvTmaDesc") { - // TODO: - assert(false); + throw NotImplementedException( + "tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet."); } return std::nullopt; } diff --git a/src/utils.cc b/src/utils.cc index 10f0132..b04bc9c 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -5,6 +5,9 @@ #include #include #include +#ifndef NDEBUG +#include +#endif #define CUDA_CHECK(code) \ do { \ @@ -20,11 +23,11 @@ using namespace triton_tvm_ffi; // --------------- Definitions --------------- template struct ValueCast { - TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); + TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value); }; template struct ValueCastSet { - TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); + TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value); }; using GenericValueCastSet = ValueCastSet< @@ -36,29 +39,32 @@ using GenericValueCastSet = ValueCastSet< // --------------- Implementations --------------- template -TRITON_TVM_FFI_INLINE bool ValueCast::apply(void **addr, +TRITON_TVM_FFI_INLINE bool ValueCast::apply(void *ptr, const TypedValue &value) { if (value.GetType() == T) { if constexpr (T == Type::PTR) { tvm::ffi::TensorView cvalue = value.GetValue().cast(); - void **ptr = reinterpret_cast(alloca(sizeof(void *))); - *ptr = cvalue.data_ptr(); - *addr = ptr; + void **p = reinterpret_cast(ptr); + *p = cvalue.data_ptr(); + } else if constexpr (T == Type::CONSTEXPR) { +#ifdef NDEBUG + __builtin_unreachable(); +#else + throw NotImplementedException("CONSTEXPR for value casting"); +#endif } else { using ctype = type_to_ctype_t; ctype cvalue = value.GetValue().cast(); - ctype *ptr = reinterpret_cast(alloca(sizeof(ctype))); - *ptr = cvalue; - *addr = ptr; + ctype *p = reinterpret_cast(ptr); + *p = cvalue; } return true; - } else { - return false; } + return false; } template -TRITON_TVM_FFI_INLINE bool ValueCastSet::apply(void **ptr, +TRITON_TVM_FFI_INLINE bool ValueCastSet::apply(void *ptr, const TypedValue &value) { return (ValueCast::apply(ptr, value) || ...); } @@ -164,15 +170,23 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); } const int32_t kernelArgNum = kernelArgs.size(); + uint8_t *buffer = + reinterpret_cast(alloca(kMaxOpaqueSize * (kernelArgNum))); void **params = reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); size_t j = 0; for (size_t i = 0; i < kernelArgNum; ++i) { TypedValue value = kernelArgs[i].cast(); if (value.GetType() != Type::CONSTEXPR) { - if (!GenericValueCastSet::apply(¶ms[j++], value)) { +#ifndef NDEBUG + assert(j < kernelArgNum); +#endif + void *ptr = buffer + j * kMaxOpaqueSize; + params[j] = ptr; + if (!GenericValueCastSet::apply(ptr, value)) { throw UnknownTypeException(value.GetType()); } + ++j; } } // TODO: unwrap PyObject* from scratch pointers and assign to kernel args