From 1a01c9f2d81913ce0f623fa35d3da3a47e9980a2 Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Thu, 29 Jan 2026 18:23:02 +0800 Subject: [PATCH] put utils in C++ side Signed-off-by: Jinjie Liu --- include/exception.h | 9 +++ include/value.h | 8 +-- python/triton_tvm_ffi/_ffi_api.py | 9 --- python/triton_tvm_ffi/driver.py | 70 ++++-------------- python/triton_tvm_ffi/utils.py | 38 ++++++++++ src/exception.cc | 7 ++ src/utils.cc | 114 +++++++++++++++++------------- 7 files changed, 137 insertions(+), 118 deletions(-) create mode 100644 python/triton_tvm_ffi/utils.py diff --git a/include/exception.h b/include/exception.h index 1685605..cf118af 100644 --- a/include/exception.h +++ b/include/exception.h @@ -16,6 +16,15 @@ private: const CUresult code_; }; +class NotImplementedException : public std::exception { +public: + NotImplementedException(std::string_view name); + const char *what() const noexcept override; + +private: + const std::string message_; +}; + class UnknownTypeException : public std::exception { public: UnknownTypeException(Type type); diff --git a/include/value.h b/include/value.h index b764d48..15c338d 100644 --- a/include/value.h +++ b/include/value.h @@ -16,10 +16,10 @@ public: TypedValueObj(TypedValueObj &&other) = default; TypedValueObj &operator=(const TypedValueObj &other) = default; TypedValueObj &operator=(TypedValueObj &&other) = default; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TypedValue", TypedValueObj, - tvm::ffi::Object); TRITON_TVM_FFI_INLINE Type GetType() const { return type_; } TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { return value_; } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TypedValue", TypedValueObj, + tvm::ffi::Object); private: Type type_; @@ -32,12 +32,12 @@ public: TypedValue(Type type, tvm::ffi::Any &&value); using tvm::ffi::ObjectRef::ObjectRef; using tvm::ffi::ObjectRef::operator=; - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypedValue, tvm::ffi::ObjectRef, - TypedValueObj); TRITON_TVM_FFI_INLINE Type GetType() const { return get()->GetType(); } TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { return get()->GetValue(); } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypedValue, tvm::ffi::ObjectRef, + TypedValueObj); }; } // namespace triton_tvm_ffi diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index 310f727..794fc89 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -6,7 +6,6 @@ from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, regis from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Mapping from tvm_ffi import Object from typing import Any # isort: on @@ -19,15 +18,11 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "triton_tvm_ffi") # fmt: off _FFI_INIT_FUNC("triton_tvm_ffi", __name__) if TYPE_CHECKING: - def get_device_properties(_0: int, /) -> Mapping[str, int]: ... - def launch(*args: Any) -> Any: ... - def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... def string_to_type(_0: str, /) -> int | None: ... def type_to_string(_0: int, /) -> str: ... # fmt: on # tvm-ffi-stubgen(end) - # tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ # tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object @_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") @@ -40,14 +35,10 @@ class TypedValue(_ffi_Object): # fmt: on # tvm-ffi-stubgen(end) - __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", "TypedValue", - "get_device_properties", - "launch", - "load_binary", "string_to_type", "type_to_string", # tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 63fc124..22c29ff 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,67 +1,23 @@ from __future__ import annotations -from ctypes import c_void_p -from typing import Any, List, Mapping, Optional, Tuple, Type +from typing import Any, List, Optional, Type from triton.backends.nvidia.driver import CudaDriver -from . import TypedValue, get_device_properties, launch, load_binary, string_to_type - - -class TVMFFIUtils(object): - def __new__(cls: Type[TVMFFIUtils]) -> TVMFFIUtils: - if not hasattr(cls, "instance"): - cls.instance = super().__new__(cls) - return cls.instance - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def load_binary( - self, name: str, data: bytes, shared: int, device: int - ) -> Tuple[c_void_p, c_void_p, int, int, int]: - return load_binary(name, data, shared, device) - - def get_device_properties(self, device_id: int) -> Mapping[str, int]: - return get_device_properties(device_id) - - def cuOccupancyMaxActiveClusters(self, *args, **kwargs): - raise NotImplementedError( - '"cuOccupancyMaxActiveClusters isn\'t expected to be invoked"' - ) - - def set_printf_fifo_size(self, *args, **kwargs): - raise NotImplementedError( - '"set_printf_fifo_size" isn\'t expected to be invoked' - ) - - def fill_tma_descriptor(self, *args, **kwargs): - raise NotImplementedError( - '"fill_tma_descriptor" hasn\'t been supported for Hopper' - ) - - def launch(self, *args, **kwargs): - raise NotImplementedError( - '"launch" is introduced in triton after commit d2b3925410689155e0f6028e8554bba972989348, which is still not supported yed' - ) - - def build_signature_metadata(self, *args, **kwargs): - raise NotImplementedError( - '"launch" is introduced in triton after commit d2b3925410689155e0f6028e8554bba972989348, which is still not supported yed' - ) +from . import TypedValue, utils, string_to_type class TVMLauncher(object): def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher: super().__init__(*args, **kwargs) - self.signature = src.signature.values() - self.num_ctas = getattr(metadata, "num_ctas", 1) - self.launch = launch - self.global_scratch_size = metadata.global_scratch_size - self.global_scratch_align = metadata.global_scratch_align - self.profile_scratch_size = metadata.profile_scratch_size - self.profile_scratch_align = metadata.profile_scratch_align - self.launch_cooperative_grid = metadata.launch_cooperative_grid - self.launch_pdl = metadata.launch_pdl + self.signature: List[str] = src.signature.values() + self.num_ctas: int = getattr(metadata, "num_ctas", 1) + self.launch = utils.launch + self.global_scratch_size: int = metadata.global_scratch_size + self.global_scratch_align: int = metadata.global_scratch_align + self.profile_scratch_size: int = metadata.profile_scratch_size + self.profile_scratch_align: int = metadata.profile_scratch_align + self.launch_cooperative_grid: bool = metadata.launch_cooperative_grid + self.launch_pdl: bool = metadata.launch_pdl def __call__( self, @@ -105,7 +61,7 @@ class TVMLauncher(object): args = [canonicalize(arg, sig) for arg, sig in zip(args, self.signature)] - return launch( + return self.launch( gridX, gridY, gridZ, @@ -126,7 +82,7 @@ class TVMLauncher(object): class TVMFFIDriver(CudaDriver): def __init__(self, *args, **kwargs) -> TVMFFIDriver: super().__init__(*args, **kwargs) - self.utils: TVMFFIUtils = TVMFFIUtils() + self.utils = utils self.launcher_cls: Type[TVMLauncher] = TVMLauncher diff --git a/python/triton_tvm_ffi/utils.py b/python/triton_tvm_ffi/utils.py new file mode 100644 index 0000000..f0b22ac --- /dev/null +++ b/python/triton_tvm_ffi/utils.py @@ -0,0 +1,38 @@ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + +# tvm-ffi-stubgen(begin): global/triton_tvm_ffi.utils +# fmt: off +_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__) +if TYPE_CHECKING: + def build_signature_metadata(*args: Any) -> Any: ... + def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ... + def fill_tma_descriptor(*args: Any) -> Any: ... + def get_device_properties(_0: int, /) -> Mapping[str, int]: ... + def launch(*args: Any) -> Any: ... + def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... + def set_printf_fifo_size(*args: Any) -> Any: ... +# fmt: on +# tvm-ffi-stubgen(end) + +__all__ = [ + # tvm-ffi-stubgen(begin): __all__ + "build_signature_metadata", + "cuOccupancyMaxActiveClusters", + "fill_tma_descriptor", + "get_device_properties", + "launch", + "load_binary", + "set_printf_fifo_size", + # tvm-ffi-stubgen(end) +] diff --git a/src/exception.cc b/src/exception.cc index a255add..09021c2 100644 --- a/src/exception.cc +++ b/src/exception.cc @@ -10,6 +10,13 @@ const char *CUDAException::what() const noexcept { return p; } +NotImplementedException::NotImplementedException(std::string_view name) + : message_("\"" + std::string(name) + "\" is not implemented") {} + +const char *NotImplementedException::what() const noexcept { + return message_.c_str(); +} + UnknownTypeException::UnknownTypeException(Type type) : message_("unknown type: " + std::string(TypeToString(type))) {} diff --git a/src/utils.cc b/src/utils.cc index ed75cc2..10f0132 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -2,8 +2,8 @@ #include "type.h" #include "value.h" #include -#include #include +#include #include #define CUDA_CHECK(code) \ @@ -101,49 +101,6 @@ tvm::ffi::Map GetDeviceProperties(int device_id) { {"mem_bus_width", memBusWidth}}; } -tvm::ffi::Tuple -LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, - int32_t shared, CUdevice device) { - CUcontext pctx; - CUfunction fun; - CUmodule mod; - int32_t nRegs = 0; - int32_t nSpills = 0; - int32_t nMaxThreads = 0; - int32_t sharedOptin = 0; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); - if (!pctx) { - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); - } - CUDA_CHECK(cuModuleLoadData(&mod, data.data())); - CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data())); - CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); - CUDA_CHECK( - cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); - CUDA_CHECK(cuFuncGetAttribute(&nMaxThreads, - CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); - CUDA_CHECK(cuDeviceGetAttribute( - &sharedOptin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - static constexpr int64_t kExpectedMaxDynamicSharedMemory = 49152; - if (shared > kExpectedMaxDynamicSharedMemory && - sharedOptin > kExpectedMaxDynamicSharedMemory) { - CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); - int32_t sharedTotal = 0, sharedStatic = 0; - CUDA_CHECK(cuDeviceGetAttribute( - &sharedTotal, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, - device)); - CUDA_CHECK(cuFuncGetAttribute(&sharedStatic, - CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); - CUDA_CHECK( - cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - sharedOptin - sharedStatic)); - } - return tvm::ffi::Tuple{ - mod, fun, nRegs, nSpills, nMaxThreads}; -} - void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { CUtensorMap x; int32_t gridX = args[0].cast(); @@ -218,6 +175,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { } } } + // TODO: unwrap PyObject* from scratch pointers and assign to kernel args params[j] = &globalScratch; params[j + 1] = &profileScratch; CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr)); @@ -225,12 +183,72 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { // TODO: call `launchExitHook` } +tvm::ffi::Tuple +LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, + int32_t shared, CUdevice device) { + CUcontext pctx; + CUfunction fun; + CUmodule mod; + int32_t nRegs = 0; + int32_t nSpills = 0; + int32_t nMaxThreads = 0; + int32_t sharedOptin = 0; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } + CUDA_CHECK(cuModuleLoadData(&mod, data.data())); + CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data())); + CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK( + cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + CUDA_CHECK(cuFuncGetAttribute(&nMaxThreads, + CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + CUDA_CHECK(cuDeviceGetAttribute( + &sharedOptin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + static constexpr int64_t kExpectedMaxDynamicSharedMemory = 49152; + if (shared > kExpectedMaxDynamicSharedMemory && + sharedOptin > kExpectedMaxDynamicSharedMemory) { + CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int32_t sharedTotal = 0, sharedStatic = 0; + CUDA_CHECK(cuDeviceGetAttribute( + &sharedTotal, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK(cuFuncGetAttribute(&sharedStatic, + CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedOptin - sharedStatic)); + } + return tvm::ffi::Tuple{ + mod, fun, nRegs, nSpills, nMaxThreads}; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("triton_tvm_ffi.get_device_properties", GetDeviceProperties) - .def_packed("triton_tvm_ffi.launch", Launch) - .def("triton_tvm_ffi.load_binary", LoadBinary); + .def_packed("triton_tvm_ffi.utils.build_signature_metadata", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + throw NotImplementedException("build_signature_metadata"); + }) + .def_packed("triton_tvm_ffi.utils.cuOccupancyMaxActiveClusters", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + throw NotImplementedException( + "cuOccupancyMaxActiveClusters"); + }) + .def_packed("triton_tvm_ffi.utils.fill_tma_descriptor", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + throw NotImplementedException("fill_tma_descriptor"); + }) + .def_packed("triton_tvm_ffi.utils.set_printf_fifo_size", + [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + throw NotImplementedException("set_printf_fifo_size"); + }) + .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) + .def_packed("triton_tvm_ffi.utils.launch", Launch) + .def("triton_tvm_ffi.utils.load_binary", LoadBinary); } -} // namespace triton_tvm_ffi \ No newline at end of file +} // namespace triton_tvm_ffi