mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
int32_t numWarps = args[5].v_int64;
|
||||
int32_t numCtas = args[6].v_int64;
|
||||
int32_t sharedMemory = args[7].v_int64;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
uint64_t globalScratch = args[8].v_uint64;
|
||||
uint64_t profileScratch = args[9].v_uint64;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
config.numAttrs = numAttrs;
|
||||
{% for type in signature %}
|
||||
{% if type == "void *" %}
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{% elif type == "int32_t" %}
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||
{% else %}
|
||||
assert(false, "unsupported type yet {{ type }}");
|
||||
{% endif %}
|
||||
|
||||
Reference in New Issue
Block a user