put typedvalues initialization into cpp

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 01:38:58 +08:00
parent bdc9c03b75
commit a953cbe7cc
10 changed files with 92 additions and 47 deletions
+18 -29
View File
@@ -107,30 +107,19 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
{"mem_bus_width", memBusWidth}};
}
void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
CUtensorMap x;
int32_t gridX = args[0].cast<int32_t>();
int32_t gridY = args[1].cast<int32_t>();
int32_t gridZ = args[2].cast<int32_t>();
CUstream stream = reinterpret_cast<CUstream>(args[3].cast<uint64_t>());
CUfunction function = reinterpret_cast<CUfunction>(args[4].cast<uint64_t>());
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata =
args[5].cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
int32_t numWarps = kernelMetadata.get<0>();
int32_t numCtas = kernelMetadata.get<1>();
int32_t sharedMemory = kernelMetadata.get<2>();
tvm::ffi::ObjectRef launchMetadata = args[6].cast<tvm::ffi::ObjectRef>();
tvm::ffi::ObjectRef launchEnterHook = args[7].cast<tvm::ffi::ObjectRef>();
tvm::ffi::ObjectRef launchExitHook = args[8].cast<tvm::ffi::ObjectRef>();
bool launchCooperativeGrid = args[9].cast<bool>();
bool launchPdl = args[10].cast<bool>();
tvm::ffi::ObjectRef globalScratchObject =
args[11].cast<tvm::ffi::ObjectRef>();
tvm::ffi::ObjectRef profileScratchObject =
args[12].cast<tvm::ffi::ObjectRef>();
tvm::ffi::PackedArgs kernelArgs = args.Slice(13);
// TODO: call `launchEnterHook`
// TODO: check `globalScratchObject`
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
CUstream cStream = reinterpret_cast<CUstream>(stream);
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
@@ -145,10 +134,10 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = stream;
config.hStream = cStream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdf`
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
@@ -167,7 +156,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
config.numAttrs = numAttrs;
if (numCtas == 16) {
CUDA_CHECK(cuFuncSetAttribute(
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
}
const int32_t kernelArgNum = kernelArgs.size();
uint8_t *buffer =
@@ -192,7 +181,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
params[j] = &globalScratch;
params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr));
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
}
// TODO: call `launchExitHook`
}
@@ -261,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
throw NotImplementedException("set_printf_fifo_size");
})
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
.def_packed("triton_tvm_ffi.utils.launch", Launch)
.def("triton_tvm_ffi.utils.launch", Launch)
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
}