mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
put typedvalues initialization into cpp
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
+18
-29
@@ -107,30 +107,19 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
{"mem_bus_width", memBusWidth}};
|
||||
}
|
||||
|
||||
void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
CUtensorMap x;
|
||||
int32_t gridX = args[0].cast<int32_t>();
|
||||
int32_t gridY = args[1].cast<int32_t>();
|
||||
int32_t gridZ = args[2].cast<int32_t>();
|
||||
CUstream stream = reinterpret_cast<CUstream>(args[3].cast<uint64_t>());
|
||||
CUfunction function = reinterpret_cast<CUfunction>(args[4].cast<uint64_t>());
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata =
|
||||
args[5].cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
|
||||
int32_t numWarps = kernelMetadata.get<0>();
|
||||
int32_t numCtas = kernelMetadata.get<1>();
|
||||
int32_t sharedMemory = kernelMetadata.get<2>();
|
||||
tvm::ffi::ObjectRef launchMetadata = args[6].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef launchEnterHook = args[7].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef launchExitHook = args[8].cast<tvm::ffi::ObjectRef>();
|
||||
bool launchCooperativeGrid = args[9].cast<bool>();
|
||||
bool launchPdl = args[10].cast<bool>();
|
||||
tvm::ffi::ObjectRef globalScratchObject =
|
||||
args[11].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef profileScratchObject =
|
||||
args[12].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::PackedArgs kernelArgs = args.Slice(13);
|
||||
// TODO: call `launchEnterHook`
|
||||
// TODO: check `globalScratchObject`
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata,
|
||||
tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
|
||||
bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
|
||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
@@ -145,10 +134,10 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
config.blockDimY = 1;
|
||||
config.blockDimZ = 1;
|
||||
config.sharedMemBytes = sharedMemory;
|
||||
config.hStream = stream;
|
||||
config.hStream = cStream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
// TODO: check `launchPdf`
|
||||
// TODO: check `launchPdl`
|
||||
// TODO: check `launchCooperativeGrid`
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
@@ -167,7 +156,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
config.numAttrs = numAttrs;
|
||||
if (numCtas == 16) {
|
||||
CUDA_CHECK(cuFuncSetAttribute(
|
||||
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
}
|
||||
const int32_t kernelArgNum = kernelArgs.size();
|
||||
uint8_t *buffer =
|
||||
@@ -192,7 +181,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||
params[j] = &globalScratch;
|
||||
params[j + 1] = &profileScratch;
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr));
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||
}
|
||||
// TODO: call `launchExitHook`
|
||||
}
|
||||
@@ -261,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
throw NotImplementedException("set_printf_fifo_size");
|
||||
})
|
||||
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
|
||||
.def_packed("triton_tvm_ffi.utils.launch", Launch)
|
||||
.def("triton_tvm_ffi.utils.launch", Launch)
|
||||
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user