#include #include #include #ifdef __cplusplus extern "C" #endif TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) { int32_t gridX = args[0].v_int64; int32_t gridY = args[1].v_int64; int32_t gridZ = args[2].v_int64; CUstream stream = (CUstream)args[3].v_uint64; CUfunction function = (CUfunction)args[4].v_uint64; int32_t numWarps = args[5].v_int64; int32_t numCtas = args[6].v_int64; int32_t sharedMemory = args[7].v_int64; uint64_t globalScratch = args[8].v_uint64; uint64_t profileScratch = args[9].v_uint64; if (gridX * gridY * gridZ > 0) { CUlaunchAttribute launchAttr[4]; CUlaunchConfig config; config.gridDimX = gridX * numCtas; config.gridDimY = gridY; config.gridDimZ = gridZ; config.blockDimX = 32 * numWarps; config.blockDimY = 1; config.blockDimZ = 1; config.sharedMemBytes = sharedMemory; config.hStream = stream; config.attrs = launchAttr; int32_t numAttrs = 0; // TODO: check `launchPdl` // TODO: check `launchCooperativeGrid` if (numCtas != 1) { CUlaunchAttribute clusterAttr; clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; clusterAttr.value.clusterDim.x = numCtas; clusterAttr.value.clusterDim.y = 1; clusterAttr.value.clusterDim.z = 1; launchAttr[numAttrs++] = clusterAttr; CUlaunchAttribute clusterSchedulingAttr; clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; launchAttr[numAttrs++] = clusterSchedulingAttr; } config.numAttrs = numAttrs; {% for type in signature %} {% if type == "void *" %} {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data; {% elif type == "int32_t" %} {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64; {% else %} assert(false, "unsupported type yet {{ type }}"); {% endif %} {% endfor %} void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&globalScratch, &profileScratch }; cuLaunchKernelEx(&config, function, params, NULL); } }