diff --git a/include/launch.h b/include/launch.h new file mode 100644 index 0000000..04b6e77 --- /dev/null +++ b/include/launch.h @@ -0,0 +1,55 @@ +#ifndef TRITON_TVM_FFI_LAUNCH_H_ +#define TRITON_TVM_FFI_LAUNCH_H_ + +#include "type.h" +#include + +namespace triton_tvm_ffi { + +class TVMFFILauncherImplObj : public tvm::ffi::Object { +public: + TVMFFILauncherImplObj(const tvm::ffi::Array &signature, + bool launchCooperativeGrid, bool launchAsync); + TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default; + TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default; + void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, + uint64_t function, + tvm::ffi::Tuple kernelMetadata, + tvm::ffi::ObjectRef launchMetadata, + tvm::ffi::ObjectRef launchEnterHook, + tvm::ffi::ObjectRef launchExitHook, + tvm::ffi::ObjectRef globalScratchObject, + tvm::ffi::ObjectRef profileScratchObject, + const tvm::ffi::Array &kernelArgs) const; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl", + TVMFFILauncherImplObj, tvm::ffi::Object); + +private: + tvm::ffi::Array signature_; + const bool launchCooperativeGrid_; + const bool launchAsync_; +}; + +class TVMFFILauncherImpl : public tvm::ffi::ObjectRef { +public: + TVMFFILauncherImpl(tvm::ffi::Array signature, + bool launchCooperativeGrid, bool launchAsync); + using tvm::ffi::ObjectRef::ObjectRef; + using tvm::ffi::ObjectRef::operator=; + void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, + uint64_t function, + tvm::ffi::Tuple kernelMetadata, + tvm::ffi::ObjectRef launchMetadata, + tvm::ffi::ObjectRef launchEnterHook, + tvm::ffi::ObjectRef launchExitHook, + tvm::ffi::ObjectRef globalScratchObject, + tvm::ffi::ObjectRef profileScratchObject, + const tvm::ffi::Array &kernelArgs) const; + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl, + tvm::ffi::ObjectRef, + TVMFFILauncherImplObj); +}; + +} // namespace triton_tvm_ffi + +#endif diff --git a/include/macro.h b/include/macro.h index e1417be..3129758 100644 --- a/include/macro.h +++ b/include/macro.h @@ -1,8 +1,19 @@ #ifndef TRITON_TVM_FFI_MACRO_H_ #define TRITON_TVM_FFI_MACRO_H_ +#include "exception.h" + #if defined(__GNUC__) || defined(__clang__) #define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline #endif +#define UNLIKELY(cond) __builtin_expect((cond), 0) + +#define CUDA_CHECK(code) \ + do { \ + if (UNLIKELY((code) != CUDA_SUCCESS)) { \ + throw triton_tvm_ffi::CUDAException(code); \ + } \ + } while (false) + #endif diff --git a/include/value.h b/include/value.h deleted file mode 100644 index 9f894d8..0000000 --- a/include/value.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef TRITON_TVM_FFI_VALUE_H_ -#define TRITON_TVM_FFI_VALUE_H_ - -#include "macro.h" -#include "type.h" -#include -#include - -namespace triton_tvm_ffi { - -class TypedValueObj : public tvm::ffi::Object { -public: - TypedValueObj(Type type, const tvm::ffi::Any &value); - TypedValueObj(Type type, tvm::ffi::Any &&value); - TypedValueObj(const TypedValueObj &other) = default; - TypedValueObj(TypedValueObj &&other) = default; - TypedValueObj &operator=(const TypedValueObj &other) = default; - TypedValueObj &operator=(TypedValueObj &&other) = default; - TRITON_TVM_FFI_INLINE Type GetType() const { return type_; } - TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { return value_; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TypedValue", TypedValueObj, - tvm::ffi::Object); - -private: - Type type_; - tvm::ffi::Any value_; -}; - -class TypedValue : public tvm::ffi::ObjectRef { -public: - TypedValue(Type type, const tvm::ffi::Any &value); - TypedValue(Type type, tvm::ffi::Any &&value); - using tvm::ffi::ObjectRef::ObjectRef; - using tvm::ffi::ObjectRef::operator=; - TRITON_TVM_FFI_INLINE Type GetType() const { return get()->GetType(); } - TRITON_TVM_FFI_INLINE const tvm::ffi::Any &GetValue() const { - return get()->GetValue(); - } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypedValue, tvm::ffi::ObjectRef, - TypedValueObj); -}; - -tvm::ffi::Optional MakeTypedValue(const tvm::ffi::String &type, - const tvm::ffi::Any &value); - -tvm::ffi::Array -MakeTypedValues(const tvm::ffi::Array &types, - const tvm::ffi::Array &values); - -} // namespace triton_tvm_ffi - -#endif diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index b7c6fbe..344a060 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -27,17 +27,16 @@ if TYPE_CHECKING: # tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ # tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object -@_FFI_REG_OBJ("triton_tvm_ffi.TypedValue") -class TypedValue(_ffi_Object): - # tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TypedValue +@_FFI_REG_OBJ("triton_tvm_ffi.TVMFFILauncherImpl") +class TVMFFILauncherImpl(_ffi_Object): + """FFI binding for `triton_tvm_ffi.TVMFFILauncherImpl`.""" + + # tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TVMFFILauncherImpl # fmt: off if TYPE_CHECKING: @staticmethod - def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ... - @staticmethod - def make_typed_value(_0: str, _1: Any, /) -> TypedValue | None: ... - @staticmethod - def make_typed_values(_0: Sequence[str], _1: Sequence[Any], /) -> Sequence[TypedValue]: ... + def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ... + def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ... # fmt: on # tvm-ffi-stubgen(end) @@ -45,7 +44,7 @@ class TypedValue(_ffi_Object): __all__ = [ # tvm-ffi-stubgen(begin): __all__ "LIB", - "TypedValue", + "TVMFFILauncherImpl", "string_to_type", "type_to_string", # tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 45f2e80..1619624 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Callable, Final, List, Sequence, Type, Union from triton.backends.nvidia.driver import CudaDriver from triton.runtime import _allocation -from . import TypedValue, utils, string_to_type +from . import TVMFFILauncherImpl, utils, string_to_type class TVMLauncher(object): @@ -11,14 +11,34 @@ class TVMLauncher(object): super().__init__(*args, **kwargs) self.signature: List[str] = [*src.signature.values()] - self.num_ctas: int = getattr(metadata, "num_ctas", 1) - self.launch = utils.launch - self.global_scratch_size: int = metadata.global_scratch_size - self.global_scratch_align: int = metadata.global_scratch_align - self.profile_scratch_size: int = metadata.profile_scratch_size - self.profile_scratch_align: int = metadata.profile_scratch_align - self.launch_cooperative_grid: bool = metadata.launch_cooperative_grid - self.launch_pdl: bool = metadata.launch_pdl + self.num_ctas: Final[int] = getattr(metadata, "num_ctas", 1) + self.global_scratch_size: Final[int] = metadata.global_scratch_size + self.global_scratch_align: Final[int] = metadata.global_scratch_align + self.profile_scratch_size: Final[int] = metadata.profile_scratch_size + self.profile_scratch_align: Final[int] = metadata.profile_scratch_align + self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid + self.launch_pdl: Final[bool] = metadata.launch_pdl + self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( + [string_to_type(t) for t in self.signature], + self.launch_cooperative_grid, + self.launch_pdl, + ) + self.launch: Callable[ + [ + int, + int, + int, + int, + int, + tuple[int, int, int], + object, + object, + object, + object, + object, + Sequence[Union[Any]], + ] + ] = self.impl.launch def __call__( self, @@ -52,9 +72,9 @@ class TVMLauncher(object): assert not self.launch_cooperative_grid assert not self.launch_pdl - args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args) + # args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args) - return self.launch( + return self.impl.launch( gridX, gridY, gridZ, @@ -64,8 +84,6 @@ class TVMLauncher(object): launch_metadata, launch_enter_hook, launch_exit_hook, - self.launch_cooperative_grid, - self.launch_pdl, global_scratch, profile_scratch, args, diff --git a/python/triton_tvm_ffi/utils.py b/python/triton_tvm_ffi/utils.py index e0c19a9..5b23376 100644 --- a/python/triton_tvm_ffi/utils.py +++ b/python/triton_tvm_ffi/utils.py @@ -5,8 +5,7 @@ from __future__ import annotations from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - from tvm_ffi import Object + from collections.abc import Mapping from typing import Any # isort: on # fmt: on @@ -20,7 +19,6 @@ if TYPE_CHECKING: def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ... def fill_tma_descriptor(*args: Any) -> Any: ... def get_device_properties(_0: int, /) -> Mapping[str, int]: ... - def launch(_0: int, _1: int, _2: int, _3: int, _4: int, _5: tuple[int, int, int], _6: Object, _7: Object, _8: Object, _9: bool, _10: bool, _11: Object, _12: Object, _13: Sequence[Any], /) -> None: ... def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... def set_printf_fifo_size(*args: Any) -> Any: ... # fmt: on @@ -32,7 +30,6 @@ __all__ = [ "cuOccupancyMaxActiveClusters", "fill_tma_descriptor", "get_device_properties", - "launch", "load_binary", "set_printf_fifo_size", # tvm-ffi-stubgen(end) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e0f2423..39d2384 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,9 +2,9 @@ add_library( ${TARGET_NAME} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/exception.cc + ${CMAKE_CURRENT_SOURCE_DIR}/launch.cc ${CMAKE_CURRENT_SOURCE_DIR}/type.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/value.cc ) target_include_directories( diff --git a/src/launch.cc b/src/launch.cc new file mode 100644 index 0000000..b54015b --- /dev/null +++ b/src/launch.cc @@ -0,0 +1,141 @@ +#include "launch.h" +#include "macro.h" +#include +#include + +namespace triton_tvm_ffi { + +TVMFFILauncherImplObj::TVMFFILauncherImplObj( + const tvm::ffi::Array &signature, bool launchCooperativeGrid, + bool launchAsync) + : signature_(std::move(signature)), + launchCooperativeGrid_(launchCooperativeGrid), launchAsync_(launchAsync) { +} + +void TVMFFILauncherImplObj::Launch( + int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, + uint64_t function, + tvm::ffi::Tuple kernelMetadata, + tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook, + tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject, + tvm::ffi::ObjectRef profileScratchObject, + const tvm::ffi::Array &kernelArgs) const { + CUstream cStream = reinterpret_cast(stream); + CUfunction cFunction = reinterpret_cast(function); + auto [numWarps, numCtas, sharedMemory] = kernelMetadata; + // TODO: Implement the launch logic + CUdeviceptr globalScratch = 0; + // TODO: check `profileScratchObject` + CUdeviceptr profileScratch = 0; + if (gridX * gridY * gridZ > 0) { + CUlaunchAttribute launchAttr[4]; + CUlaunchConfig config; + config.gridDimX = gridX * numCtas; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + static constexpr int32_t kThreadsPerWarp = 32; + config.blockDimX = kThreadsPerWarp * numWarps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = sharedMemory; + config.hStream = cStream; + config.attrs = launchAttr; + int32_t numAttrs = 0; + // TODO: check `launchPdl` + // TODO: check `launchCooperativeGrid` + if (numCtas != 1) { + CUlaunchAttribute clusterAttr; + clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + clusterAttr.value.clusterDim.x = numCtas; + clusterAttr.value.clusterDim.y = 1; + clusterAttr.value.clusterDim.z = 1; + launchAttr[numAttrs++] = clusterAttr; + CUlaunchAttribute clusterSchedulingAttr; + clusterSchedulingAttr.id = + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + launchAttr[numAttrs++] = clusterSchedulingAttr; + } + config.numAttrs = numAttrs; + if (numCtas == 16) { + CUDA_CHECK(cuFuncSetAttribute( + cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + } + const int32_t kernelArgNum = kernelArgs.size(); + void **params = + reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); + size_t j = 0; +#ifndef NDEBUG + if (kernelArgNum != signature_.size()) { + throw UnmatchedArgumentException("kernelArgs", kernelArgNum, + signature_.size()); + } +#endif + for (size_t i = 0; i < kernelArgNum; ++i) { + tvm::ffi::Any value = kernelArgs[i]; + switch (signature_[i]) { +#define CASE_STMT(type, str, ctype) \ + case Type::type: { \ + using cpptype = type_to_ctype_t; \ + params[j] = reinterpret_cast(alloca(sizeof(cpptype))); \ + *reinterpret_cast(params[j]) = value.cast(); \ + ++j; \ + break; \ + } + TYPE_TABLE_NATIVE(CASE_STMT) +#undef CASE_STMT + case Type::PTR: { + params[j] = reinterpret_cast(alloca(sizeof(void *))); + *reinterpret_cast(params[j]) = + value.cast().data_ptr(); + ++j; + break; + } + case Type::CONSTEXPR: { + break; + } + default: { +#ifdef NDEBUG + __builtin_unreachable(); +#else + throw NotImplementedException("CONSTEXPR for value casting"); +#endif + } + } + } + // TODO: unwrap PyObject* from scratch pointers and assign to kernel args + params[j] = &globalScratch; + params[j + 1] = &profileScratch; + CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr)); + } + // TODO: call `launchExitHook` +} + +TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array signature, + bool launchCooperativeGrid, + bool launchAsync) + : tvm::ffi::ObjectRef(tvm::ffi::make_object( + std::move(signature), launchCooperativeGrid, launchAsync)) {} + +void TVMFFILauncherImpl::Launch( + int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, + uint64_t function, + tvm::ffi::Tuple kernelMetadata, + tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook, + tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject, + tvm::ffi::ObjectRef profileScratchObject, + const tvm::ffi::Array &kernelArgs) const { + get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata, + launchMetadata, launchEnterHook, launchExitHook, + globalScratchObject, profileScratchObject, kernelArgs); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def(refl::init &, bool, bool>()) + .def("launch", &TVMFFILauncherImplObj::Launch); +} + +} // namespace triton_tvm_ffi diff --git a/src/utils.cc b/src/utils.cc index abbc612..65aab79 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,6 +1,5 @@ -#include "exception.h" +#include "macro.h" #include "type.h" -#include "value.h" #include #include #include @@ -9,13 +8,6 @@ #include #endif -#define CUDA_CHECK(code) \ - do { \ - if (__builtin_expect((code) != CUDA_SUCCESS, 0)) { \ - throw triton_tvm_ffi::CUDAException(code); \ - } \ - } while (false) - namespace triton_tvm_ffi { tvm::ffi::Map GetDeviceProperties(int device_id) { @@ -52,102 +44,6 @@ tvm::ffi::Map GetDeviceProperties(int device_id) { {"mem_bus_width", memBusWidth}}; } -void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, - uint64_t function, - tvm::ffi::Tuple kernelMetadata, - tvm::ffi::ObjectRef launchMetadata, - tvm::ffi::ObjectRef launchEnterHook, - tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid, - bool launchPdl, tvm::ffi::ObjectRef globalScratchObject, - tvm::ffi::ObjectRef profileScratchObject, - const tvm::ffi::Array &kernelArgs) { - CUstream cStream = reinterpret_cast(stream); - CUfunction cFunction = reinterpret_cast(function); - auto [numWarps, numCtas, sharedMemory] = kernelMetadata; - // TODO: Implement the launch logic - CUdeviceptr globalScratch = 0; - // TODO: check `profileScratchObject` - CUdeviceptr profileScratch = 0; - if (gridX * gridY * gridZ > 0) { - CUlaunchAttribute launchAttr[4]; - CUlaunchConfig config; - config.gridDimX = gridX * numCtas; - config.gridDimY = gridY; - config.gridDimZ = gridZ; - static constexpr int32_t kThreadsPerWarp = 32; - config.blockDimX = kThreadsPerWarp * numWarps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = sharedMemory; - config.hStream = cStream; - config.attrs = launchAttr; - int32_t numAttrs = 0; - // TODO: check `launchPdl` - // TODO: check `launchCooperativeGrid` - if (numCtas != 1) { - CUlaunchAttribute clusterAttr; - clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - clusterAttr.value.clusterDim.x = numCtas; - clusterAttr.value.clusterDim.y = 1; - clusterAttr.value.clusterDim.z = 1; - launchAttr[numAttrs++] = clusterAttr; - CUlaunchAttribute clusterSchedulingAttr; - clusterSchedulingAttr.id = - CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = - CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - launchAttr[numAttrs++] = clusterSchedulingAttr; - } - config.numAttrs = numAttrs; - if (numCtas == 16) { - CUDA_CHECK(cuFuncSetAttribute( - cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); - } - const int32_t kernelArgNum = kernelArgs.size(); - void **params = - reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); - size_t j = 0; - for (size_t i = 0; i < kernelArgNum; ++i) { - TypedValue value = kernelArgs[i].cast(); - switch (value.GetType()) { -#define CASE_STMT(type, str, ctype) \ - case Type::type: { \ - using cpptype = type_to_ctype_t; \ - params[j] = reinterpret_cast(alloca(sizeof(cpptype))); \ - *reinterpret_cast(params[j]) = \ - value.GetValue().cast(); \ - ++j; \ - break; \ - } - TYPE_TABLE_NATIVE(CASE_STMT) -#undef CASE_STMT - case Type::PTR: { - params[j] = reinterpret_cast(alloca(sizeof(void *))); - *reinterpret_cast(params[j]) = - value.GetValue().cast().data_ptr(); - ++j; - break; - } - case Type::CONSTEXPR: { - break; - } - default: { -#ifdef NDEBUG - __builtin_unreachable(); -#else - throw NotImplementedException("CONSTEXPR for value casting"); -#endif - } - } - } - // TODO: unwrap PyObject* from scratch pointers and assign to kernel args - params[j] = &globalScratch; - params[j + 1] = &profileScratch; - CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr)); - } - // TODO: call `launchExitHook` -} - tvm::ffi::Tuple LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data, int32_t shared, CUdevice device) { @@ -212,7 +108,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { throw NotImplementedException("set_printf_fifo_size"); }) .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) - .def("triton_tvm_ffi.utils.launch", Launch) .def("triton_tvm_ffi.utils.load_binary", LoadBinary); } diff --git a/src/value.cc b/src/value.cc deleted file mode 100644 index c400707..0000000 --- a/src/value.cc +++ /dev/null @@ -1,57 +0,0 @@ -#include "value.h" -#include "exception.h" -#include "type.h" -#include -#include - -namespace triton_tvm_ffi { - -TypedValueObj::TypedValueObj(Type type, const tvm::ffi::Any &value) - : type_(type), value_(value) {} - -TypedValueObj::TypedValueObj(Type type, tvm::ffi::Any &&value) - : type_(type), value_(std::move(value)) {} - -TypedValue::TypedValue(Type type, const tvm::ffi::Any &value) - : tvm::ffi::ObjectRef(tvm::ffi::make_object(type, value)) {} - -TypedValue::TypedValue(Type type, tvm::ffi::Any &&value) - : tvm::ffi::ObjectRef( - tvm::ffi::make_object(type, std::move(value))) {} - -tvm::ffi::Optional MakeTypedValue(const tvm::ffi::String &type, - const tvm::ffi::Any &value) { - tvm::ffi::Optional typeOpt = StringToType(type); - if (!typeOpt.has_value()) { - throw UnknownTypeException(type.data()); - } - return TypedValue(*typeOpt, value); -} - -tvm::ffi::Array -MakeTypedValues(const tvm::ffi::Array &types, - const tvm::ffi::Array &values) { - const size_t n = types.size(); - if (const size_t m = values.size(); m != n) { - throw UnmatchedArgumentException("values", m, n); - } - tvm::ffi::Array rets; - for (size_t i = 0; i < n; ++i) { - tvm::ffi::Optional val = MakeTypedValue(types[i], values[i]); - if (!val.has_value()) { - throw UnknownTypeException(types[i].data()); - } - rets.emplace_back(std::move(*val)); - } - return rets; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def(refl::init()) - .def_static("make_typed_value", MakeTypedValue) - .def_static("make_typed_values", MakeTypedValues); -} - -} // namespace triton_tvm_ffi