Compare commits

..

14 Commits

Author SHA1 Message Date
JinjieLiu a727526794 add length check to unwrap arguments
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-12 15:31:50 +08:00
JinjieLiu 599957e156 support attention bwd
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-10 17:01:28 +08:00
JinjieLiu e41ec26329 support attention fwd
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-09 17:15:51 +08:00
jinjieliu 213e4fc060 use templates to substitute parts of macros
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-08 22:24:12 +08:00
jinjieliu 1c4f13c8f0 find packages by sysconfig instead of importlib
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-08 16:15:56 +08:00
jinjieliu 24237a6313 include header files by c/cpp instead of jinja
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-07 17:16:49 +08:00
jinjieliu 6a19a6b06d put num_warps and num_stages in kwargs
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-07 14:25:10 +08:00
jinjieliu 2298b6f8c8 support mm and autotune
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-07 00:41:23 +08:00
JinjieLiu f6c7a48c1b enable lambda function for grid descriptor
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-05 15:59:22 +08:00
jinjieliu 8b8aa6cb84 enable optional for numwarps and numstages
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-05 01:01:49 +08:00
JinjieLiu b7bf598fde support softmax
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-04 16:38:33 +08:00
JinjieLiu 192dc95ac0 supports decorator for jit and wrapper
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-04 10:46:14 +08:00
JinjieLiu 6e4c2d4a43 fix bugs on lacking cudatoolkit
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-04 10:27:44 +08:00
jinjieliu dc8c2c17e0 verify tvm-ffi cpp wrapper on vector-add.py
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
2026-02-04 02:36:06 +08:00
36 changed files with 2168 additions and 948 deletions
+2 -1
View File
@@ -9,8 +9,9 @@ wheels/
# Virtual environments
.venv
.vscode/
.cache
.clangd
.ruff_cache
.python-version
uv.lock
-28
View File
@@ -1,28 +0,0 @@
cmake_minimum_required(VERSION 3.18)
if(DEFINED SKBUILD_PROJECT_NAME)
project(${SKBUILD_PROJECT_NAME})
else()
project(triton-tvm-ffi)
endif()
string(REPLACE "-" "_" TARGET_NAME "${PROJECT_NAME}")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
else(CMAKE_BUILD_TYPE STREQUAL "Release")
endif()
find_package(CUDAToolkit REQUIRED)
find_package(Python COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
OUTPUT_VARIABLE TVM_FFI_CMAKEDIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)
list(APPEND CMAKE_PREFIX_PATH "${TVM_FFI_CMAKEDIR}")
find_package(tvm_ffi CONFIG REQUIRED)
add_subdirectory(${PROJECT_SOURCE_DIR}/src)
-18
View File
@@ -1,18 +0,0 @@
# Triton-TVM-FFI
## Instructions
### Debug Install
```bash
SKBUILD_BUILD_DIR="build" SKBUILD_CMAKE_BUILD_TYPE=Debug uv pip install --no-build-isolation -ve .
```
### Release Install
```bash
SKBUILD_CMAKE_BUILD_TYPE=Release uv pip install -v .
```
### Format
```bash
find python -name "*.py" | xargs ruff format && find include src -name "*.h" -o -name "*.cc" | xargs clang-format -i
```
+37
View File
@@ -0,0 +1,37 @@
#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef ADD_KERNEL_STUB
#define ADD_KERNEL_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef ADD_NAME
#define ADD_NAME ""
#endif
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
at::Tensor otorch = at::empty_like(xtorch);
int32_t numel = otorch.numel();
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
[numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
});
DLDevice device = x.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
tvm::ffi::Array<tvm::ffi::Any> args = {x, y, output, numel, 1024};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
ADD_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
return output;
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(ADD_NAME, Add);
}
+75
View File
@@ -0,0 +1,75 @@
from pathlib import Path
import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton_tvm_ffi.jit
@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.empty_like(x)
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
n_elements: int = output.numel()
BLOCK_SIZE: int = 1024
add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
x, y, output, n_elements, BLOCK_SIZE
)
return output
@triton_tvm_ffi.torch_wrap(
[add_kernel],
Path(__file__).parent / "add.cc",
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
if __name__ == "__main__":
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add_triton(x, y)
output_tvm_ffi = add(x, y)
assert torch.allclose(output_torch, output_triton)
assert torch.allclose(output_torch, output_tvm_ffi)
output_tvm_ffi = add(x, y)
assert torch.allclose(output_torch, output_tvm_ffi)
round = 1000
cp0 = time.perf_counter_ns()
for _ in range(round):
x + y
cp1 = time.perf_counter_ns()
for _ in range(round):
add_triton(x, y)
cp2 = time.perf_counter_ns()
for _ in range(round):
add(x, y)
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
)
+937
View File
@@ -0,0 +1,937 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
* Original flash attention paper (https://arxiv.org/abs/2205.14135)
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import os
from pathlib import Path
import time
from typing import Sequence
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q, #
desc_k,
desc_v, #
offset_y,
dtype: tl.constexpr,
start_m,
qk_scale, #
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr, #
N_CTX: tl.constexpr,
):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
offsetk_y = offset_y + lo
if dtype == tl.float8e5:
offsetv_y = offset_y * HEAD_DIM + lo
else:
offsetv_y = offset_y + lo
# loop over k, v and update accumulator
for start_n in tl.range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = desc_k.load([offsetk_y, 0]).T
qk = tl.dot(q, k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
p = tl.math.exp2(qk)
# -- compute correction factor
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
# prepare p and v for the dot
if dtype == tl.float8e5:
v = desc_v.load([0, offsetv_y]).T
else:
v = desc_v.load([offsetv_y, 0])
p = p.to(dtype)
# note that this non transposed v for FP8 is only supported on Blackwell
acc = tl.dot(p, v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure
l_i = l_i * alpha + l_ij
m_i = m_ij
offsetk_y += BLOCK_N
offsetv_y += BLOCK_N
return acc, l_i, m_i
def _host_descriptor_pre_hook(nargs):
BLOCK_M = nargs["BLOCK_M"]
BLOCK_N = nargs["BLOCK_N"]
HEAD_DIM = nargs["HEAD_DIM"]
if not isinstance(nargs["desc_q"], TensorDescriptor):
return
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
NUM_STAGES_OPTIONS = [2, 3, 4]
configs = [
triton.Config(
{"BLOCK_M": BM, "BLOCK_N": BN},
num_stages=s,
num_warps=w,
pre_hook=_host_descriptor_pre_hook,
)
for BM in [64, 128]
for BN in [32, 64, 128]
for s in NUM_STAGES_OPTIONS
for w in [4, 8]
]
if "PYTEST_VERSION" in os.environ:
# Use a single config in testing for reproducibility
configs = [
triton.Config(
dict(BLOCK_M=128, BLOCK_N=64),
num_stages=2,
num_warps=4,
pre_hook=_host_descriptor_pre_hook,
),
]
def keep(conf):
BLOCK_M = conf.kwargs["BLOCK_M"]
BLOCK_N = conf.kwargs["BLOCK_N"]
return not (
torch.cuda.get_device_capability()[0] == 9
and BLOCK_M * BLOCK_N < 128 * 128
and conf.num_warps == 8
)
def prune_invalid_configs(configs, named_args, **kwargs):
N_CTX = kwargs["N_CTX"]
# Filter out configs where BLOCK_M > N_CTX
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
if isinstance(desc_or_ptr, tl.tensor_descriptor):
return desc_or_ptr
else:
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
@triton_tvm_ffi.jit
@triton.autotune(
configs=list(filter(keep, configs)),
key=["N_CTX", "HEAD_DIM"],
prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _attn_fwd(
sm_scale,
M, #
Z,
H,
desc_q,
desc_k,
desc_v,
desc_o,
N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, #
):
dtype = tl.float16
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
y_dim = Z * H * N_CTX
desc_q = _maybe_make_tensor_desc(
desc_q,
shape=[y_dim, HEAD_DIM],
strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM],
)
desc_v = _maybe_make_tensor_desc(
desc_v,
shape=[y_dim, HEAD_DIM],
strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM],
)
desc_k = _maybe_make_tensor_desc(
desc_k,
shape=[y_dim, HEAD_DIM],
strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM],
)
desc_o = _maybe_make_tensor_desc(
desc_o,
shape=[y_dim, HEAD_DIM],
strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM],
)
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_M
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # 1/log(2)
# load q: it will stay in SRAM throughout
q = desc_q.load([qo_offset_y, 0])
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q, #
desc_k,
desc_v, #
offset_y,
dtype,
start_m,
qk_scale, #
BLOCK_M,
HEAD_DIM,
BLOCK_N, #
4 - STAGE,
offs_m,
offs_n,
N_CTX, #
)
# stage 2: on-band
if STAGE & 2:
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q, #
desc_k,
desc_v, #
offset_y,
dtype,
start_m,
qk_scale, #
BLOCK_M,
HEAD_DIM,
BLOCK_N, #
2,
offs_m,
offs_n,
N_CTX, #
)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
desc_o.store([qo_offset_y, 0], acc.to(dtype))
@triton_tvm_ffi.jit
@triton.jit
def _attn_bwd_preprocess(
O,
DO, #
Delta, #
Z,
H,
N_CTX, #
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr, #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(
O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
)
do = tl.load(
DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hz * N_CTX + off_m, delta)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
dk,
dv, #
Q,
k,
v,
sm_scale, #
DO, #
M,
D, #
# shared by Q/K/V/DO.
stride_tok,
stride_d, #
H,
N_CTX,
BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
HEAD_DIM: tl.constexpr, #
# Filled in by the wrapper.
start_n,
start_m,
num_steps, #
MASK: tl.constexpr,
):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, HEAD_DIM)
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
if MASK:
mask = offs_m[None, :] >= offs_n[:, None]
pT = tl.where(mask, pT, 0.0)
do = tl.load(do_ptrs)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
dq,
q,
K,
V, #
do,
m,
D,
# shared by Q/K/V/DO.
stride_tok,
stride_d, #
H,
N_CTX, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
HEAD_DIM: tl.constexpr,
# Filled in by the wrapper.
start_m,
start_n,
num_steps, #
MASK: tl.constexpr,
):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, HEAD_DIM)
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
qk = tl.dot(q, kT)
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = offs_m[:, None] >= offs_n[None, :]
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
return dq
@triton_tvm_ffi.jit
@triton.jit
def _attn_bwd(
Q,
K,
V,
sm_scale, #
DO, #
DQ,
DK,
DV, #
M,
D,
# shared by Q/K/V/DO.
stride_z,
stride_h,
stride_tok,
stride_d, #
H,
N_CTX, #
BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
BLK_SLICE_FACTOR: tl.constexpr, #
HEAD_DIM: tl.constexpr,
):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
# load scales
offs_k = tl.arange(0, HEAD_DIM)
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(
dk,
dv, #
Q,
k,
v,
sm_scale, #
DO, #
M,
D, #
stride_tok,
stride_d, #
H,
N_CTX, #
MASK_BLOCK_M1,
BLOCK_N1,
HEAD_DIM, #
start_n,
start_m,
num_steps, #
MASK=True, #
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv( #
dk,
dv, #
Q,
k,
v,
sm_scale, #
DO, #
M,
D, #
stride_tok,
stride_d, #
H,
N_CTX, #
BLOCK_M1,
BLOCK_N1,
HEAD_DIM, #
start_n,
start_m,
num_steps, #
MASK=False, #
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(
dq,
q,
K,
V, #
do,
m,
D, #
stride_tok,
stride_d, #
H,
N_CTX, #
BLOCK_M2,
MASK_BLOCK_N2,
HEAD_DIM, #
start_m,
end_n - num_steps * MASK_BLOCK_N2,
num_steps, #
MASK=True, #
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(
dq,
q,
K,
V, #
do,
m,
D, #
stride_tok,
stride_d, #
H,
N_CTX, #
BLOCK_M2,
BLOCK_N2,
HEAD_DIM, #
start_m,
end_n - num_steps * BLOCK_N2,
num_steps, #
MASK=False, #
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
class _attention_triton(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
o = torch.empty_like(q)
stage = 3 if causal else 1
M = torch.empty(
(q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
desc_q = q
desc_v = v
desc_k = k
desc_o = o
def grid(META):
return (
triton.cdiv(q.shape[2], META["BLOCK_M"]),
q.shape[0] * q.shape[1],
1,
)
_attn_fwd[grid](
sm_scale,
M, #
q.shape[0],
q.shape[1], #
desc_q,
desc_k,
desc_v,
desc_o, #
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
STAGE=stage, #
)
ctx.save_for_backward(q, k, v, o, M)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o,
do, #
delta, #
BATCH,
N_HEAD,
N_CTX, #
BLOCK_M=PRE_BLOCK,
HEAD_DIM=ctx.HEAD_DIM, #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q,
arg_k,
v,
ctx.sm_scale,
do,
dq,
dk,
dv, #
M,
delta, #
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3), #
N_HEAD,
N_CTX, #
BLOCK_M1=BLOCK_M1,
BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2,
BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=ctx.HEAD_DIM, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES, #
)
return dq, dk, dv, None, None, None, None
@triton_tvm_ffi.torch_wrap(
[_attn_fwd],
Path(__file__).parent / "attnfwd.cc",
)
def _attn_fwd_tvm_ffi(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float
) -> torch.Tensor: ...
@triton_tvm_ffi.torch_wrap(
[_attn_bwd_preprocess],
Path(__file__).parent / "attnbwdpre.cc",
)
def _attn_bwd_preprocess_tvm_ffi(
o: torch.Tensor,
do: torch.Tensor,
mshape: Sequence[int],
head_dim: int,
): ...
@triton_tvm_ffi.torch_wrap(
[_attn_bwd],
Path(__file__).parent / "attnbwd.cc",
)
def _attn_bwd_tvm_ffi(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sm_scale: float,
do: torch.Tensor,
m: torch.Tensor,
delta: torch.Tensor,
HEAD_DIM: int,
): ...
class _attention_tvm_ffi(_attention_triton):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
HEAD_DIM_K = k.shape[-1]
M, o = _attn_fwd_tvm_ffi(q, k, v, causal, sm_scale)
M = torch.from_dlpack(M)
o = torch.from_dlpack(o)
ctx.save_for_backward(q, k, v, o, M)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, M = ctx.saved_tensors
delta = _attn_bwd_preprocess_tvm_ffi(
o,
do,
M.shape,
ctx.HEAD_DIM,
)
dq, dk, dv = _attn_bwd_tvm_ffi(
q,
k,
v,
ctx.sm_scale,
do,
M,
delta,
ctx.HEAD_DIM,
)
dq = torch.from_dlpack(dq)
dk = torch.from_dlpack(dk)
dv = torch.from_dlpack(dv)
return dq, dk, dv, None, None, None, None
def attn_torch(q, k, v, causal=False, sm_scale=1.0):
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1)
p = p.to(q.dtype)
out = torch.matmul(p, v)
return out
def attn_triton(q, k, v, causal=False, sm_scale=1.0):
return _attention_triton.apply(q, k, v, causal, sm_scale)
def attn_tvm_ffi(q, k, v, causal=False, sm_scale=1.0):
return _attention_tvm_ffi.apply(q, k, v, causal, sm_scale)
if __name__ == "__main__":
Z = 1
H = 2
N_CTX = 128
HEAD_DIM = 64
causal = True
mode = "bwd"
dtype = torch.float16
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
sm_scale = 0.5
# reference implementation
ref_dtype = dtype
q = q.to(ref_dtype)
k = k.to(ref_dtype)
v = v.to(ref_dtype)
ref_out = attn_torch(q, k, v, causal, sm_scale).half()
if mode == "bwd":
dout = torch.randn_like(q)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
tri_out = attn_triton(q, k, v, causal, sm_scale).half()
tvm_ffi_out = attn_tvm_ffi(q, k, v, causal, sm_scale).half()
warmup = 5
round = 1000
if mode == "fwd":
atol = 1e-2
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=atol, rtol=0)
with torch.no_grad():
for _ in range(warmup):
attn_torch(q, k, v, causal, sm_scale)
attn_triton(q, k, v, causal, sm_scale)
attn_tvm_ffi(q, k, v, causal, sm_scale)
cp0 = time.perf_counter_ns()
for _ in range(round):
attn_torch(q, k, v, causal, sm_scale)
cp1 = time.perf_counter_ns()
for _ in range(round):
attn_triton(q, k, v, causal, sm_scale)
cp2 = time.perf_counter_ns()
for _ in range(round):
attn_tvm_ffi(q, k, v, causal, sm_scale)
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
)
elif mode == "bwd":
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
rtol = 0.0
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
if (
torch.version.hip is not None
and triton.runtime.driver.active.get_current_target().arch == "gfx90a"
):
rtol = 1e-2
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
tvm_ffi_out.backward(dout)
tvm_ffi_dv, v.grad = v.grad.clone(), None
tvm_ffi_dk, k.grad = k.grad.clone(), None
tvm_ffi_dq, q.grad = q.grad.clone(), None
# compare
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=1e-2, rtol=0)
rtol = 0.0
torch.testing.assert_close(tvm_ffi_dv, ref_dv, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tvm_ffi_dk, ref_dk, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tvm_ffi_dq, ref_dq, atol=1e-2, rtol=rtol)
for _ in range(warmup):
attn_torch(q, k, v, causal, sm_scale).backward(dout)
attn_triton(q, k, v, causal, sm_scale).backward(dout)
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
cp0 = time.perf_counter_ns()
for _ in range(round):
attn_torch(q, k, v, causal, sm_scale).backward(dout)
cp1 = time.perf_counter_ns()
for _ in range(round):
attn_triton(q, k, v, causal, sm_scale).backward(dout)
cp2 = time.perf_counter_ns()
for _ in range(round):
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
)
+56
View File
@@ -0,0 +1,56 @@
#include <ATen/DLConvertor.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/dlpack.h>
#include <ATen/ops/empty.h>
#include <torch/headeronly/core/DeviceType.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef _ATTN_BWD_STUB
#define _ATTN_BWD_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef _ATTN_BWD_TVM_FFI_NAME
#define _ATTN_BWD_TVM_FFI_NAME ""
#endif
tvm::ffi::Tuple<tvm::ffi::Tensor, tvm::ffi::Tensor, tvm::ffi::Tensor>
AttnBwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v,
const double smScale, tvm::ffi::Tensor do_, tvm::ffi::Tensor m,
tvm::ffi::Tensor delta, const int32_t kHeadDim) {
tvm::ffi::ShapeView qshape = q.shape(), qstride = q.strides();
const int32_t kBatch = qshape[0], kNHead = qshape[1], kNCtx = qshape[2],
kBlockN1 = 128;
const double kArgKScale = smScale / log(2);
at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
kTorch = at::fromDLPack(k.ToDLPack()),
vTorch = at::fromDLPack(v.ToDLPack()),
dqTorch = at::empty_like(qTorch), dkTorch = at::empty_like(kTorch),
dvTorch = at::empty_like(vTorch),
argKTorch = at::mul(kTorch, kArgKScale);
tvm::ffi::Tensor dq = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dqTorch)),
dk = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dkTorch)),
dv = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dvTorch)),
argK = tvm::ffi::Tensor::FromDLPack(at::toDLPack(argKTorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kBlockN1, 1,
kBatch * kNHead);
tvm::ffi::Array<tvm::ffi::Any> args = {
q, argK, v, smScale, do_, dq, dk, dv,
m, delta, qstride[0], qstride[1], qstride[2], qstride[3], kNHead, kNCtx};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
{"BLOCK_M1", 32}, {"BLOCK_N1", kBlockN1}, {"BLOCK_M2", 128},
{"BLOCK_N2", 32}, {"BLK_SLICE_FACTOR", 2}, {"HEAD_DIM", kHeadDim},
{"num_warps", 4}, {"num_stages", 5},
};
DLDevice device = q.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
_ATTN_BWD_STUB(grid, device.device_id, stream, args, kwargs);
return tvm::ffi::Tuple{dq, dk, dv};
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(_ATTN_BWD_TVM_FFI_NAME, AttnBwd);
}
+46
View File
@@ -0,0 +1,46 @@
#include <ATen/DLConvertor.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/dlpack.h>
#include <ATen/ops/empty.h>
#include <torch/headeronly/core/DeviceType.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef _ATTN_BWD_PREPROCESS_STUB
#define _ATTN_BWD_PREPROCESS_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef _ATTN_BWD_PREPROCESS_TVM_FFI_NAME
#define _ATTN_BWD_PREPROCESS_TVM_FFI_NAME ""
#endif
tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_,
tvm::ffi::Shape mshape,
const int32_t kHeadDim) {
const int32_t kBatch = mshape[0], kNHead = mshape[1], kNCtx = mshape[2],
kPreBlock = 128;
at::Tensor deltaTorch = at::empty(mshape, at::kFloat, std::nullopt,
at::Device(at::kCUDA, o.device().device_id),
std::nullopt, std::nullopt);
tvm::ffi::Tensor delta =
tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kPreBlock,
kBatch * kNHead, 1);
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
{"N_CTX", kNCtx},
{"BLOCK_M", kPreBlock},
{"HEAD_DIM", kHeadDim},
};
DLDevice device = o.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
_ATTN_BWD_PREPROCESS_STUB(grid, device.device_id, stream, args, kwargs);
return delta;
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(_ATTN_BWD_PREPROCESS_TVM_FFI_NAME, AttnBwdPreprocess);
}
+52
View File
@@ -0,0 +1,52 @@
#include <ATen/DLConvertor.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/dlpack.h>
#include <ATen/ops/empty.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef _ATTN_FWD_STUB
#define _ATTN_FWD_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef _ATTN_FWD_TVM_FFI_NAME
#define _ATTN_FWD_TVM_FFI_NAME ""
#endif
tvm::ffi::Tuple<tvm::ffi::Tensor, tvm::ffi::Tensor>
AttnFwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v, bool casual,
float smScale) {
const tvm::ffi::ShapeView &qshape = q.shape(), &kshape = k.shape(),
&vshape = v.shape();
const int32_t kB = qshape[0], kH = qshape[1], kN = qshape[2], kQ = qshape[3],
kK = kshape[3], kV = vshape[3], stage = casual ? 3 : 1;
at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
oTorch = at::empty_like(qTorch),
mTorch =
at::empty({kB, kH, kN}, qTorch.options().dtype(at::kFloat));
tvm::ffi::Tensor o = tvm::ffi::Tensor::FromDLPack(at::toDLPack(oTorch)),
m = tvm::ffi::Tensor::FromDLPack(at::toDLPack(mTorch));
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
[kB, kH, kN](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
const int32_t kBlockM = meta["BLOCK_M"].cast<int32_t>();
return tvm::ffi::Tuple<int32_t, int32_t, int32_t>(
(kN + kBlockM - 1) / kBlockM, kB * kH, 1);
});
tvm::ffi::Array<tvm::ffi::Any> args = {smScale, m, kB, kH, q, k, v, o, kN};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
{"HEAD_DIM", kK},
{"STAGE", stage},
};
DLDevice device = q.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
_ATTN_FWD_STUB(grid, device.device_id, stream, args, kwargs);
return tvm::ffi::Tuple{m, o};
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(_ATTN_FWD_TVM_FFI_NAME, AttnFwd);
}
+55
View File
@@ -0,0 +1,55 @@
#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef MATMUL_KERNEL_STUB
#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef MATMUL_NAME
#define MATMUL_NAME ""
#endif
tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b,
tvm::ffi::String activation) {
at::Tensor atorch = at::fromDLPack(a.ToDLPack()),
btorch = at::fromDLPack(b.ToDLPack());
const int32_t M = atorch.size(0), K = atorch.size(1), N = btorch.size(1);
at::Tensor ctorch = at::empty({M, N}, atorch.options());
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
[M, N](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
const int32_t BLOCK_SIZE_M = meta["BLOCK_SIZE_M"].cast<int32_t>(),
BLOCK_SIZE_N = meta["BLOCK_SIZE_N"].cast<int32_t>();
return tvm::ffi::Tuple<int32_t, int32_t, int32_t>{
(M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M *
((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N),
1, 1};
});
DLDevice device = a.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
tvm::ffi::Tensor c = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ctorch));
tvm::ffi::Array<tvm::ffi::Any> args = {a,
b,
c,
M,
N,
K,
atorch.stride(0),
atorch.stride(1),
btorch.stride(0),
btorch.stride(1),
ctorch.stride(0),
ctorch.stride(1)};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
{"ACTIVATION", activation},
};
MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
return c;
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(MATMUL_NAME, Matmul);
}
+307
View File
@@ -0,0 +1,307 @@
from pathlib import Path
import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def get_autotune_config():
return [
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
]
@triton_tvm_ffi.jit
@triton.autotune(
configs=get_autotune_config(),
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
def matmul_triton(a, b, activation=""):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
matmul_kernel[
lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
ACTIVATION=activation,
)
return c
@triton_tvm_ffi.torch_wrap(
[matmul_kernel],
Path(__file__).parent / "mm.cc",
)
def matmul(a: torch.Tensor, b: torch.Tensor, activation: str = "") -> torch.Tensor: ...
if __name__ == "__main__":
torch.manual_seed(0)
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
torch_output = torch.matmul(a, b)
triton_output = matmul_triton(a, b, "")
tvm_ffi_output = matmul(a, b, "")
assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=1e-2)
assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2)
tvm_ffi_output = matmul(a, b, "")
assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2)
round = 1000
cp0 = time.perf_counter_ns()
for _ in range(round):
a @ b
cp1 = time.perf_counter_ns()
for _ in range(round):
matmul_triton(a, b, "")
cp2 = time.perf_counter_ns()
for _ in range(round):
matmul(a, b, "")
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
)
+34
View File
@@ -0,0 +1,34 @@
#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef SOFTMAX_KERNEL_STUB
#define SOFTMAX_KERNEL_STUB(grid, device, stream, args, kwargs)
#endif
#ifndef SOFTMAX_NAME
#define SOFTMAX_NAME ""
#endif
tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
at::Tensor ytorch = at::empty_like(xtorch);
uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1),
xStride = xtorch.stride(0), yStride = ytorch.stride(0),
BLOCK_SIZE = 1u << (32 - __builtin_clz(nCols - 1));
tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch));
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{nRows / 1024, 1, 1};
DLDevice device = x.device();
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
tvm::ffi::Array<tvm::ffi::Any> args = {y, x, xStride, yStride,
nRows, nCols, BLOCK_SIZE};
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
SOFTMAX_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
return y;
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(SOFTMAX_NAME, Softmax);
}
+88
View File
@@ -0,0 +1,88 @@
from pathlib import Path
import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
@triton_tvm_ffi.jit
@triton.jit
def softmax_kernel(
output_ptr,
input_ptr,
input_row_stride,
output_row_stride,
n_rows,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step):
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
def softmax_triton(x):
n_rows, n_cols = x.shape
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = 8
num_stages = 4
y = torch.empty_like(x)
softmax_kernel[(n_rows, 1, 1)](
y,
x,
x.stride(0),
y.stride(0),
n_rows,
n_cols,
BLOCK_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)
return y
@triton_tvm_ffi.torch_wrap(
[softmax_kernel],
Path(__file__).parent / "softmax.cc",
)
def softmax(x: torch.Tensor) -> torch.Tensor: ...
if __name__ == "__main__":
x = torch.randn(1823, 781, device="cuda")
y_torch = torch.softmax(x, axis=1)
y_triton = softmax_triton(x)
y_tvm_ffi = softmax(x)
assert torch.allclose(y_torch, y_triton), (y_torch, y_triton)
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
y_tvm_ffi = softmax(x)
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
round = 1000
cp0 = time.perf_counter_ns()
for _ in range(round):
torch.softmax(x, axis=1)
cp1 = time.perf_counter_ns()
for _ in range(round):
softmax_triton(x)
cp2 = time.perf_counter_ns()
for _ in range(round):
softmax(x)
cp3 = time.perf_counter_ns()
print(
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
)
-49
View File
@@ -1,49 +0,0 @@
#ifndef TRITON_TVM_FFI_EXCEPTION_H_
#define TRITON_TVM_FFI_EXCEPTION_H_
#include "type.h"
#include <cuda.h>
#include <exception>
namespace triton_tvm_ffi {
class CUDAException : public std::exception {
public:
CUDAException(CUresult code);
const char *what() const noexcept override;
private:
const CUresult code_;
};
class NotImplementedException : public std::exception {
public:
NotImplementedException(std::string_view name);
const char *what() const noexcept override;
private:
const std::string message_;
};
class UnmatchedArgumentException : public std::exception {
public:
UnmatchedArgumentException(std::string_view name, size_t len, size_t expect);
const char *what() const noexcept override;
private:
const std::string message_;
};
class UnknownTypeException : public std::exception {
public:
UnknownTypeException(Type type);
UnknownTypeException(std::string_view type);
const char *what() const noexcept override;
private:
const std::string message_;
};
} // namespace triton_tvm_ffi
#endif
-48
View File
@@ -1,48 +0,0 @@
#ifndef TRITON_TVM_FFI_LAUNCH_H_
#define TRITON_TVM_FFI_LAUNCH_H_
#include "type.h"
#include <cuda.h>
#include <tvm/ffi/object.h>
namespace triton_tvm_ffi {
class TVMFFILauncherImplObj : public tvm::ffi::Object {
public:
TVMFFILauncherImplObj(const tvm::ffi::Array<Type> &signature,
bool launchCooperativeGrid, bool launchAsync);
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, int32_t numWarps, int32_t numCtas,
int32_t sharedMemory, uint64_t globalScratch,
uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
TVMFFILauncherImplObj, tvm::ffi::Object);
private:
tvm::ffi::Array<Type> signature_;
const bool launchCooperativeGrid_;
const bool launchAsync_;
};
class TVMFFILauncherImpl : public tvm::ffi::ObjectRef {
public:
TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
bool launchCooperativeGrid, bool launchAsync);
using tvm::ffi::ObjectRef::ObjectRef;
using tvm::ffi::ObjectRef::operator=;
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, int32_t numWarps, int32_t numCtas,
int32_t sharedMemory, uint64_t globalScratch,
uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
tvm::ffi::ObjectRef,
TVMFFILauncherImplObj);
};
} // namespace triton_tvm_ffi
#endif
-19
View File
@@ -1,19 +0,0 @@
#ifndef TRITON_TVM_FFI_MACRO_H_
#define TRITON_TVM_FFI_MACRO_H_
#include "exception.h"
#if defined(__GNUC__) || defined(__clang__)
#define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline
#endif
#define UNLIKELY(cond) __builtin_expect((cond), 0)
#define CUDA_CHECK(code) \
do { \
if (UNLIKELY((code) != CUDA_SUCCESS)) { \
throw triton_tvm_ffi::CUDAException(code); \
} \
} while (false)
#endif
+34
View File
@@ -0,0 +1,34 @@
#ifndef TRITON_TVM_FFI_GRID_H_
#define TRITON_TVM_FFI_GRID_H_
#include <cstdint>
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
template <typename T>
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
MakeGridDim(const T &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
template <>
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
return grid;
}
template <>
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
MakeGridDim<tvm::ffi::Function>(
const tvm::ffi::Function &grid,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
return MakeGridDim(tuple, meta);
}
} // namespace triton_tvm_ffi
#endif
+39
View File
@@ -0,0 +1,39 @@
#ifndef TRITON_TVM_FFI_KERNEL_H_
#define TRITON_TVM_FFI_KERNEL_H_
#include "macro.h"
#include <cstdint>
#include <cuda.h>
#include <unordered_map>
namespace triton_tvm_ffi {
template <const char kFnName[], const char kCubin[], size_t kSMem>
inline CUfunction GetKernel(int32_t device) {
static std::unordered_map<int32_t, CUfunction> functions = {};
if (functions.find(device) == functions.end()) {
CUmodule module;
CUfunction func;
__CUDA_CHECK(cuModuleLoadData(&module, kCubin));
__CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
if (kSMem > 49152) {
int32_t shared_optin, shared_static;
__CUDA_CHECK(cuDeviceGetAttribute(
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
if (shared_optin >= kSMem) {
__CUDA_CHECK(cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
__CUDA_CHECK(cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
}
functions[device] = func;
}
return functions[device];
};
} // namespace triton_tvm_ffi
#endif
+22
View File
@@ -0,0 +1,22 @@
#ifndef TRITON_TVM_FFI_MACRO_H_
#define TRITON_TVM_FFI_MACRO_H_
#include <cuda.h>
#include <sstream>
#include <stdexcept>
#include <string>
#define __CUDA_CHECK(__code) \
do { \
if ((__code) != CUDA_SUCCESS) { \
const char *errorName = nullptr, *errorStr = nullptr; \
cuGetErrorName((__code), &errorName); \
cuGetErrorString((__code), &errorStr); \
std::ostringstream __oss; \
__oss << "[" << errorName << "] " << errorStr << ", at " << __FILE__ \
<< ":" << __LINE__; \
throw std::runtime_error(__oss.str()); \
} \
} while (false)
#endif
+52
View File
@@ -0,0 +1,52 @@
#ifndef TRITON_TVM_FFI_META_H_
#define TRITON_TVM_FFI_META_H_
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
template <const char... Ks[]> struct FillMetaImpl {
static inline void
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs);
};
template <> struct FillMetaImpl<> {
static inline void
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {}
};
template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
static inline void
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
if (argsBegin != argsEnd) {
meta.Set(K, *argsBegin++);
} else if (auto val = kwargs.Get(K)) {
meta.Set(K, *val);
}
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
}
};
template <const char... Ks[]> struct FillMeta {
static inline void
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
const tvm::ffi::Array<tvm::ffi::Any> &args,
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
}
};
} // namespace triton_tvm_ffi
#endif
-54
View File
@@ -1,54 +0,0 @@
#ifndef TRITON_TVM_FFI_TYPE_H_
#define TRITON_TVM_FFI_TYPE_H_
#include <cstdint>
#include <tvm/ffi/tvm_ffi.h>
#include <type_traits>
namespace triton_tvm_ffi {
// --------------- Definitions --------------- //
#define TYPE_TABLE_NATIVE(V) \
V(I1, "i1", int8_t) \
V(I8, "i8", int8_t) \
V(I16, "i16", int16_t) \
V(I32, "i32", int32_t) \
V(I64, "i64", int64_t) \
V(U1, "u1", uint8_t) \
V(U8, "u8", uint8_t) \
V(U16, "u16", uint16_t) \
V(U32, "u32", uint32_t) \
V(U64, "u64", uint64_t) \
V(FP16, "fp16", double) \
V(BF16, "bf16", double) \
V(FP32, "f32", double) \
V(FP64, "fp64", double)
#define TYPE_TABLE(V) \
TYPE_TABLE_NATIVE(V) \
V(PTR, "*?", void *) \
V(CONSTEXPR, "constexpr", void)
enum class Type : int64_t {
#define DEFINE_ENUM(type, str, ctype) type,
TYPE_TABLE(DEFINE_ENUM)
#undef DEFINE_ENUM
};
const char *TypeToString(Type type);
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name);
const char *TypeToCType(Type type);
template <Type T> struct type_to_ctype;
#define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \
template <> struct type_to_ctype<Type::type> { using t = ctype; };
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
#undef DEFINE_TYPE_TO_CTYPE
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
// --------------- Implementations --------------- //
} // namespace triton_tvm_ffi
#endif
+5 -8
View File
@@ -1,7 +1,7 @@
[project]
name = "triton-tvm-ffi"
version = "0.1.0"
description = "Add your description here"
description = "A Python package for the FFI bindings of Triton TVM."
readme = "README.md"
dependencies = [
"apache-tvm-ffi",
@@ -9,12 +9,9 @@ dependencies = [
]
[build-system]
requires = ["apache-tvm-ffi", "scikit-build-core"]
requires = ["scikit-build-core"]
build-backend = "scikit_build_core.build"
[project.entry-points."triton.backends"]
nvidia = "triton_tvm_ffi"
[tool.scikit-build]
wheel.install-dir = "triton_tvm_ffi"
wheel.packages = ["python/triton_tvm_ffi"]
[tool.setuptools]
packages = ["triton_tvm_ffi"]
package-dir = {"" = "python"}
+5 -11
View File
@@ -1,11 +1,5 @@
# tvm-ffi-stubgen(begin): export/_ffi_api
# fmt: off
# isort: off
from ._ffi_api import * # noqa: F403
from ._ffi_api import __all__ as _ffi_api__all__
if "__all__" not in globals():
__all__ = []
__all__.extend(_ffi_api__all__)
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
from .jit import jit
from .utils import include_paths
from .wrap import torch_wrap, wrap
__all__ = ["include_paths", "jit", "torch_wrap", "wrap"]
-53
View File
@@ -1,53 +0,0 @@
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, register_object as _FFI_REG_OBJ
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
from tvm_ffi import Object
from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
# tvm-ffi-stubgen(import-object): tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB
LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "triton_tvm_ffi")
# tvm-ffi-stubgen(begin): global/triton_tvm_ffi
# fmt: off
_FFI_INIT_FUNC("triton_tvm_ffi", __name__)
if TYPE_CHECKING:
def string_to_type(_0: str, /) -> int | None: ...
def type_to_ctype(_0: int, /) -> str: ...
def type_to_string(_0: int, /) -> str: ...
# fmt: on
# tvm-ffi-stubgen(end)
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
@_FFI_REG_OBJ("triton_tvm_ffi.TVMFFILauncherImpl")
class TVMFFILauncherImpl(_ffi_Object):
"""FFI binding for `triton_tvm_ffi.TVMFFILauncherImpl`."""
# tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TVMFFILauncherImpl
# fmt: off
if TYPE_CHECKING:
@staticmethod
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
# fmt: on
# tvm-ffi-stubgen(end)
__all__ = [
# tvm-ffi-stubgen(begin): __all__
"LIB",
"TVMFFILauncherImpl",
"string_to_type",
"type_to_ctype",
"type_to_string",
# tvm-ffi-stubgen(end)
]
-7
View File
@@ -1,7 +0,0 @@
from triton.backends.nvidia.compiler import CUDABackend
class TVMFFIBackend(CUDABackend): ...
del CUDABackend
-176
View File
@@ -1,176 +0,0 @@
from __future__ import annotations
from functools import cached_property
import os
from typing import Any, Final, List, Type
import jinja2
from triton.backends.nvidia.driver import CudaDriver
from triton.runtime import _allocation
import tvm_ffi
from . import TVMFFILauncherImpl, utils, string_to_type, type_to_ctype
class TVMLauncher(object):
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
super().__init__(*args, **kwargs)
self.signature: List[str] = [*src.signature.values()]
self.num_ctas: Final[int] = getattr(metadata, "num_ctas", 1)
self.global_scratch_size: Final[int] = metadata.global_scratch_size
self.global_scratch_align: Final[int] = metadata.global_scratch_align
self.profile_scratch_size: Final[int] = metadata.profile_scratch_size
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
self.launch_pdl: Final[bool] = metadata.launch_pdl
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
"launch",
cpp_sources=[self.codegen],
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
extra_include_paths=[
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
],
)
launch: tvm_ffi.Function = mod.get_function("launch")
self.launch = (
lambda grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args: launch(
grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args,
)
)
else:
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
[string_to_type(t) for t in self.signature],
self.launch_cooperative_grid,
self.launch_pdl,
)
self.launch = (
lambda grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
*args: self.impl.launch(
grid_x,
grid_y,
grid_z,
stream,
function,
num_warps,
num_ctas,
shared_memory,
global_scratch,
profile_scratch,
args,
)
)
def __call__(
self,
gridX,
gridY,
gridZ,
stream,
function,
kernel_metadata,
launch_metadata,
launch_enter_hook,
launch_exit_hook,
*args,
):
def allocate_scratch(size, align, allocator):
if size > 0:
grid_size = gridX * gridY * gridZ
alloc_size = grid_size * self.num_ctas * size
alloc_fn = allocator.get()
return alloc_fn(alloc_size, align, stream)
return None
global_scratch = allocate_scratch(
self.global_scratch_size, self.global_scratch_align, _allocation._allocator
)
profile_scratch = allocate_scratch(
self.profile_scratch_size,
self.profile_scratch_align,
_allocation._profile_allocator,
)
def canonicalize(obj: Any) -> int:
if obj is None:
return 0
elif isinstance(obj, int):
return obj
elif get_ptr := getattr(obj, "data_ptr", None):
return get_ptr()
else:
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
(num_warps, num_ctas, shared_memory) = kernel_metadata
if launch_enter_hook:
launch_enter_hook(launch_metadata)
ret = self.launch(
gridX,
gridY,
gridZ,
stream,
function,
num_warps,
num_ctas,
shared_memory,
canonicalize(global_scratch),
canonicalize(profile_scratch),
*args,
)
if launch_exit_hook:
launch_exit_hook(launch_metadata)
return ret
@cached_property
def codegen(self) -> str:
env: jinja2.Environment = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
lstrip_blocks=True,
)
template: jinja2.Template = env.get_template("launch.c.j2")
signature: List[int] = list(
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
)
html: str = template.render(signature=signature)
return html
class TVMFFIDriver(CudaDriver):
def __init__(self, *args, **kwargs) -> TVMFFIDriver:
super().__init__(*args, **kwargs)
self.utils = utils
self.launcher_cls: Type[TVMLauncher] = TVMLauncher
del CudaDriver
+119
View File
@@ -0,0 +1,119 @@
from __future__ import annotations
from functools import cached_property
import inspect
from typing import (
Any,
Callable,
Dict,
Final,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import torch
from triton.compiler import CompiledKernel
from triton.runtime import Autotuner, JITFunction
import tvm_ffi
from .utils import type_canonicalize
class TVMFFIJITFunction(object):
def __init__(self, fn: Union[Autotuner, JITFunction], *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fn: Final[Union[Autotuner, JITFunction]] = fn
self.signature: List[str] = [*inspect.signature(self.basefn).parameters.keys()]
self.best_config: Optional[Dict[str, Any]] = None
self.ctypes: Optional[List[Optional[str]]] = None
self.kernel: Optional[bytes] = None
self.num_warps: Optional[int] = None
self.shmem: int = 0
@tvm_ffi.register_global_func(self.fullname)
def _(
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
_device: int,
_stream: int,
args: Sequence[Any],
kwargs: Mapping[str, Any],
):
args: Iterator[Any] = map(self.canonicalize, args)
kwargs: Dict[str, Any] = {
k: self.canonicalize(v) for k, v in kwargs.items()
}
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
self.num_warps, _, self.shmem = kernel.packed_metadata
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
self.kernel = kernel.kernel
if isinstance(self.fn, Autotuner):
self.best_config = self.fn.best_config.all_kwargs()
return kernel
def __getitem__(
self,
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
):
return self.fn[grid]
@cached_property
def basefn(self) -> Callable:
return self.jitfn.fn
@property
def cache_hash(self) -> int:
return self.ctypes_hash ^ self.kernel_hash
@property
def ctypes_hash(self) -> int:
return hash(tuple(self.ctypes) if self.ctypes is not None else None)
@cached_property
def fnname(self) -> str:
return self.basefn.__name__
@cached_property
def fullname(self) -> str:
return f"triton.{self.name}"
@cached_property
def jitfn(self) -> JITFunction:
fn: Union[Autotuner, JITFunction] = self.fn
while not isinstance(fn, JITFunction):
fn = fn.fn
return fn
@property
def kernel_hash(self) -> int:
return hash(self.kernel)
@property
def kernel_cstr(self) -> Optional[str]:
if self.kernel is not None:
return "".join(f"\\x{byte:02x}" for byte in self.kernel)
else:
return None
@cached_property
def name(self) -> str:
return f"{self.fnname}_{hash(self.basefn)}"
@staticmethod
def canonicalize(val: Any) -> Any:
if hasattr(val, "__dlpack__"):
return torch.from_dlpack(val)
else:
return val
def jit(fn: JITFunction) -> TVMFFIJITFunction:
return TVMFFIJITFunction(fn)
@@ -0,0 +1,45 @@
#include <cuda.h>
#include <tvm/ffi/function.h>
#include "triton_tvm_ffi/grid.h"
#include "triton_tvm_ffi/kernel.h"
#include "triton_tvm_ffi/macro.h"
#include "triton_tvm_ffi/meta.h"
#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% if fn.best_config != none %}
{% for k, v in fn.best_config.items() %}
{ "{{ k }}", {{ v }} }, \
{% endfor %}
{% endif %}
}; \
{% for type in fn.signature %}
static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
{% endfor %}
triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
void *dummy = nullptr; \
const size_t __args_len = __args.size(); \
{% for ctype in fn.ctypes %}
{% if ctype == "CUdeviceptr" %}
void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
{% elif ctype != none %}
{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
{% endif %}
{% endfor %}
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
} while (false)
{% endif %}
{% endfor %}
{{ code }}
@@ -1,64 +0,0 @@
#include <assert.h>
#include <cuda.h>
#include <tvm/ffi/tvm_ffi.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_FFI_DLL_EXPORT void
__tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args,
TVMFFIAny *result) {
int32_t gridX = args[0].v_int64;
int32_t gridY = args[1].v_int64;
int32_t gridZ = args[2].v_int64;
CUstream stream = (CUstream)args[3].v_uint64;
CUfunction function = (CUfunction)args[4].v_uint64;
int32_t numWarps = args[5].v_int64;
int32_t numCtas = args[6].v_int64;
int32_t sharedMemory = args[7].v_int64;
uint64_t globalScratch = args[8].v_uint64;
uint64_t profileScratch = args[9].v_uint64;
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
config.gridDimX = gridX * numCtas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;
config.blockDimX = 32 * numWarps;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = stream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
// TODO: check `launchPdl`
// TODO: check `launchCooperativeGrid`
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = numCtas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[numAttrs++] = clusterAttr;
CUlaunchAttribute clusterSchedulingAttr;
clusterSchedulingAttr.id =
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
launchAttr[numAttrs++] = clusterSchedulingAttr;
}
config.numAttrs = numAttrs;
{% for type in signature %}
{% if type == "void *" %}
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
{% elif type == "int32_t" %}
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
{% elif type == "void" %}
{% else %}
assert(false, "unsupported type yet {{ type }}");
{% endif %}
{% endfor %}
void *params[] = { {% for type in signature %} {% if type != "void" %} &arg{{ loop.index0 }}, {% endif %} {% endfor %}&globalScratch, &profileScratch };
cuLaunchKernelEx(&config, function, params, NULL);
}
}
+14 -34
View File
@@ -1,36 +1,16 @@
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
import sysconfig
from typing import List, Optional
# tvm-ffi-stubgen(begin): global/triton_tvm_ffi.utils
# fmt: off
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__)
if TYPE_CHECKING:
def build_signature_metadata(*args: Any) -> Any: ...
def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ...
def fill_tma_descriptor(*args: Any) -> Any: ...
def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ...
def set_printf_fifo_size(*args: Any) -> Any: ...
# fmt: on
# tvm-ffi-stubgen(end)
from triton.backends.nvidia.driver import ty_to_cpp
__all__ = [
# tvm-ffi-stubgen(begin): __all__
"build_signature_metadata",
"cuOccupancyMaxActiveClusters",
"fill_tma_descriptor",
"get_device_properties",
"load_binary",
"set_printf_fifo_size",
# tvm-ffi-stubgen(end)
]
def include_paths() -> List[str]:
pkg_path: str = sysconfig.get_path("purelib")
return [f"{pkg_path}/triton_tvm_ffi/include"]
def type_canonicalize(ty: str) -> Optional[str]:
if ty == "constexpr":
return None
else:
return ty_to_cpp(ty)
+144
View File
@@ -0,0 +1,144 @@
from functools import cached_property
from io import TextIOWrapper
from pathlib import Path
from typing import Any, Callable, Final, List, Optional, Sequence, Union
import jinja2
import torch.utils.cpp_extension
import tvm_ffi
from .jit import TVMFFIJITFunction
from .utils import include_paths
class TVMFFIWrapperFunction(object):
def __init__(
self,
name: str,
fns: List[TVMFFIJITFunction],
code: Union[str, Path, TextIOWrapper],
extra_cflags: Optional[Sequence[str]] = None,
extra_cuda_cflags: Optional[Sequence[str]] = None,
extra_ldflags: Optional[Sequence[str]] = None,
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.name: Final[str] = name
self.fns: List[TVMFFIJITFunction] = [*fns]
if isinstance(code, Path):
with open(code, "r") as f:
self.code: Final[str] = f.read()
elif isinstance(code, TextIOWrapper):
self.code: Final[str] = code.read()
else:
self.code: Final[str] = f"{code}"
self.extra_cflags: Optional[Sequence[str]] = extra_cflags
self.extra_cuda_cflags: Optional[Sequence[str]] = extra_cuda_cflags
self.extra_ldflags: Optional[Sequence[str]] = extra_ldflags
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
extra_include_paths
)
self.env: Final[jinja2.Environment] = jinja2.Environment(
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
trim_blocks=True,
)
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
def __call__(self, *args, **kwargs) -> None:
func: tvm_ffi.Function = self.compile()
return func(*args, **kwargs)
@property
def fns_hash(self) -> int:
return hash(tuple(fn.cache_hash for fn in self.fns))
@cached_property
def fullname(self) -> str:
return f"triton.{self.name}"
@property
def emit(self) -> str:
return self.tpl.render(
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
)
@property
def uniquename(self) -> str:
return f"{self.name}_{self.fns_hash}"
def compile(self) -> tvm_ffi.Function:
if func := tvm_ffi.get_global_func(self.uniquename, allow_missing=True):
return func
else:
tvm_ffi.cpp.load_inline(
self.name,
cpp_sources=[self.emit],
extra_cflags=self.extra_cflags,
extra_cuda_cflags=self.extra_cuda_cflags,
extra_ldflags=self.extra_ldflags,
extra_include_paths=self.extra_include_paths,
embed_cubin={
f"triton_{fn.fnname}": fn.kernel
for fn in self.fns
if fn.kernel is not None
},
)
return tvm_ffi.get_global_func(self.uniquename)
def wrap(
fns: List[TVMFFIJITFunction],
code: Union[str, Path, TextIOWrapper],
extra_cflags: Optional[Sequence[str]] = None,
extra_cuda_cflags: Optional[Sequence[str]] = None,
extra_ldflags: Optional[Sequence[str]] = None,
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
) -> TVMFFIWrapperFunction:
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
return TVMFFIWrapperFunction(
fn if isinstance(fn, str) else fn.__name__,
fns,
code,
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
include_paths() + (extra_include_paths or []),
)
return decorate
def torch_wrap(
fns: List[TVMFFIJITFunction],
code: Union[str, Path, TextIOWrapper],
extra_cflags: Optional[Sequence[str]] = None,
extra_cuda_cflags: Optional[Sequence[str]] = None,
extra_ldflags: Optional[Sequence[str]] = None,
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
) -> TVMFFIWrapperFunction:
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
return wrap(
fns,
code,
extra_ldflags=[
"-Wl,--no-as-needed",
f"-L{cuda_home}/lib64",
*map(
lambda path: f"-L{path}",
torch.utils.cpp_extension.library_paths(),
),
"-lcuda",
"-lc10",
"-ltorch",
]
+ (extra_ldflags or []),
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=[
f"{cuda_home}/include",
*torch.utils.cpp_extension.include_paths(),
]
+ (extra_include_paths or []),
)
-36
View File
@@ -1,36 +0,0 @@
add_library(
${TARGET_NAME}
SHARED
${CMAKE_CURRENT_SOURCE_DIR}/exception.cc
${CMAKE_CURRENT_SOURCE_DIR}/launch.cc
${CMAKE_CURRENT_SOURCE_DIR}/type.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
)
target_include_directories(
${TARGET_NAME}
PRIVATE ${PROJECT_SOURCE_DIR}/include
)
target_compile_options(
${TARGET_NAME}
PRIVATE
$<$<CONFIG:Debug>:-O0 -g>
$<$<CONFIG:Release>:-O3 -DNDEBUG>
)
target_link_libraries(
${TARGET_NAME}
PRIVATE
CUDA::cudart
CUDA::cuda_driver
)
tvm_ffi_configure_target(
${TARGET_NAME}
STUB_DIR "${CMAKE_SOURCE_DIR}/python"
STUB_INIT ON
)
install(
TARGETS ${TARGET_NAME}
LIBRARY DESTINATION .
)
tvm_ffi_install(${TARGET_NAME} DESTINATION .)
-42
View File
@@ -1,42 +0,0 @@
#include "exception.h"
namespace triton_tvm_ffi {
CUDAException::CUDAException(CUresult code) : code_(code) {}
const char *CUDAException::what() const noexcept {
const char *p = nullptr;
cuGetErrorString(code_, &p);
return p;
}
NotImplementedException::NotImplementedException(std::string_view name)
: message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
const char *NotImplementedException::what() const noexcept {
return message_.c_str();
}
UnmatchedArgumentException::UnmatchedArgumentException(std::string_view name,
size_t len,
size_t expect)
: message_("[UnmatchedArgumentException]: argument \"" + std::string(name) +
"\" has length " + std::to_string(len) + ", but expected " +
std::to_string(expect)) {}
const char *UnmatchedArgumentException::what() const noexcept {
return message_.c_str();
}
UnknownTypeException::UnknownTypeException(Type type)
: message_("[UnknownTypeException]: unknown type: \"" +
std::string(TypeToString(type)) + "\"") {}
UnknownTypeException::UnknownTypeException(std::string_view type)
: message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
"\"") {}
const char *UnknownTypeException::what() const noexcept {
return message_.c_str();
}
} // namespace triton_tvm_ffi
-126
View File
@@ -1,126 +0,0 @@
#include "launch.h"
#include "macro.h"
#include <cstdint>
#include <tvm/ffi/base_details.h>
namespace triton_tvm_ffi {
TVMFFILauncherImplObj::TVMFFILauncherImplObj(
const tvm::ffi::Array<Type> &signature, bool launchCooperativeGrid,
bool launchAsync)
: signature_(std::move(signature)),
launchCooperativeGrid_(launchCooperativeGrid), launchAsync_(launchAsync) {
}
void TVMFFILauncherImplObj::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
uint64_t globalScratch, uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
CUstream cStream = reinterpret_cast<CUstream>(stream);
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
if (gridX * gridY * gridZ > 0) {
CUlaunchAttribute launchAttr[4];
CUlaunchConfig config;
config.gridDimX = gridX * numCtas;
config.gridDimY = gridY;
config.gridDimZ = gridZ;
static constexpr int32_t kThreadsPerWarp = 32;
config.blockDimX = kThreadsPerWarp * numWarps;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = sharedMemory;
config.hStream = cStream;
config.attrs = launchAttr;
int32_t numAttrs = 0;
if (numCtas != 1) {
CUlaunchAttribute clusterAttr;
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
clusterAttr.value.clusterDim.x = numCtas;
clusterAttr.value.clusterDim.y = 1;
clusterAttr.value.clusterDim.z = 1;
launchAttr[numAttrs++] = clusterAttr;
CUlaunchAttribute clusterSchedulingAttr;
clusterSchedulingAttr.id =
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
launchAttr[numAttrs++] = clusterSchedulingAttr;
}
config.numAttrs = numAttrs;
if (numCtas == 16) {
CUDA_CHECK(cuFuncSetAttribute(
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
}
const int32_t kernelArgNum = kernelArgs.size();
void **params =
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
size_t j = 0;
#ifndef NDEBUG
if (kernelArgNum != signature_.size()) {
throw UnmatchedArgumentException("kernelArgs", kernelArgNum,
signature_.size());
}
#endif
for (size_t i = 0; i < kernelArgNum; ++i) {
tvm::ffi::Any value = kernelArgs[i];
switch (signature_[i]) {
#define CASE_STMT(type, str, ctype) \
case Type::type: { \
using cpptype = type_to_ctype_t<Type::type>; \
params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype))); \
*reinterpret_cast<cpptype *>(params[j]) = value.cast<cpptype>(); \
++j; \
break; \
}
TYPE_TABLE_NATIVE(CASE_STMT)
#undef CASE_STMT
case Type::PTR: {
params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
*reinterpret_cast<void **>(params[j]) =
value.cast<tvm::ffi::TensorView>().data_ptr();
++j;
break;
}
case Type::CONSTEXPR: {
break;
}
default: {
#ifdef NDEBUG
__builtin_unreachable();
#else
throw NotImplementedException("CONSTEXPR for value casting");
#endif
}
}
}
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
params[j] = &globalScratch;
params[j + 1] = &profileScratch;
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
}
}
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
bool launchCooperativeGrid,
bool launchAsync)
: tvm::ffi::ObjectRef(tvm::ffi::make_object<TVMFFILauncherImplObj>(
std::move(signature), launchCooperativeGrid, launchAsync)) {}
void TVMFFILauncherImpl::Launch(
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
uint64_t globalScratch, uint64_t profileScratch,
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
sharedMemory, globalScratch, profileScratch, kernelArgs);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TVMFFILauncherImplObj>()
.def(refl::init<const tvm::ffi::Array<Type> &, bool, bool>())
.def("launch", &TVMFFILauncherImplObj::Launch);
}
} // namespace triton_tvm_ffi
-61
View File
@@ -1,61 +0,0 @@
#include "type.h"
#include "exception.h"
#include <cassert>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/tvm_ffi.h>
namespace triton_tvm_ffi {
const char *TypeToString(Type type) {
switch (type) {
#define CASE_ENUM(type, str, ctype) \
case Type::type: \
return str;
TYPE_TABLE(CASE_ENUM)
#undef CASE_ENUM
default:
throw UnknownTypeException(type);
}
}
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name) {
if (name.starts_with("*")) {
return Type::PTR;
}
if (name == "constexpr") {
return Type::CONSTEXPR;
}
#define IF_ENUM(type, str, ctype) \
if (name == str) { \
return Type::type; \
}
TYPE_TABLE(IF_ENUM)
#undef IF_ENUM
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
throw NotImplementedException(
"tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
}
return std::nullopt;
}
const char *TypeToCType(Type type) {
switch (type) {
#define CASE_ENUM(type, str, ctype) \
case Type::type: \
return #ctype;
TYPE_TABLE(CASE_ENUM)
#undef CASE_ENUM
default:
throw UnknownTypeException(type);
}
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("triton_tvm_ffi.type_to_string", TypeToString)
.def("triton_tvm_ffi.string_to_type", StringToType)
.def("triton_tvm_ffi.type_to_ctype", TypeToCType);
}
} // namespace triton_tvm_ffi
-113
View File
@@ -1,113 +0,0 @@
#include "macro.h"
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/tvm_ffi.h>
#ifndef NDEBUG
#include <cassert>
#endif
namespace triton_tvm_ffi {
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
tvm::ffi::cuda_api::DeviceHandle device;
CUDA_CHECK(cuDeviceGet(&device, device_id));
int maxSharedMem = 0;
int maxNumRegs = 0;
int multiprocessorCount = 0;
int warpSize = 0;
int smClockRate = 0;
int memClockRate = 0;
int memBusWidth = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&maxSharedMem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&maxNumRegs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
CUDA_CHECK(cuDeviceGetAttribute(
&multiprocessorCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(
cuDeviceGetAttribute(&warpSize, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
CUDA_CHECK(cuDeviceGetAttribute(&smClockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE,
device));
CUDA_CHECK(cuDeviceGetAttribute(
&memClockRate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(
&memBusWidth, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
return {{"max_shared_mem", maxSharedMem},
{"max_num_regs", maxNumRegs},
{"multiprocessor_count", multiprocessorCount},
{"warpSize", warpSize},
{"sm_clock_rate", smClockRate},
{"mem_clock_rate", memClockRate},
{"mem_bus_width", memBusWidth}};
}
tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>
LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data,
int32_t shared, CUdevice device) {
CUcontext pctx;
CUfunction fun;
CUmodule mod;
int32_t nRegs = 0;
int32_t nSpills = 0;
int32_t nMaxThreads = 0;
int32_t sharedOptin = 0;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
CUDA_CHECK(cuModuleLoadData(&mod, data.data()));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data()));
CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(
cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
CUDA_CHECK(cuFuncGetAttribute(&nMaxThreads,
CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
CUDA_CHECK(cuDeviceGetAttribute(
&sharedOptin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
static constexpr int64_t kExpectedMaxDynamicSharedMemory = 49152;
if (shared > kExpectedMaxDynamicSharedMemory &&
sharedOptin > kExpectedMaxDynamicSharedMemory) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int32_t sharedTotal = 0, sharedStatic = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&sharedTotal, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
CUDA_CHECK(cuFuncGetAttribute(&sharedStatic,
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
sharedOptin - sharedStatic));
}
return tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>{
mod, fun, nRegs, nSpills, nMaxThreads};
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("triton_tvm_ffi.utils.build_signature_metadata",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
throw NotImplementedException("build_signature_metadata");
})
.def_packed("triton_tvm_ffi.utils.cuOccupancyMaxActiveClusters",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
throw NotImplementedException(
"cuOccupancyMaxActiveClusters");
})
.def_packed("triton_tvm_ffi.utils.fill_tma_descriptor",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
throw NotImplementedException("fill_tma_descriptor");
})
.def_packed("triton_tvm_ffi.utils.set_printf_fifo_size",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
throw NotImplementedException("set_printf_fifo_size");
})
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
}
} // namespace triton_tvm_ffi