enable cjit launcher

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 10:36:41 +08:00
parent f0739b9dca
commit ac7497b2c8
7 changed files with 180 additions and 38 deletions
@@ -0,0 +1,65 @@
#include <cassert>
#include <cuda.h>
#include <tvm/ffi/tvm_ffi.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) {
int32_t gridX = args[0].v_int64;
int32_t gridY = args[1].v_int64;
int32_t gridZ = args[2].v_int64;
CUstream stream = (CUstream)args[3].v_uint64;
CUfunction function = (CUfunction)args[4].v_uint64;
int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
config.gridDimX = gridX * numCtas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;
static constexpr int32_t kThreadsPerWarp = 32;
config.blockDimX = kThreadsPerWarp * numWarps;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = stream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = numCtas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[numAttrs++] = clusterAttr;
CUlaunchAttribute clusterSchedulingAttr;
clusterSchedulingAttr.id =
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
launchAttr[numAttrs++] = clusterSchedulingAttr;
}
config.numAttrs = numAttrs;
{% for type in signature %}
{% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
{% else %}
assert(false, "unsupported type yet {{ type }}");
{% endif %}
{% endfor %}
void *foo = NULL, *bar = NULL;
void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&foo, &bar };
cuLaunchKernelEx(&config, function, params, NULL);
}
}