From b79d880adf279c5334c03b71a0db4b9c4a5ed1ee Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Wed, 28 Jan 2026 14:30:43 +0800 Subject: [PATCH] remove CudaUtils dependencies Replace the CudaUtils-backed TVMFFIUtils.load_binary with the native triton_tvm_ffi.utils.load_binary (LoadBinary in src/utils.cc). fill_tma_descriptor, launch and build_signature_metadata no longer delegate to CudaUtils and raise NotImplementedError instead. Signed-off-by: Jinjie Liu --- python/triton_tvm_ffi/driver.py | 26 ++++++++----- python/triton_tvm_ffi/utils/_ffi_api.py | 2 + src/utils.cc | 49 ++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index d9c035d..88954ab 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Mapping, Type +from ctypes import c_void_p +from typing import Mapping, Tuple, Type from triton.backends.nvidia.driver import CudaDriver -from .utils import get_device_properties +from .utils import get_device_properties, load_binary class TVMFFIUtils(object): @@ -13,12 +14,11 @@ class TVMFFIUtils(object): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - from triton.backends.nvidia.driver import CudaUtils - self._utils: CudaUtils = CudaUtils() - - def load_binary(self, *args, **kwargs): - return self._utils.load_binary(*args, **kwargs) + def load_binary( + self, name: str, data: bytes, shared: int, device: int + ) -> Tuple[c_void_p, c_void_p, int, int, int]: + return load_binary(name, data, shared, device) def get_device_properties(self, device_id: int) -> Mapping[str, int]: return get_device_properties(device_id) @@ -34,13 +34,19 @@ class TVMFFIUtils(object): ) def fill_tma_descriptor(self, *args, **kwargs): - return self._utils.fill_tma_descriptor(*args, **kwargs) + raise NotImplementedError( + '"fill_tma_descriptor" hasn\'t been supported for Hopper' + ) def launch(self, *args, **kwargs): - return self._utils.launch(*args, **kwargs) + raise NotImplementedError( + '"launch" is introduced in triton after commit d2b3925410689155e0f6028e8554bba972989348, which is still not supported yet' + ) def 
build_signature_metadata(self, *args, **kwargs): - return self._utils.build_signature_metadata(*args, **kwargs) + raise NotImplementedError( + '"build_signature_metadata" is introduced in triton after commit d2b3925410689155e0f6028e8554bba972989348, which is still not supported yet' + ) class TVMFFIDriver(CudaDriver): diff --git a/python/triton_tvm_ffi/utils/_ffi_api.py b/python/triton_tvm_ffi/utils/_ffi_api.py index 5dbcc8f..79b64d3 100644 --- a/python/triton_tvm_ffi/utils/_ffi_api.py +++ b/python/triton_tvm_ffi/utils/_ffi_api.py @@ -17,6 +17,7 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils") _FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__) if TYPE_CHECKING: def get_device_properties(_0: int, /) -> Mapping[str, int]: ... + def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... # fmt: on # tvm-ffi-stubgen(end) @@ -24,5 +25,6 @@ __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", "get_device_properties", + "load_binary", # tvm-ffi-stubgen(end) ] diff --git a/src/utils.cc b/src/utils.cc index 0466a19..104f377 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,6 +1,7 @@ #include "exception.h" #include #include +#include #include #define CUDA_CHECK(code) \ @@ -44,8 +45,52 @@ tvm::ffi::Map GetDeviceProperties(int device_id) { {"mem_bus_width", mem_bus_width}}; } +tvm::ffi::Tuple +LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, + int32_t shared, CUdevice device) { + CUcontext pctx; + CUfunction fun; + CUmodule mod; + int32_t nRegs = 0; + int32_t nSpills = 0; + int32_t nMaxThreads = 0; + int32_t sharedOptin = 0; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } + CUDA_CHECK(cuModuleLoadData(&mod, data.data())); + CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data())); + CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK( + cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); 
+ CUDA_CHECK(cuFuncGetAttribute(&nMaxThreads, + CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + CUDA_CHECK(cuDeviceGetAttribute( + &sharedOptin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + static constexpr int64_t kExpectedMaxDynamicSharedMemory = 49152; + if (shared > kExpectedMaxDynamicSharedMemory && + sharedOptin > kExpectedMaxDynamicSharedMemory) { + CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int32_t sharedTotal = 0, sharedStatic = 0; + CUDA_CHECK(cuDeviceGetAttribute( + &sharedTotal, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK(cuFuncGetAttribute(&sharedStatic, + CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedOptin - sharedStatic)); + } + return tvm::ffi::Tuple{ + mod, fun, nRegs, nSpills, nMaxThreads}; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties", - GetDeviceProperties); + refl::GlobalDef() + .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) + .def("triton_tvm_ffi.utils.load_binary", LoadBinary); }