support typedvalue

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-01-29 16:46:00 +08:00
parent fecf48e403
commit 781a5396cc
15 changed files with 367 additions and 89 deletions
+70 -21
View File
@@ -1,6 +1,8 @@
#include "exception.h"
#include <cassert>
#include "type.h"
#include "value.h"
#include <cuda.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
@@ -11,6 +13,60 @@
} \
} while (false)
namespace {
using namespace triton_tvm_ffi;
// --------------- Definitions ---------------
template <Type T> struct ValueCast {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
};
template <Type... Ts> struct ValueCastSet {
TRITON_TVM_FFI_INLINE static bool apply(void **ptr, const TypedValue &value);
};
using GenericValueCastSet = ValueCastSet<
#define DEFINE_TYPE_CAST(type, str, ctype) Type::type,
TYPE_TABLE(DEFINE_TYPE_CAST)
#undef DEFINE_TYPE_CAST
Type::PTR>;
// --------------- Implementations ---------------
template <Type T>
TRITON_TVM_FFI_INLINE bool ValueCast<T>::apply(void **addr,
const TypedValue &value) {
if (value.GetType() == T) {
if constexpr (T == Type::PTR) {
tvm::ffi::TensorView cvalue =
value.GetValue().cast<tvm::ffi::TensorView>();
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
*ptr = cvalue.data_ptr();
*addr = ptr;
} else {
using ctype = type_to_ctype_t<T>;
ctype cvalue = value.GetValue().cast<ctype>();
ctype *ptr = reinterpret_cast<ctype *>(alloca(sizeof(ctype)));
*ptr = cvalue;
*addr = ptr;
}
return true;
} else {
return false;
}
}
template <Type... Ts>
TRITON_TVM_FFI_INLINE bool ValueCastSet<Ts...>::apply(void **ptr,
const TypedValue &value) {
return (ValueCast<Ts>::apply(ptr, value) || ...);
}
} // namespace
namespace triton_tvm_ffi {
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
tvm::ffi::cuda_api::DeviceHandle device;
CUDA_CHECK(cuDeviceGet(&device, device_id));
@@ -153,26 +209,17 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
const int32_t kernelArgNum = kernelArgs.size();
void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0;
for (size_t i = 0; i < kernelArgNum; ++i) {
tvm::ffi::AnyView arg = kernelArgs[i];
if (auto val = arg.try_cast<tvm::ffi::TensorView>()) {
void **ptr = reinterpret_cast<void **>(alloca(sizeof(void *)));
*ptr = val->data_ptr();
params[i] = ptr;
} else if (auto val = arg.try_cast<int32_t>()) {
int32_t *ptr = reinterpret_cast<int32_t *>(alloca(sizeof(int32_t)));
*ptr = *val;
params[i] = ptr;
} else if (auto val = arg.try_cast<float>()) {
float *ptr = reinterpret_cast<float *>(alloca(sizeof(float)));
*ptr = *val;
params[i] = ptr;
} else {
assert(false && "unsupported kernel argument type");
TypedValue value = kernelArgs[i].cast<TypedValue>();
if (value.GetType() != Type::CONSTEXPR) {
if (!GenericValueCastSet::apply(&params[j++], value)) {
throw UnknownTypeException(value.GetType());
}
}
}
params[kernelArgNum] = &globalScratch;
params[kernelArgNum + 1] = &profileScratch;
params[j] = &globalScratch;
params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr));
}
// TODO: call `launchExitHook`
@@ -181,7 +228,9 @@ void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
.def_packed("triton_tvm_ffi.utils.launch", Launch)
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
.def("triton_tvm_ffi.get_device_properties", GetDeviceProperties)
.def_packed("triton_tvm_ffi.launch", Launch)
.def("triton_tvm_ffi.load_binary", LoadBinary);
}
} // namespace triton_tvm_ffi