mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
put typedvalues initialization into cpp
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -25,6 +25,15 @@ private:
|
||||
const std::string message_;
|
||||
};
|
||||
|
||||
// Exception reporting that an argument's length differs from the length that
// was expected for it.
class UnmatchedArgumentException : public std::exception {
public:
  // name:   name of the offending argument, quoted in the message.
  // len:    length the argument actually has.
  // expect: length the argument was expected to have.
  UnmatchedArgumentException(std::string_view name, size_t len, size_t expect);

  // Returns the message assembled by the constructor.
  const char *what() const noexcept override;

private:
  // Fully formatted message; built once at construction and never modified.
  const std::string message_;
};
|
||||
|
||||
class UnknownTypeException : public std::exception {
|
||||
public:
|
||||
UnknownTypeException(Type type);
|
||||
|
||||
+2
-3
@@ -2,8 +2,7 @@
|
||||
#define TRITON_TVM_FFI_TYPE_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/object.h>
|
||||
#include <tvm/ffi/string.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
@@ -35,7 +34,7 @@ enum class Type : int64_t {
|
||||
};
|
||||
|
||||
const char *TypeToString(Type type);
|
||||
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String str);
|
||||
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name);
|
||||
|
||||
template <Type T> struct type_to_ctype;
|
||||
#define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \
|
||||
|
||||
@@ -40,6 +40,13 @@ public:
|
||||
TypedValueObj);
|
||||
};
|
||||
|
||||
// Builds a TypedValue from a type name and a value.
// The type name is resolved through StringToType; an unresolvable name is
// reported by throwing UnknownTypeException.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
                                              const tvm::ffi::Any &value);

// Element-wise counterpart of MakeTypedValue: pairs types[i] with values[i].
// Throws UnmatchedArgumentException when the two arrays differ in length.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
                const tvm::ffi::Array<tvm::ffi::Any> &values);
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,7 @@ from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, regis
|
||||
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from tvm_ffi import Object
|
||||
from typing import Any
|
||||
# isort: on
|
||||
@@ -33,6 +34,10 @@ class TypedValue(_ffi_Object):
|
||||
if TYPE_CHECKING:
|
||||
@staticmethod
|
||||
def __c_ffi_init__(_0: int, _1: Any, /) -> Object: ...
|
||||
@staticmethod
|
||||
def make_typed_value(_0: str, _1: Any, /) -> TypedValue | None: ...
|
||||
@staticmethod
|
||||
def make_typed_values(_0: Sequence[str], _1: Sequence[Any], /) -> Sequence[TypedValue]: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any, List, Optional, Sequence, Type
|
||||
from triton.backends.nvidia.driver import CudaDriver
|
||||
from triton.runtime import _allocation
|
||||
from . import TypedValue, utils, string_to_type
|
||||
@@ -10,7 +10,7 @@ class TVMLauncher(object):
|
||||
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.signature: List[str] = src.signature.values()
|
||||
self.signature: List[str] = [*src.signature.values()]
|
||||
self.num_ctas: int = getattr(metadata, "num_ctas", 1)
|
||||
self.launch = utils.launch
|
||||
self.global_scratch_size: int = metadata.global_scratch_size
|
||||
@@ -51,14 +51,8 @@ class TVMLauncher(object):
|
||||
)
|
||||
assert not self.launch_cooperative_grid
|
||||
assert not self.launch_pdl
|
||||
assert len(self.signature) == len(args)
|
||||
|
||||
def canonicalize(arg: Any, sig: str) -> TypedValue:
|
||||
ty: Optional[int] = string_to_type(sig)
|
||||
assert ty is not None, sig
|
||||
return TypedValue(ty, arg)
|
||||
|
||||
args = [canonicalize(arg, sig) for arg, sig in zip(args, self.signature)]
|
||||
args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args)
|
||||
|
||||
return self.launch(
|
||||
gridX,
|
||||
@@ -74,7 +68,7 @@ class TVMLauncher(object):
|
||||
self.launch_pdl,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args,
|
||||
args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from __future__ import annotations
|
||||
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from tvm_ffi import Object
|
||||
from typing import Any
|
||||
# isort: on
|
||||
# fmt: on
|
||||
@@ -19,7 +20,7 @@ if TYPE_CHECKING:
|
||||
def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ...
|
||||
def fill_tma_descriptor(*args: Any) -> Any: ...
|
||||
def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
|
||||
def launch(*args: Any) -> Any: ...
|
||||
def launch(_0: int, _1: int, _2: int, _3: int, _4: int, _5: tuple[int, int, int], _6: Object, _7: Object, _8: Object, _9: bool, _10: bool, _11: Object, _12: Object, _13: Sequence[Any], /) -> None: ...
|
||||
def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ...
|
||||
def set_printf_fifo_size(*args: Any) -> Any: ...
|
||||
# fmt: on
|
||||
|
||||
@@ -17,6 +17,17 @@ const char *NotImplementedException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
// Formats the diagnostic once, up front, because message_ is const and
// what() must be noexcept.
UnmatchedArgumentException::UnmatchedArgumentException(std::string_view name,
                                                       size_t len,
                                                       size_t expect)
    : message_(std::string("[UnmatchedArgumentException]: argument \"")
                   .append(name)
                   .append("\" has length ")
                   .append(std::to_string(len))
                   .append(", but expected ")
                   .append(std::to_string(expect))) {}
|
||||
|
||||
const char *UnmatchedArgumentException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
UnknownTypeException::UnknownTypeException(Type type)
|
||||
: message_("[UnknownTypeException]: unknown type: \"" +
|
||||
std::string(TypeToString(type)) + "\"") {}
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ const char *TypeToString(Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
tvm::ffi::Optional<Type> StringToType(tvm::ffi::String name) {
|
||||
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name) {
|
||||
if (name.starts_with("*")) {
|
||||
return Type::PTR;
|
||||
}
|
||||
|
||||
+18
-29
@@ -107,30 +107,19 @@ tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
{"mem_bus_width", memBusWidth}};
|
||||
}
|
||||
|
||||
void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
CUtensorMap x;
|
||||
int32_t gridX = args[0].cast<int32_t>();
|
||||
int32_t gridY = args[1].cast<int32_t>();
|
||||
int32_t gridZ = args[2].cast<int32_t>();
|
||||
CUstream stream = reinterpret_cast<CUstream>(args[3].cast<uint64_t>());
|
||||
CUfunction function = reinterpret_cast<CUfunction>(args[4].cast<uint64_t>());
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata =
|
||||
args[5].cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
|
||||
int32_t numWarps = kernelMetadata.get<0>();
|
||||
int32_t numCtas = kernelMetadata.get<1>();
|
||||
int32_t sharedMemory = kernelMetadata.get<2>();
|
||||
tvm::ffi::ObjectRef launchMetadata = args[6].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef launchEnterHook = args[7].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef launchExitHook = args[8].cast<tvm::ffi::ObjectRef>();
|
||||
bool launchCooperativeGrid = args[9].cast<bool>();
|
||||
bool launchPdl = args[10].cast<bool>();
|
||||
tvm::ffi::ObjectRef globalScratchObject =
|
||||
args[11].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::ObjectRef profileScratchObject =
|
||||
args[12].cast<tvm::ffi::ObjectRef>();
|
||||
tvm::ffi::PackedArgs kernelArgs = args.Slice(13);
|
||||
// TODO: call `launchEnterHook`
|
||||
// TODO: check `globalScratchObject`
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function,
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata,
|
||||
tvm::ffi::ObjectRef launchMetadata,
|
||||
tvm::ffi::ObjectRef launchEnterHook,
|
||||
tvm::ffi::ObjectRef launchExitHook, bool launchCooperativeGrid,
|
||||
bool launchPdl, tvm::ffi::ObjectRef globalScratchObject,
|
||||
tvm::ffi::ObjectRef profileScratchObject,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) {
|
||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||
auto [numWarps, numCtas, sharedMemory] = kernelMetadata;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
@@ -145,10 +134,10 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
config.blockDimY = 1;
|
||||
config.blockDimZ = 1;
|
||||
config.sharedMemBytes = sharedMemory;
|
||||
config.hStream = stream;
|
||||
config.hStream = cStream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
// TODO: check `launchPdf`
|
||||
// TODO: check `launchPdl`
|
||||
// TODO: check `launchCooperativeGrid`
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
@@ -167,7 +156,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
config.numAttrs = numAttrs;
|
||||
if (numCtas == 16) {
|
||||
CUDA_CHECK(cuFuncSetAttribute(
|
||||
function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
}
|
||||
const int32_t kernelArgNum = kernelArgs.size();
|
||||
uint8_t *buffer =
|
||||
@@ -192,7 +181,7 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||
params[j] = &globalScratch;
|
||||
params[j + 1] = &profileScratch;
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr));
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||
}
|
||||
// TODO: call `launchExitHook`
|
||||
}
|
||||
@@ -261,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
throw NotImplementedException("set_printf_fifo_size");
|
||||
})
|
||||
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
|
||||
.def_packed("triton_tvm_ffi.utils.launch", Launch)
|
||||
.def("triton_tvm_ffi.utils.launch", Launch)
|
||||
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
|
||||
}
|
||||
|
||||
|
||||
+32
-2
@@ -1,4 +1,5 @@
|
||||
#include "value.h"
|
||||
#include "exception.h"
|
||||
#include "type.h"
|
||||
#include <tvm/ffi/reflection/registry.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
@@ -18,10 +19,39 @@ TypedValue::TypedValue(Type type, tvm::ffi::Any &&value)
|
||||
: tvm::ffi::ObjectRef(
|
||||
tvm::ffi::make_object<TypedValueObj>(type, std::move(value))) {}
|
||||
|
||||
// Resolves the type name with StringToType and wraps `value` in a TypedValue
// of that type. A name that StringToType cannot resolve is reported by
// throwing UnknownTypeException.
tvm::ffi::Optional<TypedValue> MakeTypedValue(const tvm::ffi::String &type,
                                              const tvm::ffi::Any &value) {
  if (tvm::ffi::Optional<Type> resolved = StringToType(type);
      resolved.has_value()) {
    return TypedValue(*resolved, value);
  }
  throw UnknownTypeException(type.data());
}
|
||||
|
||||
// Converts parallel arrays of type names and values into TypedValues, pairing
// types[i] with values[i].
//
// Throws UnmatchedArgumentException when `values` and `types` differ in
// length, and UnknownTypeException when an element cannot be converted.
tvm::ffi::Array<TypedValue>
MakeTypedValues(const tvm::ffi::Array<tvm::ffi::String> &types,
                const tvm::ffi::Array<tvm::ffi::Any> &values) {
  const size_t n = types.size();
  if (const size_t m = values.size(); m != n) {
    throw UnmatchedArgumentException("values", m, n);
  }
  tvm::ffi::Array<TypedValue> rets;
  // The final length is already known, so allocate the storage once instead
  // of letting the array grow while elements are appended.
  rets.reserve(static_cast<int64_t>(n));
  for (size_t i = 0; i < n; ++i) {
    tvm::ffi::Optional<TypedValue> val = MakeTypedValue(types[i], values[i]);
    // Defensive: the Optional return type permits an empty result, so never
    // dereference it unchecked.
    if (!val.has_value()) {
      throw UnknownTypeException(types[i].data());
    }
    rets.emplace_back(std::move(*val));
  }
  return rets;
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::ObjectDef<TypedValueObj>().def(
|
||||
refl::init<Type, const tvm::ffi::Any &>());
|
||||
refl::ObjectDef<TypedValueObj>()
|
||||
.def(refl::init<Type, const tvm::ffi::Any &>())
|
||||
.def_static("make_typed_value", MakeTypedValue)
|
||||
.def_static("make_typed_values", MakeTypedValues);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
Reference in New Issue
Block a user