unify launch apis

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-31 11:29:32 +08:00
parent ac7497b2c8
commit e9576d265e
5 changed files with 59 additions and 89 deletions
+7 -14
View File
@@ -2,6 +2,7 @@
#define TRITON_TVM_FFI_LAUNCH_H_ #define TRITON_TVM_FFI_LAUNCH_H_
#include "type.h" #include "type.h"
#include <cuda.h>
#include <tvm/ffi/object.h> #include <tvm/ffi/object.h>
namespace triton_tvm_ffi { namespace triton_tvm_ffi {
@@ -13,13 +14,9 @@ public:
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default; TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default; TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, uint64_t function, int32_t numWarps, int32_t numCtas,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata, int32_t sharedMemory, uint64_t globalScratch,
tvm::ffi::ObjectRef launchMetadata, uint64_t profileScratch,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const; const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl", TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
TVMFFILauncherImplObj, tvm::ffi::Object); TVMFFILauncherImplObj, tvm::ffi::Object);
@@ -37,13 +34,9 @@ public:
using tvm::ffi::ObjectRef::ObjectRef; using tvm::ffi::ObjectRef::ObjectRef;
using tvm::ffi::ObjectRef::operator=; using tvm::ffi::ObjectRef::operator=;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, uint64_t function, int32_t numWarps, int32_t numCtas,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata, int32_t sharedMemory, uint64_t globalScratch,
tvm::ffi::ObjectRef launchMetadata, uint64_t profileScratch,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const; const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
tvm::ffi::ObjectRef, tvm::ffi::ObjectRef,
+1 -1
View File
@@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object):
if TYPE_CHECKING: if TYPE_CHECKING:
@staticmethod @staticmethod
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ... def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ... def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
# fmt: on # fmt: on
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
+29 -35
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from functools import cached_property from functools import cached_property
import os import os
from typing import Any, Callable, Final, List, Sequence, Type from typing import Any, Final, List, Type
import jinja2 import jinja2
from triton.backends.nvidia.driver import CudaDriver from triton.backends.nvidia.driver import CudaDriver
@@ -24,19 +24,16 @@ class TVMLauncher(object):
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
self.launch_pdl: Final[bool] = metadata.launch_pdl self.launch_pdl: Final[bool] = metadata.launch_pdl
self.enable_jit: Final[bool] = ( if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
)
if self.enable_jit:
mod = tvm_ffi.cpp.load_inline(
"launch", "launch",
cpp_sources=self.codegen, cpp_sources=[self.codegen],
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"], extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
extra_include_paths=[ extra_include_paths=[
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include" f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
], ],
) )
launch = mod.get_function("launch") launch: tvm_ffi.Function = mod.get_function("launch")
self.launch = launch self.launch = launch
else: else:
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
@@ -50,10 +47,9 @@ class TVMLauncher(object):
grid_z, grid_z,
stream, stream,
function, function,
kernel_metadata, num_warps,
launch_metadata, num_ctas,
launch_enter_hook, shared_memory,
launch_exit_hook,
global_scratch, global_scratch,
profile_scratch, profile_scratch,
*args: self.impl.launch( *args: self.impl.launch(
@@ -62,10 +58,9 @@ class TVMLauncher(object):
grid_z, grid_z,
stream, stream,
function, function,
kernel_metadata, num_warps,
launch_metadata, num_ctas,
launch_enter_hook, shared_memory,
launch_exit_hook,
global_scratch, global_scratch,
profile_scratch, profile_scratch,
args, args,
@@ -101,12 +96,21 @@ class TVMLauncher(object):
self.profile_scratch_align, self.profile_scratch_align,
_allocation._profile_allocator, _allocation._profile_allocator,
) )
assert not self.launch_cooperative_grid
assert not self.launch_pdl
if self.enable_jit: def canonicalize(obj: Any) -> int:
if obj is None:
return 0
elif isinstance(obj, int):
return obj
elif get_ptr := getattr(obj, "data_ptr", None):
return get_ptr()
else:
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
(num_warps, num_ctas, shared_memory) = kernel_metadata (num_warps, num_ctas, shared_memory) = kernel_metadata
return self.launch( if launch_enter_hook:
launch_enter_hook(launch_metadata)
ret = self.launch(
gridX, gridX,
gridY, gridY,
gridZ, gridZ,
@@ -115,23 +119,13 @@ class TVMLauncher(object):
num_warps, num_warps,
num_ctas, num_ctas,
shared_memory, shared_memory,
canonicalize(global_scratch),
canonicalize(profile_scratch),
*args, *args,
) )
else: if launch_exit_hook:
return self.launch( launch_exit_hook(launch_metadata)
gridX, return ret
gridY,
gridZ,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
global_scratch,
profile_scratch,
*args,
)
@cached_property @cached_property
def codegen(self) -> str: def codegen(self) -> str:
+4 -6
View File
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
int32_t numWarps = args[5].v_int64; int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64; int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64; int32_t sharedMemory = args[7].v_int64;
// TODO: Implement the launch logic uint64_t globalScratch = args[8].v_uint64;
CUdeviceptr globalScratch = 0; uint64_t profileScratch = args[9].v_uint64;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
if (gridX * gridY * gridZ > 0) { if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4]; CUlaunchAttribute launchAttr[4];
CUlaunchConfig config; CUlaunchConfig config;
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
config.numAttrs = numAttrs; config.numAttrs = numAttrs;
{% for type in signature %} {% for type in signature %}
{% if type == "void *" %} {% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data; {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %} {% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64; {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
{% else %} {% else %}
assert(false, "unsupported type yet {{ type }}"); assert(false, "unsupported type yet {{ type }}");
{% endif %} {% endif %}
+7 -22
View File
@@ -1,6 +1,6 @@
#include "launch.h" #include "launch.h"
#include "macro.h" #include "macro.h"
#include <cuda.h> #include <cstdint>
#include <tvm/ffi/base_details.h> #include <tvm/ffi/base_details.h>
namespace triton_tvm_ffi { namespace triton_tvm_ffi {
@@ -14,19 +14,11 @@ TVMFFILauncherImplObj::TVMFFILauncherImplObj(
void TVMFFILauncherImplObj::Launch( void TVMFFILauncherImplObj::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata, uint64_t globalScratch, uint64_t profileScratch,
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const { const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
CUstream cStream = reinterpret_cast<CUstream>(stream); CUstream cStream = reinterpret_cast<CUstream>(stream);
CUfunction cFunction = reinterpret_cast<CUfunction>(function); CUfunction cFunction = reinterpret_cast<CUfunction>(function);
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
if (gridX * gridY * gridZ > 0) { if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4]; CUlaunchAttribute launchAttr[4];
CUlaunchConfig config; CUlaunchConfig config;
@@ -41,8 +33,6 @@ void TVMFFILauncherImplObj::Launch(
config.hStream = cStream; config.hStream = cStream;
config.attrs = launchAttr; config.attrs = launchAttr;
int32_t numAttrs = 0; int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) { if (numCtas != 1) {
CUlaunchAttribute clusterAttr; CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
@@ -109,7 +99,6 @@ void TVMFFILauncherImplObj::Launch(
params[j + 1] = &profileScratch; params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr)); CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
} }
// TODO: call `launchExitHook`
} }
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature, TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
@@ -120,15 +109,11 @@ TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
void TVMFFILauncherImpl::Launch( void TVMFFILauncherImpl::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata, uint64_t globalScratch, uint64_t profileScratch,
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const { const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata, get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
launchMetadata, launchEnterHook, launchExitHook, sharedMemory, globalScratch, profileScratch, kernelArgs);
globalScratchObject, profileScratchObject, kernelArgs);
} }
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {