mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
support GetDeviceProperties
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -11,6 +11,7 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
|||||||
else(CMAKE_BUILD_TYPE STREQUAL "Release")
|
else(CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
find_package(CUDAToolkit REQUIRED)
|
||||||
find_package(Python COMPONENTS Interpreter REQUIRED)
|
find_package(Python COMPONENTS Interpreter REQUIRED)
|
||||||
|
|
||||||
execute_process(
|
execute_process(
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ SKBUILD_BUILD_DIR="build" SKBUILD_CMAKE_BUILD_TYPE=Debug uv pip install --no-bui
|
|||||||
|
|
||||||
### Format
|
### Format
|
||||||
```bash
|
```bash
|
||||||
find python -name "*.py" | xargs ruff format
|
find python -name "*.py" | xargs ruff format && find include src -name "*.h" -o -name "*.cc" | xargs clang-format -i
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#ifndef TRITON_TVM_FFI_EXCEPTION_H_
|
||||||
|
#define TRITON_TVM_FFI_EXCEPTION_H_
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <exception>
|
||||||
|
|
||||||
|
namespace triton_tvm_ffi {
|
||||||
|
|
||||||
|
class CUDAException : public std::exception {
|
||||||
|
public:
|
||||||
|
CUDAException(CUresult code);
|
||||||
|
const char *what() const noexcept override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const CUresult code;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace triton_tvm_ffi
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Type
|
from typing import Mapping, Type
|
||||||
from triton.backends.nvidia.driver import CudaDriver
|
from triton.backends.nvidia.driver import CudaDriver
|
||||||
|
from .utils import get_device_properties
|
||||||
|
|
||||||
|
|
||||||
class TVMFFIUtils(object):
|
class TVMFFIUtils(object):
|
||||||
@@ -19,8 +20,8 @@ class TVMFFIUtils(object):
|
|||||||
def load_binary(self, *args, **kwargs):
|
def load_binary(self, *args, **kwargs):
|
||||||
return self._utils.load_binary(*args, **kwargs)
|
return self._utils.load_binary(*args, **kwargs)
|
||||||
|
|
||||||
def get_device_properties(self, *args, **kwargs):
|
def get_device_properties(self, device_id: int) -> Mapping[str, int]:
|
||||||
return self._utils.get_device_properties(*args, **kwargs)
|
return get_device_properties(device_id)
|
||||||
|
|
||||||
def cuOccupancyMaxActiveClusters(self, *args, **kwargs):
|
def cuOccupancyMaxActiveClusters(self, *args, **kwargs):
|
||||||
return self._utils.cuOccupancyMaxActiveClusters(*args, **kwargs)
|
return self._utils.cuOccupancyMaxActiveClusters(*args, **kwargs)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
|
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
|
||||||
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
|
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping
|
||||||
# isort: on
|
# isort: on
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
@@ -14,13 +16,13 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils")
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__)
|
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
def hello() -> None: ...
|
def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# tvm-ffi-stubgen(begin): __all__
|
# tvm-ffi-stubgen(begin): __all__
|
||||||
"LIB",
|
"LIB",
|
||||||
"hello",
|
"get_device_properties",
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
]
|
]
|
||||||
|
|||||||
+7
-1
@@ -1,6 +1,7 @@
|
|||||||
add_library(
|
add_library(
|
||||||
utils
|
utils
|
||||||
SHARED
|
SHARED
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/exception.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,7 +15,12 @@ target_compile_options(
|
|||||||
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
|
$<$<CONFIG:Debug>:-O0 -g -DDEBUG>
|
||||||
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
||||||
)
|
)
|
||||||
|
target_link_libraries(
|
||||||
|
utils
|
||||||
|
PRIVATE
|
||||||
|
CUDA::cudart
|
||||||
|
CUDA::cuda_driver
|
||||||
|
)
|
||||||
tvm_ffi_configure_target(
|
tvm_ffi_configure_target(
|
||||||
utils
|
utils
|
||||||
STUB_DIR "${CMAKE_SOURCE_DIR}/python"
|
STUB_DIR "${CMAKE_SOURCE_DIR}/python"
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#include "exception.h"
|
||||||
|
|
||||||
|
namespace triton_tvm_ffi {
|
||||||
|
|
||||||
|
CUDAException::CUDAException(CUresult code) : code(code) {}
|
||||||
|
|
||||||
|
const char *CUDAException::what() const noexcept {
|
||||||
|
const char *p = nullptr;
|
||||||
|
cuGetErrorString(code, &p);
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace triton_tvm_ffi
|
||||||
+45
-2
@@ -1,8 +1,51 @@
|
|||||||
|
#include "exception.h"
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/tvm_ffi.h>
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|
||||||
void hello() { std::cout << "Hello, world!\n"; }
|
#define CUDA_CHECK(code) \
|
||||||
|
do { \
|
||||||
|
if ((code) != CUDA_SUCCESS) { \
|
||||||
|
throw triton_tvm_ffi::CUDAException(code); \
|
||||||
|
} \
|
||||||
|
} while (false)
|
||||||
|
|
||||||
|
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||||
|
tvm::ffi::cuda_api::DeviceHandle device;
|
||||||
|
CUDA_CHECK(cuDeviceGet(&device, device_id));
|
||||||
|
int max_shared_mem = 0;
|
||||||
|
int max_num_regs = 0;
|
||||||
|
int multiprocessor_count = 0;
|
||||||
|
int warp_size = 0;
|
||||||
|
int sm_clock_rate = 0;
|
||||||
|
int mem_clock_rate = 0;
|
||||||
|
int mem_bus_width = 0;
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(
|
||||||
|
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||||
|
device));
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(
|
||||||
|
&max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(
|
||||||
|
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||||
|
CUDA_CHECK(
|
||||||
|
cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||||
|
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(
|
||||||
|
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||||
|
CUDA_CHECK(cuDeviceGetAttribute(
|
||||||
|
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||||
|
return {{"max_shared_mem", max_shared_mem},
|
||||||
|
{"max_num_regs", max_num_regs},
|
||||||
|
{"multiprocessor_count", multiprocessor_count},
|
||||||
|
{"warp_size", warp_size},
|
||||||
|
{"sm_clock_rate", sm_clock_rate},
|
||||||
|
{"mem_clock_rate", mem_clock_rate},
|
||||||
|
{"mem_bus_width", mem_bus_width}};
|
||||||
|
}
|
||||||
|
|
||||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||||
namespace refl = tvm::ffi::reflection;
|
namespace refl = tvm::ffi::reflection;
|
||||||
refl::GlobalDef().def("triton_tvm_ffi.utils.hello", hello);
|
refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties",
|
||||||
|
GetDeviceProperties);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user