implement launch with cpp

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 15:05:05 +08:00
parent 37a8f4a5be
commit 524cf83708
10 changed files with 250 additions and 243 deletions
+55
View File
@@ -0,0 +1,55 @@
#ifndef TRITON_TVM_FFI_LAUNCH_H_
#define TRITON_TVM_FFI_LAUNCH_H_
#include "type.h"
#include <tvm/ffi/object.h>
namespace triton_tvm_ffi {
// Object node holding the per-kernel launch configuration: the argument
// signature plus two boolean launch flags. Registered under the type key
// "triton_tvm_ffi.TVMFFILauncherImpl".
class TVMFFILauncherImplObj : public tvm::ffi::Object {
public:
// Stores `signature` and the two flags; all three are fixed for the lifetime
// of the object (the flags are `const` members).
// NOTE(review): the Python caller passes `metadata.launch_pdl` as the third
// argument, which is named `launchAsync` here — confirm the naming is intended.
TVMFFILauncherImplObj(const tvm::ffi::Array<Type> &signature,
bool launchCooperativeGrid, bool launchAsync);
// Copy/move construction is defaulted. Assignment is not declared; with
// `const` data members it cannot be defaulted.
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
// Launches the kernel identified by `function` on `stream`.
//   gridX/gridY/gridZ - launch grid dimensions.
//   stream, function  - driver handles passed as 64-bit integers.
//   kernelMetadata    - tuple unpacked by the implementation as
//                       (numWarps, numCtas, sharedMemory).
//   kernelArgs        - kernel arguments, interpreted via `signature_`.
// The hook, metadata and scratch parameters are accepted for interface
// compatibility; the current implementation marks their handling as TODO.
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
TVMFFILauncherImplObj, tvm::ffi::Object);
private:
// Type of each kernel argument, in positional order.
tvm::ffi::Array<Type> signature_;
// Launch flags captured at construction; the implementation does not read
// them yet (see the TODOs in Launch).
const bool launchCooperativeGrid_;
const bool launchAsync_;
};
// Reference (handle) type for TVMFFILauncherImplObj.
class TVMFFILauncherImpl : public tvm::ffi::ObjectRef {
public:
// Allocates a new TVMFFILauncherImplObj from the given signature and flags.
TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
bool launchCooperativeGrid, bool launchAsync);
// Inherit the base-class constructors and assignment operators.
using tvm::ffi::ObjectRef::ObjectRef;
using tvm::ffi::ObjectRef::operator=;
// Forwards every argument unchanged to TVMFFILauncherImplObj::Launch on the
// referenced object.
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook,
tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
tvm::ffi::ObjectRef,
TVMFFILauncherImplObj);
};
} // namespace triton_tvm_ffi
#endif
+11
View File
@@ -1,8 +1,19 @@
#ifndef TRITON_TVM_FFI_MACRO_H_ #ifndef TRITON_TVM_FFI_MACRO_H_
#define TRITON_TVM_FFI_MACRO_H_ #define TRITON_TVM_FFI_MACRO_H_
#include "exception.h"
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline #define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline
#endif #endif
// Branch-prediction hint: tells the compiler that `cond` is expected to be
// false.
#define UNLIKELY(cond) __builtin_expect((cond), 0)

// Evaluates the CUDA driver call `code` exactly once and throws
// triton_tvm_ffi::CUDAException carrying its result when that result is not
// CUDA_SUCCESS. The result is captured in a local first: expanding `code` a
// second time for the exception would re-issue the driver call (for example,
// launch a kernel again) and report the status of the second call instead of
// the one that failed.
#define CUDA_CHECK(code)                                                       \
  do {                                                                         \
    const auto tritonTvmFfiCudaStatus_ = (code);                               \
    if (UNLIKELY(tritonTvmFfiCudaStatus_ != CUDA_SUCCESS)) {                   \
      throw triton_tvm_ffi::CUDAException(tritonTvmFfiCudaStatus_);            \
    }                                                                          \
  } while (false)
#endif #endif
-52
View File
@@ -1,52 +0,0 @@
#ifndef TRITON_TVM_FFI_VALUE_H_
#define TRITON_TVM_FFI_VALUE_H_
#include "macro.h"
#include "type.h"
#include <tvm/ffi/any.h>
#include <tvm/ffi/object.h>
namespace triton_tvm_ffi {
// Object node pairing a `Type` tag with a dynamically typed value.
// Registered under the type key "triton_tvm_ffi.TypedValue".
class TypedValueObj : public tvm::ffi::Object {
public:
// Construct from a type tag and a value (copying or moving the value).
TypedValueObj(Type type, const tvm::ffi::Any &value);
TypedValueObj(Type type, tvm::ffi::Any &&value);
TypedValueObj(const TypedValueObj &other) = default;
TypedValueObj(TypedValueObj &&other) = default;
TypedValueObj &operator=(const TypedValueObj &other) = default;
TypedValueObj &operator=(TypedValueObj &&other) = default;
// Returns the stored type tag.
TRITON_TVM_FFI_INLINE Type GetType() const { return type_; }
// Returns a reference to the stored value; valid while this object is alive.
TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { return value_; }
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TypedValue", TypedValueObj,
tvm::ffi::Object);
private:
Type type_;
tvm::ffi::Any value_;
};
// Reference (handle) type for TypedValueObj.
class TypedValue : public tvm::ffi::ObjectRef {
public:
// Allocate a new TypedValueObj holding `type` and `value`.
TypedValue(Type type, const tvm::ffi::Any &value);
TypedValue(Type type, tvm::ffi::Any &&value);
// Inherit the base-class constructors and assignment operators.
using tvm::ffi::ObjectRef::ObjectRef;
using tvm::ffi::ObjectRef::operator=;
// Accessors that delegate to the referenced TypedValueObj.
TRITON_TVM_FFI_INLINE Type GetType() const { return get()->GetType(); }
TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const {
return get()->GetValue();
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypedValue, tvm::ffi::ObjectRef,
TypedValueObj);
};
// Parses the type name `type` and wraps `value` with the resulting tag.
// The implementation throws UnknownTypeException when the name is not
// recognised, so the returned Optional always holds a value.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
const tvm::ffi::Any &value);
// Element-wise MakeTypedValue over two parallel arrays. The implementation
// throws UnmatchedArgumentException when the arrays differ in length.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
const tvm::ffi::Array<tvm::ffi::Any> &values);
} // namespace triton_tvm_ffi
#endif
+8 -9
View File
@@ -27,17 +27,16 @@ if TYPE_CHECKING:
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ # tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object # tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") @_FFI_REG_OBJ("triton_tvm_ffi.TVMFFILauncherImpl")
class TypedValue(_ffi_Object): class TVMFFILauncherImpl(_ffi_Object):
# tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TypedValue """FFI binding for `triton_tvm_ffi.TVMFFILauncherImpl`."""
# tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TVMFFILauncherImpl
# fmt: off # fmt: off
if TYPE_CHECKING: if TYPE_CHECKING:
@staticmethod @staticmethod
def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ... def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
@staticmethod def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ...
def make_typed_value(_0: str, _1: Any, /) -> TypedValue | None: ...
@staticmethod
def make_typed_values(_0: Sequence[str], _1: Sequence[Any], /) -> Sequence[TypedValue]: ...
# fmt: on # fmt: on
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
@@ -45,7 +44,7 @@ class TypedValue(_ffi_Object):
__all__ = [ __all__ = [
# tvm-ffi-stubgen(begin): __all__ # tvm-ffi-stubgen(begin): __all__
"LIB", "LIB",
"TypedValue", "TVMFFILauncherImpl",
"string_to_type", "string_to_type",
"type_to_string", "type_to_string",
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
+32 -14
View File
@@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Sequence, Type from typing import Any, Callable, Final, List, Sequence, Type, Union
from triton.backends.nvidia.driver import CudaDriver from triton.backends.nvidia.driver import CudaDriver
from triton.runtime import _allocation from triton.runtime import _allocation
from . import TypedValue, utils, string_to_type from . import TVMFFILauncherImpl, utils, string_to_type
class TVMLauncher(object): class TVMLauncher(object):
@@ -11,14 +11,34 @@ class TVMLauncher(object):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.signature: List[str] = [*src.signature.values()] self.signature: List[str] = [*src.signature.values()]
self.num_ctas: int = getattr(metadata, "num_ctas", 1) self.num_ctas: Final[int] = getattr(metadata, "num_ctas", 1)
self.launch = utils.launch self.global_scratch_size: Final[int] = metadata.global_scratch_size
self.global_scratch_size: int = metadata.global_scratch_size self.global_scratch_align: Final[int] = metadata.global_scratch_align
self.global_scratch_align: int = metadata.global_scratch_align self.profile_scratch_size: Final[int] = metadata.profile_scratch_size
self.profile_scratch_size: int = metadata.profile_scratch_size self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
self.profile_scratch_align: int = metadata.profile_scratch_align self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
self.launch_cooperative_grid: bool = metadata.launch_cooperative_grid self.launch_pdl: Final[bool] = metadata.launch_pdl
self.launch_pdl: bool = metadata.launch_pdl self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
[string_to_type(t) for t in self.signature],
self.launch_cooperative_grid,
self.launch_pdl,
)
self.launch: Callable[
[
int,
int,
int,
int,
int,
tuple[int, int, int],
object,
object,
object,
object,
object,
Sequence[Union[Any]],
]
] = self.impl.launch
def __call__( def __call__(
self, self,
@@ -52,9 +72,9 @@ class TVMLauncher(object):
assert not self.launch_cooperative_grid assert not self.launch_cooperative_grid
assert not self.launch_pdl assert not self.launch_pdl
args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args) # args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args)
return self.launch( return self.impl.launch(
gridX, gridX,
gridY, gridY,
gridZ, gridZ,
@@ -64,8 +84,6 @@ class TVMLauncher(object):
launch_metadata, launch_metadata,
launch_enter_hook, launch_enter_hook,
launch_exit_hook, launch_exit_hook,
self.launch_cooperative_grid,
self.launch_pdl,
global_scratch, global_scratch,
profile_scratch, profile_scratch,
args, args,
+1 -4
View File
@@ -5,8 +5,7 @@ from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping, Sequence from collections.abc import Mapping
from tvm_ffi import Object
from typing import Any from typing import Any
# isort: on # isort: on
# fmt: on # fmt: on
@@ -20,7 +19,6 @@ if TYPE_CHECKING:
def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ... def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ...
def fill_tma_descriptor(*args: Any) -> Any: ... def fill_tma_descriptor(*args: Any) -> Any: ...
def get_device_properties(_0: int, /) -> Mapping[str, int]: ... def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
def launch(_0: int, _1: int, _2: int, _3: int, _4: int, _5: tuple[int, int, int], _6: Object, _7: Object, _8: Object, _9: bool, _10: bool, _11: Object, _12: Object, _13: Sequence[Any], /) -> None: ...
def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ...
def set_printf_fifo_size(*args: Any) -> Any: ... def set_printf_fifo_size(*args: Any) -> Any: ...
# fmt: on # fmt: on
@@ -32,7 +30,6 @@ __all__ = [
"cuOccupancyMaxActiveClusters", "cuOccupancyMaxActiveClusters",
"fill_tma_descriptor", "fill_tma_descriptor",
"get_device_properties", "get_device_properties",
"launch",
"load_binary", "load_binary",
"set_printf_fifo_size", "set_printf_fifo_size",
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
+1 -1
View File
@@ -2,9 +2,9 @@ add_library(
${TARGET_NAME} ${TARGET_NAME}
SHARED SHARED
${CMAKE_CURRENT_SOURCE_DIR}/exception.cc ${CMAKE_CURRENT_SOURCE_DIR}/exception.cc
${CMAKE_CURRENT_SOURCE_DIR}/launch.cc
${CMAKE_CURRENT_SOURCE_DIR}/type.cc ${CMAKE_CURRENT_SOURCE_DIR}/type.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/value.cc
) )
target_include_directories( target_include_directories(
+141
View File
@@ -0,0 +1,141 @@
#include "launch.h"
#include "macro.h"
#include <cuda.h>
#include <tvm/ffi/base_details.h>
namespace triton_tvm_ffi {
/// Stores the kernel argument signature and the two launch flags.
///
/// `signature` is taken by const reference, so it is copied into the member.
/// The previous `std::move(signature)` on a const reference also performed a
/// copy but read as if it were a move.
TVMFFILauncherImplObj::TVMFFILauncherImplObj(
    const tvm::ffi::Array<Type> &signature, bool launchCooperativeGrid,
    bool launchAsync)
    : signature_(signature), launchCooperativeGrid_(launchCooperativeGrid),
      launchAsync_(launchAsync) {}
/// Launches the CUDA kernel `function` on `stream`.
///
/// Kernel arguments are unpacked according to `signature_`: native scalar
/// types are copied into stack slots, `PTR` arguments contribute the data
/// pointer of a `tvm::ffi::TensorView`, and `CONSTEXPR` arguments are skipped.
/// Two trailing parameters (the global and profile scratch pointers, both
/// currently 0) are appended after the user arguments.
///
/// Nothing is launched when any grid dimension is not positive.
///
/// @param gridX,gridY,gridZ  Grid dimensions; gridX is scaled by numCtas.
/// @param stream             CUstream handle passed as an integer.
/// @param function           CUfunction handle passed as an integer.
/// @param kernelMetadata     Tuple of (numWarps, numCtas, sharedMemory).
/// @param launchMetadata,launchEnterHook,launchExitHook,globalScratchObject,
///        profileScratchObject  Not used yet (see the TODOs below).
/// @param kernelArgs         One entry per element of the signature.
/// @throws UnmatchedArgumentException if `kernelArgs` and the signature differ
///         in length.
/// @throws CUDAException (via CUDA_CHECK) if a driver call fails.
void TVMFFILauncherImplObj::Launch(
    int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
    uint64_t function,
    tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
    tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
    tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
    tvm::ffi::ObjectRef profileScratchObject,
    const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
  CUstream cStream = reinterpret_cast<CUstream>(stream);
  CUfunction cFunction = reinterpret_cast<CUfunction>(function);
  auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
  // TODO: Implement the launch logic
  CUdeviceptr globalScratch = 0;
  // TODO: check `profileScratchObject`
  CUdeviceptr profileScratch = 0;
  // Test each dimension separately. The 32-bit product gridX * gridY * gridZ
  // can overflow (undefined behaviour) for valid grids, e.g. 2^20 x 2^12 x 1
  // wraps to 0 and the launch would be silently skipped.
  if (gridX > 0 && gridY > 0 && gridZ > 0) {
    CUlaunchAttribute launchAttr[4];
    CUlaunchConfig config;
    // Multiply as unsigned so a large gridX * numCtas cannot hit signed
    // overflow; the field itself is unsigned.
    config.gridDimX =
        static_cast<unsigned int>(gridX) * static_cast<unsigned int>(numCtas);
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;
    static constexpr int32_t kThreadsPerWarp = 32;
    config.blockDimX = kThreadsPerWarp * numWarps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = sharedMemory;
    config.hStream = cStream;
    config.attrs = launchAttr;
    int32_t numAttrs = 0;
    // TODO: check `launchPdl`
    // TODO: check `launchCooperativeGrid`
    if (numCtas != 1) {
      CUlaunchAttribute clusterAttr;
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = numCtas;
      clusterAttr.value.clusterDim.y = 1;
      clusterAttr.value.clusterDim.z = 1;
      launchAttr[numAttrs++] = clusterAttr;
      CUlaunchAttribute clusterSchedulingAttr;
      clusterSchedulingAttr.id =
          CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
          CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      launchAttr[numAttrs++] = clusterSchedulingAttr;
    }
    config.numAttrs = numAttrs;
    if (numCtas == 16) {
      CUDA_CHECK(cuFuncSetAttribute(
          cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
    }
    const size_t kernelArgNum = kernelArgs.size();
    // Always validate the argument count, not only in debug builds: with too
    // few arguments the driver would read kernel parameters past the end of
    // `params`.
    if (kernelArgNum != signature_.size()) {
      throw UnmatchedArgumentException("kernelArgs", kernelArgNum,
                                       signature_.size());
    }
    // One slot per argument plus the two trailing scratch pointers. alloca
    // storage stays valid until this function returns, which covers the
    // launch call below.
    void **params =
        reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
    size_t j = 0;
    for (size_t i = 0; i < kernelArgNum; ++i) {
      tvm::ffi::Any value = kernelArgs[i];
      switch (signature_[i]) {
#define CASE_STMT(type, str, ctype)                                            \
  case Type::type: {                                                           \
    using cpptype = type_to_ctype_t<Type::type>;                               \
    params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype)));             \
    *reinterpret_cast<cpptype *>(params[j]) = value.cast<cpptype>();           \
    ++j;                                                                       \
    break;                                                                     \
  }
        TYPE_TABLE_NATIVE(CASE_STMT)
#undef CASE_STMT
      case Type::PTR: {
        params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
        *reinterpret_cast<void **>(params[j]) =
            value.cast<tvm::ffi::TensorView>().data_ptr();
        ++j;
        break;
      }
      case Type::CONSTEXPR: {
        // Compile-time constants are not passed as kernel parameters.
        break;
      }
      default: {
#ifdef NDEBUG
        __builtin_unreachable();
#else
        throw NotImplementedException("CONSTEXPR for value casting");
#endif
      }
      }
    }
    // TODO: unwrap PyObject* from scratch pointers and assign to kernel args
    params[j] = &globalScratch;
    params[j + 1] = &profileScratch;
    CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
  }
  // TODO: call `launchExitHook`
}
// Allocates a TVMFFILauncherImplObj and makes this reference point at it.
// `signature` is received by value and moved into the allocation call.
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
bool launchCooperativeGrid,
bool launchAsync)
: tvm::ffi::ObjectRef(tvm::ffi::make_object<TVMFFILauncherImplObj>(
std::move(signature), launchCooperativeGrid, launchAsync)) {}
void TVMFFILauncherImpl::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata,
launchMetadata, launchEnterHook, launchExitHook,
globalScratchObject, profileScratchObject, kernelArgs);
}
// Static registration for TVMFFILauncherImplObj: exposes a constructor taking
// (signature, bool, bool) and the `Launch` method under the name "launch".
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TVMFFILauncherImplObj>()
.def(refl::init<const tvm::ffi::Array<Type> &, bool, bool>())
.def("launch", &TVMFFILauncherImplObj::Launch);
}
} // namespace triton_tvm_ffi
+1 -106
View File
@@ -1,6 +1,5 @@
#include "exception.h" #include "macro.h"
#include "type.h" #include "type.h"
#include "value.h"
#include <cuda.h> #include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h> #include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
@@ -9,13 +8,6 @@
#include <cassert> #include <cassert>
#endif #endif
#define CUDA_CHECK(code) \
do { \
if (__builtin_expect((code) != CUDA_SUCCESS, 0)) { \
throw triton_tvm_ffi::CUDAException(code); \
} \
} while (false)
namespace triton_tvm_ffi { namespace triton_tvm_ffi {
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) { tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
@@ -52,102 +44,6 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
{"mem_bus_width", memBusWidth}}; {"mem_bus_width", memBusWidth}};
} }
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
tvm::ffi::ObjectRef launchMetadata,
tvm::ffi::ObjectRef launchEnterHook,
tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::ObjectRef profileScratchObject,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
CUstream cStream = reinterpret_cast<CUstream>(stream);
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
// TODO: Implement the launch logic
CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
config.gridDimX = gridX * numCtas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;
static constexpr int32_t kThreadsPerWarp = 32;
config.blockDimX = kThreadsPerWarp * numWarps;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = cStream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = numCtas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[numAttrs++] = clusterAttr;
CUlaunchAttribute clusterSchedulingAttr;
clusterSchedulingAttr.id =
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
launchAttr[numAttrs++] = clusterSchedulingAttr;
}
config.numAttrs = numAttrs;
if (numCtas == 16) {
CUDA_CHECK(cuFuncSetAttribute(
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
}
const int32_t kernelArgNum = kernelArgs.size();
void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0;
for (size_t i = 0; i < kernelArgNum; ++i) {
TypedValue value = kernelArgs[i].cast<TypedValue>();
switch (value.GetType()) {
#define CASE_STMT(type, str, ctype) \
case Type::type: { \
using cpptype = type_to_ctype_t<Type::type>; \
params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype))); \
*reinterpret_cast<cpptype *>(params[j]) = \
value.GetValue().cast<cpptype>(); \
++j; \
break; \
}
TYPE_TABLE_NATIVE(CASE_STMT)
#undef CASE_STMT
case Type::PTR: {
params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
*reinterpret_cast<void **>(params[j]) =
value.GetValue().cast<tvm::ffi::TensorView>().data_ptr();
++j;
break;
}
case Type::CONSTEXPR: {
break;
}
default: {
#ifdef NDEBUG
__builtin_unreachable();
#else
throw NotImplementedException("CONSTEXPR for value casting");
#endif
}
}
}
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
params[j] = &globalScratch;
params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
}
// TODO: call `launchExitHook`
}
tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t> tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>
LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data,
int32_t shared, CUdevice device) { int32_t shared, CUdevice device) {
@@ -212,7 +108,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
throw NotImplementedException("set_printf_fifo_size"); throw NotImplementedException("set_printf_fifo_size");
}) })
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
.def("triton_tvm_ffi.utils.launch", Launch)
.def("triton_tvm_ffi.utils.load_binary", LoadBinary); .def("triton_tvm_ffi.utils.load_binary", LoadBinary);
} }
-57
View File
@@ -1,57 +0,0 @@
#include "value.h"
#include "exception.h"
#include "type.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
// Copying constructor: stores the type tag and a copy of `value`.
TypedValueObj::TypedValueObj(Type type, const tvm::ffi::Any &value)
: type_(type), value_(value) {}
// Moving constructor: stores the type tag and takes over `value`.
TypedValueObj::TypedValueObj(Type type, tvm::ffi::Any &&value)
: type_(type), value_(std::move(value)) {}
// Allocates a TypedValueObj holding `type` and a copy of `value`.
TypedValue::TypedValue(Type type, const tvm::ffi::Any &value)
: tvm::ffi::ObjectRef(tvm::ffi::make_object<TypedValueObj>(type, value)) {}
// Allocates a TypedValueObj holding `type`, moving `value` into it.
TypedValue::TypedValue(Type type, tvm::ffi::Any &&value)
: tvm::ffi::ObjectRef(
tvm::ffi::make_object<TypedValueObj>(type, std::move(value))) {}
/// Wraps `value` together with the `Type` parsed from the name `type`.
///
/// @throws UnknownTypeException when `StringToType` does not recognise the
///         name; on the non-throwing path the returned Optional holds a value.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
                                              const tvm::ffi::Any &value) {
  tvm::ffi::Optional<Type> parsedType = StringToType(type);
  if (parsedType.has_value()) {
    return TypedValue(*parsedType, value);
  }
  throw UnknownTypeException(type.data());
}
/// Builds one TypedValue per (type name, value) pair, in order.
///
/// @throws UnmatchedArgumentException when `values` and `types` differ in
///         length.
/// @throws UnknownTypeException when a type name cannot be converted.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
                const tvm::ffi::Array<tvm::ffi::Any> &values) {
  const size_t expectedCount = types.size();
  const size_t actualCount = values.size();
  if (actualCount != expectedCount) {
    throw UnmatchedArgumentException("values", actualCount, expectedCount);
  }
  tvm::ffi::Array<TypedValue> wrappedValues;
  size_t index = 0;
  while (index < expectedCount) {
    tvm::ffi::Optional<TypedValue> wrapped =
        MakeTypedValue(types[index], values[index]);
    if (!wrapped.has_value()) {
      throw UnknownTypeException(types[index].data());
    }
    wrappedValues.emplace_back(std::move(*wrapped));
    ++index;
  }
  return wrappedValues;
}
// Static registration for TypedValueObj: exposes a (Type, Any) constructor
// and the two factory helpers as static methods "make_typed_value" and
// "make_typed_values".
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TypedValueObj>()
.def(refl::init<Type, const tvm::ffi::Any &>())
.def_static("make_typed_value", MakeTypedValue)
.def_static("make_typed_values", MakeTypedValues);
}
} // namespace triton_tvm_ffi