support GetDeviceProperties

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-28 02:16:10 +08:00
parent ab83dded12
commit 3a485166b4
8 changed files with 95 additions and 9 deletions
+45 -2
View File
@@ -1,8 +1,51 @@
#include "exception.h"
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
void hello() { std::cout << "Hello, world!\n"; }
#define CUDA_CHECK(code) \
do { \
if ((code) != CUDA_SUCCESS) { \
throw triton_tvm_ffi::CUDAException(code); \
} \
} while (false)
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
tvm::ffi::cuda_api::DeviceHandle device;
CUDA_CHECK(cuDeviceGet(&device, device_id));
int max_shared_mem = 0;
int max_num_regs = 0;
int multiprocessor_count = 0;
int warp_size = 0;
int sm_clock_rate = 0;
int mem_clock_rate = 0;
int mem_bus_width = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(
cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return {{"max_shared_mem", max_shared_mem},
{"max_num_regs", max_num_regs},
{"multiprocessor_count", multiprocessor_count},
{"warp_size", warp_size},
{"sm_clock_rate", sm_clock_rate},
{"mem_clock_rate", mem_clock_rate},
{"mem_bus_width", mem_bus_width}};
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("triton_tvm_ffi.utils.hello", hello);
refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties",
GetDeviceProperties);
}