mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
put void branch in template
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -56,11 +56,7 @@ class TVMLauncher(object):
|
|||||||
shared_memory,
|
shared_memory,
|
||||||
global_scratch,
|
global_scratch,
|
||||||
profile_scratch,
|
profile_scratch,
|
||||||
*(
|
*args,
|
||||||
arg
|
|
||||||
for arg, type in zip(args, self.signature)
|
|
||||||
if type != "constexpr"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -164,11 +160,8 @@ class TVMLauncher(object):
|
|||||||
)
|
)
|
||||||
template: jinja2.Template = env.get_template("launch.c.j2")
|
template: jinja2.Template = env.get_template("launch.c.j2")
|
||||||
signature: List[int] = list(
|
signature: List[int] = list(
|
||||||
filter(
|
|
||||||
lambda t: t != "void",
|
|
||||||
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
|
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
html: str = template.render(signature=signature)
|
html: str = template.render(signature=signature)
|
||||||
return html
|
return html
|
||||||
|
|
||||||
|
|||||||
@@ -53,11 +53,12 @@ extern "C"
|
|||||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||||
{% elif type == "int32_t" %}
|
{% elif type == "int32_t" %}
|
||||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||||
|
{% elif type == "void" %}
|
||||||
{% else %}
|
{% else %}
|
||||||
assert(false, "unsupported type yet {{ type }}");
|
assert(false, "unsupported type yet {{ type }}");
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&globalScratch, &profileScratch };
|
void *params[] = { {% for type in signature %} {% if type != "void" %} &arg{{ loop.index0 }}, {% endif %} {% endfor %}&globalScratch, &profileScratch };
|
||||||
cuLaunchKernelEx(&config, function, params, NULL);
|
cuLaunchKernelEx(&config, function, params, NULL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user