mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
expand argument extractions with macro
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
+5
-24
@@ -9,7 +9,7 @@ namespace triton_tvm_ffi {
|
|||||||
|
|
||||||
// --------------- Definitions --------------- //
|
// --------------- Definitions --------------- //
|
||||||
|
|
||||||
#define TYPE_TABLE(V) \
|
#define TYPE_TABLE_NATIVE(V) \
|
||||||
V(I1, "i1", int8_t) \
|
V(I1, "i1", int8_t) \
|
||||||
V(I8, "i8", int8_t) \
|
V(I8, "i8", int8_t) \
|
||||||
V(I16, "i16", int16_t) \
|
V(I16, "i16", int16_t) \
|
||||||
@@ -23,7 +23,10 @@ namespace triton_tvm_ffi {
|
|||||||
V(FP16, "fp16", double) \
|
V(FP16, "fp16", double) \
|
||||||
V(BF16, "bf16", double) \
|
V(BF16, "bf16", double) \
|
||||||
V(FP32, "f32", double) \
|
V(FP32, "f32", double) \
|
||||||
V(FP64, "fp64", double) \
|
V(FP64, "fp64", double)
|
||||||
|
|
||||||
|
#define TYPE_TABLE(V) \
|
||||||
|
TYPE_TABLE_NATIVE(V) \
|
||||||
V(PTR, "*?", void *) \
|
V(PTR, "*?", void *) \
|
||||||
V(CONSTEXPR, "constexpr", void)
|
V(CONSTEXPR, "constexpr", void)
|
||||||
|
|
||||||
@@ -43,28 +46,6 @@ TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
|
|||||||
#undef DEFINE_TYPE_TO_CTYPE
|
#undef DEFINE_TYPE_TO_CTYPE
|
||||||
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
||||||
|
|
||||||
template <typename T, typename = void> struct type_size {
|
|
||||||
static constexpr size_t value = 0;
|
|
||||||
};
|
|
||||||
template <typename T>
|
|
||||||
struct type_size<T, std::enable_if_t<!std::is_void_v<decltype(sizeof(T))>>> {
|
|
||||||
static constexpr size_t value = sizeof(T);
|
|
||||||
};
|
|
||||||
template <typename T> constexpr size_t type_size_v = type_size<T>::value;
|
|
||||||
|
|
||||||
template <size_t... Ns> struct max;
|
|
||||||
template <size_t... Ns> constexpr size_t max_v = max<Ns...>::value;
|
|
||||||
template <size_t N> struct max<N> { static constexpr size_t value = N; };
|
|
||||||
template <size_t N, size_t... Ns> struct max<N, Ns...> {
|
|
||||||
static constexpr size_t value = N > max_v<Ns...> ? N : max_v<Ns...>;
|
|
||||||
};
|
|
||||||
|
|
||||||
static constexpr size_t kMaxOpaqueSize = max_v<
|
|
||||||
#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v<ctype>,
|
|
||||||
TYPE_TABLE(DEFINE_TYPE_SIZE)
|
|
||||||
#undef DEFINE_TYPE_SIZE
|
|
||||||
0>;
|
|
||||||
|
|
||||||
// --------------- Implementations --------------- //
|
// --------------- Implementations --------------- //
|
||||||
|
|
||||||
} // namespace triton_tvm_ffi
|
} // namespace triton_tvm_ffi
|
||||||
|
|||||||
+29
-67
@@ -11,66 +11,11 @@
|
|||||||
|
|
||||||
#define CUDA_CHECK(code) \
|
#define CUDA_CHECK(code) \
|
||||||
do { \
|
do { \
|
||||||
if ((code) != CUDA_SUCCESS) { \
|
if (__builtin_expect((code) != CUDA_SUCCESS, 0)) { \
|
||||||
throw triton_tvm_ffi::CUDAException(code); \
|
throw triton_tvm_ffi::CUDAException(code); \
|
||||||
} \
|
} \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using namespace triton_tvm_ffi;
|
|
||||||
|
|
||||||
// --------------- Definitions ---------------
|
|
||||||
|
|
||||||
template <Type T> struct ValueCast {
|
|
||||||
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <Type... Ts> struct ValueCastSet {
|
|
||||||
TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value);
|
|
||||||
};
|
|
||||||
|
|
||||||
using GenericValueCastSet = ValueCastSet<
|
|
||||||
#define DEFINE_TYPE_CAST(type, str, ctype) Type::type,
|
|
||||||
TYPE_TABLE(DEFINE_TYPE_CAST)
|
|
||||||
#undef DEFINE_TYPE_CAST
|
|
||||||
Type::PTR>;
|
|
||||||
|
|
||||||
// --------------- Implementations ---------------
|
|
||||||
|
|
||||||
template <Type T>
|
|
||||||
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void *ptr,
|
|
||||||
const TypedValue &value) {
|
|
||||||
if (value.GetType() == T) {
|
|
||||||
if constexpr (T == Type::PTR) {
|
|
||||||
tvm::ffi::TensorView cvalue =
|
|
||||||
value.GetValue().cast<tvm::ffi::TensorView>();
|
|
||||||
void **p = reinterpret_cast<void **>(ptr);
|
|
||||||
*p = cvalue.data_ptr();
|
|
||||||
} else if constexpr (T == Type::CONSTEXPR) {
|
|
||||||
#ifdef NDEBUG
|
|
||||||
__builtin_unreachable();
|
|
||||||
#else
|
|
||||||
throw NotImplementedException("CONSTEXPR for value casting");
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
using ctype = type_to_ctype_t<T>;
|
|
||||||
ctype cvalue = value.GetValue().cast<ctype>();
|
|
||||||
ctype *p = reinterpret_cast<ctype *>(ptr);
|
|
||||||
*p = cvalue;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
template <Type... Ts>
|
|
||||||
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void *ptr,
|
|
||||||
const TypedValue &value) {
|
|
||||||
return (ValueCast<Ts>::apply(ptr, value) || ...);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace triton_tvm_ffi {
|
namespace triton_tvm_ffi {
|
||||||
|
|
||||||
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||||
@@ -159,23 +104,40 @@ void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
|||||||
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||||
}
|
}
|
||||||
const int32_t kernelArgNum = kernelArgs.size();
|
const int32_t kernelArgNum = kernelArgs.size();
|
||||||
uint8_t *buffer =
|
|
||||||
reinterpret_cast<uint8_t *>(alloca(kMaxOpaqueSize * (kernelArgNum)));
|
|
||||||
void **params =
|
void **params =
|
||||||
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
for (size_t i = 0; i < kernelArgNum; ++i) {
|
for (size_t i = 0; i < kernelArgNum; ++i) {
|
||||||
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
TypedValue value = kernelArgs[i].cast<TypedValue>();
|
||||||
if (value.GetType() != Type::CONSTEXPR) {
|
switch (value.GetType()) {
|
||||||
#ifndef NDEBUG
|
#define CASE_STMT(type, str, ctype) \
|
||||||
assert(j < kernelArgNum);
|
case Type::type: { \
|
||||||
#endif
|
using cpptype = type_to_ctype_t<Type::type>; \
|
||||||
void *ptr = buffer + j * kMaxOpaqueSize;
|
params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype))); \
|
||||||
params[j] = ptr;
|
*reinterpret_cast<cpptype *>(params[j]) = \
|
||||||
if (!GenericValueCastSet::apply(ptr, value)) {
|
value.GetValue().cast<cpptype>(); \
|
||||||
throw UnknownTypeException(value.GetType());
|
++j; \
|
||||||
}
|
break; \
|
||||||
|
}
|
||||||
|
TYPE_TABLE_NATIVE(CASE_STMT)
|
||||||
|
#undef CASE_STMT
|
||||||
|
case Type::PTR: {
|
||||||
|
params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
|
||||||
|
*reinterpret_cast<void **>(params[j]) =
|
||||||
|
value.GetValue().cast<tvm::ffi::TensorView>().data_ptr();
|
||||||
++j;
|
++j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Type::CONSTEXPR: {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
#ifdef NDEBUG
|
||||||
|
__builtin_unreachable();
|
||||||
|
#else
|
||||||
|
throw NotImplementedException("CONSTEXPR for value casting");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||||
|
|||||||
Reference in New Issue
Block a user