diff --git a/include/type.h b/include/type.h index 3736104..68d10cf 100644 --- a/include/type.h +++ b/include/type.h @@ -9,7 +9,7 @@ namespace triton_tvm_ffi { // --------------- Definitions --------------- // -#define TYPE_TABLE(V) \ +#define TYPE_TABLE_NATIVE(V) \ V(I1, "i1", int8_t) \ V(I8, "i8", int8_t) \ V(I16, "i16", int16_t) \ @@ -23,7 +23,10 @@ namespace triton_tvm_ffi { V(FP16, "fp16", double) \ V(BF16, "bf16", double) \ V(FP32, "f32", double) \ - V(FP64, "fp64", double) \ + V(FP64, "fp64", double) + +#define TYPE_TABLE(V) \ + TYPE_TABLE_NATIVE(V) \ V(PTR, "*?", void *) \ V(CONSTEXPR, "constexpr", void) @@ -43,28 +46,6 @@ TYPE_TABLE(DEFINE_TYPE_TO_CTYPE) #undef DEFINE_TYPE_TO_CTYPE template using type_to_ctype_t = typename type_to_ctype::t; -template struct type_size { - static constexpr size_t value = 0; -}; -template -struct type_size>> { - static constexpr size_t value = sizeof(T); -}; -template constexpr size_t type_size_v = type_size::value; - -template struct max; -template constexpr size_t max_v = max::value; -template struct max { static constexpr size_t value = N; }; -template struct max { - static constexpr size_t value = N > max_v ? N : max_v; -}; - -static constexpr size_t kMaxOpaqueSize = max_v< -#define DEFINE_TYPE_SIZE(type, str, ctype) type_size_v, - TYPE_TABLE(DEFINE_TYPE_SIZE) -#undef DEFINE_TYPE_SIZE - 0>; - // --------------- Implementations --------------- // } // namespace triton_tvm_ffi diff --git a/src/utils.cc b/src/utils.cc index 42aea17..abbc612 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -11,66 +11,11 @@ #define CUDA_CHECK(code) \ do { \ - if ((code) != CUDA_SUCCESS) { \ + if (__builtin_expect((code) != CUDA_SUCCESS, 0)) { \ throw triton_tvm_ffi::CUDAException(code); \ } \ } while (false) -namespace { - -using namespace triton_tvm_ffi; - -// --------------- Definitions --------------- - -template struct ValueCast { - TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value); -}; - -template struct ValueCastSet { - TRITON_TVM_FFI_INLINE static bool apply(void *ptr, const TypedValue &value); -}; - -using GenericValueCastSet = ValueCastSet< -#define DEFINE_TYPE_CAST(type, str, ctype) Type::type, - TYPE_TABLE(DEFINE_TYPE_CAST) -#undef DEFINE_TYPE_CAST - Type::PTR>; - -// --------------- Implementations --------------- - -template -TRITON_TVM_FFI_INLINE bool ValueCast::apply(void *ptr, - const TypedValue &value) { - if (value.GetType() == T) { - if constexpr (T == Type::PTR) { - tvm::ffi::TensorView cvalue = - value.GetValue().cast(); - void **p = reinterpret_cast(ptr); - *p = cvalue.data_ptr(); - } else if constexpr (T == Type::CONSTEXPR) { -#ifdef NDEBUG - __builtin_unreachable(); -#else - throw NotImplementedException("CONSTEXPR for value casting"); -#endif - } else { - using ctype = type_to_ctype_t; - ctype cvalue = value.GetValue().cast(); - ctype *p = reinterpret_cast(ptr); - *p = cvalue; - } - return true; - } - return false; -} -template -TRITON_TVM_FFI_INLINE bool ValueCastSet::apply(void *ptr, - const TypedValue &value) { - return (ValueCast::apply(ptr, value) || ...); -} - -} // namespace - namespace triton_tvm_ffi { tvm::ffi::Map GetDeviceProperties(int device_id) { @@ -159,23 +104,40 @@ void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); } const int32_t kernelArgNum = kernelArgs.size(); - uint8_t *buffer = - reinterpret_cast(alloca(kMaxOpaqueSize * (kernelArgNum))); void **params = reinterpret_cast(alloca(sizeof(void *) * (kernelArgNum + 2))); size_t j = 0; for (size_t i = 0; i < kernelArgNum; ++i) { TypedValue value = kernelArgs[i].cast(); - if (value.GetType() != Type::CONSTEXPR) { -#ifndef NDEBUG - assert(j < kernelArgNum); -#endif - void *ptr = buffer + j * kMaxOpaqueSize; - params[j] = ptr; - if (!GenericValueCastSet::apply(ptr, value)) { - throw UnknownTypeException(value.GetType()); - } + switch (value.GetType()) { +#define CASE_STMT(type, str, ctype) \ + case Type::type: { \ + using cpptype = type_to_ctype_t; \ + params[j] = reinterpret_cast(alloca(sizeof(cpptype))); \ + *reinterpret_cast(params[j]) = \ + value.GetValue().cast(); \ + ++j; \ + break; \ + } + TYPE_TABLE_NATIVE(CASE_STMT) +#undef CASE_STMT + case Type::PTR: { + params[j] = reinterpret_cast(alloca(sizeof(void *))); + *reinterpret_cast(params[j]) = + value.GetValue().cast().data_ptr(); ++j; + break; + } + case Type::CONSTEXPR: { + break; + } + default: { +#ifdef NDEBUG + __builtin_unreachable(); +#else + throw NotImplementedException("CONSTEXPR for value casting"); +#endif + } } } // TODO: unwrap PyObject* from scratch pointers and assign to kernel args