mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
+7
-14
@@ -2,6 +2,7 @@
|
||||
#define TRITON_TVM_FFI_LAUNCH_H_
|
||||
|
||||
#include "type.h"
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/object.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
@@ -13,13 +14,9 @@ public:
|
||||
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
|
||||
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata,
|
||||
tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook,
|
||||
tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||
int32_t sharedMemory, uint64_t globalScratch,
|
||||
uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
|
||||
TVMFFILauncherImplObj, tvm::ffi::Object);
|
||||
@@ -37,13 +34,9 @@ public:
|
||||
using tvm::ffi::ObjectRef::ObjectRef;
|
||||
using tvm::ffi::ObjectRef::operator=;
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata,
|
||||
tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook,
|
||||
tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||
int32_t sharedMemory, uint64_t globalScratch,
|
||||
uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
|
||||
tvm::ffi::ObjectRef,
|
||||
|
||||
@@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object):
|
||||
if TYPE_CHECKING:
|
||||
@staticmethod
|
||||
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
|
||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ...
|
||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import Any, Callable, Final, List, Sequence, Type
|
||||
from typing import Any, Final, List, Type
|
||||
|
||||
import jinja2
|
||||
from triton.backends.nvidia.driver import CudaDriver
|
||||
@@ -24,19 +24,16 @@ class TVMLauncher(object):
|
||||
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
|
||||
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
|
||||
self.launch_pdl: Final[bool] = metadata.launch_pdl
|
||||
self.enable_jit: Final[bool] = (
|
||||
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None
|
||||
)
|
||||
if self.enable_jit:
|
||||
mod = tvm_ffi.cpp.load_inline(
|
||||
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
|
||||
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
|
||||
"launch",
|
||||
cpp_sources=self.codegen,
|
||||
cpp_sources=[self.codegen],
|
||||
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
|
||||
extra_include_paths=[
|
||||
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
|
||||
],
|
||||
)
|
||||
launch = mod.get_function("launch")
|
||||
launch: tvm_ffi.Function = mod.get_function("launch")
|
||||
self.launch = launch
|
||||
else:
|
||||
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
||||
@@ -50,10 +47,9 @@ class TVMLauncher(object):
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args: self.impl.launch(
|
||||
@@ -62,10 +58,9 @@ class TVMLauncher(object):
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
args,
|
||||
@@ -101,37 +96,36 @@ class TVMLauncher(object):
|
||||
self.profile_scratch_align,
|
||||
_allocation._profile_allocator,
|
||||
)
|
||||
assert not self.launch_cooperative_grid
|
||||
assert not self.launch_pdl
|
||||
|
||||
if self.enable_jit:
|
||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||
return self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
*args,
|
||||
)
|
||||
else:
|
||||
return self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args,
|
||||
)
|
||||
def canonicalize(obj: Any) -> int:
|
||||
if obj is None:
|
||||
return 0
|
||||
elif isinstance(obj, int):
|
||||
return obj
|
||||
elif get_ptr := getattr(obj, "data_ptr", None):
|
||||
return get_ptr()
|
||||
else:
|
||||
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
|
||||
|
||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||
if launch_enter_hook:
|
||||
launch_enter_hook(launch_metadata)
|
||||
ret = self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
canonicalize(global_scratch),
|
||||
canonicalize(profile_scratch),
|
||||
*args,
|
||||
)
|
||||
if launch_exit_hook:
|
||||
launch_exit_hook(launch_metadata)
|
||||
return ret
|
||||
|
||||
@cached_property
|
||||
def codegen(self) -> str:
|
||||
|
||||
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
int32_t numWarps = args[5].v_int64;
|
||||
int32_t numCtas = args[6].v_int64;
|
||||
int32_t sharedMemory = args[7].v_int64;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
uint64_t globalScratch = args[8].v_uint64;
|
||||
uint64_t profileScratch = args[9].v_uint64;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
config.numAttrs = numAttrs;
|
||||
{% for type in signature %}
|
||||
{% if type == "void *" %}
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{% elif type == "int32_t" %}
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||
{% else %}
|
||||
assert(false, "unsupported type yet {{ type }}");
|
||||
{% endif %}
|
||||
|
||||
+7
-22
@@ -1,6 +1,6 @@
|
||||
#include "launch.h"
|
||||
#include "macro.h"
|
||||
#include <cuda.h>
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/base_details.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
@@ -14,19 +14,11 @@ TVMFFILauncherImplObj::TVMFFILauncherImplObj(
|
||||
|
||||
void TVMFFILauncherImplObj::Launch(
|
||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||
uint64_t globalScratch, uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
@@ -41,8 +33,6 @@ void TVMFFILauncherImplObj::Launch(
|
||||
config.hStream = cStream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
// TODO: check `launchPdl`
|
||||
// TODO: check `launchCooperativeGrid`
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||
@@ -109,7 +99,6 @@ void TVMFFILauncherImplObj::Launch(
|
||||
params[j + 1] = &profileScratch;
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||
}
|
||||
// TODO: call `launchExitHook`
|
||||
}
|
||||
|
||||
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
||||
@@ -120,15 +109,11 @@ TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
||||
|
||||
void TVMFFILauncherImpl::Launch(
|
||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||
uint64_t globalScratch, uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||
get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata,
|
||||
launchMetadata, launchEnterHook, launchExitHook,
|
||||
globalScratchObject, profileScratchObject, kernelArgs);
|
||||
get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
|
||||
sharedMemory, globalScratch, profileScratch, kernelArgs);
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
|
||||
Reference in New Issue
Block a user