unify launch apis

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 11:29:32 +08:00
parent ac7497b2c8
commit e9576d265e
5 changed files with 59 additions and 89 deletions
+7 -14
View File
@@ -2,6 +2,7 @@
#define TRITON_TVM_FFI_LAUNCH_H_
#include "type.h"
#include <cuda.h>
#include <tvm/ffi/object.h>
namespace triton_tvm_ffi {
@@ -13,13 +14,9 @@ public:
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
uint64_t function, int32_t numWarps, int32_t numCtas,
int32_t sharedMemory, uint64_t globalScratch,
uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
TVMFFILauncherImplObj, tvm::ffi::Object);
@@ -37,13 +34,9 @@ public:
using tvm::ffi::ObjectRef::ObjectRef;
using tvm::ffi::ObjectRef::operator=;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
uint64_t function, int32_t numWarps, int32_t numCtas,
int32_t sharedMemory, uint64_t globalScratch,
uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
tvm::ffi::ObjectRef,
+1 -1
View File
@@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object):
if TYPE_CHECKING:
@staticmethod
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ...
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
# fmt: on
# tvm-ffi-stubgen(end)
+40 -46
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from functools import cached_property
import os
from typing import Any, Callable, Final, List, Sequence, Type
from typing import Any, Final, List, Type
import jinja2
from triton.backends.nvidia.driver import CudaDriver
@@ -24,19 +24,16 @@ class TVMLauncher(object):
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
self.launch_pdl: Final[bool] = metadata.launch_pdl
self.enable_jit: Final[bool] = (
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None
)
if self.enable_jit:
mod = tvm_ffi.cpp.load_inline(
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
"launch",
cpp_sources=self.codegen,
cpp_sources=[self.codegen],
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
extra_include_paths=[
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
],
)
launch = mod.get_function("launch")
launch: tvm_ffi.Function = mod.get_function("launch")
self.launch = launch
else:
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
@@ -50,10 +47,9 @@ class TVMLauncher(object):
grid_z,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args: self.impl.launch(
@@ -62,10 +58,9 @@ class TVMLauncher(object):
grid_z,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
args,
@@ -101,37 +96,36 @@ class TVMLauncher(object):
self.profile_scratch_align,
_allocation._profile_allocator,
)
assert not self.launch_cooperative_grid
assert not self.launch_pdl
if self.enable_jit:
(num_warps, num_ctas, shared_memory) = kernel_metadata
return self.launch(
gridX,
gridY,
gridZ,
stream,
function,
num_warps,
num_ctas,
shared_memory,
*args,
)
else:
return self.launch(
gridX,
gridY,
gridZ,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
global_scratch,
profile_scratch,
*args,
)
def canonicalize(obj: Any) -> int:
if obj is None:
return 0
elif isinstance(obj, int):
return obj
elif get_ptr := getattr(obj, "data_ptr", None):
return get_ptr()
else:
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
(num_warps, num_ctas, shared_memory) = kernel_metadata
if launch_enter_hook:
launch_enter_hook(launch_metadata)
ret = self.launch(
gridX,
gridY,
gridZ,
stream,
function,
num_warps,
num_ctas,
shared_memory,
canonicalize(global_scratch),
canonicalize(profile_scratch),
*args,
)
if launch_exit_hook:
launch_exit_hook(launch_metadata)
return ret
@cached_property
def codegen(self) -> str:
+4 -6
View File
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
uint64_t globalScratch = args[8].v_uint64;
uint64_t profileScratch = args[9].v_uint64;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
config.numAttrs = numAttrs;
{% for type in signature %}
{% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
{% else %}
assert(false, "unsupported type yet {{ type }}");
{% endif %}
+7 -22
View File
@@ -1,6 +1,6 @@
#include "launch.h"
#include "macro.h"
#include <cuda.h>
#include <cstdint>
#include <tvm/ffi/base_details.h>
namespace triton_tvm_ffi {
@@ -14,19 +14,11 @@ TVMFFILauncherImplObj::TVMFFILauncherImplObj(
void TVMFFILauncherImplObj::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
uint64_t globalScratch, uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
CUstream cStream = reinterpret_cast<CUstream>(stream);
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
@@ -41,8 +33,6 @@ void TVMFFILauncherImplObj::Launch(
config.hStream = cStream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
@@ -109,7 +99,6 @@ void TVMFFILauncherImplObj::Launch(
params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
}
// TODO: call `launchExitHook`
}
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
@@ -120,15 +109,11 @@ TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
void TVMFFILauncherImpl::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
uint64_t globalScratch, uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata,
launchMetadata, launchEnterHook, launchExitHook,
globalScratchObject, profileScratchObject, kernelArgs);
get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
sharedMemory, globalScratch, profileScratch, kernelArgs);
}
TVM_FFI_STATIC_INIT_BLOCK() {