diff --git a/include/exception.h b/include/exception.h index cf118af..685f861 100644 --- a/include/exception.h +++ b/include/exception.h @@ -25,6 +25,15 @@ private: const std::string message_; }; +class UnmatchedArgumentException : public std::exception { +public: + UnmatchedArgumentException(std::string_view name, size_t len, size_t expect); + const char *what() const noexcept override; + +private: + const std::string message_; +}; + class UnknownTypeException : public std::exception { public: UnknownTypeException(Type type); diff --git a/include/type.h b/include/type.h index 176712b..3736104 100644 --- a/include/type.h +++ b/include/type.h @@ -2,8 +2,7 @@ #define TRITON_TVM_FFI_TYPE_H_ #include -#include -#include +#include #include namespace triton_tvm_ffi { @@ -35,7 +34,7 @@ enum class Type : int64_t { }; const char *TypeToString(Type type); -tvm::ffi::Optional StringToType(tvm::ffi::String str); +tvm::ffi::Optional StringToType(const tvm::ffi::String &name); template struct type_to_ctype; #define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \ diff --git a/include/value.h b/include/value.h index 15c338d..9f894d8 100644 --- a/include/value.h +++ b/include/value.h @@ -40,6 +40,13 @@ public: TypedValueObj); }; +tvm::ffi::Optional MakeTypedValue(const tvm::ffi::String &type, + const tvm::ffi::Any &value); + +tvm::ffi::Array +MakeTypedValues(const tvm::ffi::Array &types, + const tvm::ffi::Array &values); + } // namespace triton_tvm_ffi #endif diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index 9e1b326..b7c6fbe 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -6,6 +6,7 @@ from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, regis from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Sequence from tvm_ffi import Object from typing import Any # isort: on @@ -33,6 +34,10 @@ class TypedValue(_ffi_Object): if TYPE_CHECKING: @staticmethod def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ... + @staticmethod + def make_typed_value(_0: str, _1: Any, /) -> TypedValue | None: ... + @staticmethod + def make_typed_values(_0: Sequence[str], _1: Sequence[Any], /) -> Sequence[TypedValue]: ... # fmt: on # tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index b6dcea0..45f2e80 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Sequence, Type from triton.backends.nvidia.driver import CudaDriver from triton.runtime import _allocation from . import TypedValue, utils, string_to_type @@ -10,7 +10,7 @@ class TVMLauncher(object): def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher: super().__init__(*args, **kwargs) - self.signature: List[str] = src.signature.values() + self.signature: List[str] = [*src.signature.values()] self.num_ctas: int = getattr(metadata, "num_ctas", 1) self.launch = utils.launch self.global_scratch_size: int = metadata.global_scratch_size @@ -51,14 +51,8 @@ class TVMLauncher(object): ) assert not self.launch_cooperative_grid assert not self.launch_pdl - assert len(self.signature) == len(args) - def canonicalize(arg: Any, sig: str) -> TypedValue: - ty: Optional[int] = string_to_type(sig) - assert ty is not None, sig - return TypedValue(ty, arg) - - args = [canonicalize(arg, sig) for arg, sig in zip(args, self.signature)] + args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args) return self.launch( gridX, @@ -74,7 +68,7 @@ class TVMLauncher(object): self.launch_pdl, global_scratch, profile_scratch, - *args, + args, ) diff --git a/python/triton_tvm_ffi/utils.py b/python/triton_tvm_ffi/utils.py index f0b22ac..e0c19a9 100644 --- a/python/triton_tvm_ffi/utils.py +++ b/python/triton_tvm_ffi/utils.py @@ -5,7 +5,8 @@ from __future__ import annotations from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence + from tvm_ffi import Object from typing import Any # isort: on # fmt: on @@ -19,7 +20,7 @@ if TYPE_CHECKING: def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ... def fill_tma_descriptor(*args: Any) -> Any: ... def get_device_properties(_0: int, /) -> Mapping[str, int]: ... - def launch(*args: Any) -> Any: ... + def launch(_0: int, _1: int, _2: int, _3: int, _4: int, _5: tuple[int, int, int], _6: Object, _7: Object, _8: Object, _9: bool, _10: bool, _11: Object, _12: Object, _13: Sequence[Any], /) -> None: ... def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... def set_printf_fifo_size(*args: Any) -> Any: ... # fmt: on diff --git a/src/exception.cc b/src/exception.cc index 1bb5757..fc40e93 100644 --- a/src/exception.cc +++ b/src/exception.cc @@ -17,6 +17,17 @@ const char *NotImplementedException::what() const noexcept { return message_.c_str(); } +UnmatchedArgumentException::UnmatchedArgumentException(std::string_view name, + size_t len, + size_t expect) + : message_("[UnmatchedArgumentException]: argument \"" + std::string(name) + + "\" has length " + std::to_string(len) + ", but expected " + + std::to_string(expect)) {} + +const char *UnmatchedArgumentException::what() const noexcept { + return message_.c_str(); +} + UnknownTypeException::UnknownTypeException(Type type) : message_("[UnknownTypeException]: unknown type: \"" + std::string(TypeToString(type)) + "\"") {} diff --git a/src/type.cc b/src/type.cc index 9836b5d..c5762ce 100644 --- a/src/type.cc +++ b/src/type.cc @@ -18,7 +18,7 @@ const char *TypeToString(Type type) { } } -tvm::ffi::Optional StringToType(tvm::ffi::String name) { +tvm::ffi::Optional StringToType(const tvm::ffi::String &name) { if (name.starts_with("*")) { return Type::PTR; } diff --git a/src/utils.cc b/src/utils.cc index b04bc9c..42aea17 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -107,30 +107,19 @@ tvm::ffi::Map GetDeviceProperties(int device_id) { {"mem_bus_width", memBusWidth}}; } -void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { - CUtensorMap x; - int32_t gridX = args[0].cast(); - int32_t gridY = args[1].cast(); - int32_t gridZ = args[2].cast(); - CUstream stream = reinterpret_cast(args[3].cast()); - CUfunction function = reinterpret_cast(args[4].cast()); - tvm::ffi::Tuple kernelMetadata = - args[5].cast>(); - int32_t numWarps = kernelMetadata.get<0>(); - int32_t numCtas = kernelMetadata.get<1>(); - int32_t sharedMemory = kernelMetadata.get<2>(); - tvm::ffi::ObjectRef launchMetadata = args[6].cast(); - tvm::ffi::ObjectRef launchEnterHook = args[7].cast(); - tvm::ffi::ObjectRef launchExitHook = args[8].cast(); - bool launchCooperativeGrid = args[9].cast(); - bool launchPdl = args[10].cast(); - tvm::ffi::ObjectRef globalScratchObject = - args[11].cast(); - tvm::ffi::ObjectRef profileScratchObject = - args[12].cast(); - tvm::ffi::PackedArgs kernelArgs = args.Slice(13); - // TODO: call `launchEnterHook` - // TODO: check `globalScratchObject` +void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, + uint64_t function, + tvm::ffi::Tuple kernelMetadata, + tvm::ffi::ObjectRef launchMetadata, + tvm::ffi::ObjectRef launchEnterHook, + tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid, + bool launchPdl, tvm::ffi::ObjectRef globalScratchObject, + tvm::ffi::ObjectRef profileScratchObject, + const tvm::ffi::Array &kernelArgs) { + CUstream cStream = reinterpret_cast(stream); + CUfunction cFunction = reinterpret_cast(function); + auto [numWarps, numCtas, sharedMemory] = kernelMetadata; + // TODO: Implement the launch logic CUdeviceptr globalScratch = 0; // TODO: check `profileScratchObject` CUdeviceptr profileScratch = 0; @@ -145,10 +134,10 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { config.blockDimY = 1; config.blockDimZ = 1; config.sharedMemBytes = sharedMemory; - config.hStream = stream; + config.hStream = cStream; config.attrs = launchAttr; int32_t numAttrs = 0; - // TODO: check `launchPdf` + // TODO: check `launchPdl` // TODO: check `launchCooperativeGrid` if (numCtas != 1) { CUlaunchAttribute clusterAttr; @@ -167,7 +156,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { config.numAttrs = numAttrs; if (numCtas == 16) { CUDA_CHECK(cuFuncSetAttribute( - function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); } const int32_t kernelArgNum = kernelArgs.size(); uint8_t *buffer = @@ -192,7 +181,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { // TODO: unwrap PyObject* from scratch pointers and assign to kernel args params[j] = &globalScratch; params[j + 1] = &profileScratch; - CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr)); + CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr)); } // TODO: call `launchExitHook` } @@ -261,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { throw NotImplementedException("set_printf_fifo_size"); }) .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) - .def_packed("triton_tvm_ffi.utils.launch", Launch) + .def("triton_tvm_ffi.utils.launch", Launch) .def("triton_tvm_ffi.utils.load_binary", LoadBinary); } diff --git a/src/value.cc b/src/value.cc index 92557ca..c400707 100644 --- a/src/value.cc +++ b/src/value.cc @@ -1,4 +1,5 @@ #include "value.h" +#include "exception.h" #include "type.h" #include #include @@ -18,10 +19,39 @@ TypedValue::TypedValue(Type type, tvm::ffi::Any &&value) : tvm::ffi::ObjectRef( tvm::ffi::make_object(type, std::move(value))) {} +tvm::ffi::Optional MakeTypedValue(const tvm::ffi::String &type, + const tvm::ffi::Any &value) { + tvm::ffi::Optional typeOpt = StringToType(type); + if (!typeOpt.has_value()) { + throw UnknownTypeException(type.data()); + } + return TypedValue(*typeOpt, value); +} + +tvm::ffi::Array +MakeTypedValues(const tvm::ffi::Array &types, + const tvm::ffi::Array &values) { + const size_t n = types.size(); + if (const size_t m = values.size(); m != n) { + throw UnmatchedArgumentException("values", m, n); + } + tvm::ffi::Array rets; + for (size_t i = 0; i < n; ++i) { + tvm::ffi::Optional val = MakeTypedValue(types[i], values[i]); + if (!val.has_value()) { + throw UnknownTypeException(types[i].data()); + } + rets.emplace_back(std::move(*val)); + } + return rets; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def( - refl::init()); + refl::ObjectDef() + .def(refl::init()) + .def_static("make_typed_value", MakeTypedValue) + .def_static("make_typed_values", MakeTypedValues); } } // namespace triton_tvm_ffi