From fecf48e403642672c202f8bbe7ef5907b015ac15 Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Thu, 29 Jan 2026 01:54:16 +0800 Subject: [PATCH] implement launch Signed-off-by: Jinjie Liu --- python/triton_tvm_ffi/driver.py | 78 +++++++++++++- python/triton_tvm_ffi/utils/_ffi_api.py | 3 + src/utils.cc | 137 ++++++++++++++++++++---- 3 files changed, 193 insertions(+), 25 deletions(-) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 88954ab..dababea 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,9 +1,9 @@ from __future__ import annotations from ctypes import c_void_p -from typing import Mapping, Tuple, Type +from typing import List, Mapping, Tuple, Type from triton.backends.nvidia.driver import CudaDriver -from .utils import get_device_properties, load_binary +from .utils import get_device_properties, launch, load_binary class TVMFFIUtils(object): @@ -49,10 +49,84 @@ class TVMFFIUtils(object): ) +class TVMLauncher(object): + def __init__(self, src: List[bool], metadata, *args, **kwargs) -> TVMLauncher: + super().__init__(*args, **kwargs) + + self.mask: List[bool] = [annotation != "constexpr" for annotation in src.signature.values()] + self.num_ctas = getattr(metadata, "num_ctas", 1) + self.launch = launch + self.global_scratch_size = metadata.global_scratch_size + self.global_scratch_align = metadata.global_scratch_align + self.profile_scratch_size = metadata.profile_scratch_size + self.profile_scratch_align = metadata.profile_scratch_align + self.launch_cooperative_grid = metadata.launch_cooperative_grid + self.launch_pdl = metadata.launch_pdl + + # We assume the global Triton allocator is not enabled: `_allocator` must be a NullAllocator. + # This module depends on NullAllocator behavior; ensure no other code replaces the allocator. 
+ from triton.runtime._allocation import _allocator, NullAllocator + + assert isinstance(_allocator.get(), NullAllocator) + + def __call__( + self, + gridX, + gridY, + gridZ, + stream, + function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + *args, + ): + from triton.runtime import _allocation + + def allocate_scratch(size, align, allocator): + if size > 0: + grid_size = gridX * gridY * gridZ + alloc_size = grid_size * self.num_ctas * size + alloc_fn = allocator.get() + return alloc_fn(alloc_size, align, stream) + return None + + global_scratch = allocate_scratch( + self.global_scratch_size, self.global_scratch_align, _allocation._allocator + ) + profile_scratch = allocate_scratch( + self.profile_scratch_size, + self.profile_scratch_align, + _allocation._profile_allocator, + ) + assert not self.launch_cooperative_grid + assert not self.launch_pdl + assert len(self.mask) == len(args) + args = [arg for arg, m in zip(args, self.mask) if m] + return launch( + gridX, + gridY, + gridZ, + stream, + function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + self.launch_cooperative_grid, + self.launch_pdl, + global_scratch, + profile_scratch, + *args, + ) + + class TVMFFIDriver(CudaDriver): def __init__(self, *args, **kwargs) -> TVMFFIDriver: super().__init__(*args, **kwargs) self.utils: TVMFFIUtils = TVMFFIUtils() + self.launcher_cls: Type[TVMLauncher] = TVMLauncher del CudaDriver diff --git a/python/triton_tvm_ffi/utils/_ffi_api.py b/python/triton_tvm_ffi/utils/_ffi_api.py index 79b64d3..351a47f 100644 --- a/python/triton_tvm_ffi/utils/_ffi_api.py +++ b/python/triton_tvm_ffi/utils/_ffi_api.py @@ -7,6 +7,7 @@ from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Mapping + from typing import Any # isort: on # fmt: on # tvm-ffi-stubgen(end) @@ -17,6 +18,7 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "utils") 
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__) if TYPE_CHECKING: def get_device_properties(_0: int, /) -> Mapping[str, int]: ... + def launch(*args: Any) -> Any: ... def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... # fmt: on # tvm-ffi-stubgen(end) @@ -25,6 +27,7 @@ __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", "get_device_properties", + "launch", "load_binary", # tvm-ffi-stubgen(end) ] diff --git a/src/utils.cc b/src/utils.cc index 104f377..3713d84 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "exception.h" +#include #include #include -#include #include #define CUDA_CHECK(code) \ @@ -14,35 +14,35 @@ tvm::ffi::Map GetDeviceProperties(int device_id) { tvm::ffi::cuda_api::DeviceHandle device; CUDA_CHECK(cuDeviceGet(&device, device_id)); - int max_shared_mem = 0; - int max_num_regs = 0; - int multiprocessor_count = 0; - int warp_size = 0; - int sm_clock_rate = 0; - int mem_clock_rate = 0; - int mem_bus_width = 0; + int maxSharedMem = 0; + int maxNumRegs = 0; + int multiprocessorCount = 0; + int warpSize = 0; + int smClockRate = 0; + int memClockRate = 0; + int memBusWidth = 0; CUDA_CHECK(cuDeviceGetAttribute( - &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + &maxSharedMem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); CUDA_CHECK(cuDeviceGetAttribute( - &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + &maxNumRegs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); CUDA_CHECK(cuDeviceGetAttribute( - &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + &multiprocessorCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); CUDA_CHECK( - cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); - CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, - CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + cuDeviceGetAttribute(&warpSize, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + 
CUDA_CHECK(cuDeviceGetAttribute(&smClockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, + device)); CUDA_CHECK(cuDeviceGetAttribute( - &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + &memClockRate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); CUDA_CHECK(cuDeviceGetAttribute( - &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); - return {{"max_shared_mem", max_shared_mem}, - {"max_num_regs", max_num_regs}, - {"multiprocessor_count", multiprocessor_count}, - {"warp_size", warp_size}, - {"sm_clock_rate", sm_clock_rate}, - {"mem_clock_rate", mem_clock_rate}, - {"mem_bus_width", mem_bus_width}}; + &memBusWidth, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + return {{"max_shared_mem", maxSharedMem}, + {"max_num_regs", maxNumRegs}, + {"multiprocessor_count", multiprocessorCount}, + {"warpSize", warpSize}, + {"sm_clock_rate", smClockRate}, + {"mem_clock_rate", memClockRate}, + {"mem_bus_width", memBusWidth}}; } tvm::ffi::Tuple @@ -88,9 +88,100 @@ LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, mod, fun, nRegs, nSpills, nMaxThreads}; } +void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { + CUtensorMap x; + int32_t gridX = args[0].cast(); + int32_t gridY = args[1].cast(); + int32_t gridZ = args[2].cast(); + CUstream stream = reinterpret_cast(args[3].cast()); + CUfunction function = reinterpret_cast(args[4].cast()); + tvm::ffi::Tuple kernelMetadata = + args[5].cast>(); + int32_t numWarps = kernelMetadata.get<0>(); + int32_t numCtas = kernelMetadata.get<1>(); + int32_t sharedMemory = kernelMetadata.get<2>(); + tvm::ffi::ObjectRef launchMetadata = args[6].cast(); + tvm::ffi::ObjectRef launchEnterHook = args[7].cast(); + tvm::ffi::ObjectRef launchExitHook = args[8].cast(); + bool launchCooperativeGrid = args[9].cast(); + bool launchPdl = args[10].cast(); + tvm::ffi::ObjectRef globalScratchObject = + args[11].cast(); + tvm::ffi::ObjectRef profileScratchObject = + args[12].cast(); + tvm::ffi::PackedArgs 
kernelArgs = args.Slice(13); + // TODO: call `launchEnterHook` + // TODO: check `globalScratchObject` + CUdeviceptr globalScratch = 0; + // TODO: check `profileScratchObject` + CUdeviceptr profileScratch = 0; + if (gridX * gridY * gridZ > 0) { + CUlaunchAttribute launchAttr[4]; + CUlaunchConfig config; + config.gridDimX = gridX * numCtas; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + static constexpr int32_t kThreadsPerWarp = 32; + config.blockDimX = kThreadsPerWarp * numWarps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = sharedMemory; + config.hStream = stream; + config.attrs = launchAttr; + int32_t numAttrs = 0; + // TODO: check `launchPdl` + // TODO: check `launchCooperativeGrid` + if (numCtas != 1) { + CUlaunchAttribute clusterAttr; + clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + clusterAttr.value.clusterDim.x = numCtas; + clusterAttr.value.clusterDim.y = 1; + clusterAttr.value.clusterDim.z = 1; + launchAttr[numAttrs++] = clusterAttr; + CUlaunchAttribute clusterSchedulingAttr; + clusterSchedulingAttr.id = + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + launchAttr[numAttrs++] = clusterSchedulingAttr; + } + config.numAttrs = numAttrs; + if (numCtas == 16) { + CUDA_CHECK(cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + } + const int32_t kernelArgNum = kernelArgs.size(); + void **params = + reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); + for (size_t i = 0; i < kernelArgNum; ++i) { + tvm::ffi::AnyView arg = kernelArgs[i]; + if (auto val = arg.try_cast()) { + void **ptr = reinterpret_cast(alloca(sizeof(void *))); + *ptr = val->data_ptr(); + params[i] = ptr; + } else if (auto val = arg.try_cast()) { + int32_t *ptr = reinterpret_cast(alloca(sizeof(int32_t))); + *ptr = *val; + params[i] = ptr; + } else if (auto val = arg.try_cast()) { + float 
*ptr = reinterpret_cast(alloca(sizeof(float))); + *ptr = *val; + params[i] = ptr; + } else { + assert(false && "unsupported kernel argument type"); + } + } + params[kernelArgNum] = &globalScratch; + params[kernelArgNum + 1] = &profileScratch; + CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr)); + } + // TODO: call `launchExitHook` +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) + .def_packed("triton_tvm_ffi.utils.launch", Launch) .def("triton_tvm_ffi.utils.load_binary", LoadBinary); }