From 875145b3331a5f614d9a9491ca4300ca11bb8f95 Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Sat, 31 Jan 2026 17:48:59 +0800 Subject: [PATCH] fix bugs on failure of mm example Signed-off-by: Jinjie Liu --- python/triton_tvm_ffi/driver.py | 38 ++++++++++++++++++--- python/triton_tvm_ffi/templates/launch.c.j2 | 12 +++---- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 1e224ae..d761656 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -34,7 +34,35 @@ class TVMLauncher(object): ], ) launch: tvm_ffi.Function = mod.get_function("launch") - self.launch = launch + self.launch = ( + lambda grid_x, + grid_y, + grid_z, + stream, + function, + num_warps, + num_ctas, + shared_memory, + global_scratch, + profile_scratch, + *args: launch( + grid_x, + grid_y, + grid_z, + stream, + function, + num_warps, + num_ctas, + shared_memory, + global_scratch, + profile_scratch, + *( + arg + for arg, type in zip(args, self.signature) + if type != "constexpr" + ), + ) + ) else: self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( [string_to_type(t) for t in self.signature], @@ -129,19 +157,19 @@ class TVMLauncher(object): @cached_property def codegen(self) -> str: - env: Final[jinja2.Environment] = jinja2.Environment( + env: jinja2.Environment = jinja2.Environment( loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"), trim_blocks=True, lstrip_blocks=True, ) - template = env.get_template("launch.c.j2") - signature = list( + template: jinja2.Template = env.get_template("launch.c.j2") + signature: List[int] = list( filter( lambda t: t != "void", map(lambda t: type_to_ctype(string_to_type(t)), self.signature), ) ) - html = template.render(signature=signature) + html: str = template.render(signature=signature) return html diff --git a/python/triton_tvm_ffi/templates/launch.c.j2 b/python/triton_tvm_ffi/templates/launch.c.j2 index c318749..53caf94 100644 --- a/python/triton_tvm_ffi/templates/launch.c.j2 +++ b/python/triton_tvm_ffi/templates/launch.c.j2 @@ -1,11 +1,13 @@ -#include +#include #include #include #ifdef __cplusplus extern "C" #endif -TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) { + TVM_FFI_DLL_EXPORT void + __tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args, + TVMFFIAny *result) { int32_t gridX = args[0].v_int64; int32_t gridY = args[1].v_int64; int32_t gridZ = args[2].v_int64; @@ -22,8 +24,7 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in config.gridDimX = gridX * numCtas; config.gridDimY = gridY; config.gridDimZ = gridZ; - static constexpr int32_t kThreadsPerWarp = 32; - config.blockDimX = kThreadsPerWarp * numWarps; + config.blockDimX = 32 * numWarps; config.blockDimY = 1; config.blockDimZ = 1; config.sharedMemBytes = sharedMemory; @@ -56,8 +57,7 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in assert(false, "unsupported type yet {{ type }}"); {% endif %} {% endfor %} - void *foo = NULL, *bar = NULL; - void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&foo, &bar }; + void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&globalScratch, &profileScratch }; cuLaunchKernelEx(&config, function, params, NULL); } }