mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
+7
-14
@@ -2,6 +2,7 @@
|
|||||||
#define TRITON_TVM_FFI_LAUNCH_H_
|
#define TRITON_TVM_FFI_LAUNCH_H_
|
||||||
|
|
||||||
#include "type.h"
|
#include "type.h"
|
||||||
|
#include <cuda.h>
|
||||||
#include <tvm/ffi/object.h>
|
#include <tvm/ffi/object.h>
|
||||||
|
|
||||||
namespace triton_tvm_ffi {
|
namespace triton_tvm_ffi {
|
||||||
@@ -13,13 +14,9 @@ public:
|
|||||||
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
|
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
|
||||||
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
|
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
|
||||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||||
uint64_t function,
|
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
int32_t sharedMemory, uint64_t globalScratch,
|
||||||
tvm::ffi::ObjectRef launchMetadata,
|
uint64_t profileScratch,
|
||||||
tvm::ffi::ObjectRef launchEnterHook,
|
|
||||||
tvm::ffi::ObjectRef launchExitHook,
|
|
||||||
tvm::ffi::ObjectRef globalScratchObject,
|
|
||||||
tvm::ffi::ObjectRef profileScratchObject,
|
|
||||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||||
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
|
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
|
||||||
TVMFFILauncherImplObj, tvm::ffi::Object);
|
TVMFFILauncherImplObj, tvm::ffi::Object);
|
||||||
@@ -37,13 +34,9 @@ public:
|
|||||||
using tvm::ffi::ObjectRef::ObjectRef;
|
using tvm::ffi::ObjectRef::ObjectRef;
|
||||||
using tvm::ffi::ObjectRef::operator=;
|
using tvm::ffi::ObjectRef::operator=;
|
||||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||||
uint64_t function,
|
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
int32_t sharedMemory, uint64_t globalScratch,
|
||||||
tvm::ffi::ObjectRef launchMetadata,
|
uint64_t profileScratch,
|
||||||
tvm::ffi::ObjectRef launchEnterHook,
|
|
||||||
tvm::ffi::ObjectRef launchExitHook,
|
|
||||||
tvm::ffi::ObjectRef globalScratchObject,
|
|
||||||
tvm::ffi::ObjectRef profileScratchObject,
|
|
||||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||||
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
|
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
|
||||||
tvm::ffi::ObjectRef,
|
tvm::ffi::ObjectRef,
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
|
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
|
||||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ...
|
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# tvm-ffi-stubgen(end)
|
# tvm-ffi-stubgen(end)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Final, List, Sequence, Type
|
from typing import Any, Final, List, Type
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
from triton.backends.nvidia.driver import CudaDriver
|
from triton.backends.nvidia.driver import CudaDriver
|
||||||
@@ -24,19 +24,16 @@ class TVMLauncher(object):
|
|||||||
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
|
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
|
||||||
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
|
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
|
||||||
self.launch_pdl: Final[bool] = metadata.launch_pdl
|
self.launch_pdl: Final[bool] = metadata.launch_pdl
|
||||||
self.enable_jit: Final[bool] = (
|
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
|
||||||
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None
|
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
|
||||||
)
|
|
||||||
if self.enable_jit:
|
|
||||||
mod = tvm_ffi.cpp.load_inline(
|
|
||||||
"launch",
|
"launch",
|
||||||
cpp_sources=self.codegen,
|
cpp_sources=[self.codegen],
|
||||||
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
|
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
|
||||||
extra_include_paths=[
|
extra_include_paths=[
|
||||||
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
|
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
launch = mod.get_function("launch")
|
launch: tvm_ffi.Function = mod.get_function("launch")
|
||||||
self.launch = launch
|
self.launch = launch
|
||||||
else:
|
else:
|
||||||
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
||||||
@@ -50,10 +47,9 @@ class TVMLauncher(object):
|
|||||||
grid_z,
|
grid_z,
|
||||||
stream,
|
stream,
|
||||||
function,
|
function,
|
||||||
kernel_metadata,
|
num_warps,
|
||||||
launch_metadata,
|
num_ctas,
|
||||||
launch_enter_hook,
|
shared_memory,
|
||||||
launch_exit_hook,
|
|
||||||
global_scratch,
|
global_scratch,
|
||||||
profile_scratch,
|
profile_scratch,
|
||||||
*args: self.impl.launch(
|
*args: self.impl.launch(
|
||||||
@@ -62,10 +58,9 @@ class TVMLauncher(object):
|
|||||||
grid_z,
|
grid_z,
|
||||||
stream,
|
stream,
|
||||||
function,
|
function,
|
||||||
kernel_metadata,
|
num_warps,
|
||||||
launch_metadata,
|
num_ctas,
|
||||||
launch_enter_hook,
|
shared_memory,
|
||||||
launch_exit_hook,
|
|
||||||
global_scratch,
|
global_scratch,
|
||||||
profile_scratch,
|
profile_scratch,
|
||||||
args,
|
args,
|
||||||
@@ -101,37 +96,36 @@ class TVMLauncher(object):
|
|||||||
self.profile_scratch_align,
|
self.profile_scratch_align,
|
||||||
_allocation._profile_allocator,
|
_allocation._profile_allocator,
|
||||||
)
|
)
|
||||||
assert not self.launch_cooperative_grid
|
|
||||||
assert not self.launch_pdl
|
|
||||||
|
|
||||||
if self.enable_jit:
|
def canonicalize(obj: Any) -> int:
|
||||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
if obj is None:
|
||||||
return self.launch(
|
return 0
|
||||||
gridX,
|
elif isinstance(obj, int):
|
||||||
gridY,
|
return obj
|
||||||
gridZ,
|
elif get_ptr := getattr(obj, "data_ptr", None):
|
||||||
stream,
|
return get_ptr()
|
||||||
function,
|
else:
|
||||||
num_warps,
|
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
|
||||||
num_ctas,
|
|
||||||
shared_memory,
|
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||||
*args,
|
if launch_enter_hook:
|
||||||
)
|
launch_enter_hook(launch_metadata)
|
||||||
else:
|
ret = self.launch(
|
||||||
return self.launch(
|
gridX,
|
||||||
gridX,
|
gridY,
|
||||||
gridY,
|
gridZ,
|
||||||
gridZ,
|
stream,
|
||||||
stream,
|
function,
|
||||||
function,
|
num_warps,
|
||||||
kernel_metadata,
|
num_ctas,
|
||||||
launch_metadata,
|
shared_memory,
|
||||||
launch_enter_hook,
|
canonicalize(global_scratch),
|
||||||
launch_exit_hook,
|
canonicalize(profile_scratch),
|
||||||
global_scratch,
|
*args,
|
||||||
profile_scratch,
|
)
|
||||||
*args,
|
if launch_exit_hook:
|
||||||
)
|
launch_exit_hook(launch_metadata)
|
||||||
|
return ret
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def codegen(self) -> str:
|
def codegen(self) -> str:
|
||||||
|
|||||||
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
|||||||
int32_t numWarps = args[5].v_int64;
|
int32_t numWarps = args[5].v_int64;
|
||||||
int32_t numCtas = args[6].v_int64;
|
int32_t numCtas = args[6].v_int64;
|
||||||
int32_t sharedMemory = args[7].v_int64;
|
int32_t sharedMemory = args[7].v_int64;
|
||||||
// TODO: Implement the launch logic
|
uint64_t globalScratch = args[8].v_uint64;
|
||||||
CUdeviceptr globalScratch = 0;
|
uint64_t profileScratch = args[9].v_uint64;
|
||||||
// TODO: check `profileScratchObject`
|
|
||||||
CUdeviceptr profileScratch = 0;
|
|
||||||
if (gridX * gridY * gridZ > 0) {
|
if (gridX * gridY * gridZ > 0) {
|
||||||
CUlaunchAttribute launchAttr[4];
|
CUlaunchAttribute launchAttr[4];
|
||||||
CUlaunchConfig config;
|
CUlaunchConfig config;
|
||||||
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
|||||||
config.numAttrs = numAttrs;
|
config.numAttrs = numAttrs;
|
||||||
{% for type in signature %}
|
{% for type in signature %}
|
||||||
{% if type == "void *" %}
|
{% if type == "void *" %}
|
||||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||||
{% elif type == "int32_t" %}
|
{% elif type == "int32_t" %}
|
||||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
|
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||||
{% else %}
|
{% else %}
|
||||||
assert(false, "unsupported type yet {{ type }}");
|
assert(false, "unsupported type yet {{ type }}");
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
+7
-22
@@ -1,6 +1,6 @@
|
|||||||
#include "launch.h"
|
#include "launch.h"
|
||||||
#include "macro.h"
|
#include "macro.h"
|
||||||
#include <cuda.h>
|
#include <cstdint>
|
||||||
#include <tvm/ffi/base_details.h>
|
#include <tvm/ffi/base_details.h>
|
||||||
|
|
||||||
namespace triton_tvm_ffi {
|
namespace triton_tvm_ffi {
|
||||||
@@ -14,19 +14,11 @@ TVMFFILauncherImplObj::TVMFFILauncherImplObj(
|
|||||||
|
|
||||||
void TVMFFILauncherImplObj::Launch(
|
void TVMFFILauncherImplObj::Launch(
|
||||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||||
uint64_t function,
|
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
uint64_t globalScratch, uint64_t profileScratch,
|
||||||
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
|
|
||||||
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
|
|
||||||
tvm::ffi::ObjectRef profileScratchObject,
|
|
||||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||||
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
|
|
||||||
// TODO: Implement the launch logic
|
|
||||||
CUdeviceptr globalScratch = 0;
|
|
||||||
// TODO: check `profileScratchObject`
|
|
||||||
CUdeviceptr profileScratch = 0;
|
|
||||||
if (gridX * gridY * gridZ > 0) {
|
if (gridX * gridY * gridZ > 0) {
|
||||||
CUlaunchAttribute launchAttr[4];
|
CUlaunchAttribute launchAttr[4];
|
||||||
CUlaunchConfig config;
|
CUlaunchConfig config;
|
||||||
@@ -41,8 +33,6 @@ void TVMFFILauncherImplObj::Launch(
|
|||||||
config.hStream = cStream;
|
config.hStream = cStream;
|
||||||
config.attrs = launchAttr;
|
config.attrs = launchAttr;
|
||||||
int32_t numAttrs = 0;
|
int32_t numAttrs = 0;
|
||||||
// TODO: check `launchPdl`
|
|
||||||
// TODO: check `launchCooperativeGrid`
|
|
||||||
if (numCtas != 1) {
|
if (numCtas != 1) {
|
||||||
CUlaunchAttribute clusterAttr;
|
CUlaunchAttribute clusterAttr;
|
||||||
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||||
@@ -109,7 +99,6 @@ void TVMFFILauncherImplObj::Launch(
|
|||||||
params[j + 1] = &profileScratch;
|
params[j + 1] = &profileScratch;
|
||||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||||
}
|
}
|
||||||
// TODO: call `launchExitHook`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
||||||
@@ -120,15 +109,11 @@ TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
|||||||
|
|
||||||
void TVMFFILauncherImpl::Launch(
|
void TVMFFILauncherImpl::Launch(
|
||||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||||
uint64_t function,
|
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
uint64_t globalScratch, uint64_t profileScratch,
|
||||||
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
|
|
||||||
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
|
|
||||||
tvm::ffi::ObjectRef profileScratchObject,
|
|
||||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||||
get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata,
|
get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
|
||||||
launchMetadata, launchEnterHook, launchExitHook,
|
sharedMemory, globalScratch, profileScratch, kernelArgs);
|
||||||
globalScratchObject, profileScratchObject, kernelArgs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||||
|
|||||||
Reference in New Issue
Block a user