From 599957e156625574c93870e0d7ae68eddb66b167 Mon Sep 17 00:00:00 2001
From: Jinjie Liu
Date: Tue, 10 Feb 2026 17:01:28 +0800
Subject: [PATCH] support attention bwd

Signed-off-by: Jinjie Liu
---
 CMakeLists.txt                               |   8 --
 examples/attention/attention.py              | 116 +++++++++++++++---
 examples/attention/attnbwd.cc                |  57 +++++++++
 examples/attention/attnbwdpre.cc             |  45 +++++++
 .../attention/{attention.cc => attnfwd.cc}   |   2 +
 python/triton_tvm_ffi/jit.py                 |   4 +-
 6 files changed, 206 insertions(+), 26 deletions(-)
 delete mode 100644 CMakeLists.txt
 create mode 100644 examples/attention/attnbwd.cc
 create mode 100644 examples/attention/attnbwdpre.cc
 rename examples/attention/{attention.cc => attnfwd.cc} (97%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
deleted file mode 100644
index 536b718..0000000
--- a/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-cmake_minimum_required(VERSION 3.18)
-
-project(${SKBUILD_PROJECT_NAME})
-
-install(
-  DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include
-  DESTINATION ${CMAKE_INSTALL_PREFIX}/triton_tvm_ffi
-)
diff --git a/examples/attention/attention.py b/examples/attention/attention.py
index 2d25b79..2b7a310 100644
--- a/examples/attention/attention.py
+++ b/examples/attention/attention.py
@@ -16,6 +16,7 @@ Extra Credits:
 import os
 from pathlib import Path
 import time
+from typing import Sequence
 
 import torch
 import triton
@@ -731,13 +732,41 @@ class _attention_triton(torch.autograd.Function):
 
 @triton_tvm_ffi.torch_wrap(
     [_attn_fwd],
-    Path(__file__).parent / "attention.cc",
+    Path(__file__).parent / "attnfwd.cc",
 )
 def _attn_fwd_tvm_ffi(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float
 ) -> torch.Tensor: ...
 
 
+@triton_tvm_ffi.torch_wrap(
+    [_attn_bwd_preprocess],
+    Path(__file__).parent / "attnbwdpre.cc",
+)
+def _attn_bwd_preprocess_tvm_ffi(
+    o: torch.Tensor,
+    do: torch.Tensor,
+    mshape: Sequence[int],
+    head_dim: int,
+): ...
+
+
+@triton_tvm_ffi.torch_wrap(
+    [_attn_bwd],
+    Path(__file__).parent / "attnbwd.cc",
+)
+def _attn_bwd_tvm_ffi(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    sm_scale: float,
+    do: torch.Tensor,
+    m: torch.Tensor,
+    delta: torch.Tensor,
+    HEAD_DIM: int,
+): ...
+
+
 class _attention_tvm_ffi(_attention_triton):
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
@@ -752,6 +781,31 @@ class _attention_tvm_ffi(_attention_triton):
         ctx.causal = causal
         return o
 
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, M = ctx.saved_tensors
+        delta = _attn_bwd_preprocess_tvm_ffi(
+            o,
+            do,
+            M.shape,
+            ctx.HEAD_DIM,
+        )
+        dq, dk, dv = _attn_bwd_tvm_ffi(
+            q,
+            k,
+            v,
+            ctx.sm_scale,
+            do,
+            M,
+            delta,
+            ctx.HEAD_DIM,
+        )
+        dq = torch.from_dlpack(dq)
+        dk = torch.from_dlpack(dk)
+        dv = torch.from_dlpack(dv)
+
+        return dq, dk, dv, None, None, None, None
+
 
 def attn_torch(q, k, v, causal=False, sm_scale=1.0):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
@@ -778,7 +832,7 @@ if __name__ == "__main__":
     N_CTX = 128
     HEAD_DIM = 64
     causal = True
-    mode = "fwd"
+    mode = "bwd"
     dtype = torch.float16
     torch.manual_seed(20)
     q = (
@@ -811,25 +865,27 @@ if __name__ == "__main__":
     ref_dq, q.grad = q.grad.clone(), None
     tri_out = attn_triton(q, k, v, causal, sm_scale).half()
     tvm_ffi_out = attn_tvm_ffi(q, k, v, causal, sm_scale).half()
+    warmup = 5
     round = 1000
     if mode == "fwd":
         atol = 1e-2
         torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
         torch.testing.assert_close(tvm_ffi_out, ref_out, atol=atol, rtol=0)
-        for _ in range(5):
-            attn_torch(q, k, v, causal, sm_scale)
-            attn_triton(q, k, v, causal, sm_scale)
-            attn_tvm_ffi(q, k, v, causal, sm_scale)
-        cp0 = time.perf_counter_ns()
-        for _ in range(round):
-            attn_torch(q, k, v, causal, sm_scale)
-        cp1 = time.perf_counter_ns()
-        for _ in range(round):
-            attn_triton(q, k, v, causal, sm_scale)
-        cp2 = time.perf_counter_ns()
-        for _ in range(round):
-            attn_tvm_ffi(q, k, v, causal, sm_scale)
-        cp3 = time.perf_counter_ns()
+        with torch.no_grad():
+            for _ in range(warmup):
+                attn_torch(q, k, v, causal, sm_scale)
+                attn_triton(q, k, v, causal, sm_scale)
+                attn_tvm_ffi(q, k, v, causal, sm_scale)
+            cp0 = time.perf_counter_ns()
+            for _ in range(round):
+                attn_torch(q, k, v, causal, sm_scale)
+            cp1 = time.perf_counter_ns()
+            for _ in range(round):
+                attn_triton(q, k, v, causal, sm_scale)
+            cp2 = time.perf_counter_ns()
+            for _ in range(round):
+                attn_tvm_ffi(q, k, v, causal, sm_scale)
+            cp3 = time.perf_counter_ns()
         print(
             f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
         )
@@ -851,3 +907,31 @@
         torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
         torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
         torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
+        tvm_ffi_out.backward(dout)
+        tvm_ffi_dv, v.grad = v.grad.clone(), None
+        tvm_ffi_dk, k.grad = k.grad.clone(), None
+        tvm_ffi_dq, q.grad = q.grad.clone(), None
+        # compare
+        torch.testing.assert_close(tvm_ffi_out, ref_out, atol=1e-2, rtol=0)
+        rtol = 0.0
+        torch.testing.assert_close(tvm_ffi_dv, ref_dv, atol=1e-2, rtol=rtol)
+        torch.testing.assert_close(tvm_ffi_dk, ref_dk, atol=1e-2, rtol=rtol)
+        torch.testing.assert_close(tvm_ffi_dq, ref_dq, atol=1e-2, rtol=rtol)
+
+        for _ in range(warmup):
+            attn_torch(q, k, v, causal, sm_scale).backward(dout)
+            attn_triton(q, k, v, causal, sm_scale).backward(dout)
+            attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
+        cp0 = time.perf_counter_ns()
+        for _ in range(round):
+            attn_torch(q, k, v, causal, sm_scale).backward(dout)
+        cp1 = time.perf_counter_ns()
+        for _ in range(round):
+            attn_triton(q, k, v, causal, sm_scale).backward(dout)
+        cp2 = time.perf_counter_ns()
+        for _ in range(round):
+            attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
+        cp3 = time.perf_counter_ns()
+        print(
+            f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
+        )
diff --git a/examples/attention/attnbwd.cc b/examples/attention/attnbwd.cc
new file mode 100644
index 0000000..394ff62
--- /dev/null
+++ b/examples/attention/attnbwd.cc
@@ -0,0 +1,57 @@
+#include "ATen/core/ATen_fwd.h"
+#include "ATen/ops/empty.h"
+#include "c10/core/Device.h"
+#include "torch/headeronly/core/DeviceType.h"
+#include "tvm/ffi/container/tensor.h"
+#include
+#include
+#include
+#include
+#include
+
+#ifndef _ATTN_BWD_STUB
+#define _ATTN_BWD_STUB(grid, device, stream, args, kwargs)
+#endif
+
+#ifndef _ATTN_BWD_TVM_FFI_NAME
+#define _ATTN_BWD_TVM_FFI_NAME ""
+#endif
+
+tvm::ffi::Tuple
+AttnBwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v,
+        const double smScale, tvm::ffi::Tensor do_, tvm::ffi::Tensor m,
+        tvm::ffi::Tensor delta, const int32_t kHeadDim) {
+  tvm::ffi::ShapeView qshape = q.shape(), qstride = q.strides();
+  const int32_t kBatch = qshape[0], kNHead = qshape[1], kNCtx = qshape[2],
+                kBlockN1 = 128;
+  const double kArgKScale = smScale / log(2);
+  at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
+             kTorch = at::fromDLPack(k.ToDLPack()),
+             vTorch = at::fromDLPack(v.ToDLPack()),
+             dqTorch = at::empty_like(qTorch), dkTorch = at::empty_like(kTorch),
+             dvTorch = at::empty_like(vTorch),
+             argKTorch = at::mul(kTorch, kArgKScale);
+  tvm::ffi::Tensor dq = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dqTorch)),
+                   dk = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dkTorch)),
+                   dv = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dvTorch)),
+                   argK = tvm::ffi::Tensor::FromDLPack(at::toDLPack(argKTorch));
+  tvm::ffi::Tuple grid(kNCtx / kBlockN1, 1,
+                       kBatch * kNHead);
+  tvm::ffi::Array args = {
+      q, argK, v, smScale, do_, dq, dk, dv,
+      m, delta, qstride[0], qstride[1], qstride[2], qstride[3], kNHead, kNCtx};
+  tvm::ffi::Map kwargs = {
+      {"BLOCK_M1", 32}, {"BLOCK_N1", kBlockN1}, {"BLOCK_M2", 128},
+      {"BLOCK_N2", 32}, {"BLK_SLICE_FACTOR", 2}, {"HEAD_DIM", kHeadDim},
+      {"num_warps", 4}, {"num_stages", 5},
+  };
+  DLDevice device = q.device();
+  void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
+  _ATTN_BWD_STUB(grid, device.device_id, stream, args, kwargs);
+  return tvm::ffi::Tuple{dq, dk, dv};
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def(_ATTN_BWD_TVM_FFI_NAME, AttnBwd);
+}
diff --git a/examples/attention/attnbwdpre.cc b/examples/attention/attnbwdpre.cc
new file mode 100644
index 0000000..b9d518b
--- /dev/null
+++ b/examples/attention/attnbwdpre.cc
@@ -0,0 +1,45 @@
+#include "ATen/core/ATen_fwd.h"
+#include "ATen/ops/empty.h"
+#include "torch/headeronly/core/DeviceType.h"
+#include "tvm/ffi/container/tensor.h"
+#include
+#include
+#include
+#include
+#include
+
+#ifndef _ATTN_BWD_PREPROCESS_STUB
+#define _ATTN_BWD_PREPROCESS_STUB(grid, device, stream, args, kwargs)
+#endif
+
+#ifndef _ATTN_BWD_PREPROCESS_TVM_FFI_NAME
+#define _ATTN_BWD_PREPROCESS_TVM_FFI_NAME ""
+#endif
+
+tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_,
+                                   tvm::ffi::Shape mshape,
+                                   const int32_t kHeadDim) {
+  const int32_t kBatch = mshape[0], kNHead = mshape[1], kNCtx = mshape[2],
+                kPreBlock = 128;
+  at::Tensor deltaTorch = at::empty(mshape, at::kFloat, std::nullopt,
+                                    at::Device(at::kCUDA, o.device().device_id),
+                                    std::nullopt, std::nullopt);
+  tvm::ffi::Tensor delta =
+      tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
+  tvm::ffi::Tuple grid(kNCtx / kPreBlock,
+                       kBatch * kNHead, 1);
+  tvm::ffi::Array args = {o, do_, delta, kBatch, kNHead, kNCtx};
+  tvm::ffi::Map kwargs = {
+      {"BLOCK_M", kPreBlock},
+      {"HEAD_DIM", kHeadDim},
+  };
+  DLDevice device = o.device();
+  void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
+  _ATTN_BWD_PREPROCESS_STUB(grid, device.device_id, stream, args, kwargs);
+  return delta;
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def(_ATTN_BWD_PREPROCESS_TVM_FFI_NAME, AttnBwdPreprocess);
+}
diff --git a/examples/attention/attention.cc b/examples/attention/attnfwd.cc
similarity index 97%
rename from examples/attention/attention.cc
rename to examples/attention/attnfwd.cc
index 1fe5073..eec91cb 100644
--- a/examples/attention/attention.cc
+++ b/examples/attention/attnfwd.cc
@@ -1,3 +1,5 @@
+#include "ATen/core/ATen_fwd.h"
+#include "ATen/ops/empty.h"
 #include "tvm/ffi/container/tensor.h"
 #include
 #include
diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py
index 31e07f3..be19cd4 100644
--- a/python/triton_tvm_ffi/jit.py
+++ b/python/triton_tvm_ffi/jit.py
@@ -47,8 +47,8 @@ class TVMFFIJITFunction(object):
     ):
         args: Iterator[Any] = map(self.canonicalize, args)
         kwargs: Dict[str, Any] = {
-            k: v for k, v in zip(self.signature, args) if v is not None
-        } | {k: self.canonicalize(v) for k, v in kwargs.items()}
+            k: self.canonicalize(v) for k, v in kwargs.items()
+        }
         kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
         self.num_warps, _, self.shmem = kernel.packed_metadata
         self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]