mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
support GetDeviceProperties
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
+45
-2
@@ -1,8 +1,51 @@
|
||||
#include "exception.h"
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
void hello() { std::cout << "Hello, world!\n"; }
|
||||
#define CUDA_CHECK(code) \
|
||||
do { \
|
||||
if ((code) != CUDA_SUCCESS) { \
|
||||
throw triton_tvm_ffi::CUDAException(code); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
tvm::ffi::cuda_api::DeviceHandle device;
|
||||
CUDA_CHECK(cuDeviceGet(&device, device_id));
|
||||
int max_shared_mem = 0;
|
||||
int max_num_regs = 0;
|
||||
int multiprocessor_count = 0;
|
||||
int warp_size = 0;
|
||||
int sm_clock_rate = 0;
|
||||
int mem_clock_rate = 0;
|
||||
int mem_bus_width = 0;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(
|
||||
cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate,
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
return {{"max_shared_mem", max_shared_mem},
|
||||
{"max_num_regs", max_num_regs},
|
||||
{"multiprocessor_count", multiprocessor_count},
|
||||
{"warp_size", warp_size},
|
||||
{"sm_clock_rate", sm_clock_rate},
|
||||
{"mem_clock_rate", mem_clock_rate},
|
||||
{"mem_bus_width", mem_bus_width}};
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def("triton_tvm_ffi.utils.hello", hello);
|
||||
refl::GlobalDef().def("triton_tvm_ffi.utils.get_device_properties",
|
||||
GetDeviceProperties);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user