put typedvalues initialization into cpp

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-30 01:38:58 +08:00
parent bdc9c03b75
commit a953cbe7cc
10 changed files with 92 additions and 47 deletions
+9
View File
@@ -25,6 +25,15 @@ private:
const std::string message_; const std::string message_;
}; };
// Exception raised when an argument's length differs from the length that was
// expected for it.
class UnmatchedArgumentException : public std::exception {
public:
// Composes the diagnostic text from the argument `name`, its actual length
// `len`, and the expected length `expect`.
UnmatchedArgumentException(std::string_view name, size_t len, size_t expect);
// Returns the message composed at construction time.
const char *what() const noexcept override;
private:
// Complete diagnostic text; never modified after construction.
const std::string message_;
};
class UnknownTypeException : public std::exception { class UnknownTypeException : public std::exception {
public: public:
UnknownTypeException(Type type); UnknownTypeException(Type type);
+2 -3
View File
@@ -2,8 +2,7 @@
#define TRITON_TVM_FFI_TYPE_H_ #define TRITON_TVM_FFI_TYPE_H_
#include <cstdint> #include <cstdint>
#include <tvm/ffi/object.h> #include <tvm/ffi/tvm_ffi.h>
#include <tvm/ffi/string.h>
#include <type_traits> #include <type_traits>
namespace triton_tvm_ffi { namespace triton_tvm_ffi {
@@ -35,7 +34,7 @@ enum class Type : int64_t {
}; };
const char *TypeToString(Type type); const char *TypeToString(Type type);
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String str); tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name);
template <Type T> struct type_to_ctype; template <Type T> struct type_to_ctype;
#define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \ #define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \
+7
View File
@@ -40,6 +40,13 @@ public:
TypedValueObj); TypedValueObj);
}; };
// Resolves the type name `type` via StringToType and pairs the resulting Type
// with `value` in a TypedValue.
// Throws UnknownTypeException when `type` does not name a known Type.
// NOTE(review): the definition in this change throws instead of returning an
// empty Optional — confirm whether the Optional return type is still needed.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
const tvm::ffi::Any &value);
// Builds one TypedValue per (types[i], values[i]) pair, preserving order.
// Throws UnmatchedArgumentException when `values` and `types` differ in
// length, and UnknownTypeException when a type name cannot be resolved.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
const tvm::ffi::Array<tvm::ffi::Any> &values);
} // namespace triton_tvm_ffi } // namespace triton_tvm_ffi
#endif #endif
+5
View File
@@ -6,6 +6,7 @@ from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, regis
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from tvm_ffi import Object from tvm_ffi import Object
from typing import Any from typing import Any
# isort: on # isort: on
@@ -33,6 +34,10 @@ class TypedValue(_ffi_Object):
if TYPE_CHECKING: if TYPE_CHECKING:
@staticmethod @staticmethod
def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ... def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ...
@staticmethod
def make_typed_value(_0: str, _1: Any, /) -> TypedValue | None: ...
@staticmethod
def make_typed_values(_0: Sequence[str], _1: Sequence[Any], /) -> Sequence[TypedValue]: ...
# fmt: on # fmt: on
# tvm-ffi-stubgen(end) # tvm-ffi-stubgen(end)
+4 -10
View File
@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Type from typing import Any, List, Optional, Sequence, Type
from triton.backends.nvidia.driver import CudaDriver from triton.backends.nvidia.driver import CudaDriver
from triton.runtime import _allocation from triton.runtime import _allocation
from . import TypedValue, utils, string_to_type from . import TypedValue, utils, string_to_type
@@ -10,7 +10,7 @@ class TVMLauncher(object):
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher: def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.signature: List[str] = src.signature.values() self.signature: List[str] = [*src.signature.values()]
self.num_ctas: int = getattr(metadata, "num_ctas", 1) self.num_ctas: int = getattr(metadata, "num_ctas", 1)
self.launch = utils.launch self.launch = utils.launch
self.global_scratch_size: int = metadata.global_scratch_size self.global_scratch_size: int = metadata.global_scratch_size
@@ -51,14 +51,8 @@ class TVMLauncher(object):
) )
assert not self.launch_cooperative_grid assert not self.launch_cooperative_grid
assert not self.launch_pdl assert not self.launch_pdl
assert len(self.signature) == len(args)
def canonicalize(arg: Any, sig: str) -> TypedValue: args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args)
ty: Optional[int] = string_to_type(sig)
assert ty is not None, sig
return TypedValue(ty, arg)
args = [canonicalize(arg, sig) for arg, sig in zip(args, self.signature)]
return self.launch( return self.launch(
gridX, gridX,
@@ -74,7 +68,7 @@ class TVMLauncher(object):
self.launch_pdl, self.launch_pdl,
global_scratch, global_scratch,
profile_scratch, profile_scratch,
*args, args,
) )
+3 -2
View File
@@ -5,7 +5,8 @@ from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from tvm_ffi import Object
from typing import Any from typing import Any
# isort: on # isort: on
# fmt: on # fmt: on
@@ -19,7 +20,7 @@ if TYPE_CHECKING:
def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ... def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ...
def fill_tma_descriptor(*args: Any) -> Any: ... def fill_tma_descriptor(*args: Any) -> Any: ...
def get_device_properties(_0: int, /) -> Mapping[str, int]: ... def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
def launch(*args: Any) -> Any: ... def launch(_0: int, _1: int, _2: int, _3: int, _4: int, _5: tuple[int, int, int], _6: Object, _7: Object, _8: Object, _9: bool, _10: bool, _11: Object, _12: Object, _13: Sequence[Any], /) -> None: ...
def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ... def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ...
def set_printf_fifo_size(*args: Any) -> Any: ... def set_printf_fifo_size(*args: Any) -> Any: ...
# fmt: on # fmt: on
+11
View File
@@ -17,6 +17,17 @@ const char *NotImplementedException::what() const noexcept {
return message_.c_str(); return message_.c_str();
} }
UnmatchedArgumentException::UnmatchedArgumentException(std::string_view name,
size_t len,
size_t expect)
: message_("[UnmatchedArgumentException]: argument \"" + std::string(name) +
"\" has length " + std::to_string(len) + ", but expected " +
std::to_string(expect)) {}
// Exposes the message assembled by the constructor as a NUL-terminated C
// string; the pointer stays valid for the lifetime of this exception object.
const char *UnmatchedArgumentException::what() const noexcept {
  const std::string &text = message_;
  return text.c_str();
}
UnknownTypeException::UnknownTypeException(Type type) UnknownTypeException::UnknownTypeException(Type type)
: message_("[UnknownTypeException]: unknown type: \"" + : message_("[UnknownTypeException]: unknown type: \"" +
std::string(TypeToString(type)) + "\"") {} std::string(TypeToString(type)) + "\"") {}
+1 -1
View File
@@ -18,7 +18,7 @@ const char *TypeToString(Type type) {
} }
} }
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) { tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name) {
if (name.starts_with("*")) { if (name.starts_with("*")) {
return Type::PTR; return Type::PTR;
} }
+18 -29
View File
@@ -107,30 +107,19 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
{"mem_bus_width", memBusWidth}}; {"mem_bus_width", memBusWidth}};
} }
void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) { void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
CUtensorMap x; uint64_t function,
int32_t gridX = args[0].cast<int32_t>(); tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
int32_t gridY = args[1].cast<int32_t>(); tvm::ffi::ObjectRef launchMetadata,
int32_t gridZ = args[2].cast<int32_t>(); tvm::ffi::ObjectRef launchEnterHook,
CUstream stream = reinterpret_cast<CUstream>(args[3].cast<uint64_t>()); tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
CUfunction function = reinterpret_cast<CUfunction>(args[4].cast<uint64_t>()); bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata = tvm::ffi::ObjectRef profileScratchObject,
args[5].cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(); const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
int32_t numWarps = kernelMetadata.get<0>(); CUstream cStream = reinterpret_cast<CUstream>(stream);
int32_t numCtas = kernelMetadata.get<1>(); CUfunction cFunction = reinterpret_cast<CUfunction>(function);
int32_t sharedMemory = kernelMetadata.get<2>(); auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
tvm::ffi::ObjectRef launchMetadata = args[6].cast<tvm::ffi::ObjectRef>(); // TODO: Implement the launch logic
tvm::ffi::ObjectRef launchEnterHook = args[7].cast<tvm::ffi::ObjectRef>();
tvm::ffi::ObjectRef launchExitHook = args[8].cast<tvm::ffi::ObjectRef>();
bool launchCooperativeGrid = args[9].cast<bool>();
bool launchPdl = args[10].cast<bool>();
tvm::ffi::ObjectRef globalScratchObject =
args[11].cast<tvm::ffi::ObjectRef>();
tvm::ffi::ObjectRef profileScratchObject =
args[12].cast<tvm::ffi::ObjectRef>();
tvm::ffi::PackedArgs kernelArgs = args.Slice(13);
// TODO: call `launchEnterHook`
// TODO: check `globalScratchObject`
CUdeviceptr globalScratch = 0; CUdeviceptr globalScratch = 0;
// TODO: check `profileScratchObject` // TODO: check `profileScratchObject`
CUdeviceptr profileScratch = 0; CUdeviceptr profileScratch = 0;
@@ -145,10 +134,10 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
config.blockDimY = 1; config.blockDimY = 1;
config.blockDimZ = 1; config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory; config.sharedMemBytes = sharedMemory;
config.hStream = stream; config.hStream = cStream;
config.attrs = launchAttr; config.attrs = launchAttr;
int32_t numAttrs = 0; int32_t numAttrs = 0;
// TODO: check `launchPdf` // TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid` // TODO: check `launchCooperativeGrid`
if (numCtas != 1) { if (numCtas != 1) {
CUlaunchAttribute clusterAttr; CUlaunchAttribute clusterAttr;
@@ -167,7 +156,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
config.numAttrs = numAttrs; config.numAttrs = numAttrs;
if (numCtas == 16) { if (numCtas == 16) {
CUDA_CHECK(cuFuncSetAttribute( CUDA_CHECK(cuFuncSetAttribute(
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
} }
const int32_t kernelArgNum = kernelArgs.size(); const int32_t kernelArgNum = kernelArgs.size();
uint8_t *buffer = uint8_t *buffer =
@@ -192,7 +181,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args // TODO: unwrap PyObject* from scratch pointers and assign to kernel args
params[j] = &globalScratch; params[j] = &globalScratch;
params[j + 1] = &profileScratch; params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr)); CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
} }
// TODO: call `launchExitHook` // TODO: call `launchExitHook`
} }
@@ -261,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
throw NotImplementedException("set_printf_fifo_size"); throw NotImplementedException("set_printf_fifo_size");
}) })
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties) .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
.def_packed("triton_tvm_ffi.utils.launch", Launch) .def("triton_tvm_ffi.utils.launch", Launch)
.def("triton_tvm_ffi.utils.load_binary", LoadBinary); .def("triton_tvm_ffi.utils.load_binary", LoadBinary);
} }
+32 -2
View File
@@ -1,4 +1,5 @@
#include "value.h" #include "value.h"
#include "exception.h"
#include "type.h" #include "type.h"
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h> #include <tvm/ffi/tvm_ffi.h>
@@ -18,10 +19,39 @@ TypedValue::TypedValue(Type type, tvm::ffi::Any &&value)
: tvm::ffi::ObjectRef( : tvm::ffi::ObjectRef(
tvm::ffi::make_object<TypedValueObj>(type, std::move(value))) {} tvm::ffi::make_object<TypedValueObj>(type, std::move(value))) {}
// Turns a textual type name plus an arbitrary value into a TypedValue.
// The name is resolved with StringToType; an unresolvable name is reported by
// throwing UnknownTypeException rather than by returning an empty Optional.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
                                              const tvm::ffi::Any &value) {
  if (tvm::ffi::Optional<Type> resolved = StringToType(type);
      resolved.has_value()) {
    return TypedValue(*resolved, value);
  }
  throw UnknownTypeException(type.data());
}
// Converts parallel arrays of type names and values into an array of
// TypedValue, preserving order.
//
// Throws UnmatchedArgumentException when `values` does not have the same
// length as `types`, and UnknownTypeException when a type name cannot be
// resolved.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
                const tvm::ffi::Array<tvm::ffi::Any> &values) {
  const size_t n = types.size();
  if (const size_t m = values.size(); m != n) {
    throw UnmatchedArgumentException("values", m, n);
  }
  tvm::ffi::Array<TypedValue> rets;
  // The final size is known up front, so allocate once instead of growing the
  // array repeatedly inside the loop.
  rets.reserve(static_cast<int64_t>(n));
  for (size_t i = 0; i < n; ++i) {
    tvm::ffi::Optional<TypedValue> val = MakeTypedValue(types[i], values[i]);
    // Defensive guard: MakeTypedValue reports unknown names by throwing, but
    // its return type still permits an empty Optional, so never dereference
    // one blindly.
    if (!val.has_value()) {
      throw UnknownTypeException(types[i].data());
    }
    rets.emplace_back(std::move(*val));
  }
  return rets;
}
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TypedValueObj>().def( refl::ObjectDef<TypedValueObj>()
refl::init<Type, const tvm::ffi::Any &>()); .def(refl::init<Type, const tvm::ffi::Any &>())
.def_static("make_typed_value", MakeTypedValue)
.def_static("make_typed_values", MakeTypedValues);
} }
} // namespace triton_tvm_ffi } // namespace triton_tvm_ffi