unify launch apis

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 11:29:32 +08:00
parent ac7497b2c8
commit e9576d265e
5 changed files with 59 additions and 89 deletions
+4 -6
View File
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
uint64_t globalScratch = args[8].v_uint64;
uint64_t profileScratch = args[9].v_uint64;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
config.numAttrs = numAttrs;
{% for type in signature %}
{% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
{% else %}
assert(false, "unsupported type yet {{ type }}");
{% endif %}