mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
@@ -16,6 +16,7 @@ Extra Credits:
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -731,13 +732,41 @@ class _attention_triton(torch.autograd.Function):
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_fwd],
|
||||
Path(__file__).parent / "attention.cc",
|
||||
Path(__file__).parent / "attnfwd.cc",
|
||||
)
|
||||
def _attn_fwd_tvm_ffi(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_bwd_preprocess],
|
||||
Path(__file__).parent / "attnbwdpre.cc",
|
||||
)
|
||||
def _attn_bwd_preprocess_tvm_ffi(
|
||||
o: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
mshape: Sequence[int],
|
||||
head_dim: int,
|
||||
): ...
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_bwd],
|
||||
Path(__file__).parent / "attnbwd.cc",
|
||||
)
|
||||
def _attn_bwd_tvm_ffi(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
sm_scale: float,
|
||||
do: torch.Tensor,
|
||||
m: torch.Tensor,
|
||||
delta: torch.Tensor,
|
||||
HEAD_DIM: int,
|
||||
): ...
|
||||
|
||||
|
||||
class _attention_tvm_ffi(_attention_triton):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
@@ -752,6 +781,31 @@ class _attention_tvm_ffi(_attention_triton):
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, M = ctx.saved_tensors
|
||||
delta = _attn_bwd_preprocess_tvm_ffi(
|
||||
o,
|
||||
do,
|
||||
M.shape,
|
||||
ctx.HEAD_DIM,
|
||||
)
|
||||
dq, dk, dv = _attn_bwd_tvm_ffi(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
ctx.sm_scale,
|
||||
do,
|
||||
M,
|
||||
delta,
|
||||
ctx.HEAD_DIM,
|
||||
)
|
||||
dq = torch.from_dlpack(dq)
|
||||
dk = torch.from_dlpack(dk)
|
||||
dv = torch.from_dlpack(dv)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
def attn_torch(q, k, v, causal=False, sm_scale=1.0):
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
|
||||
@@ -778,7 +832,7 @@ if __name__ == "__main__":
|
||||
N_CTX = 128
|
||||
HEAD_DIM = 64
|
||||
causal = True
|
||||
mode = "fwd"
|
||||
mode = "bwd"
|
||||
dtype = torch.float16
|
||||
torch.manual_seed(20)
|
||||
q = (
|
||||
@@ -811,25 +865,27 @@ if __name__ == "__main__":
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
tri_out = attn_triton(q, k, v, causal, sm_scale).half()
|
||||
tvm_ffi_out = attn_tvm_ffi(q, k, v, causal, sm_scale).half()
|
||||
warmup = 5
|
||||
round = 1000
|
||||
if mode == "fwd":
|
||||
atol = 1e-2
|
||||
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=atol, rtol=0)
|
||||
for _ in range(5):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp3 = time.perf_counter_ns()
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
@@ -851,3 +907,31 @@ if __name__ == "__main__":
|
||||
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
|
||||
tvm_ffi_out.backward(dout)
|
||||
tvm_ffi_dv, v.grad = v.grad.clone(), None
|
||||
tvm_ffi_dk, k.grad = k.grad.clone(), None
|
||||
tvm_ffi_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=1e-2, rtol=0)
|
||||
rtol = 0.0
|
||||
torch.testing.assert_close(tvm_ffi_dv, ref_dv, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tvm_ffi_dk, ref_dk, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tvm_ffi_dq, ref_dq, atol=1e-2, rtol=rtol)
|
||||
|
||||
for _ in range(warmup):
|
||||
attn_torch(q, k, v, causal, sm_scale).backward(dout)
|
||||
attn_triton(q, k, v, causal, sm_scale).backward(dout)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user