From 3a485166b4c9e636216821832ad0bf33acc9d6d1 Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Wed, 28 Jan 2026 02:16:10 +0800 Subject: [PATCH] support GetDeviceProperties Signed-off-by: Jinjie Liu --- CMakeLists.txt | 1 + README.md | 2 +- include/exception.h | 20 +++++++++++ python/triton_tvm_ffi/driver.py | 7 ++-- python/triton_tvm_ffi/utils/_ffi_api.py | 6 ++-- src/CMakeLists.txt | 8 ++++- src/exception.cc | 13 +++++++ src/utils.cc | 47 +++++++++++++++++++++++-- 8 files changed, 95 insertions(+), 9 deletions(-) create mode 100644 include/exception.h create mode 100644 src/exception.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 508829b..b659ac9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug") else(CMAKE_BUILD_TYPE STREQUAL "Release") endif() +find_package(CUDAToolkit REQUIRED) find_package(Python COMPONENTS Interpreter REQUIRED) execute_process( diff --git a/README.md b/README.md index 251643c..c7e70e7 100644 --- a/README.md +++ b/README.md @@ -9,5 +9,5 @@ SKBUILD_BUILD_DIR="build" SKBUILD_CMAKE_BUILD_TYPE=Debug uv pip install --no-bui ### Format ```bash -find python -name "*.py" | xargs ruff format +find python -name "*.py" | xargs ruff format && find include src -name "*.h" -o -name "*.cc" | xargs clang-format -i ``` diff --git a/include/exception.h b/include/exception.h new file mode 100644 index 0000000..ee9d9c3 --- /dev/null +++ b/include/exception.h @@ -0,0 +1,20 @@ +#ifndef TRITON_TVM_FFI_EXCEPTION_H_ +#define TRITON_TVM_FFI_EXCEPTION_H_ + +#include +#include + +namespace triton_tvm_ffi { + +class CUDAException : public std::exception { +public: + explicit CUDAException(CUresult code); + const char *what() const noexcept override; + +private: + const CUresult code; +}; + +} // namespace triton_tvm_ffi + +#endif diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index eea20b6..2c49b66 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,7 
+1,8 @@ from __future__ import annotations -from typing import Type +from typing import Mapping, Type from triton.backends.nvidia.driver import CudaDriver +from .utils import get_device_properties class TVMFFIUtils(object): @@ -19,8 +20,8 @@ class TVMFFIUtils(object): def load_binary(self, *args, **kwargs): return self._utils.load_binary(*args, **kwargs) - def get_device_properties(self, *args, **kwargs): - return self._utils.get_device_properties(*args, **kwargs) + def get_device_properties(self, device_id: int) -> Mapping[str, int]: + return get_device_properties(device_id) def cuOccupancyMaxActiveClusters(self, *args, **kwargs): return self._utils.cuOccupancyMaxActiveClusters(*args, **kwargs) diff --git a/python/triton_tvm_ffi/utils/_ffi_api.py b/python/triton_tvm_ffi/utils/_ffi_api.py index d2f4994..5dbcc8f 100644 --- a/python/triton_tvm_ffi/utils/_ffi_api.py +++ b/python/triton_tvm_ffi/utils/_ffi_api.py @@ -5,6 +5,8 @@ from __future__ import annotations from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping # isort: on # fmt: on # tvm-ffi-stubgen(end) @@ -14,13 +16,13 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils") # fmt: off _FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__) if TYPE_CHECKING: - def hello() -> None: ... + def get_device_properties(_0: int, /) -> Mapping[str, int]: ... 
# fmt: on # tvm-ffi-stubgen(end) __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", - "hello", + "get_device_properties", # tvm-ffi-stubgen(end) ] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 81ef820..4d6ba2d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,6 +1,7 @@ add_library( utils SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/exception.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc ) @@ -14,7 +15,12 @@ target_compile_options( $<$:-O0 -g -DDEBUG> $<$:-O3 -DNDEBUG> ) - +target_link_libraries( + utils + PRIVATE + CUDA::cudart + CUDA::cuda_driver +) tvm_ffi_configure_target( utils STUB_DIR "${CMAKE_SOURCE_DIR}/python" diff --git a/src/exception.cc b/src/exception.cc new file mode 100644 index 0000000..110eeb7 --- /dev/null +++ b/src/exception.cc @@ -0,0 +1,13 @@ +#include "exception.h" + +namespace triton_tvm_ffi { + +CUDAException::CUDAException(CUresult code) : code(code) {} + +const char *CUDAException::what() const noexcept { + const char *p = nullptr; + cuGetErrorString(code, &p); + return p != nullptr ? p : "unknown CUDA error"; // unrecognized code +} + +} // namespace triton_tvm_ffi diff --git a/src/utils.cc b/src/utils.cc index af5f6db..0466a19 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,8 +1,51 @@ +#include "exception.h" +#include +#include #include -void hello() { std::cout << "Hello, world!\n"; } +#define CUDA_CHECK(code) \ + do { \ + if (const CUresult cuda_rc = (code); cuda_rc != CUDA_SUCCESS) { \ + throw triton_tvm_ffi::CUDAException(cuda_rc); \ + } \ + } while (false) + +tvm::ffi::Map GetDeviceProperties(int device_id) { + tvm::ffi::cuda_api::DeviceHandle device; + CUDA_CHECK(cuDeviceGet(&device, device_id)); + int max_shared_mem = 0; + int max_num_regs = 0; + int multiprocessor_count = 0; + int warp_size = 0; + int sm_clock_rate = 0; + int mem_clock_rate = 0; + int mem_bus_width = 0; + CUDA_CHECK(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + 
CUDA_CHECK(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, + CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + return {{"max_shared_mem", max_shared_mem}, + {"max_num_regs", max_num_regs}, + {"multiprocessor_count", multiprocessor_count}, + {"warp_size", warp_size}, + {"sm_clock_rate", sm_clock_rate}, + {"mem_clock_rate", mem_clock_rate}, + {"mem_bus_width", mem_bus_width}}; +} TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("triton_tvm_ffi.utils.hello", hello); + refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties", + GetDeviceProperties); }