diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index d761656..686a1a8 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -56,11 +56,7 @@ class TVMLauncher(object): shared_memory, global_scratch, profile_scratch, - *( - arg - for arg, type in zip(args, self.signature) - if type != "constexpr" - ), + *args, ) ) else: @@ -164,10 +160,7 @@ class TVMLauncher(object): ) template: jinja2.Template = env.get_template("launch.c.j2") signature: List[int] = list( - filter( - lambda t: t != "void", map(lambda t: type_to_ctype(string_to_type(t)), self.signature), - ) ) html: str = template.render(signature=signature) return html diff --git a/python/triton_tvm_ffi/templates/launch.c.j2 b/python/triton_tvm_ffi/templates/launch.c.j2 index 53caf94..0b20889 100644 --- a/python/triton_tvm_ffi/templates/launch.c.j2 +++ b/python/triton_tvm_ffi/templates/launch.c.j2 @@ -53,11 +53,12 @@ extern "C" {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data; {% elif type == "int32_t" %} {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64; + {% elif type == "void" %} {% else %} assert(false, "unsupported type yet {{ type }}"); {% endif %} {% endfor %} - void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&globalScratch, &profileScratch }; + void *params[] = { {% for type in signature %} {% if type != "void" %} &arg{{ loop.index0 }}, {% endif %} {% endfor %}&globalScratch, &profileScratch }; cuLaunchKernelEx(&config, function, params, NULL); } }