from __future__ import annotations from functools import cached_property import os from typing import Any, Final, List, Type import jinja2 from triton.backends.nvidia.driver import CudaDriver from triton.runtime import _allocation import tvm_ffi from . import TVMFFILauncherImpl, utils, string_to_type, type_to_ctype class TVMLauncher(object): def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher: super().__init__(*args, **kwargs) self.signature: List[str] = [*src.signature.values()] self.num_ctas: Final[int] = getattr(metadata, "num_ctas", 1) self.global_scratch_size: Final[int] = metadata.global_scratch_size self.global_scratch_align: Final[int] = metadata.global_scratch_align self.profile_scratch_size: Final[int] = metadata.profile_scratch_size self.profile_scratch_align: Final[int] = metadata.profile_scratch_align self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid self.launch_pdl: Final[bool] = metadata.launch_pdl if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}: mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline( "launch", cpp_sources=[self.codegen], extra_ldflags=["-Wl,--no-as-needed", "-lcuda"], extra_include_paths=[ f"{tvm_ffi.cpp.extension._find_cuda_home()}/include" ], ) launch: tvm_ffi.Function = mod.get_function("launch") self.launch = launch else: self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( [string_to_type(t) for t in self.signature], self.launch_cooperative_grid, self.launch_pdl, ) self.launch = ( lambda grid_x, grid_y, grid_z, stream, function, num_warps, num_ctas, shared_memory, global_scratch, profile_scratch, *args: self.impl.launch( grid_x, grid_y, grid_z, stream, function, num_warps, num_ctas, shared_memory, global_scratch, profile_scratch, args, ) ) def __call__( self, gridX, gridY, gridZ, stream, function, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, *args, ): def allocate_scratch(size, align, allocator): if size > 0: grid_size = gridX * gridY * gridZ alloc_size = grid_size * self.num_ctas * size alloc_fn = allocator.get() return alloc_fn(alloc_size, align, stream) return None global_scratch = allocate_scratch( self.global_scratch_size, self.global_scratch_align, _allocation._allocator ) profile_scratch = allocate_scratch( self.profile_scratch_size, self.profile_scratch_align, _allocation._profile_allocator, ) def canonicalize(obj: Any) -> int: if obj is None: return 0 elif isinstance(obj, int): return obj elif get_ptr := getattr(obj, "data_ptr", None): return get_ptr() else: raise TypeError(f"cannot canonicalize object of type {type(obj)}") (num_warps, num_ctas, shared_memory) = kernel_metadata if launch_enter_hook: launch_enter_hook(launch_metadata) ret = self.launch( gridX, gridY, gridZ, stream, function, num_warps, num_ctas, shared_memory, canonicalize(global_scratch), canonicalize(profile_scratch), *args, ) if launch_exit_hook: launch_exit_hook(launch_metadata) return ret @cached_property def codegen(self) -> str: env: Final[jinja2.Environment] = jinja2.Environment( loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"), trim_blocks=True, lstrip_blocks=True, ) template = env.get_template("launch.c.j2") signature = list( filter( lambda t: t != "void", map(lambda t: type_to_ctype(string_to_type(t)), self.signature), ) ) html = template.render(signature=signature) return html class TVMFFIDriver(CudaDriver): def __init__(self, *args, **kwargs) -> TVMFFIDriver: super().__init__(*args, **kwargs) self.utils = utils self.launcher_cls: Type[TVMLauncher] = TVMLauncher del CudaDriver