diff --git a/CMakeLists.txt b/CMakeLists.txt index b659ac9..85caa3a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,8 @@ else() project(triton-tvm-ffi) endif() +string(REPLACE "-" "_" TARGET_NAME "${PROJECT_NAME}") + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) else(CMAKE_BUILD_TYPE STREQUAL "Release") diff --git a/include/exception.h b/include/exception.h index ee9d9c3..1685605 100644 --- a/include/exception.h +++ b/include/exception.h @@ -1,6 +1,7 @@ #ifndef TRITON_TVM_FFI_EXCEPTION_H_ #define TRITON_TVM_FFI_EXCEPTION_H_ +#include "type.h" #include #include @@ -12,7 +13,17 @@ public: const char *what() const noexcept override; private: - const CUresult code; + const CUresult code_; +}; + +class UnknownTypeException : public std::exception { +public: + UnknownTypeException(Type type); + UnknownTypeException(std::string_view type); + const char *what() const noexcept override; + +private: + const std::string message_; }; } // namespace triton_tvm_ffi diff --git a/include/macro.h b/include/macro.h new file mode 100644 index 0000000..e1417be --- /dev/null +++ b/include/macro.h @@ -0,0 +1,8 @@ +#ifndef TRITON_TVM_FFI_MACRO_H_ +#define TRITON_TVM_FFI_MACRO_H_ + +#if defined(__GNUC__) || defined(__clang__) +#define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline +#endif + +#endif diff --git a/include/type.h b/include/type.h new file mode 100644 index 0000000..1e5c929 --- /dev/null +++ b/include/type.h @@ -0,0 +1,53 @@ +#ifndef TRITON_TVM_FFI_TYPE_H_ +#define TRITON_TVM_FFI_TYPE_H_ + +#include +#include +#include + +namespace triton_tvm_ffi { + +// --------------- Definitions --------------- // + +#define TYPE_TABLE(V) \ + V(I1, "i1", int8_t) \ + V(I8, "i8", int8_t) \ + V(I16, "i16", int16_t) \ + V(I32, "i32", int32_t) \ + V(I64, "i64", int64_t) \ + V(U1, "u1", uint8_t) \ + V(U8, "u8", uint8_t) \ + V(U16, "u16", uint16_t) \ + V(U32, "u32", uint32_t) \ + V(U64, "u64", uint64_t) \ + V(FP16, "fp16", double) \ + V(BF16, "bf16", double) \ + V(FP32, "f32", double) \ + V(FP64, "fp64", double) + +enum class Type : int64_t { +#define DEFINE_ENUM(type, str, ctype) type, + TYPE_TABLE(DEFINE_ENUM) +#undef DEFINE_ENUM + PTR, + CONSTEXPR, +}; + +const char *TypeToString(Type type); +tvm::ffi::Optional StringToType(tvm::ffi::String str); + +template struct type_to_ctype; +#define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \ + template <> struct type_to_ctype { using t = ctype; }; +TYPE_TABLE(DEFINE_TYPE_TO_CTYPE) +#undef DEFINE_TYPE_TO_CTYPE +template <> struct type_to_ctype { using t = void *; }; +// TODO: check whether CUtensorMap* is correct +template <> struct type_to_ctype { using t = void; }; +template using type_to_ctype_t = typename type_to_ctype::t; + +// --------------- Implementations --------------- // + +} // namespace triton_tvm_ffi + +#endif diff --git a/include/value.h b/include/value.h new file mode 100644 index 0000000..b764d48 --- /dev/null +++ b/include/value.h @@ -0,0 +1,45 @@ +#ifndef TRITON_TVM_FFI_VALUE_H_ +#define TRITON_TVM_FFI_VALUE_H_ + +#include "macro.h" +#include "type.h" +#include +#include + +namespace triton_tvm_ffi { + +class TypedValueObj : public tvm::ffi::Object { +public: + TypedValueObj(Type type, const tvm::ffi::Any &value); + TypedValueObj(Type type, tvm::ffi::Any &&value); + TypedValueObj(const TypedValueObj &other) = default; + TypedValueObj(TypedValueObj &&other) = default; + TypedValueObj &operator=(const TypedValueObj &other) = default; + TypedValueObj &operator=(TypedValueObj &&other) = default; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TypedValue", TypedValueObj, + tvm::ffi::Object); + TRITON_TVM_FFI_INLINE Type GetType() const { return type_; } + TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { return value_; } + +private: + Type type_; + tvm::ffi::Any value_; +}; + +class TypedValue : public tvm::ffi::ObjectRef { +public: + TypedValue(Type type, const tvm::ffi::Any &value); + TypedValue(Type type, tvm::ffi::Any &&value); + using tvm::ffi::ObjectRef::ObjectRef; + using tvm::ffi::ObjectRef::operator=; + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypedValue, tvm::ffi::ObjectRef, + TypedValueObj); + TRITON_TVM_FFI_INLINE Type GetType() const { return get()->GetType(); } + TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { + return get()->GetValue(); + } +}; + +} // namespace triton_tvm_ffi + +#endif diff --git a/python/triton_tvm_ffi/__init__.py b/python/triton_tvm_ffi/__init__.py index 03235a6..04e1469 100644 --- a/python/triton_tvm_ffi/__init__.py +++ b/python/triton_tvm_ffi/__init__.py @@ -1,3 +1,11 @@ -from . import utils - -__all__ = ["utils"] +# tvm-ffi-stubgen(begin): export/_ffi_api +# fmt: off +# isort: off +from ._ffi_api import * # noqa: F403 +from ._ffi_api import __all__ as _ffi_api__all__ +if "__all__" not in globals(): + __all__ = [] +__all__.extend(_ffi_api__all__) +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py new file mode 100644 index 0000000..310f727 --- /dev/null +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -0,0 +1,54 @@ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, register_object as _FFI_REG_OBJ +from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + +# tvm-ffi-stubgen(import-object): tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB +LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "triton_tvm_ffi") +# tvm-ffi-stubgen(begin): global/triton_tvm_ffi +# fmt: off +_FFI_INIT_FUNC("triton_tvm_ffi", __name__) +if TYPE_CHECKING: + def get_device_properties(_0: int, /) -> Mapping[str, int]: ... + def launch(*args: Any) -> Any: ... + def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... + def string_to_type(_0: str, /) -> int | None: ... + def type_to_string(_0: int, /) -> str: ... +# fmt: on +# tvm-ffi-stubgen(end) + + +# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ +# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object +@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") +class TypedValue(_ffi_Object): + # tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TypedValue + # fmt: off + if TYPE_CHECKING: + @staticmethod + def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ... + # fmt: on + # tvm-ffi-stubgen(end) + + +__all__ = [ + # tvm-ffi-stubgen(begin): __all__ + "LIB", + "TypedValue", + "get_device_properties", + "launch", + "load_binary", + "string_to_type", + "type_to_string", + # tvm-ffi-stubgen(end) +] diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index dababea..63fc124 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,9 +1,9 @@ from __future__ import annotations from ctypes import c_void_p -from typing import List, Mapping, Tuple, Type +from typing import Any, List, Mapping, Optional, Tuple, Type from triton.backends.nvidia.driver import CudaDriver -from .utils import get_device_properties, launch, load_binary +from . import TypedValue, get_device_properties, launch, load_binary, string_to_type class TVMFFIUtils(object): @@ -53,7 +53,7 @@ class TVMLauncher(object): def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher: super().__init__(*args, **kwargs) - self.mask: List[bool] = [annotation != "constexpr" for annotation in src.signature.values()] + self.signature = src.signature.values() self.num_ctas = getattr(metadata, "num_ctas", 1) self.launch = launch self.global_scratch_size = metadata.global_scratch_size @@ -63,12 +63,6 @@ class TVMLauncher(object): self.launch_cooperative_grid = metadata.launch_cooperative_grid self.launch_pdl = metadata.launch_pdl - # We assume the global Triton allocator is not enabled: `_allocator` must be a NullAllocator. - # This module depends on NullAllocator behavior; ensure no other code replaces the allocator. - from triton.runtime._allocation import _allocator, NullAllocator - - assert isinstance(_allocator.get(), NullAllocator) - def __call__( self, gridX, @@ -102,8 +96,15 @@ class TVMLauncher(object): ) assert not self.launch_cooperative_grid assert not self.launch_pdl - assert len(self.mask) == len(args) - args = [arg for arg, m in zip(args, self.mask) if m] + assert len(self.signature) == len(args) + + def canonicalize(arg: Any, sig: str) -> TypedValue: + ty: Optional[int] = string_to_type(sig) + assert ty is not None, sig + return TypedValue(ty, arg) + + args = [canonicalize(arg, sig) for arg, sig in zip(args, self.signature)] + return launch( gridX, gridY, diff --git a/python/triton_tvm_ffi/utils/__init__.py b/python/triton_tvm_ffi/utils/__init__.py deleted file mode 100644 index 04e1469..0000000 --- a/python/triton_tvm_ffi/utils/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# tvm-ffi-stubgen(begin): export/_ffi_api -# fmt: off -# isort: off -from ._ffi_api import * # noqa: F403 -from ._ffi_api import __all__ as _ffi_api__all__ -if "__all__" not in globals(): - __all__ = [] -__all__.extend(_ffi_api__all__) -# isort: on -# fmt: on -# tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/utils/_ffi_api.py b/python/triton_tvm_ffi/utils/_ffi_api.py deleted file mode 100644 index 351a47f..0000000 --- a/python/triton_tvm_ffi/utils/_ffi_api.py +++ /dev/null @@ -1,33 +0,0 @@ -# tvm-ffi-stubgen(begin): import-section -# fmt: off -# isort: off -from __future__ import annotations -from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC -from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from collections.abc import Mapping - from typing import Any -# isort: on -# fmt: on -# tvm-ffi-stubgen(end) -# tvm-ffi-stubgen(import-object): tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB -LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils") -# tvm-ffi-stubgen(begin): global/triton_tvm_ffi.utils -# fmt: off -_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__) -if TYPE_CHECKING: - def get_device_properties(_0: int, /) -> Mapping[str, int]: ... - def launch(*args: Any) -> Any: ... - def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... -# fmt: on -# tvm-ffi-stubgen(end) - -__all__ = [ - # tvm-ffi-stubgen(begin): __all__ - "LIB", - "get_device_properties", - "launch", - "load_binary", - # tvm-ffi-stubgen(end) -] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4d6ba2d..aef7cc9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,34 +1,36 @@ add_library( - utils + ${TARGET_NAME} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/exception.cc + ${CMAKE_CURRENT_SOURCE_DIR}/type.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/value.cc ) target_include_directories( - utils + ${TARGET_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include ) target_compile_options( - utils + ${TARGET_NAME} PRIVATE $<$:-O0 -g -DDEBUG> $<$:-O3 -DNDEBUG> ) target_link_libraries( - utils + ${TARGET_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver ) tvm_ffi_configure_target( - utils + ${TARGET_NAME} STUB_DIR "${CMAKE_SOURCE_DIR}/python" STUB_INIT ON ) install( - TARGETS utils + TARGETS ${TARGET_NAME} LIBRARY DESTINATION . ) -tvm_ffi_install(utils DESTINATION .) +tvm_ffi_install(${TARGET_NAME} DESTINATION .) diff --git a/src/exception.cc b/src/exception.cc index 110eeb7..a255add 100644 --- a/src/exception.cc +++ b/src/exception.cc @@ -2,12 +2,22 @@ namespace triton_tvm_ffi { -CUDAException::CUDAException(CUresult code) : code(code) {} +CUDAException::CUDAException(CUresult code) : code_(code) {} const char *CUDAException::what() const noexcept { const char *p = nullptr; - cuGetErrorString(code, &p); + cuGetErrorString(code_, &p); return p; } +UnknownTypeException::UnknownTypeException(Type type) + : message_("unknown type: " + std::string(TypeToString(type))) {} + +UnknownTypeException::UnknownTypeException(std::string_view type) + : message_("unknown type: " + std::string(type)) {} + +const char *UnknownTypeException::what() const noexcept { + return message_.c_str(); +} + } // namespace triton_tvm_ffi diff --git a/src/type.cc b/src/type.cc new file mode 100644 index 0000000..0e7a426 --- /dev/null +++ b/src/type.cc @@ -0,0 +1,52 @@ +#include "type.h" +#include "exception.h" +#include +#include +#include + +namespace triton_tvm_ffi { + +const char *TypeToString(Type type) { + switch (type) { +#define CASE_ENUM(type, str, ctype) \ + case Type::type: \ + return str; + TYPE_TABLE(CASE_ENUM) +#undef CASE_ENUM + case Type::PTR: + return "*?"; + case Type::CONSTEXPR: + return "constexpr"; + default: + throw UnknownTypeException(type); + } +} + +tvm::ffi::Optional StringToType(tvm::ffi::String name) { +#define IF_ENUM(type, str, ctype) \ + if (name == str) { \ + return Type::type; \ + } + TYPE_TABLE(IF_ENUM) +#undef IF_ENUM + if (name.starts_with("*")) { + return Type::PTR; + } + if (name == "constexpr") { + return Type::CONSTEXPR; + } + if (name.starts_with("tensordesc") || name == "nvTmaDesc") { + // TODO: + assert(false); + } + return std::nullopt; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("triton_tvm_ffi.type_to_string", TypeToString) + .def("triton_tvm_ffi.string_to_type", StringToType); +} + +} // namespace triton_tvm_ffi diff --git a/src/utils.cc b/src/utils.cc index 3713d84..ed75cc2 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,6 +1,8 @@ #include "exception.h" -#include +#include "type.h" +#include "value.h" #include +#include #include #include @@ -11,6 +13,60 @@ } \ } while (false) +namespace { + +using namespace triton_tvm_ffi; + +// --------------- Definitions --------------- + +template struct ValueCast { + TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); +}; + +template struct ValueCastSet { + TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value); +}; + +using GenericValueCastSet = ValueCastSet< +#define DEFINE_TYPE_CAST(type, str, ctype) Type::type, + TYPE_TABLE(DEFINE_TYPE_CAST) +#undef DEFINE_TYPE_CAST + Type::PTR>; + +// --------------- Implementations --------------- + +template +TRITON_TVM_FFI_INLINE bool ValueCast::apply(void **addr, + const TypedValue &value) { + if (value.GetType() == T) { + if constexpr (T == Type::PTR) { + tvm::ffi::TensorView cvalue = + value.GetValue().cast(); + void **ptr = reinterpret_cast(alloca(sizeof(void *))); + *ptr = cvalue.data_ptr(); + *addr = ptr; + } else { + using ctype = type_to_ctype_t; + ctype cvalue = value.GetValue().cast(); + ctype *ptr = reinterpret_cast(alloca(sizeof(ctype))); + *ptr = cvalue; + *addr = ptr; + } + return true; + } else { + return false; + } +} +template +TRITON_TVM_FFI_INLINE bool ValueCastSet::apply(void **ptr, + const TypedValue &value) { + return (ValueCast::apply(ptr, value) || ...); +} + +} // namespace + +namespace triton_tvm_ffi { + tvm::ffi::Map GetDeviceProperties(int device_id) { tvm::ffi::cuda_api::DeviceHandle device; CUDA_CHECK(cuDeviceGet(&device, device_id)); @@ -153,26 +209,17 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { const int32_t kernelArgNum = kernelArgs.size(); void **params = reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); + size_t j = 0; for (size_t i = 0; i < kernelArgNum; ++i) { - tvm::ffi::AnyView arg = kernelArgs[i]; - if (auto val = arg.try_cast()) { - void **ptr = reinterpret_cast(alloca(sizeof(void *))); - *ptr = val->data_ptr(); - params[i] = ptr; - } else if (auto val = arg.try_cast()) { - int32_t *ptr = reinterpret_cast(alloca(sizeof(int32_t))); - *ptr = *val; - params[i] = ptr; - } else if (auto val = arg.try_cast()) { - float *ptr = reinterpret_cast(alloca(sizeof(float))); - *ptr = *val; - params[i] = ptr; - } else { - assert(false && "unsupported kernel argument type"); + TypedValue value = kernelArgs[i].cast(); + if (value.GetType() != Type::CONSTEXPR) { + if (!GenericValueCastSet::apply(¶ms[j++], value)) { + throw UnknownTypeException(value.GetType()); + } } } - params[kernelArgNum] = &globalScratch; - params[kernelArgNum + 1] = &profileScratch; + params[j] = &globalScratch; + params[j + 1] = &profileScratch; CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr)); } // TODO: call `launchExitHook` @@ -181,7 +228,9 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) - .def_packed("triton_tvm_ffi.utils.launch", Launch) - .def("triton_tvm_ffi.utils.load_binary", LoadBinary); + .def("triton_tvm_ffi.get_device_properties", GetDeviceProperties) + .def_packed("triton_tvm_ffi.launch", Launch) + .def("triton_tvm_ffi.load_binary", LoadBinary); } + +} // namespace triton_tvm_ffi \ No newline at end of file diff --git a/src/value.cc b/src/value.cc new file mode 100644 index 0000000..92557ca --- /dev/null +++ b/src/value.cc @@ -0,0 +1,27 @@ +#include "value.h" +#include "type.h" +#include +#include + +namespace triton_tvm_ffi { + +TypedValueObj::TypedValueObj(Type type, const tvm::ffi::Any &value) + : type_(type), value_(value) {} + +TypedValueObj::TypedValueObj(Type type, tvm::ffi::Any &&value) + : type_(type), value_(std::move(value)) {} + +TypedValue::TypedValue(Type type, const tvm::ffi::Any &value) + : tvm::ffi::ObjectRef(tvm::ffi::make_object(type, value)) {} + +TypedValue::TypedValue(Type type, tvm::ffi::Any &&value) + : tvm::ffi::ObjectRef( + tvm::ffi::make_object(type, std::move(value))) {} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def( + refl::init()); +} + +} // namespace triton_tvm_ffi