mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
+1
-106
@@ -1,6 +1,5 @@
|
||||
#include "exception.h"
|
||||
#include "macro.h"
|
||||
#include "type.h"
|
||||
#include "value.h"
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/reflection/registry.h>
|
||||
@@ -9,13 +8,6 @@
|
||||
#include <cassert>
|
||||
#endif
|
||||
|
||||
#define CUDA_CHECK(code) \
|
||||
do { \
|
||||
if (__builtin_expect((code) != CUDA_SUCCESS, 0)) { \
|
||||
throw triton_tvm_ffi::CUDAException(code); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
@@ -52,102 +44,6 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
{"mem_bus_width", memBusWidth}};
|
||||
}
|
||||
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata,
|
||||
tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
|
||||
bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
|
||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
config.gridDimX = gridX * numCtas;
|
||||
config.gridDimY = gridY;
|
||||
config.gridDimZ = gridZ;
|
||||
static constexpr int32_t kThreadsPerWarp = 32;
|
||||
config.blockDimX = kThreadsPerWarp * numWarps;
|
||||
config.blockDimY = 1;
|
||||
config.blockDimZ = 1;
|
||||
config.sharedMemBytes = sharedMemory;
|
||||
config.hStream = cStream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
// TODO: check `launchPdl`
|
||||
// TODO: check `launchCooperativeGrid`
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||
clusterAttr.value.clusterDim.x = numCtas;
|
||||
clusterAttr.value.clusterDim.y = 1;
|
||||
clusterAttr.value.clusterDim.z = 1;
|
||||
launchAttr[numAttrs++] = clusterAttr;
|
||||
CUlaunchAttribute clusterSchedulingAttr;
|
||||
clusterSchedulingAttr.id =
|
||||
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
|
||||
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
|
||||
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
|
||||
launchAttr[numAttrs++] = clusterSchedulingAttr;
|
||||
}
|
||||
config.numAttrs = numAttrs;
|
||||
if (numCtas == 16) {
|
||||
CUDA_CHECK(cuFuncSetAttribute(
|
||||
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
}
|
||||
const int32_t kernelArgNum = kernelArgs.size();
|
||||
void **params =
|
||||
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < kernelArgNum; ++i) {
|
||||
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
||||
switch (value.GetType()) {
|
||||
#define CASE_STMT(type, str, ctype) \
|
||||
case Type::type: { \
|
||||
using cpptype = type_to_ctype_t<Type::type>; \
|
||||
params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype))); \
|
||||
*reinterpret_cast<cpptype *>(params[j]) = \
|
||||
value.GetValue().cast<cpptype>(); \
|
||||
++j; \
|
||||
break; \
|
||||
}
|
||||
TYPE_TABLE_NATIVE(CASE_STMT)
|
||||
#undef CASE_STMT
|
||||
case Type::PTR: {
|
||||
params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
|
||||
*reinterpret_cast<void **>(params[j]) =
|
||||
value.GetValue().cast<tvm::ffi::TensorView>().data_ptr();
|
||||
++j;
|
||||
break;
|
||||
}
|
||||
case Type::CONSTEXPR: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
#ifdef NDEBUG
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw NotImplementedException("CONSTEXPR for value casting");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||
params[j] = &globalScratch;
|
||||
params[j + 1] = &profileScratch;
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||
}
|
||||
// TODO: call `launchExitHook`
|
||||
}
|
||||
|
||||
tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>
|
||||
LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data,
|
||||
int32_t shared, CUdevice device) {
|
||||
@@ -212,7 +108,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
throw NotImplementedException("set_printf_fifo_size");
|
||||
})
|
||||
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
|
||||
.def("triton_tvm_ffi.utils.launch", Launch)
|
||||
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user