support GetDeviceProperties

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-28 02:16:10 +08:00
parent ab83dded12
commit 3a485166b4
8 changed files with 95 additions and 9 deletions
+1
View File
@@ -11,6 +11,7 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
else(CMAKE_BUILD_TYPE STREQUAL "Release")
endif()
find_package(CUDAToolkit REQUIRED)
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(
+1 -1
View File
@@ -9,5 +9,5 @@ SKBUILD_BUILD_DIR="build" SKBUILD_CMAKE_BUILD_TYPE=Debug uv pip install --no-bui
### Format
```bash
find python -name "*.py" | xargs ruff format
find python -name "*.py" | xargs ruff format && find include src -name "*.h" -o -name "*.cc" | xargs clang-format -i
```
+20
View File
@@ -0,0 +1,20 @@
#ifndef TRITON_TVM_FFI_EXCEPTION_H_
#define TRITON_TVM_FFI_EXCEPTION_H_
#include <cuda.h>
#include <exception>
namespace triton_tvm_ffi {
class CUDAException : public std::exception {
public:
CUDAException(CUresult code);
const char *what() const noexcept override;
private:
const CUresult code;
};
} // namespace triton_tvm_ffi
#endif
+4 -3
View File
@@ -1,7 +1,8 @@
from __future__ import annotations
from typing import Type
from typing import Mapping, Type
from triton.backends.nvidia.driver import CudaDriver
from .utils import get_device_properties
class TVMFFIUtils(object):
@@ -19,8 +20,8 @@ class TVMFFIUtils(object):
def load_binary(self, *args, **kwargs):
return self._utils.load_binary(*args, **kwargs)
def get_device_properties(self, *args, **kwargs):
return self._utils.get_device_properties(*args, **kwargs)
def get_device_properties(self, device_id: int) -> Mapping[str, int]:
return get_device_properties(device_id)
def cuOccupancyMaxActiveClusters(self, *args, **kwargs):
return self._utils.cuOccupancyMaxActiveClusters(*args, **kwargs)
+4 -2
View File
@@ -5,6 +5,8 @@ from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
@@ -14,13 +16,13 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils")
# fmt: off
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__)
if TYPE_CHECKING:
def hello() -> None: ...
def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
# fmt: on
# tvm-ffi-stubgen(end)
__all__ = [
# tvm-ffi-stubgen(begin): __all__
"LIB",
"hello",
"get_device_properties",
# tvm-ffi-stubgen(end)
]
+7 -1
View File
@@ -1,6 +1,7 @@
add_library(
utils
SHARED
${CMAKE_CURRENT_SOURCE_DIR}/exception.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
)
@@ -14,7 +15,12 @@ target_compile_options(
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
$<$<CONFIG:Release>:-O3 -DNDEBUG>
)
target_link_libraries(
utils
PRIVATE
CUDA::cudart
CUDA::cuda_driver
)
tvm_ffi_configure_target(
utils
STUB_DIR "${CMAKE_SOURCE_DIR}/python"
+13
View File
@@ -0,0 +1,13 @@
#include "exception.h"
namespace triton_tvm_ffi {
CUDAException::CUDAException(CUresult code) : code(code) {}
const char *CUDAException::what() const noexcept {
const char *p = nullptr;
cuGetErrorString(code, &p);
return p;
}
} // namespace triton_tvm_ffi
+45 -2
View File
@@ -1,8 +1,51 @@
#include "exception.h"
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
void hello() { std::cout << "Hello, world!\n"; }
#define CUDA_CHECK(code) \
do { \
if ((code) != CUDA_SUCCESS) { \
throw triton_tvm_ffi::CUDAException(code); \
} \
} while (false)
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
tvm::ffi::cuda_api::DeviceHandle device;
CUDA_CHECK(cuDeviceGet(&device, device_id));
int max_shared_mem = 0;
int max_num_regs = 0;
int multiprocessor_count = 0;
int warp_size = 0;
int sm_clock_rate = 0;
int mem_clock_rate = 0;
int mem_bus_width = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(
cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return {{"max_shared_mem", max_shared_mem},
{"max_num_regs", max_num_regs},
{"multiprocessor_count", multiprocessor_count},
{"warp_size", warp_size},
{"sm_clock_rate", sm_clock_rate},
{"mem_clock_rate", mem_clock_rate},
{"mem_bus_width", mem_bus_width}};
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("triton_tvm_ffi.utils.hello", hello);
refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties",
GetDeviceProperties);
}