fix bugs on failure of mm example

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 17:48:59 +08:00
parent e9576d265e
commit 875145b333
2 changed files with 39 additions and 11 deletions
+33 -5
View File
@@ -34,7 +34,35 @@ class TVMLauncher(object):
],
)
launch: tvm_ffi.Function = mod.get_function("launch")
self.launch = launch
self.launch = (
lambda grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args: launch(
grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*(
arg
for arg, type in zip(args, self.signature)
if type != "constexpr"
),
)
)
else:
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
[string_to_type(t) for t in self.signature],
@@ -129,19 +157,19 @@ class TVMLauncher(object):
@cached_property
def codegen(self) -> str:
env: Final[jinja2.Environment] = jinja2.Environment(
env: jinja2.Environment = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
lstrip_blocks=True,
)
template = env.get_template("launch.c.j2")
signature = list(
template: jinja2.Template = env.get_template("launch.c.j2")
signature: List[int] = list(
filter(
lambda t: t != "void",
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
)
)
html = template.render(signature=signature)
html: str = template.render(signature=signature)
return html