unify launch apis

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 11:29:32 +08:00
parent ac7497b2c8
commit e9576d265e
5 changed files with 59 additions and 89 deletions
+40 -46
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from functools import cached_property
import os
from typing import Any, Callable, Final, List, Sequence, Type
from typing import Any, Final, List, Type
import jinja2
from triton.backends.nvidia.driver import CudaDriver
@@ -24,19 +24,16 @@ class TVMLauncher(object):
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
self.launch_pdl: Final[bool] = metadata.launch_pdl
self.enable_jit: Final[bool] = (
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None
)
if self.enable_jit:
mod = tvm_ffi.cpp.load_inline(
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
"launch",
cpp_sources=self.codegen,
cpp_sources=[self.codegen],
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
extra_include_paths=[
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
],
)
launch = mod.get_function("launch")
launch: tvm_ffi.Function = mod.get_function("launch")
self.launch = launch
else:
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
@@ -50,10 +47,9 @@ class TVMLauncher(object):
grid_z,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args: self.impl.launch(
@@ -62,10 +58,9 @@ class TVMLauncher(object):
grid_z,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
args,
@@ -101,37 +96,36 @@ class TVMLauncher(object):
self.profile_scratch_align,
_allocation._profile_allocator,
)
assert not self.launch_cooperative_grid
assert not self.launch_pdl
if self.enable_jit:
(num_warps, num_ctas, shared_memory) = kernel_metadata
return self.launch(
gridX,
gridY,
gridZ,
stream,
function,
num_warps,
num_ctas,
shared_memory,
*args,
)
else:
return self.launch(
gridX,
gridY,
gridZ,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
global_scratch,
profile_scratch,
*args,
)
def canonicalize(obj: Any) -> int:
    """Reduce *obj* to a plain integer.

    The mapping is:

    * ``None`` becomes ``0``;
    * an ``int`` is returned unchanged;
    * any other object with a truthy ``data_ptr`` attribute yields the
      result of calling ``obj.data_ptr()``.

    Raises:
        TypeError: if *obj* matches none of the cases above.
    """
    # Guard clauses, checked in the same order as the accepted input kinds.
    if obj is None:
        return 0
    if isinstance(obj, int):
        return obj

    # Duck-typed lookup: anything exposing ``data_ptr`` is accepted.
    # NOTE(review): presumably a tensor-like buffer whose ``data_ptr()``
    # returns a raw address -- confirm against the callers.
    data_ptr = getattr(obj, "data_ptr", None)
    if not data_ptr:
        raise TypeError(f"cannot canonicalize object of type {type(obj)}")
    return data_ptr()
(num_warps, num_ctas, shared_memory) = kernel_metadata
if launch_enter_hook:
launch_enter_hook(launch_metadata)
ret = self.launch(
gridX,
gridY,
gridZ,
stream,
function,
num_warps,
num_ctas,
shared_memory,
canonicalize(global_scratch),
canonicalize(profile_scratch),
*args,
)
if launch_exit_hook:
launch_exit_hook(launch_metadata)
return ret
@cached_property
def codegen(self) -> str: