// NOTE(review): the header names after each #include below, and the template
// argument lists on tvm::ffi::Map / tvm::ffi::Tuple / tvm::ffi::Array and on
// .cast(), are not visible in this view of the file -- presumably lost during
// extraction. Confirm against the original source before building.
#include
#include
#include
#include

// Kernel launch hook. When the build does not supply MATMUL_KERNEL_STUB, the
// fallback below expands to nothing, so Matmul() performs no launch at all.
// The second macro parameter is named `device`, but the call site in Matmul()
// passes `device.device_id`.
#ifndef MATMUL_KERNEL_STUB
#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs)
#endif

// Name under which Matmul is registered in the static-init block at the end
// of this file; falls back to the empty string when the build does not set it.
#ifndef MATMUL_NAME
#define MATMUL_NAME ""
#endif

/// Host-side wrapper that prepares arguments for a matmul kernel and hands
/// them to MATMUL_KERNEL_STUB.
///
/// @param a          Left operand; dimensions 0 and 1 are read as M and K.
/// @param b          Right operand; dimension 1 is read as N.
/// @param activation Forwarded unchanged to the kernel as the "ACTIVATION"
///                   keyword argument.
/// @return A tvm::ffi::Tensor wrapping a freshly allocated {M, N} tensor that
///         was created with `a`'s tensor options.
///
/// NOTE(review): b.size(0) is never compared against K, and neither input is
/// checked for being 2-D -- confirm whether callers are expected to validate.
tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b, tvm::ffi::String activation) {
  // View both inputs as at::Tensor through DLPack so sizes/strides can be queried.
  at::Tensor atorch = at::fromDLPack(a.ToDLPack()), btorch = at::fromDLPack(b.ToDLPack());
  // NOTE(review): the sizes are stored as int32_t; presumably size() yields a
  // wider integer, in which case very large dimensions would be truncated -- TODO confirm.
  const int32_t M = atorch.size(0), K = atorch.size(1), N = btorch.size(1);
  // Output buffer, allocated with the same options as `a`.
  at::Tensor ctorch = at::empty({M, N}, atorch.options());
  // Grid callback: reads BLOCK_SIZE_M / BLOCK_SIZE_N from the meta map and
  // returns (tiles_along_M * tiles_along_N, 1, 1), where each tile count is the
  // rounded-up quotient of the dimension by its block size.
  tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
      [M, N](const tvm::ffi::Map &meta) -> tvm::ffi::Tuple {
        const int32_t BLOCK_SIZE_M = meta["BLOCK_SIZE_M"].cast(), BLOCK_SIZE_N = meta["BLOCK_SIZE_N"].cast();
        // The product is computed in int32_t.
        return tvm::ffi::Tuple{ (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N), 1, 1};
      });
  // Device and stream are both taken from `a`; `b` is not consulted here.
  DLDevice device = a.device();
  void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
  // Wrap the output for the FFI side via DLPack; this is also the return value.
  tvm::ffi::Tensor c = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ctorch));
  // Positional kernel arguments: the three tensors, the three problem sizes,
  // then the (dim 0, dim 1) strides of a, b and the output, in that order.
  tvm::ffi::Array args = {a, b, c, M, N, K, atorch.stride(0), atorch.stride(1), btorch.stride(0), btorch.stride(1), ctorch.stride(0), ctorch.stride(1)};
  tvm::ffi::Map kwargs = { {"ACTIVATION", activation}, };
  // Presumably the supplied stub launches the kernel that fills `c`; with the
  // empty fallback definition above nothing is written to it.
  MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
  return c;
}

// Registers Matmul as a global definition named MATMUL_NAME via the
// tvm::ffi::reflection API when this static-init block runs.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def(MATMUL_NAME, Matmul);
}