Files
triton-tvm-ffi/python/triton_tvm_ffi/templates/launch.c.j2
T
2026-01-31 17:48:59 +08:00

64 lines
2.3 KiB
Django/Jinja

#include <assert.h>
#include <cuda.h>
#include <tvm/ffi/tvm_ffi.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_FFI_DLL_EXPORT void
__tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args,
TVMFFIAny *result) {
int32_t gridX = args[0].v_int64;
int32_t gridY = args[1].v_int64;
int32_t gridZ = args[2].v_int64;
CUstream stream = (CUstream)args[3].v_uint64;
CUfunction function = (CUfunction)args[4].v_uint64;
int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64;
uint64_t globalScratch = args[8].v_uint64;
uint64_t profileScratch = args[9].v_uint64;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
config.gridDimX = gridX * numCtas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;
config.blockDimX = 32 * numWarps;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = stream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = numCtas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[numAttrs++] = clusterAttr;
CUlaunchAttribute clusterSchedulingAttr;
clusterSchedulingAttr.id =
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
launchAttr[numAttrs++] = clusterSchedulingAttr;
}
config.numAttrs = numAttrs;
{% for type in signature %}
{% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
{% else %}
assert(false, "unsupported type yet {{ type }}");
{% endif %}
{% endfor %}
void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&globalScratch, &profileScratch };
cuLaunchKernelEx(&config, function, params, NULL);
}
}