mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
fix bugs on failure of mm example
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -34,7 +34,35 @@ class TVMLauncher(object):
|
||||
],
|
||||
)
|
||||
launch: tvm_ffi.Function = mod.get_function("launch")
|
||||
self.launch = launch
|
||||
self.launch = (
|
||||
lambda grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args: launch(
|
||||
grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*(
|
||||
arg
|
||||
for arg, type in zip(args, self.signature)
|
||||
if type != "constexpr"
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
||||
[string_to_type(t) for t in self.signature],
|
||||
@@ -129,19 +157,19 @@ class TVMLauncher(object):
|
||||
|
||||
@cached_property
|
||||
def codegen(self) -> str:
|
||||
env: Final[jinja2.Environment] = jinja2.Environment(
|
||||
env: jinja2.Environment = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
template = env.get_template("launch.c.j2")
|
||||
signature = list(
|
||||
template: jinja2.Template = env.get_template("launch.c.j2")
|
||||
signature: List[int] = list(
|
||||
filter(
|
||||
lambda t: t != "void",
|
||||
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
|
||||
)
|
||||
)
|
||||
html = template.render(signature=signature)
|
||||
html: str = template.render(signature=signature)
|
||||
return html
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user