mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
50
examples/attention/attention.cc
Normal file
50
examples/attention/attention.cc
Normal file
@@ -0,0 +1,50 @@
|
||||
#include "tvm/ffi/container/tensor.h"
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef _ATTN_FWD_STUB
|
||||
#define _ATTN_FWD_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef _ATTN_FWD_TVM_FFI_NAME
|
||||
#define _ATTN_FWD_TVM_FFI_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tuple<tvm::ffi::Tensor, tvm::ffi::Tensor>
|
||||
AttnFwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v, bool casual,
|
||||
float smScale) {
|
||||
const tvm::ffi::ShapeView &qshape = q.shape(), &kshape = k.shape(),
|
||||
&vshape = v.shape();
|
||||
const int32_t kB = qshape[0], kH = qshape[1], kN = qshape[2], kQ = qshape[3],
|
||||
kK = kshape[3], kV = vshape[3], stage = casual ? 3 : 1;
|
||||
at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
|
||||
oTorch = at::empty_like(qTorch),
|
||||
mTorch =
|
||||
at::empty({kB, kH, kN}, qTorch.options().dtype(at::kFloat));
|
||||
tvm::ffi::Tensor o = tvm::ffi::Tensor::FromDLPack(at::toDLPack(oTorch)),
|
||||
m = tvm::ffi::Tensor::FromDLPack(at::toDLPack(mTorch));
|
||||
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||
[kB, kH, kN](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||
const int32_t kBlockM = meta["BLOCK_M"].cast<int32_t>();
|
||||
return tvm::ffi::Tuple<int32_t, int32_t, int32_t>(
|
||||
(kN + kBlockM - 1) / kBlockM, kB * kH, 1);
|
||||
});
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {smScale, m, kB, kH, q, k, v, o, kN};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"HEAD_DIM", kK},
|
||||
{"STAGE", stage},
|
||||
};
|
||||
DLDevice device = q.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
_ATTN_FWD_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return tvm::ffi::Tuple{m, o};
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(_ATTN_FWD_TVM_FFI_NAME, AttnFwd);
|
||||
}
|
||||
853
examples/attention/attention.py
Normal file
853
examples/attention/attention.py
Normal file
@@ -0,0 +1,853 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
|
||||
|
||||
Credits: OpenAI kernel team
|
||||
|
||||
Extra Credits:
|
||||
|
||||
* Original flash attention paper (https://arxiv.org/abs/2205.14135)
|
||||
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.tools.tensor_descriptor import TensorDescriptor
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype: tl.constexpr,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr,
|
||||
offs_m: tl.constexpr,
|
||||
offs_n: tl.constexpr, #
|
||||
N_CTX: tl.constexpr,
|
||||
):
|
||||
# range of values handled by this stage
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
elif STAGE == 2:
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
lo = tl.multiple_of(lo, BLOCK_M)
|
||||
# causal = False
|
||||
else:
|
||||
lo, hi = 0, N_CTX
|
||||
offsetk_y = offset_y + lo
|
||||
if dtype == tl.float8e5:
|
||||
offsetv_y = offset_y * HEAD_DIM + lo
|
||||
else:
|
||||
offsetv_y = offset_y + lo
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in tl.range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = desc_k.load([offsetk_y, 0]).T
|
||||
qk = tl.dot(q, k)
|
||||
if STAGE == 2:
|
||||
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
||||
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk -= m_ij[:, None]
|
||||
else:
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
|
||||
qk = qk * qk_scale - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
# -- compute correction factor
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
l_ij = tl.sum(p, 1)
|
||||
acc = acc * alpha[:, None]
|
||||
# prepare p and v for the dot
|
||||
if dtype == tl.float8e5:
|
||||
v = desc_v.load([0, offsetv_y]).T
|
||||
else:
|
||||
v = desc_v.load([offsetv_y, 0])
|
||||
p = p.to(dtype)
|
||||
# note that this non transposed v for FP8 is only supported on Blackwell
|
||||
acc = tl.dot(p, v, acc)
|
||||
# update m_i and l_i
|
||||
# place this at the end of the loop to reduce register pressure
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
offsetk_y += BLOCK_N
|
||||
offsetv_y += BLOCK_N
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def _host_descriptor_pre_hook(nargs):
|
||||
BLOCK_M = nargs["BLOCK_M"]
|
||||
BLOCK_N = nargs["BLOCK_N"]
|
||||
HEAD_DIM = nargs["HEAD_DIM"]
|
||||
if not isinstance(nargs["desc_q"], TensorDescriptor):
|
||||
return
|
||||
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
|
||||
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
|
||||
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
|
||||
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
|
||||
|
||||
|
||||
NUM_STAGES_OPTIONS = [2, 3, 4]
|
||||
|
||||
configs = [
|
||||
triton.Config(
|
||||
{"BLOCK_M": BM, "BLOCK_N": BN},
|
||||
num_stages=s,
|
||||
num_warps=w,
|
||||
pre_hook=_host_descriptor_pre_hook,
|
||||
)
|
||||
for BM in [64, 128]
|
||||
for BN in [32, 64, 128]
|
||||
for s in NUM_STAGES_OPTIONS
|
||||
for w in [4, 8]
|
||||
]
|
||||
if "PYTEST_VERSION" in os.environ:
|
||||
# Use a single config in testing for reproducibility
|
||||
configs = [
|
||||
triton.Config(
|
||||
dict(BLOCK_M=128, BLOCK_N=64),
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
pre_hook=_host_descriptor_pre_hook,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def keep(conf):
|
||||
BLOCK_M = conf.kwargs["BLOCK_M"]
|
||||
BLOCK_N = conf.kwargs["BLOCK_N"]
|
||||
return not (
|
||||
torch.cuda.get_device_capability()[0] == 9
|
||||
and BLOCK_M * BLOCK_N < 128 * 128
|
||||
and conf.num_warps == 8
|
||||
)
|
||||
|
||||
|
||||
def prune_invalid_configs(configs, named_args, **kwargs):
|
||||
N_CTX = kwargs["N_CTX"]
|
||||
|
||||
# Filter out configs where BLOCK_M > N_CTX
|
||||
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
|
||||
if isinstance(desc_or_ptr, tl.tensor_descriptor):
|
||||
return desc_or_ptr
|
||||
else:
|
||||
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.autotune(
|
||||
configs=list(filter(keep, configs)),
|
||||
key=["N_CTX", "HEAD_DIM"],
|
||||
prune_configs_by={"early_config_prune": prune_invalid_configs},
|
||||
)
|
||||
@triton.jit
|
||||
def _attn_fwd(
|
||||
sm_scale,
|
||||
M, #
|
||||
Z,
|
||||
H,
|
||||
desc_q,
|
||||
desc_k,
|
||||
desc_v,
|
||||
desc_o,
|
||||
N_CTX, #
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr, #
|
||||
):
|
||||
dtype = tl.float16
|
||||
tl.static_assert(BLOCK_N <= HEAD_DIM)
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
|
||||
y_dim = Z * H * N_CTX
|
||||
desc_q = _maybe_make_tensor_desc(
|
||||
desc_q,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_M, HEAD_DIM],
|
||||
)
|
||||
desc_v = _maybe_make_tensor_desc(
|
||||
desc_v,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_N, HEAD_DIM],
|
||||
)
|
||||
desc_k = _maybe_make_tensor_desc(
|
||||
desc_k,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_N, HEAD_DIM],
|
||||
)
|
||||
desc_o = _maybe_make_tensor_desc(
|
||||
desc_o,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_M, HEAD_DIM],
|
||||
)
|
||||
|
||||
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
|
||||
qo_offset_y = offset_y + start_m * BLOCK_M
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
# load scales
|
||||
qk_scale = sm_scale
|
||||
qk_scale *= 1.44269504 # 1/log(2)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = desc_q.load([qo_offset_y, 0])
|
||||
# stage 1: off-band
|
||||
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
|
||||
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
|
||||
if STAGE & 1:
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N, #
|
||||
4 - STAGE,
|
||||
offs_m,
|
||||
offs_n,
|
||||
N_CTX, #
|
||||
)
|
||||
# stage 2: on-band
|
||||
if STAGE & 2:
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N, #
|
||||
2,
|
||||
offs_m,
|
||||
offs_n,
|
||||
N_CTX, #
|
||||
)
|
||||
# epilogue
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(m_ptrs, m_i)
|
||||
desc_o.store([qo_offset_y, 0], acc.to(dtype))
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def _attn_bwd_preprocess(
|
||||
O,
|
||||
DO, #
|
||||
Delta, #
|
||||
Z,
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_hz = tl.program_id(1)
|
||||
off_n = tl.arange(0, HEAD_DIM)
|
||||
# load
|
||||
o = tl.load(
|
||||
O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
|
||||
)
|
||||
do = tl.load(
|
||||
DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
|
||||
).to(tl.float32)
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(Delta + off_hz * N_CTX + off_m, delta)
|
||||
|
||||
|
||||
# The main inner-loop logic for computing dK and dV.
|
||||
@triton.jit
|
||||
def _attn_bwd_dkdv(
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX,
|
||||
BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
# Filled in by the wrapper.
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK: tl.constexpr,
|
||||
):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M1)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
||||
curr_m = start_m
|
||||
step_m = BLOCK_M1
|
||||
for blk_idx in range(num_steps):
|
||||
qT = tl.load(qT_ptrs)
|
||||
# Load m before computing qk to reduce pipeline stall.
|
||||
offs_m = curr_m + tl.arange(0, BLOCK_M1)
|
||||
m = tl.load(M + offs_m)
|
||||
qkT = tl.dot(k, qT)
|
||||
pT = tl.math.exp2(qkT - m[None, :])
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
mask = offs_m[None, :] >= offs_n[:, None]
|
||||
pT = tl.where(mask, pT, 0.0)
|
||||
do = tl.load(do_ptrs)
|
||||
# Compute dV.
|
||||
ppT = pT
|
||||
ppT = ppT.to(tl.float16)
|
||||
dv += tl.dot(ppT, do)
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
dsT = dsT.to(tl.float16)
|
||||
dk += tl.dot(dsT, tl.trans(qT))
|
||||
# Increment pointers.
|
||||
curr_m += step_m
|
||||
qT_ptrs += step_m * stride_tok
|
||||
do_ptrs += step_m * stride_tok
|
||||
return dk, dv
|
||||
|
||||
|
||||
# the main inner-loop logic for computing dQ
|
||||
@triton.jit
|
||||
def _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr,
|
||||
# Filled in by the wrapper.
|
||||
start_m,
|
||||
start_n,
|
||||
num_steps, #
|
||||
MASK: tl.constexpr,
|
||||
):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N2)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
||||
curr_n = start_n
|
||||
step_n = BLOCK_N2
|
||||
for blk_idx in range(num_steps):
|
||||
kT = tl.load(kT_ptrs)
|
||||
vT = tl.load(vT_ptrs)
|
||||
qk = tl.dot(q, kT)
|
||||
p = tl.math.exp2(qk - m)
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
offs_n = curr_n + tl.arange(0, BLOCK_N2)
|
||||
mask = offs_m[:, None] >= offs_n[None, :]
|
||||
p = tl.where(mask, p, 0.0)
|
||||
# Compute dP and dS.
|
||||
dp = tl.dot(do, vT).to(tl.float32)
|
||||
ds = p * (dp - Di[:, None])
|
||||
ds = ds.to(tl.float16)
|
||||
# Compute dQ.
|
||||
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
|
||||
dq += tl.dot(ds, tl.trans(kT))
|
||||
# Increment pointers.
|
||||
curr_n += step_n
|
||||
kT_ptrs += step_n * stride_tok
|
||||
vT_ptrs += step_n * stride_tok
|
||||
return dq
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def _attn_bwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
DQ,
|
||||
DK,
|
||||
DV, #
|
||||
M,
|
||||
D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_z,
|
||||
stride_h,
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
BLK_SLICE_FACTOR: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr,
|
||||
):
|
||||
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
|
||||
|
||||
bhid = tl.program_id(2)
|
||||
off_chz = (bhid * N_CTX).to(tl.int64)
|
||||
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# offset pointers for batch/head
|
||||
Q += adj
|
||||
K += adj
|
||||
V += adj
|
||||
DO += adj
|
||||
DQ += adj
|
||||
DK += adj
|
||||
DV += adj
|
||||
M += off_chz
|
||||
D += off_chz
|
||||
|
||||
# load scales
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
MASK_BLOCK_M1,
|
||||
BLOCK_N1,
|
||||
HEAD_DIM, #
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK=True, #
|
||||
)
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv( #
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M1,
|
||||
BLOCK_N1,
|
||||
HEAD_DIM, #
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK=False, #
|
||||
)
|
||||
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2,
|
||||
MASK_BLOCK_N2,
|
||||
HEAD_DIM, #
|
||||
start_m,
|
||||
end_n - num_steps * MASK_BLOCK_N2,
|
||||
num_steps, #
|
||||
MASK=True, #
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2,
|
||||
BLOCK_N2,
|
||||
HEAD_DIM, #
|
||||
start_m,
|
||||
end_n - num_steps * BLOCK_N2,
|
||||
num_steps, #
|
||||
MASK=False, #
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
|
||||
|
||||
class _attention_triton(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
# shape constraints
|
||||
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
|
||||
# when v is in float8_e5m2 it is transposed.
|
||||
HEAD_DIM_V = v.shape[-1]
|
||||
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
|
||||
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
|
||||
o = torch.empty_like(q)
|
||||
stage = 3 if causal else 1
|
||||
|
||||
M = torch.empty(
|
||||
(q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
|
||||
)
|
||||
desc_q = q
|
||||
desc_v = v
|
||||
desc_k = k
|
||||
desc_o = o
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(q.shape[2], META["BLOCK_M"]),
|
||||
q.shape[0] * q.shape[1],
|
||||
1,
|
||||
)
|
||||
|
||||
_attn_fwd[grid](
|
||||
sm_scale,
|
||||
M, #
|
||||
q.shape[0],
|
||||
q.shape[1], #
|
||||
desc_q,
|
||||
desc_k,
|
||||
desc_v,
|
||||
desc_o, #
|
||||
N_CTX=q.shape[2], #
|
||||
HEAD_DIM=HEAD_DIM_K, #
|
||||
STAGE=stage, #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.HEAD_DIM = HEAD_DIM_K
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, M = ctx.saved_tensors
|
||||
assert do.is_contiguous()
|
||||
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
||||
PRE_BLOCK = 128
|
||||
NUM_WARPS, NUM_STAGES = 4, 5
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
|
||||
BLK_SLICE_FACTOR = 2
|
||||
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
||||
arg_k = k
|
||||
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
|
||||
PRE_BLOCK = 128
|
||||
assert N_CTX % PRE_BLOCK == 0
|
||||
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
|
||||
delta = torch.empty_like(M)
|
||||
_attn_bwd_preprocess[pre_grid](
|
||||
o,
|
||||
do, #
|
||||
delta, #
|
||||
BATCH,
|
||||
N_HEAD,
|
||||
N_CTX, #
|
||||
BLOCK_M=PRE_BLOCK,
|
||||
HEAD_DIM=ctx.HEAD_DIM, #
|
||||
)
|
||||
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
|
||||
_attn_bwd[grid](
|
||||
q,
|
||||
arg_k,
|
||||
v,
|
||||
ctx.sm_scale,
|
||||
do,
|
||||
dq,
|
||||
dk,
|
||||
dv, #
|
||||
M,
|
||||
delta, #
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3), #
|
||||
N_HEAD,
|
||||
N_CTX, #
|
||||
BLOCK_M1=BLOCK_M1,
|
||||
BLOCK_N1=BLOCK_N1, #
|
||||
BLOCK_M2=BLOCK_M2,
|
||||
BLOCK_N2=BLOCK_N2, #
|
||||
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
|
||||
HEAD_DIM=ctx.HEAD_DIM, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_stages=NUM_STAGES, #
|
||||
)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_fwd],
|
||||
Path(__file__).parent / "attention.cc",
|
||||
)
|
||||
def _attn_fwd_tvm_ffi(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class _attention_tvm_ffi(_attention_triton):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
# shape constraints
|
||||
HEAD_DIM_K = k.shape[-1]
|
||||
M, o = _attn_fwd_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
M = torch.from_dlpack(M)
|
||||
o = torch.from_dlpack(o)
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.HEAD_DIM = HEAD_DIM_K
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
|
||||
def attn_torch(q, k, v, causal=False, sm_scale=1.0):
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
if causal:
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1)
|
||||
p = p.to(q.dtype)
|
||||
out = torch.matmul(p, v)
|
||||
return out
|
||||
|
||||
|
||||
def attn_triton(q, k, v, causal=False, sm_scale=1.0):
|
||||
return _attention_triton.apply(q, k, v, causal, sm_scale)
|
||||
|
||||
|
||||
def attn_tvm_ffi(q, k, v, causal=False, sm_scale=1.0):
|
||||
return _attention_tvm_ffi.apply(q, k, v, causal, sm_scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Z = 1
|
||||
H = 2
|
||||
N_CTX = 128
|
||||
HEAD_DIM = 64
|
||||
causal = True
|
||||
mode = "fwd"
|
||||
dtype = torch.float16
|
||||
torch.manual_seed(20)
|
||||
q = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
sm_scale = 0.5
|
||||
# reference implementation
|
||||
ref_dtype = dtype
|
||||
q = q.to(ref_dtype)
|
||||
k = k.to(ref_dtype)
|
||||
v = v.to(ref_dtype)
|
||||
ref_out = attn_torch(q, k, v, causal, sm_scale).half()
|
||||
if mode == "bwd":
|
||||
dout = torch.randn_like(q)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
tri_out = attn_triton(q, k, v, causal, sm_scale).half()
|
||||
tvm_ffi_out = attn_tvm_ffi(q, k, v, causal, sm_scale).half()
|
||||
round = 1000
|
||||
if mode == "fwd":
|
||||
atol = 1e-2
|
||||
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=atol, rtol=0)
|
||||
for _ in range(5):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
elif mode == "bwd":
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
|
||||
rtol = 0.0
|
||||
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
|
||||
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
|
||||
if (
|
||||
torch.version.hip is not None
|
||||
and triton.runtime.driver.active.get_current_target().arch == "gfx90a"
|
||||
):
|
||||
rtol = 1e-2
|
||||
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
|
||||
@@ -303,5 +303,5 @@ if __name__ == "__main__":
|
||||
matmul(a, b, "")
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch matmul: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton matmul: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI matmul: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user