mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 00:42:05 +08:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a727526794 | |||
| 599957e156 | |||
| e41ec26329 | |||
| 213e4fc060 | |||
| 1c4f13c8f0 | |||
| 24237a6313 | |||
| 6a19a6b06d | |||
| 2298b6f8c8 | |||
| f6c7a48c1b | |||
| 8b8aa6cb84 | |||
| b7bf598fde | |||
| 192dc95ac0 | |||
| 6e4c2d4a43 | |||
| dc8c2c17e0 |
+2
-1
@@ -9,8 +9,9 @@ wheels/
|
||||
# Virtual environments
|
||||
.venv
|
||||
|
||||
.vscode/
|
||||
|
||||
.cache
|
||||
.clangd
|
||||
.ruff_cache
|
||||
.python-version
|
||||
uv.lock
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.18)
|
||||
|
||||
if(DEFINED SKBUILD_PROJECT_NAME)
|
||||
project(${SKBUILD_PROJECT_NAME})
|
||||
else()
|
||||
project(triton-tvm-ffi)
|
||||
endif()
|
||||
|
||||
string(REPLACE "-" "_" TARGET_NAME "${PROJECT_NAME}")
|
||||
|
||||
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
else(CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||
endif()
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
|
||||
OUTPUT_VARIABLE TVM_FFI_CMAKEDIR
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${TVM_FFI_CMAKEDIR}")
|
||||
|
||||
find_package(tvm_ffi CONFIG REQUIRED)
|
||||
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/src)
|
||||
@@ -1,18 +0,0 @@
|
||||
# Triton-TVM-FFI
|
||||
|
||||
## Instructions
|
||||
|
||||
### Debug Install
|
||||
```bash
|
||||
SKBUILD_BUILD_DIR="build" SKBUILD_CMAKE_BUILD_TYPE=Debug uv pip install --no-build-isolation -ve .
|
||||
```
|
||||
|
||||
### Release Install
|
||||
```bash
|
||||
SKBUILD_CMAKE_BUILD_TYPE=Release uv pip install -v .
|
||||
```
|
||||
|
||||
### Format
|
||||
```bash
|
||||
find python -name "*.py" | xargs ruff format && find include src -name "*.h" -o -name "*.cc" | xargs clang-format -i
|
||||
```
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef ADD_KERNEL_STUB
|
||||
#define ADD_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef ADD_NAME
|
||||
#define ADD_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
||||
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||
at::Tensor otorch = at::empty_like(xtorch);
|
||||
int32_t numel = otorch.numel();
|
||||
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
||||
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||
[numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
|
||||
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
|
||||
});
|
||||
DLDevice device = x.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {x, y, output, numel, 1024};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
|
||||
ADD_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return output;
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(ADD_NAME, Add);
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output: torch.Tensor = torch.empty_like(x)
|
||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||
n_elements: int = output.numel()
|
||||
BLOCK_SIZE: int = 1024
|
||||
add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
|
||||
x, y, output, n_elements, BLOCK_SIZE
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[add_kernel],
|
||||
Path(__file__).parent / "add.cc",
|
||||
)
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device=DEVICE)
|
||||
y = torch.rand(size, device=DEVICE)
|
||||
output_torch = x + y
|
||||
output_triton = add_triton(x, y)
|
||||
output_tvm_ffi = add(x, y)
|
||||
assert torch.allclose(output_torch, output_triton)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
output_tvm_ffi = add(x, y)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
|
||||
round = 1000
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
x + y
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add_triton(x, y)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add(x, y)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
@@ -0,0 +1,937 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
|
||||
|
||||
Credits: OpenAI kernel team
|
||||
|
||||
Extra Credits:
|
||||
|
||||
* Original flash attention paper (https://arxiv.org/abs/2205.14135)
|
||||
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.tools.tensor_descriptor import TensorDescriptor
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype: tl.constexpr,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr,
|
||||
offs_m: tl.constexpr,
|
||||
offs_n: tl.constexpr, #
|
||||
N_CTX: tl.constexpr,
|
||||
):
|
||||
# range of values handled by this stage
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
elif STAGE == 2:
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
lo = tl.multiple_of(lo, BLOCK_M)
|
||||
# causal = False
|
||||
else:
|
||||
lo, hi = 0, N_CTX
|
||||
offsetk_y = offset_y + lo
|
||||
if dtype == tl.float8e5:
|
||||
offsetv_y = offset_y * HEAD_DIM + lo
|
||||
else:
|
||||
offsetv_y = offset_y + lo
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in tl.range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = desc_k.load([offsetk_y, 0]).T
|
||||
qk = tl.dot(q, k)
|
||||
if STAGE == 2:
|
||||
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
||||
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk -= m_ij[:, None]
|
||||
else:
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
|
||||
qk = qk * qk_scale - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
# -- compute correction factor
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
l_ij = tl.sum(p, 1)
|
||||
acc = acc * alpha[:, None]
|
||||
# prepare p and v for the dot
|
||||
if dtype == tl.float8e5:
|
||||
v = desc_v.load([0, offsetv_y]).T
|
||||
else:
|
||||
v = desc_v.load([offsetv_y, 0])
|
||||
p = p.to(dtype)
|
||||
# note that this non transposed v for FP8 is only supported on Blackwell
|
||||
acc = tl.dot(p, v, acc)
|
||||
# update m_i and l_i
|
||||
# place this at the end of the loop to reduce register pressure
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
offsetk_y += BLOCK_N
|
||||
offsetv_y += BLOCK_N
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def _host_descriptor_pre_hook(nargs):
|
||||
BLOCK_M = nargs["BLOCK_M"]
|
||||
BLOCK_N = nargs["BLOCK_N"]
|
||||
HEAD_DIM = nargs["HEAD_DIM"]
|
||||
if not isinstance(nargs["desc_q"], TensorDescriptor):
|
||||
return
|
||||
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
|
||||
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
|
||||
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
|
||||
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
|
||||
|
||||
|
||||
NUM_STAGES_OPTIONS = [2, 3, 4]
|
||||
|
||||
configs = [
|
||||
triton.Config(
|
||||
{"BLOCK_M": BM, "BLOCK_N": BN},
|
||||
num_stages=s,
|
||||
num_warps=w,
|
||||
pre_hook=_host_descriptor_pre_hook,
|
||||
)
|
||||
for BM in [64, 128]
|
||||
for BN in [32, 64, 128]
|
||||
for s in NUM_STAGES_OPTIONS
|
||||
for w in [4, 8]
|
||||
]
|
||||
if "PYTEST_VERSION" in os.environ:
|
||||
# Use a single config in testing for reproducibility
|
||||
configs = [
|
||||
triton.Config(
|
||||
dict(BLOCK_M=128, BLOCK_N=64),
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
pre_hook=_host_descriptor_pre_hook,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def keep(conf):
|
||||
BLOCK_M = conf.kwargs["BLOCK_M"]
|
||||
BLOCK_N = conf.kwargs["BLOCK_N"]
|
||||
return not (
|
||||
torch.cuda.get_device_capability()[0] == 9
|
||||
and BLOCK_M * BLOCK_N < 128 * 128
|
||||
and conf.num_warps == 8
|
||||
)
|
||||
|
||||
|
||||
def prune_invalid_configs(configs, named_args, **kwargs):
|
||||
N_CTX = kwargs["N_CTX"]
|
||||
|
||||
# Filter out configs where BLOCK_M > N_CTX
|
||||
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
|
||||
if isinstance(desc_or_ptr, tl.tensor_descriptor):
|
||||
return desc_or_ptr
|
||||
else:
|
||||
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.autotune(
|
||||
configs=list(filter(keep, configs)),
|
||||
key=["N_CTX", "HEAD_DIM"],
|
||||
prune_configs_by={"early_config_prune": prune_invalid_configs},
|
||||
)
|
||||
@triton.jit
|
||||
def _attn_fwd(
|
||||
sm_scale,
|
||||
M, #
|
||||
Z,
|
||||
H,
|
||||
desc_q,
|
||||
desc_k,
|
||||
desc_v,
|
||||
desc_o,
|
||||
N_CTX, #
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
STAGE: tl.constexpr, #
|
||||
):
|
||||
dtype = tl.float16
|
||||
tl.static_assert(BLOCK_N <= HEAD_DIM)
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
|
||||
y_dim = Z * H * N_CTX
|
||||
desc_q = _maybe_make_tensor_desc(
|
||||
desc_q,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_M, HEAD_DIM],
|
||||
)
|
||||
desc_v = _maybe_make_tensor_desc(
|
||||
desc_v,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_N, HEAD_DIM],
|
||||
)
|
||||
desc_k = _maybe_make_tensor_desc(
|
||||
desc_k,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_N, HEAD_DIM],
|
||||
)
|
||||
desc_o = _maybe_make_tensor_desc(
|
||||
desc_o,
|
||||
shape=[y_dim, HEAD_DIM],
|
||||
strides=[HEAD_DIM, 1],
|
||||
block_shape=[BLOCK_M, HEAD_DIM],
|
||||
)
|
||||
|
||||
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
|
||||
qo_offset_y = offset_y + start_m * BLOCK_M
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
# load scales
|
||||
qk_scale = sm_scale
|
||||
qk_scale *= 1.44269504 # 1/log(2)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = desc_q.load([qo_offset_y, 0])
|
||||
# stage 1: off-band
|
||||
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
|
||||
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
|
||||
if STAGE & 1:
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N, #
|
||||
4 - STAGE,
|
||||
offs_m,
|
||||
offs_n,
|
||||
N_CTX, #
|
||||
)
|
||||
# stage 2: on-band
|
||||
if STAGE & 2:
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q, #
|
||||
desc_k,
|
||||
desc_v, #
|
||||
offset_y,
|
||||
dtype,
|
||||
start_m,
|
||||
qk_scale, #
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N, #
|
||||
2,
|
||||
offs_m,
|
||||
offs_n,
|
||||
N_CTX, #
|
||||
)
|
||||
# epilogue
|
||||
m_i += tl.math.log2(l_i)
|
||||
acc = acc / l_i[:, None]
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(m_ptrs, m_i)
|
||||
desc_o.store([qo_offset_y, 0], acc.to(dtype))
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def _attn_bwd_preprocess(
|
||||
O,
|
||||
DO, #
|
||||
Delta, #
|
||||
Z,
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_hz = tl.program_id(1)
|
||||
off_n = tl.arange(0, HEAD_DIM)
|
||||
# load
|
||||
o = tl.load(
|
||||
O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
|
||||
)
|
||||
do = tl.load(
|
||||
DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
|
||||
).to(tl.float32)
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(Delta + off_hz * N_CTX + off_m, delta)
|
||||
|
||||
|
||||
# The main inner-loop logic for computing dK and dV.
|
||||
@triton.jit
|
||||
def _attn_bwd_dkdv(
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX,
|
||||
BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr, #
|
||||
# Filled in by the wrapper.
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK: tl.constexpr,
|
||||
):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M1)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
||||
curr_m = start_m
|
||||
step_m = BLOCK_M1
|
||||
for blk_idx in range(num_steps):
|
||||
qT = tl.load(qT_ptrs)
|
||||
# Load m before computing qk to reduce pipeline stall.
|
||||
offs_m = curr_m + tl.arange(0, BLOCK_M1)
|
||||
m = tl.load(M + offs_m)
|
||||
qkT = tl.dot(k, qT)
|
||||
pT = tl.math.exp2(qkT - m[None, :])
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
mask = offs_m[None, :] >= offs_n[:, None]
|
||||
pT = tl.where(mask, pT, 0.0)
|
||||
do = tl.load(do_ptrs)
|
||||
# Compute dV.
|
||||
ppT = pT
|
||||
ppT = ppT.to(tl.float16)
|
||||
dv += tl.dot(ppT, do)
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
dsT = dsT.to(tl.float16)
|
||||
dk += tl.dot(dsT, tl.trans(qT))
|
||||
# Increment pointers.
|
||||
curr_m += step_m
|
||||
qT_ptrs += step_m * stride_tok
|
||||
do_ptrs += step_m * stride_tok
|
||||
return dk, dv
|
||||
|
||||
|
||||
# the main inner-loop logic for computing dQ
|
||||
@triton.jit
|
||||
def _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr,
|
||||
# Filled in by the wrapper.
|
||||
start_m,
|
||||
start_n,
|
||||
num_steps, #
|
||||
MASK: tl.constexpr,
|
||||
):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N2)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
||||
curr_n = start_n
|
||||
step_n = BLOCK_N2
|
||||
for blk_idx in range(num_steps):
|
||||
kT = tl.load(kT_ptrs)
|
||||
vT = tl.load(vT_ptrs)
|
||||
qk = tl.dot(q, kT)
|
||||
p = tl.math.exp2(qk - m)
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
offs_n = curr_n + tl.arange(0, BLOCK_N2)
|
||||
mask = offs_m[:, None] >= offs_n[None, :]
|
||||
p = tl.where(mask, p, 0.0)
|
||||
# Compute dP and dS.
|
||||
dp = tl.dot(do, vT).to(tl.float32)
|
||||
ds = p * (dp - Di[:, None])
|
||||
ds = ds.to(tl.float16)
|
||||
# Compute dQ.
|
||||
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
|
||||
dq += tl.dot(ds, tl.trans(kT))
|
||||
# Increment pointers.
|
||||
curr_n += step_n
|
||||
kT_ptrs += step_n * stride_tok
|
||||
vT_ptrs += step_n * stride_tok
|
||||
return dq
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def _attn_bwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
DQ,
|
||||
DK,
|
||||
DV, #
|
||||
M,
|
||||
D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_z,
|
||||
stride_h,
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M1: tl.constexpr, #
|
||||
BLOCK_N1: tl.constexpr, #
|
||||
BLOCK_M2: tl.constexpr, #
|
||||
BLOCK_N2: tl.constexpr, #
|
||||
BLK_SLICE_FACTOR: tl.constexpr, #
|
||||
HEAD_DIM: tl.constexpr,
|
||||
):
|
||||
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
|
||||
|
||||
bhid = tl.program_id(2)
|
||||
off_chz = (bhid * N_CTX).to(tl.int64)
|
||||
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# offset pointers for batch/head
|
||||
Q += adj
|
||||
K += adj
|
||||
V += adj
|
||||
DO += adj
|
||||
DQ += adj
|
||||
DK += adj
|
||||
DV += adj
|
||||
M += off_chz
|
||||
D += off_chz
|
||||
|
||||
# load scales
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
MASK_BLOCK_M1,
|
||||
BLOCK_N1,
|
||||
HEAD_DIM, #
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK=True, #
|
||||
)
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv( #
|
||||
dk,
|
||||
dv, #
|
||||
Q,
|
||||
k,
|
||||
v,
|
||||
sm_scale, #
|
||||
DO, #
|
||||
M,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M1,
|
||||
BLOCK_N1,
|
||||
HEAD_DIM, #
|
||||
start_n,
|
||||
start_m,
|
||||
num_steps, #
|
||||
MASK=False, #
|
||||
)
|
||||
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2,
|
||||
MASK_BLOCK_N2,
|
||||
HEAD_DIM, #
|
||||
start_m,
|
||||
end_n - num_steps * MASK_BLOCK_N2,
|
||||
num_steps, #
|
||||
MASK=True, #
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq,
|
||||
q,
|
||||
K,
|
||||
V, #
|
||||
do,
|
||||
m,
|
||||
D, #
|
||||
stride_tok,
|
||||
stride_d, #
|
||||
H,
|
||||
N_CTX, #
|
||||
BLOCK_M2,
|
||||
BLOCK_N2,
|
||||
HEAD_DIM, #
|
||||
start_m,
|
||||
end_n - num_steps * BLOCK_N2,
|
||||
num_steps, #
|
||||
MASK=False, #
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
|
||||
|
||||
class _attention_triton(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
# shape constraints
|
||||
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
|
||||
# when v is in float8_e5m2 it is transposed.
|
||||
HEAD_DIM_V = v.shape[-1]
|
||||
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
|
||||
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
|
||||
o = torch.empty_like(q)
|
||||
stage = 3 if causal else 1
|
||||
|
||||
M = torch.empty(
|
||||
(q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
|
||||
)
|
||||
desc_q = q
|
||||
desc_v = v
|
||||
desc_k = k
|
||||
desc_o = o
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(q.shape[2], META["BLOCK_M"]),
|
||||
q.shape[0] * q.shape[1],
|
||||
1,
|
||||
)
|
||||
|
||||
_attn_fwd[grid](
|
||||
sm_scale,
|
||||
M, #
|
||||
q.shape[0],
|
||||
q.shape[1], #
|
||||
desc_q,
|
||||
desc_k,
|
||||
desc_v,
|
||||
desc_o, #
|
||||
N_CTX=q.shape[2], #
|
||||
HEAD_DIM=HEAD_DIM_K, #
|
||||
STAGE=stage, #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.HEAD_DIM = HEAD_DIM_K
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, M = ctx.saved_tensors
|
||||
assert do.is_contiguous()
|
||||
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
||||
PRE_BLOCK = 128
|
||||
NUM_WARPS, NUM_STAGES = 4, 5
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
|
||||
BLK_SLICE_FACTOR = 2
|
||||
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
||||
arg_k = k
|
||||
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
|
||||
PRE_BLOCK = 128
|
||||
assert N_CTX % PRE_BLOCK == 0
|
||||
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
|
||||
delta = torch.empty_like(M)
|
||||
_attn_bwd_preprocess[pre_grid](
|
||||
o,
|
||||
do, #
|
||||
delta, #
|
||||
BATCH,
|
||||
N_HEAD,
|
||||
N_CTX, #
|
||||
BLOCK_M=PRE_BLOCK,
|
||||
HEAD_DIM=ctx.HEAD_DIM, #
|
||||
)
|
||||
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
|
||||
_attn_bwd[grid](
|
||||
q,
|
||||
arg_k,
|
||||
v,
|
||||
ctx.sm_scale,
|
||||
do,
|
||||
dq,
|
||||
dk,
|
||||
dv, #
|
||||
M,
|
||||
delta, #
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3), #
|
||||
N_HEAD,
|
||||
N_CTX, #
|
||||
BLOCK_M1=BLOCK_M1,
|
||||
BLOCK_N1=BLOCK_N1, #
|
||||
BLOCK_M2=BLOCK_M2,
|
||||
BLOCK_N2=BLOCK_N2, #
|
||||
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
|
||||
HEAD_DIM=ctx.HEAD_DIM, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_stages=NUM_STAGES, #
|
||||
)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_fwd],
|
||||
Path(__file__).parent / "attnfwd.cc",
|
||||
)
|
||||
def _attn_fwd_tvm_ffi(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_bwd_preprocess],
|
||||
Path(__file__).parent / "attnbwdpre.cc",
|
||||
)
|
||||
def _attn_bwd_preprocess_tvm_ffi(
|
||||
o: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
mshape: Sequence[int],
|
||||
head_dim: int,
|
||||
): ...
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[_attn_bwd],
|
||||
Path(__file__).parent / "attnbwd.cc",
|
||||
)
|
||||
def _attn_bwd_tvm_ffi(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
sm_scale: float,
|
||||
do: torch.Tensor,
|
||||
m: torch.Tensor,
|
||||
delta: torch.Tensor,
|
||||
HEAD_DIM: int,
|
||||
): ...
|
||||
|
||||
|
||||
class _attention_tvm_ffi(_attention_triton):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
# shape constraints
|
||||
HEAD_DIM_K = k.shape[-1]
|
||||
M, o = _attn_fwd_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
M = torch.from_dlpack(M)
|
||||
o = torch.from_dlpack(o)
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.HEAD_DIM = HEAD_DIM_K
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, M = ctx.saved_tensors
|
||||
delta = _attn_bwd_preprocess_tvm_ffi(
|
||||
o,
|
||||
do,
|
||||
M.shape,
|
||||
ctx.HEAD_DIM,
|
||||
)
|
||||
dq, dk, dv = _attn_bwd_tvm_ffi(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
ctx.sm_scale,
|
||||
do,
|
||||
M,
|
||||
delta,
|
||||
ctx.HEAD_DIM,
|
||||
)
|
||||
dq = torch.from_dlpack(dq)
|
||||
dk = torch.from_dlpack(dk)
|
||||
dv = torch.from_dlpack(dv)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
def attn_torch(q, k, v, causal=False, sm_scale=1.0):
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
if causal:
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1)
|
||||
p = p.to(q.dtype)
|
||||
out = torch.matmul(p, v)
|
||||
return out
|
||||
|
||||
|
||||
def attn_triton(q, k, v, causal=False, sm_scale=1.0):
|
||||
return _attention_triton.apply(q, k, v, causal, sm_scale)
|
||||
|
||||
|
||||
def attn_tvm_ffi(q, k, v, causal=False, sm_scale=1.0):
|
||||
return _attention_tvm_ffi.apply(q, k, v, causal, sm_scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Z = 1
|
||||
H = 2
|
||||
N_CTX = 128
|
||||
HEAD_DIM = 64
|
||||
causal = True
|
||||
mode = "bwd"
|
||||
dtype = torch.float16
|
||||
torch.manual_seed(20)
|
||||
q = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
sm_scale = 0.5
|
||||
# reference implementation
|
||||
ref_dtype = dtype
|
||||
q = q.to(ref_dtype)
|
||||
k = k.to(ref_dtype)
|
||||
v = v.to(ref_dtype)
|
||||
ref_out = attn_torch(q, k, v, causal, sm_scale).half()
|
||||
if mode == "bwd":
|
||||
dout = torch.randn_like(q)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
tri_out = attn_triton(q, k, v, causal, sm_scale).half()
|
||||
tvm_ffi_out = attn_tvm_ffi(q, k, v, causal, sm_scale).half()
|
||||
warmup = 5
|
||||
round = 1000
|
||||
if mode == "fwd":
|
||||
atol = 1e-2
|
||||
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=atol, rtol=0)
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
elif mode == "bwd":
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
|
||||
rtol = 0.0
|
||||
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
|
||||
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
|
||||
if (
|
||||
torch.version.hip is not None
|
||||
and triton.runtime.driver.active.get_current_target().arch == "gfx90a"
|
||||
):
|
||||
rtol = 1e-2
|
||||
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
|
||||
tvm_ffi_out.backward(dout)
|
||||
tvm_ffi_dv, v.grad = v.grad.clone(), None
|
||||
tvm_ffi_dk, k.grad = k.grad.clone(), None
|
||||
tvm_ffi_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
torch.testing.assert_close(tvm_ffi_out, ref_out, atol=1e-2, rtol=0)
|
||||
rtol = 0.0
|
||||
torch.testing.assert_close(tvm_ffi_dv, ref_dv, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tvm_ffi_dk, ref_dk, atol=1e-2, rtol=rtol)
|
||||
torch.testing.assert_close(tvm_ffi_dq, ref_dq, atol=1e-2, rtol=rtol)
|
||||
|
||||
for _ in range(warmup):
|
||||
attn_torch(q, k, v, causal, sm_scale).backward(dout)
|
||||
attn_triton(q, k, v, causal, sm_scale).backward(dout)
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_torch(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_triton(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
attn_tvm_ffi(q, k, v, causal, sm_scale).backward(dout)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <torch/headeronly/core/DeviceType.h>
|
||||
#include <tvm/ffi/container/tensor.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef _ATTN_BWD_STUB
|
||||
#define _ATTN_BWD_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef _ATTN_BWD_TVM_FFI_NAME
|
||||
#define _ATTN_BWD_TVM_FFI_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tuple<tvm::ffi::Tensor, tvm::ffi::Tensor, tvm::ffi::Tensor>
|
||||
AttnBwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v,
|
||||
const double smScale, tvm::ffi::Tensor do_, tvm::ffi::Tensor m,
|
||||
tvm::ffi::Tensor delta, const int32_t kHeadDim) {
|
||||
tvm::ffi::ShapeView qshape = q.shape(), qstride = q.strides();
|
||||
const int32_t kBatch = qshape[0], kNHead = qshape[1], kNCtx = qshape[2],
|
||||
kBlockN1 = 128;
|
||||
const double kArgKScale = smScale / log(2);
|
||||
at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
|
||||
kTorch = at::fromDLPack(k.ToDLPack()),
|
||||
vTorch = at::fromDLPack(v.ToDLPack()),
|
||||
dqTorch = at::empty_like(qTorch), dkTorch = at::empty_like(kTorch),
|
||||
dvTorch = at::empty_like(vTorch),
|
||||
argKTorch = at::mul(kTorch, kArgKScale);
|
||||
tvm::ffi::Tensor dq = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dqTorch)),
|
||||
dk = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dkTorch)),
|
||||
dv = tvm::ffi::Tensor::FromDLPack(at::toDLPack(dvTorch)),
|
||||
argK = tvm::ffi::Tensor::FromDLPack(at::toDLPack(argKTorch));
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kBlockN1, 1,
|
||||
kBatch * kNHead);
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {
|
||||
q, argK, v, smScale, do_, dq, dk, dv,
|
||||
m, delta, qstride[0], qstride[1], qstride[2], qstride[3], kNHead, kNCtx};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"BLOCK_M1", 32}, {"BLOCK_N1", kBlockN1}, {"BLOCK_M2", 128},
|
||||
{"BLOCK_N2", 32}, {"BLK_SLICE_FACTOR", 2}, {"HEAD_DIM", kHeadDim},
|
||||
{"num_warps", 4}, {"num_stages", 5},
|
||||
};
|
||||
DLDevice device = q.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
_ATTN_BWD_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return tvm::ffi::Tuple{dq, dk, dv};
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(_ATTN_BWD_TVM_FFI_NAME, AttnBwd);
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <torch/headeronly/core/DeviceType.h>
|
||||
#include <tvm/ffi/container/tensor.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef _ATTN_BWD_PREPROCESS_STUB
|
||||
#define _ATTN_BWD_PREPROCESS_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef _ATTN_BWD_PREPROCESS_TVM_FFI_NAME
|
||||
#define _ATTN_BWD_PREPROCESS_TVM_FFI_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tensor AttnBwdPreprocess(tvm::ffi::Tensor o, tvm::ffi::Tensor do_,
|
||||
tvm::ffi::Shape mshape,
|
||||
const int32_t kHeadDim) {
|
||||
const int32_t kBatch = mshape[0], kNHead = mshape[1], kNCtx = mshape[2],
|
||||
kPreBlock = 128;
|
||||
at::Tensor deltaTorch = at::empty(mshape, at::kFloat, std::nullopt,
|
||||
at::Device(at::kCUDA, o.device().device_id),
|
||||
std::nullopt, std::nullopt);
|
||||
tvm::ffi::Tensor delta =
|
||||
tvm::ffi::Tensor::FromDLPack(at::toDLPack(deltaTorch));
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid(kNCtx / kPreBlock,
|
||||
kBatch * kNHead, 1);
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {o, do_, delta, kBatch, kNHead};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"N_CTX", kNCtx},
|
||||
{"BLOCK_M", kPreBlock},
|
||||
{"HEAD_DIM", kHeadDim},
|
||||
};
|
||||
DLDevice device = o.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
_ATTN_BWD_PREPROCESS_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return delta;
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(_ATTN_BWD_PREPROCESS_TVM_FFI_NAME, AttnBwdPreprocess);
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <tvm/ffi/container/tensor.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef _ATTN_FWD_STUB
|
||||
#define _ATTN_FWD_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef _ATTN_FWD_TVM_FFI_NAME
|
||||
#define _ATTN_FWD_TVM_FFI_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tuple<tvm::ffi::Tensor, tvm::ffi::Tensor>
|
||||
AttnFwd(tvm::ffi::Tensor q, tvm::ffi::Tensor k, tvm::ffi::Tensor v, bool casual,
|
||||
float smScale) {
|
||||
const tvm::ffi::ShapeView &qshape = q.shape(), &kshape = k.shape(),
|
||||
&vshape = v.shape();
|
||||
const int32_t kB = qshape[0], kH = qshape[1], kN = qshape[2], kQ = qshape[3],
|
||||
kK = kshape[3], kV = vshape[3], stage = casual ? 3 : 1;
|
||||
at::Tensor qTorch = at::fromDLPack(q.ToDLPack()),
|
||||
oTorch = at::empty_like(qTorch),
|
||||
mTorch =
|
||||
at::empty({kB, kH, kN}, qTorch.options().dtype(at::kFloat));
|
||||
tvm::ffi::Tensor o = tvm::ffi::Tensor::FromDLPack(at::toDLPack(oTorch)),
|
||||
m = tvm::ffi::Tensor::FromDLPack(at::toDLPack(mTorch));
|
||||
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||
[kB, kH, kN](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||
const int32_t kBlockM = meta["BLOCK_M"].cast<int32_t>();
|
||||
return tvm::ffi::Tuple<int32_t, int32_t, int32_t>(
|
||||
(kN + kBlockM - 1) / kBlockM, kB * kH, 1);
|
||||
});
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {smScale, m, kB, kH, q, k, v, o, kN};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"HEAD_DIM", kK},
|
||||
{"STAGE", stage},
|
||||
};
|
||||
DLDevice device = q.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
_ATTN_FWD_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return tvm::ffi::Tuple{m, o};
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(_ATTN_FWD_TVM_FFI_NAME, AttnFwd);
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef MATMUL_KERNEL_STUB
|
||||
#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef MATMUL_NAME
|
||||
#define MATMUL_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b,
|
||||
tvm::ffi::String activation) {
|
||||
at::Tensor atorch = at::fromDLPack(a.ToDLPack()),
|
||||
btorch = at::fromDLPack(b.ToDLPack());
|
||||
const int32_t M = atorch.size(0), K = atorch.size(1), N = btorch.size(1);
|
||||
at::Tensor ctorch = at::empty({M, N}, atorch.options());
|
||||
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||
[M, N](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||
const int32_t BLOCK_SIZE_M = meta["BLOCK_SIZE_M"].cast<int32_t>(),
|
||||
BLOCK_SIZE_N = meta["BLOCK_SIZE_N"].cast<int32_t>();
|
||||
return tvm::ffi::Tuple<int32_t, int32_t, int32_t>{
|
||||
(M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M *
|
||||
((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N),
|
||||
1, 1};
|
||||
});
|
||||
DLDevice device = a.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
tvm::ffi::Tensor c = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ctorch));
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {a,
|
||||
b,
|
||||
c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
atorch.stride(0),
|
||||
atorch.stride(1),
|
||||
btorch.stride(0),
|
||||
btorch.stride(1),
|
||||
ctorch.stride(0),
|
||||
ctorch.stride(1)};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"ACTIVATION", activation},
|
||||
};
|
||||
MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return c;
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(MATMUL_NAME, Matmul);
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
from pathlib import Path
|
||||
import time
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
def get_autotune_config():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.autotune(
|
||||
configs=get_autotune_config(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
tl.assume(pid_m >= 0)
|
||||
tl.assume(pid_n >= 0)
|
||||
tl.assume(stride_am > 0)
|
||||
tl.assume(stride_ak > 0)
|
||||
tl.assume(stride_bn > 0)
|
||||
tl.assume(stride_bk > 0)
|
||||
tl.assume(stride_cm > 0)
|
||||
tl.assume(stride_cn > 0)
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
if ACTIVATION == "leaky_relu":
|
||||
accumulator = leaky_relu(accumulator)
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01 * x)
|
||||
|
||||
|
||||
def matmul_triton(a, b, activation=""):
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
matmul_kernel[
|
||||
lambda META: (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[matmul_kernel],
|
||||
Path(__file__).parent / "mm.cc",
|
||||
)
|
||||
def matmul(a: torch.Tensor, b: torch.Tensor, activation: str = "") -> torch.Tensor: ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
|
||||
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
|
||||
torch_output = torch.matmul(a, b)
|
||||
triton_output = matmul_triton(a, b, "")
|
||||
tvm_ffi_output = matmul(a, b, "")
|
||||
assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=1e-2)
|
||||
assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2)
|
||||
tvm_ffi_output = matmul(a, b, "")
|
||||
assert torch.allclose(torch_output, tvm_ffi_output, atol=1e-2, rtol=1e-2)
|
||||
|
||||
round = 1000
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
a @ b
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
matmul_triton(a, b, "")
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
matmul(a, b, "")
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef SOFTMAX_KERNEL_STUB
|
||||
#define SOFTMAX_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef SOFTMAX_NAME
|
||||
#define SOFTMAX_NAME ""
|
||||
#endif
|
||||
|
||||
tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
|
||||
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||
at::Tensor ytorch = at::empty_like(xtorch);
|
||||
uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1),
|
||||
xStride = xtorch.stride(0), yStride = ytorch.stride(0),
|
||||
BLOCK_SIZE = 1u << (32 - __builtin_clz(nCols - 1));
|
||||
tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch));
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{nRows / 1024, 1, 1};
|
||||
DLDevice device = x.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {y, x, xStride, yStride,
|
||||
nRows, nCols, BLOCK_SIZE};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
|
||||
SOFTMAX_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return y;
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef().def(SOFTMAX_NAME, Softmax);
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton_tvm_ffi
|
||||
|
||||
|
||||
@triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr,
|
||||
input_ptr,
|
||||
input_row_stride,
|
||||
output_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
row_start = tl.program_id(0)
|
||||
row_step = tl.num_programs(0)
|
||||
for row_idx in tl.range(row_start, n_rows, row_step):
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
mask = col_offsets < n_cols
|
||||
row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=mask)
|
||||
|
||||
|
||||
def softmax_triton(x):
|
||||
n_rows, n_cols = x.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(n_cols)
|
||||
num_warps = 8
|
||||
num_stages = 4
|
||||
y = torch.empty_like(x)
|
||||
softmax_kernel[(n_rows, 1, 1)](
|
||||
y,
|
||||
x,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[softmax_kernel],
|
||||
Path(__file__).parent / "softmax.cc",
|
||||
)
|
||||
def softmax(x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.randn(1823, 781, device="cuda")
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
y_triton = softmax_triton(x)
|
||||
y_tvm_ffi = softmax(x)
|
||||
assert torch.allclose(y_torch, y_triton), (y_torch, y_triton)
|
||||
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
|
||||
y_tvm_ffi = softmax(x)
|
||||
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
|
||||
|
||||
round = 1000
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
torch.softmax(x, axis=1)
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
softmax_triton(x)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
softmax(x)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
@@ -1,49 +0,0 @@
|
||||
#ifndef TRITON_TVM_FFI_EXCEPTION_H_
|
||||
#define TRITON_TVM_FFI_EXCEPTION_H_
|
||||
|
||||
#include "type.h"
|
||||
#include <cuda.h>
|
||||
#include <exception>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
class CUDAException : public std::exception {
|
||||
public:
|
||||
CUDAException(CUresult code);
|
||||
const char *what() const noexcept override;
|
||||
|
||||
private:
|
||||
const CUresult code_;
|
||||
};
|
||||
|
||||
class NotImplementedException : public std::exception {
|
||||
public:
|
||||
NotImplementedException(std::string_view name);
|
||||
const char *what() const noexcept override;
|
||||
|
||||
private:
|
||||
const std::string message_;
|
||||
};
|
||||
|
||||
class UnmatchedArgumentException : public std::exception {
|
||||
public:
|
||||
UnmatchedArgumentException(std::string_view name, size_t len, size_t expect);
|
||||
const char *what() const noexcept override;
|
||||
|
||||
private:
|
||||
const std::string message_;
|
||||
};
|
||||
|
||||
class UnknownTypeException : public std::exception {
|
||||
public:
|
||||
UnknownTypeException(Type type);
|
||||
UnknownTypeException(std::string_view type);
|
||||
const char *what() const noexcept override;
|
||||
|
||||
private:
|
||||
const std::string message_;
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
@@ -1,48 +0,0 @@
|
||||
#ifndef TRITON_TVM_FFI_LAUNCH_H_
|
||||
#define TRITON_TVM_FFI_LAUNCH_H_
|
||||
|
||||
#include "type.h"
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/object.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
class TVMFFILauncherImplObj : public tvm::ffi::Object {
|
||||
public:
|
||||
TVMFFILauncherImplObj(const tvm::ffi::Array<Type> &signature,
|
||||
bool launchCooperativeGrid, bool launchAsync);
|
||||
TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
|
||||
TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||
int32_t sharedMemory, uint64_t globalScratch,
|
||||
uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
|
||||
TVMFFILauncherImplObj, tvm::ffi::Object);
|
||||
|
||||
private:
|
||||
tvm::ffi::Array<Type> signature_;
|
||||
const bool launchCooperativeGrid_;
|
||||
const bool launchAsync_;
|
||||
};
|
||||
|
||||
class TVMFFILauncherImpl : public tvm::ffi::ObjectRef {
|
||||
public:
|
||||
TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
||||
bool launchCooperativeGrid, bool launchAsync);
|
||||
using tvm::ffi::ObjectRef::ObjectRef;
|
||||
using tvm::ffi::ObjectRef::operator=;
|
||||
void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas,
|
||||
int32_t sharedMemory, uint64_t globalScratch,
|
||||
uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const;
|
||||
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
|
||||
tvm::ffi::ObjectRef,
|
||||
TVMFFILauncherImplObj);
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
@@ -1,19 +0,0 @@
|
||||
#ifndef TRITON_TVM_FFI_MACRO_H_
|
||||
#define TRITON_TVM_FFI_MACRO_H_
|
||||
|
||||
#include "exception.h"
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#define TRITON_TVM_FFI_INLINE __attribute__((always_inline)) inline
|
||||
#endif
|
||||
|
||||
#define UNLIKELY(cond) __builtin_expect((cond), 0)
|
||||
|
||||
#define CUDA_CHECK(code) \
|
||||
do { \
|
||||
if (UNLIKELY((code) != CUDA_SUCCESS)) { \
|
||||
throw triton_tvm_ffi::CUDAException(code); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,34 @@
|
||||
#ifndef TRITON_TVM_FFI_GRID_H_
|
||||
#define TRITON_TVM_FFI_GRID_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <typename T>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim(const T &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
|
||||
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
|
||||
return grid;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Function>(
|
||||
const tvm::ffi::Function &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
|
||||
grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
|
||||
return MakeGridDim(tuple, meta);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_TVM_FFI_KERNEL_H_
|
||||
#define TRITON_TVM_FFI_KERNEL_H_
|
||||
|
||||
#include "macro.h"
|
||||
#include <cstdint>
|
||||
#include <cuda.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <const char kFnName[], const char kCubin[], size_t kSMem>
|
||||
inline CUfunction GetKernel(int32_t device) {
|
||||
static std::unordered_map<int32_t, CUfunction> functions = {};
|
||||
if (functions.find(device) == functions.end()) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, kCubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
|
||||
if (kSMem > 49152) {
|
||||
int32_t shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
if (shared_optin >= kSMem) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(
|
||||
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(
|
||||
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
shared_optin - shared_static));
|
||||
}
|
||||
}
|
||||
functions[device] = func;
|
||||
}
|
||||
return functions[device];
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_TVM_FFI_MACRO_H_
|
||||
#define TRITON_TVM_FFI_MACRO_H_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#define __CUDA_CHECK(__code) \
|
||||
do { \
|
||||
if ((__code) != CUDA_SUCCESS) { \
|
||||
const char *errorName = nullptr, *errorStr = nullptr; \
|
||||
cuGetErrorName((__code), &errorName); \
|
||||
cuGetErrorString((__code), &errorStr); \
|
||||
std::ostringstream __oss; \
|
||||
__oss << "[" << errorName << "] " << errorStr << ", at " << __FILE__ \
|
||||
<< ":" << __LINE__; \
|
||||
throw std::runtime_error(__oss.str()); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,52 @@
|
||||
#ifndef TRITON_TVM_FFI_META_H_
|
||||
#define TRITON_TVM_FFI_META_H_
|
||||
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <const char... Ks[]> struct FillMetaImpl {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs);
|
||||
};
|
||||
|
||||
template <> struct FillMetaImpl<> {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {}
|
||||
};
|
||||
|
||||
template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
|
||||
const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
|
||||
if (argsBegin != argsEnd) {
|
||||
meta.Set(K, *argsBegin++);
|
||||
} else if (auto val = kwargs.Get(K)) {
|
||||
meta.Set(K, *val);
|
||||
}
|
||||
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
|
||||
}
|
||||
};
|
||||
|
||||
template <const char... Ks[]> struct FillMeta {
|
||||
static inline void
|
||||
apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &args,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
|
||||
tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
|
||||
FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
@@ -1,54 +0,0 @@
|
||||
#ifndef TRITON_TVM_FFI_TYPE_H_
|
||||
#define TRITON_TVM_FFI_TYPE_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
// --------------- Definitions --------------- //
|
||||
|
||||
#define TYPE_TABLE_NATIVE(V) \
|
||||
V(I1, "i1", int8_t) \
|
||||
V(I8, "i8", int8_t) \
|
||||
V(I16, "i16", int16_t) \
|
||||
V(I32, "i32", int32_t) \
|
||||
V(I64, "i64", int64_t) \
|
||||
V(U1, "u1", uint8_t) \
|
||||
V(U8, "u8", uint8_t) \
|
||||
V(U16, "u16", uint16_t) \
|
||||
V(U32, "u32", uint32_t) \
|
||||
V(U64, "u64", uint64_t) \
|
||||
V(FP16, "fp16", double) \
|
||||
V(BF16, "bf16", double) \
|
||||
V(FP32, "f32", double) \
|
||||
V(FP64, "fp64", double)
|
||||
|
||||
#define TYPE_TABLE(V) \
|
||||
TYPE_TABLE_NATIVE(V) \
|
||||
V(PTR, "*?", void *) \
|
||||
V(CONSTEXPR, "constexpr", void)
|
||||
|
||||
enum class Type : int64_t {
|
||||
#define DEFINE_ENUM(type, str, ctype) type,
|
||||
TYPE_TABLE(DEFINE_ENUM)
|
||||
#undef DEFINE_ENUM
|
||||
};
|
||||
|
||||
const char *TypeToString(Type type);
|
||||
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name);
|
||||
const char *TypeToCType(Type type);
|
||||
|
||||
template <Type T> struct type_to_ctype;
|
||||
#define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \
|
||||
template <> struct type_to_ctype<Type::type> { using t = ctype; };
|
||||
TYPE_TABLE(DEFINE_TYPE_TO_CTYPE)
|
||||
#undef DEFINE_TYPE_TO_CTYPE
|
||||
template <Type T> using type_to_ctype_t = typename type_to_ctype<T>::t;
|
||||
|
||||
// --------------- Implementations --------------- //
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
+5
-8
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "triton-tvm-ffi"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
description = "A Python package for the FFI bindings of Triton TVM."
|
||||
readme = "README.md"
|
||||
dependencies = [
|
||||
"apache-tvm-ffi",
|
||||
@@ -9,12 +9,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["apache-tvm-ffi", "scikit-build-core"]
|
||||
requires = ["scikit-build-core"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project.entry-points."triton.backends"]
|
||||
nvidia = "triton_tvm_ffi"
|
||||
|
||||
[tool.scikit-build]
|
||||
wheel.install-dir = "triton_tvm_ffi"
|
||||
wheel.packages = ["python/triton_tvm_ffi"]
|
||||
[tool.setuptools]
|
||||
packages = ["triton_tvm_ffi"]
|
||||
package-dir = {"" = "python"}
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
# tvm-ffi-stubgen(begin): export/_ffi_api
|
||||
# fmt: off
|
||||
# isort: off
|
||||
from ._ffi_api import * # noqa: F403
|
||||
from ._ffi_api import __all__ as _ffi_api__all__
|
||||
if "__all__" not in globals():
|
||||
__all__ = []
|
||||
__all__.extend(_ffi_api__all__)
|
||||
# isort: on
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
from .jit import jit
|
||||
from .utils import include_paths
|
||||
from .wrap import torch_wrap, wrap
|
||||
|
||||
__all__ = ["include_paths", "jit", "torch_wrap", "wrap"]
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
# tvm-ffi-stubgen(begin): import-section
|
||||
# fmt: off
|
||||
# isort: off
|
||||
from __future__ import annotations
|
||||
from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, register_object as _FFI_REG_OBJ
|
||||
from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from tvm_ffi import Object
|
||||
from typing import Any
|
||||
# isort: on
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
# tvm-ffi-stubgen(import-object): tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB
|
||||
LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "triton_tvm_ffi")
|
||||
# tvm-ffi-stubgen(begin): global/triton_tvm_ffi
|
||||
# fmt: off
|
||||
_FFI_INIT_FUNC("triton_tvm_ffi", __name__)
|
||||
if TYPE_CHECKING:
|
||||
def string_to_type(_0: str, /) -> int | None: ...
|
||||
def type_to_ctype(_0: int, /) -> str: ...
|
||||
def type_to_string(_0: int, /) -> str: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
|
||||
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
|
||||
@_FFI_REG_OBJ("triton_tvm_ffi.TVMFFILauncherImpl")
|
||||
class TVMFFILauncherImpl(_ffi_Object):
|
||||
"""FFI binding for `triton_tvm_ffi.TVMFFILauncherImpl`."""
|
||||
|
||||
# tvm-ffi-stubgen(begin): object/triton_tvm_ffi.TVMFFILauncherImpl
|
||||
# fmt: off
|
||||
if TYPE_CHECKING:
|
||||
@staticmethod
|
||||
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
|
||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# tvm-ffi-stubgen(begin): __all__
|
||||
"LIB",
|
||||
"TVMFFILauncherImpl",
|
||||
"string_to_type",
|
||||
"type_to_ctype",
|
||||
"type_to_string",
|
||||
# tvm-ffi-stubgen(end)
|
||||
]
|
||||
@@ -1,7 +0,0 @@
|
||||
from triton.backends.nvidia.compiler import CUDABackend
|
||||
|
||||
|
||||
class TVMFFIBackend(CUDABackend): ...
|
||||
|
||||
|
||||
del CUDABackend
|
||||
@@ -1,176 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import Any, Final, List, Type
|
||||
|
||||
import jinja2
|
||||
from triton.backends.nvidia.driver import CudaDriver
|
||||
from triton.runtime import _allocation
|
||||
import tvm_ffi
|
||||
|
||||
from . import TVMFFILauncherImpl, utils, string_to_type, type_to_ctype
|
||||
|
||||
|
||||
class TVMLauncher(object):
|
||||
def __init__(self, src, metadata, *args, **kwargs) -> TVMLauncher:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.signature: List[str] = [*src.signature.values()]
|
||||
self.num_ctas: Final[int] = getattr(metadata, "num_ctas", 1)
|
||||
self.global_scratch_size: Final[int] = metadata.global_scratch_size
|
||||
self.global_scratch_align: Final[int] = metadata.global_scratch_align
|
||||
self.profile_scratch_size: Final[int] = metadata.profile_scratch_size
|
||||
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
|
||||
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
|
||||
self.launch_pdl: Final[bool] = metadata.launch_pdl
|
||||
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
|
||||
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
|
||||
"launch",
|
||||
cpp_sources=[self.codegen],
|
||||
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
|
||||
extra_include_paths=[
|
||||
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
|
||||
],
|
||||
)
|
||||
launch: tvm_ffi.Function = mod.get_function("launch")
|
||||
self.launch = (
|
||||
lambda grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args: launch(
|
||||
grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
||||
[string_to_type(t) for t in self.signature],
|
||||
self.launch_cooperative_grid,
|
||||
self.launch_pdl,
|
||||
)
|
||||
self.launch = (
|
||||
lambda grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args: self.impl.launch(
|
||||
grid_x,
|
||||
grid_y,
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
*args,
|
||||
):
|
||||
def allocate_scratch(size, align, allocator):
|
||||
if size > 0:
|
||||
grid_size = gridX * gridY * gridZ
|
||||
alloc_size = grid_size * self.num_ctas * size
|
||||
alloc_fn = allocator.get()
|
||||
return alloc_fn(alloc_size, align, stream)
|
||||
return None
|
||||
|
||||
global_scratch = allocate_scratch(
|
||||
self.global_scratch_size, self.global_scratch_align, _allocation._allocator
|
||||
)
|
||||
profile_scratch = allocate_scratch(
|
||||
self.profile_scratch_size,
|
||||
self.profile_scratch_align,
|
||||
_allocation._profile_allocator,
|
||||
)
|
||||
|
||||
def canonicalize(obj: Any) -> int:
|
||||
if obj is None:
|
||||
return 0
|
||||
elif isinstance(obj, int):
|
||||
return obj
|
||||
elif get_ptr := getattr(obj, "data_ptr", None):
|
||||
return get_ptr()
|
||||
else:
|
||||
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
|
||||
|
||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||
if launch_enter_hook:
|
||||
launch_enter_hook(launch_metadata)
|
||||
ret = self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
canonicalize(global_scratch),
|
||||
canonicalize(profile_scratch),
|
||||
*args,
|
||||
)
|
||||
if launch_exit_hook:
|
||||
launch_exit_hook(launch_metadata)
|
||||
return ret
|
||||
|
||||
@cached_property
|
||||
def codegen(self) -> str:
|
||||
env: jinja2.Environment = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
template: jinja2.Template = env.get_template("launch.c.j2")
|
||||
signature: List[int] = list(
|
||||
map(lambda t: type_to_ctype(string_to_type(t)), self.signature),
|
||||
)
|
||||
html: str = template.render(signature=signature)
|
||||
return html
|
||||
|
||||
|
||||
class TVMFFIDriver(CudaDriver):
|
||||
def __init__(self, *args, **kwargs) -> TVMFFIDriver:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.utils = utils
|
||||
self.launcher_cls: Type[TVMLauncher] = TVMLauncher
|
||||
|
||||
|
||||
del CudaDriver
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
from triton.compiler import CompiledKernel
|
||||
from triton.runtime import Autotuner, JITFunction
|
||||
import tvm_ffi
|
||||
|
||||
from .utils import type_canonicalize
|
||||
|
||||
|
||||
class TVMFFIJITFunction(object):
|
||||
def __init__(self, fn: Union[Autotuner, JITFunction], *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fn: Final[Union[Autotuner, JITFunction]] = fn
|
||||
self.signature: List[str] = [*inspect.signature(self.basefn).parameters.keys()]
|
||||
self.best_config: Optional[Dict[str, Any]] = None
|
||||
self.ctypes: Optional[List[Optional[str]]] = None
|
||||
self.kernel: Optional[bytes] = None
|
||||
self.num_warps: Optional[int] = None
|
||||
self.shmem: int = 0
|
||||
|
||||
@tvm_ffi.register_global_func(self.fullname)
|
||||
def _(
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
_device: int,
|
||||
_stream: int,
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
):
|
||||
args: Iterator[Any] = map(self.canonicalize, args)
|
||||
kwargs: Dict[str, Any] = {
|
||||
k: self.canonicalize(v) for k, v in kwargs.items()
|
||||
}
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, self.shmem = kernel.packed_metadata
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
self.kernel = kernel.kernel
|
||||
if isinstance(self.fn, Autotuner):
|
||||
self.best_config = self.fn.best_config.all_kwargs()
|
||||
return kernel
|
||||
|
||||
def __getitem__(
|
||||
self,
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
):
|
||||
return self.fn[grid]
|
||||
|
||||
@cached_property
|
||||
def basefn(self) -> Callable:
|
||||
return self.jitfn.fn
|
||||
|
||||
@property
|
||||
def cache_hash(self) -> int:
|
||||
return self.ctypes_hash ^ self.kernel_hash
|
||||
|
||||
@property
|
||||
def ctypes_hash(self) -> int:
|
||||
return hash(tuple(self.ctypes) if self.ctypes is not None else None)
|
||||
|
||||
@cached_property
|
||||
def fnname(self) -> str:
|
||||
return self.basefn.__name__
|
||||
|
||||
@cached_property
|
||||
def fullname(self) -> str:
|
||||
return f"triton.{self.name}"
|
||||
|
||||
@cached_property
|
||||
def jitfn(self) -> JITFunction:
|
||||
fn: Union[Autotuner, JITFunction] = self.fn
|
||||
while not isinstance(fn, JITFunction):
|
||||
fn = fn.fn
|
||||
return fn
|
||||
|
||||
@property
|
||||
def kernel_hash(self) -> int:
|
||||
return hash(self.kernel)
|
||||
|
||||
@property
|
||||
def kernel_cstr(self) -> Optional[str]:
|
||||
if self.kernel is not None:
|
||||
return "".join(f"\\x{byte:02x}" for byte in self.kernel)
|
||||
else:
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def name(self) -> str:
|
||||
return f"{self.fnname}_{hash(self.basefn)}"
|
||||
|
||||
@staticmethod
|
||||
def canonicalize(val: Any) -> Any:
|
||||
if hasattr(val, "__dlpack__"):
|
||||
return torch.from_dlpack(val)
|
||||
else:
|
||||
return val
|
||||
|
||||
|
||||
def jit(fn: JITFunction) -> TVMFFIJITFunction:
|
||||
return TVMFFIJITFunction(fn)
|
||||
@@ -0,0 +1,45 @@
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include "triton_tvm_ffi/grid.h"
|
||||
#include "triton_tvm_ffi/kernel.h"
|
||||
#include "triton_tvm_ffi/macro.h"
|
||||
#include "triton_tvm_ffi/meta.h"
|
||||
|
||||
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||
{% for fn in fns %}
|
||||
{% if fn.ctypes is none %}
|
||||
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
|
||||
{% else %}
|
||||
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
|
||||
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
|
||||
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% if fn.best_config != none %}
|
||||
{% for k, v in fn.best_config.items() %}
|
||||
{ "{{ k }}", {{ v }} }, \
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
}; \
|
||||
{% for type in fn.signature %}
|
||||
static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
|
||||
{% endfor %}
|
||||
triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
|
||||
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
|
||||
void *dummy = nullptr; \
|
||||
const size_t __args_len = __args.size(); \
|
||||
{% for ctype in fn.ctypes %}
|
||||
{% if ctype == "CUdeviceptr" %}
|
||||
void *__arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr() : __kwargs[__varname{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
||||
{% elif ctype != none %}
|
||||
{{ ctype }} __arg{{ loop.index0 }} = {{ loop.index0 }} < __args_len ? __args[{{ loop.index0 }}].cast<{{ ctype }}>() : __kwargs[__varname{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
||||
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
|
||||
} while (false)
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{{ code }}
|
||||
@@ -1,64 +0,0 @@
|
||||
#include <assert.h>
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
#endif
|
||||
TVM_FFI_DLL_EXPORT void
|
||||
__tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args,
|
||||
TVMFFIAny *result) {
|
||||
int32_t gridX = args[0].v_int64;
|
||||
int32_t gridY = args[1].v_int64;
|
||||
int32_t gridZ = args[2].v_int64;
|
||||
CUstream stream = (CUstream)args[3].v_uint64;
|
||||
CUfunction function = (CUfunction)args[4].v_uint64;
|
||||
int32_t numWarps = args[5].v_int64;
|
||||
int32_t numCtas = args[6].v_int64;
|
||||
int32_t sharedMemory = args[7].v_int64;
|
||||
uint64_t globalScratch = args[8].v_uint64;
|
||||
uint64_t profileScratch = args[9].v_uint64;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
config.gridDimX = gridX * numCtas;
|
||||
config.gridDimY = gridY;
|
||||
config.gridDimZ = gridZ;
|
||||
config.blockDimX = 32 * numWarps;
|
||||
config.blockDimY = 1;
|
||||
config.blockDimZ = 1;
|
||||
config.sharedMemBytes = sharedMemory;
|
||||
config.hStream = stream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
// TODO: check `launchPdl`
|
||||
// TODO: check `launchCooperativeGrid`
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||
clusterAttr.value.clusterDim.x = numCtas;
|
||||
clusterAttr.value.clusterDim.y = 1;
|
||||
clusterAttr.value.clusterDim.z = 1;
|
||||
launchAttr[numAttrs++] = clusterAttr;
|
||||
CUlaunchAttribute clusterSchedulingAttr;
|
||||
clusterSchedulingAttr.id =
|
||||
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
|
||||
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
|
||||
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
|
||||
launchAttr[numAttrs++] = clusterSchedulingAttr;
|
||||
}
|
||||
config.numAttrs = numAttrs;
|
||||
{% for type in signature %}
|
||||
{% if type == "void *" %}
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{% elif type == "int32_t" %}
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||
{% elif type == "void" %}
|
||||
{% else %}
|
||||
assert(false, "unsupported type yet {{ type }}");
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
void *params[] = { {% for type in signature %} {% if type != "void" %} &arg{{ loop.index0 }}, {% endif %} {% endfor %}&globalScratch, &profileScratch };
|
||||
cuLaunchKernelEx(&config, function, params, NULL);
|
||||
}
|
||||
}
|
||||
@@ -1,36 +1,16 @@
|
||||
# tvm-ffi-stubgen(begin): import-section
|
||||
# fmt: off
|
||||
# isort: off
|
||||
from __future__ import annotations
|
||||
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
# isort: on
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
import sysconfig
|
||||
from typing import List, Optional
|
||||
|
||||
# tvm-ffi-stubgen(begin): global/triton_tvm_ffi.utils
|
||||
# fmt: off
|
||||
_FFI_INIT_FUNC("triton_tvm_ffi.utils", __name__)
|
||||
if TYPE_CHECKING:
|
||||
def build_signature_metadata(*args: Any) -> Any: ...
|
||||
def cuOccupancyMaxActiveClusters(*args: Any) -> Any: ...
|
||||
def fill_tma_descriptor(*args: Any) -> Any: ...
|
||||
def get_device_properties(_0: int, /) -> Mapping[str, int]: ...
|
||||
def load_binary(_0: str, _1: bytes, _2: int, _3: int, /) -> tuple[int, int, int, int, int]: ...
|
||||
def set_printf_fifo_size(*args: Any) -> Any: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
from triton.backends.nvidia.driver import ty_to_cpp
|
||||
|
||||
__all__ = [
|
||||
# tvm-ffi-stubgen(begin): __all__
|
||||
"build_signature_metadata",
|
||||
"cuOccupancyMaxActiveClusters",
|
||||
"fill_tma_descriptor",
|
||||
"get_device_properties",
|
||||
"load_binary",
|
||||
"set_printf_fifo_size",
|
||||
# tvm-ffi-stubgen(end)
|
||||
]
|
||||
|
||||
def include_paths() -> List[str]:
|
||||
pkg_path: str = sysconfig.get_path("purelib")
|
||||
return [f"{pkg_path}/triton_tvm_ffi/include"]
|
||||
|
||||
|
||||
def type_canonicalize(ty: str) -> Optional[str]:
|
||||
if ty == "constexpr":
|
||||
return None
|
||||
else:
|
||||
return ty_to_cpp(ty)
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
from functools import cached_property
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Union
|
||||
|
||||
import jinja2
|
||||
import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
|
||||
from .jit import TVMFFIJITFunction
|
||||
from .utils import include_paths
|
||||
|
||||
|
||||
class TVMFFIWrapperFunction(object):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||
extra_ldflags: Optional[Sequence[str]] = None,
|
||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.name: Final[str] = name
|
||||
self.fns: List[TVMFFIJITFunction] = [*fns]
|
||||
if isinstance(code, Path):
|
||||
with open(code, "r") as f:
|
||||
self.code: Final[str] = f.read()
|
||||
elif isinstance(code, TextIOWrapper):
|
||||
self.code: Final[str] = code.read()
|
||||
else:
|
||||
self.code: Final[str] = f"{code}"
|
||||
self.extra_cflags: Optional[Sequence[str]] = extra_cflags
|
||||
self.extra_cuda_cflags: Optional[Sequence[str]] = extra_cuda_cflags
|
||||
self.extra_ldflags: Optional[Sequence[str]] = extra_ldflags
|
||||
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
||||
extra_include_paths
|
||||
)
|
||||
self.env: Final[jinja2.Environment] = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||
trim_blocks=True,
|
||||
)
|
||||
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> None:
|
||||
func: tvm_ffi.Function = self.compile()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def fns_hash(self) -> int:
|
||||
return hash(tuple(fn.cache_hash for fn in self.fns))
|
||||
|
||||
@cached_property
|
||||
def fullname(self) -> str:
|
||||
return f"triton.{self.name}"
|
||||
|
||||
@property
|
||||
def emit(self) -> str:
|
||||
return self.tpl.render(
|
||||
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
|
||||
)
|
||||
|
||||
@property
|
||||
def uniquename(self) -> str:
|
||||
return f"{self.name}_{self.fns_hash}"
|
||||
|
||||
def compile(self) -> tvm_ffi.Function:
|
||||
if func := tvm_ffi.get_global_func(self.uniquename, allow_missing=True):
|
||||
return func
|
||||
else:
|
||||
tvm_ffi.cpp.load_inline(
|
||||
self.name,
|
||||
cpp_sources=[self.emit],
|
||||
extra_cflags=self.extra_cflags,
|
||||
extra_cuda_cflags=self.extra_cuda_cflags,
|
||||
extra_ldflags=self.extra_ldflags,
|
||||
extra_include_paths=self.extra_include_paths,
|
||||
embed_cubin={
|
||||
f"triton_{fn.fnname}": fn.kernel
|
||||
for fn in self.fns
|
||||
if fn.kernel is not None
|
||||
},
|
||||
)
|
||||
return tvm_ffi.get_global_func(self.uniquename)
|
||||
|
||||
|
||||
def wrap(
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||
extra_ldflags: Optional[Sequence[str]] = None,
|
||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||
) -> TVMFFIWrapperFunction:
|
||||
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
|
||||
return TVMFFIWrapperFunction(
|
||||
fn if isinstance(fn, str) else fn.__name__,
|
||||
fns,
|
||||
code,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
include_paths() + (extra_include_paths or []),
|
||||
)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def torch_wrap(
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
extra_cuda_cflags: Optional[Sequence[str]] = None,
|
||||
extra_ldflags: Optional[Sequence[str]] = None,
|
||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||
) -> TVMFFIWrapperFunction:
|
||||
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
|
||||
return wrap(
|
||||
fns,
|
||||
code,
|
||||
extra_ldflags=[
|
||||
"-Wl,--no-as-needed",
|
||||
f"-L{cuda_home}/lib64",
|
||||
*map(
|
||||
lambda path: f"-L{path}",
|
||||
torch.utils.cpp_extension.library_paths(),
|
||||
),
|
||||
"-lcuda",
|
||||
"-lc10",
|
||||
"-ltorch",
|
||||
]
|
||||
+ (extra_ldflags or []),
|
||||
extra_cflags=extra_cflags,
|
||||
extra_cuda_cflags=extra_cuda_cflags,
|
||||
extra_include_paths=[
|
||||
f"{cuda_home}/include",
|
||||
*torch.utils.cpp_extension.include_paths(),
|
||||
]
|
||||
+ (extra_include_paths or []),
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
add_library(
|
||||
${TARGET_NAME}
|
||||
SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/exception.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/launch.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/type.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cc
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
${TARGET_NAME}
|
||||
PRIVATE ${PROJECT_SOURCE_DIR}/include
|
||||
)
|
||||
target_compile_options(
|
||||
${TARGET_NAME}
|
||||
PRIVATE
|
||||
$<$<CONFIG:Debug>:-O0 -g>
|
||||
$<$<CONFIG:Release>:-O3 -DNDEBUG>
|
||||
)
|
||||
target_link_libraries(
|
||||
${TARGET_NAME}
|
||||
PRIVATE
|
||||
CUDA::cudart
|
||||
CUDA::cuda_driver
|
||||
)
|
||||
tvm_ffi_configure_target(
|
||||
${TARGET_NAME}
|
||||
STUB_DIR "${CMAKE_SOURCE_DIR}/python"
|
||||
STUB_INIT ON
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS ${TARGET_NAME}
|
||||
LIBRARY DESTINATION .
|
||||
)
|
||||
tvm_ffi_install(${TARGET_NAME} DESTINATION .)
|
||||
@@ -1,42 +0,0 @@
|
||||
#include "exception.h"
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
CUDAException::CUDAException(CUresult code) : code_(code) {}
|
||||
|
||||
const char *CUDAException::what() const noexcept {
|
||||
const char *p = nullptr;
|
||||
cuGetErrorString(code_, &p);
|
||||
return p;
|
||||
}
|
||||
|
||||
NotImplementedException::NotImplementedException(std::string_view name)
|
||||
: message_("[NotImplementedException]: \"" + std::string(name) + "\"") {}
|
||||
|
||||
const char *NotImplementedException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
UnmatchedArgumentException::UnmatchedArgumentException(std::string_view name,
|
||||
size_t len,
|
||||
size_t expect)
|
||||
: message_("[UnmatchedArgumentException]: argument \"" + std::string(name) +
|
||||
"\" has length " + std::to_string(len) + ", but expected " +
|
||||
std::to_string(expect)) {}
|
||||
|
||||
const char *UnmatchedArgumentException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
UnknownTypeException::UnknownTypeException(Type type)
|
||||
: message_("[UnknownTypeException]: unknown type: \"" +
|
||||
std::string(TypeToString(type)) + "\"") {}
|
||||
|
||||
UnknownTypeException::UnknownTypeException(std::string_view type)
|
||||
: message_("[UnknownTypeException]: unknown type: \"" + std::string(type) +
|
||||
"\"") {}
|
||||
const char *UnknownTypeException::what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
-126
@@ -1,126 +0,0 @@
|
||||
#include "launch.h"
|
||||
#include "macro.h"
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/base_details.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
TVMFFILauncherImplObj::TVMFFILauncherImplObj(
|
||||
const tvm::ffi::Array<Type> &signature, bool launchCooperativeGrid,
|
||||
bool launchAsync)
|
||||
: signature_(std::move(signature)),
|
||||
launchCooperativeGrid_(launchCooperativeGrid), launchAsync_(launchAsync) {
|
||||
}
|
||||
|
||||
void TVMFFILauncherImplObj::Launch(
|
||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||
uint64_t globalScratch, uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||
CUstream cStream = reinterpret_cast<CUstream>(stream);
|
||||
CUfunction cFunction = reinterpret_cast<CUfunction>(function);
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
config.gridDimX = gridX * numCtas;
|
||||
config.gridDimY = gridY;
|
||||
config.gridDimZ = gridZ;
|
||||
static constexpr int32_t kThreadsPerWarp = 32;
|
||||
config.blockDimX = kThreadsPerWarp * numWarps;
|
||||
config.blockDimY = 1;
|
||||
config.blockDimZ = 1;
|
||||
config.sharedMemBytes = sharedMemory;
|
||||
config.hStream = cStream;
|
||||
config.attrs = launchAttr;
|
||||
int32_t numAttrs = 0;
|
||||
if (numCtas != 1) {
|
||||
CUlaunchAttribute clusterAttr;
|
||||
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||
clusterAttr.value.clusterDim.x = numCtas;
|
||||
clusterAttr.value.clusterDim.y = 1;
|
||||
clusterAttr.value.clusterDim.z = 1;
|
||||
launchAttr[numAttrs++] = clusterAttr;
|
||||
CUlaunchAttribute clusterSchedulingAttr;
|
||||
clusterSchedulingAttr.id =
|
||||
CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
|
||||
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
|
||||
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
|
||||
launchAttr[numAttrs++] = clusterSchedulingAttr;
|
||||
}
|
||||
config.numAttrs = numAttrs;
|
||||
if (numCtas == 16) {
|
||||
CUDA_CHECK(cuFuncSetAttribute(
|
||||
cFunction, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
||||
}
|
||||
const int32_t kernelArgNum = kernelArgs.size();
|
||||
void **params =
|
||||
reinterpret_cast<void **>(alloca(sizeof(void *) * (kernelArgNum + 2)));
|
||||
size_t j = 0;
|
||||
#ifndef NDEBUG
|
||||
if (kernelArgNum != signature_.size()) {
|
||||
throw UnmatchedArgumentException("kernelArgs", kernelArgNum,
|
||||
signature_.size());
|
||||
}
|
||||
#endif
|
||||
for (size_t i = 0; i < kernelArgNum; ++i) {
|
||||
tvm::ffi::Any value = kernelArgs[i];
|
||||
switch (signature_[i]) {
|
||||
#define CASE_STMT(type, str, ctype) \
|
||||
case Type::type: { \
|
||||
using cpptype = type_to_ctype_t<Type::type>; \
|
||||
params[j] = reinterpret_cast<void *>(alloca(sizeof(cpptype))); \
|
||||
*reinterpret_cast<cpptype *>(params[j]) = value.cast<cpptype>(); \
|
||||
++j; \
|
||||
break; \
|
||||
}
|
||||
TYPE_TABLE_NATIVE(CASE_STMT)
|
||||
#undef CASE_STMT
|
||||
case Type::PTR: {
|
||||
params[j] = reinterpret_cast<void *>(alloca(sizeof(void *)));
|
||||
*reinterpret_cast<void **>(params[j]) =
|
||||
value.cast<tvm::ffi::TensorView>().data_ptr();
|
||||
++j;
|
||||
break;
|
||||
}
|
||||
case Type::CONSTEXPR: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
#ifdef NDEBUG
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw NotImplementedException("CONSTEXPR for value casting");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: unwrap PyObject* from scratch pointers and assign to kernel args
|
||||
params[j] = &globalScratch;
|
||||
params[j + 1] = &profileScratch;
|
||||
CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array<Type> signature,
|
||||
bool launchCooperativeGrid,
|
||||
bool launchAsync)
|
||||
: tvm::ffi::ObjectRef(tvm::ffi::make_object<TVMFFILauncherImplObj>(
|
||||
std::move(signature), launchCooperativeGrid, launchAsync)) {}
|
||||
|
||||
void TVMFFILauncherImpl::Launch(
|
||||
int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
|
||||
uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory,
|
||||
uint64_t globalScratch, uint64_t profileScratch,
|
||||
const tvm::ffi::Array<tvm::ffi::Any> &kernelArgs) const {
|
||||
get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas,
|
||||
sharedMemory, globalScratch, profileScratch, kernelArgs);
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::ObjectDef<TVMFFILauncherImplObj>()
|
||||
.def(refl::init<const tvm::ffi::Array<Type> &, bool, bool>())
|
||||
.def("launch", &TVMFFILauncherImplObj::Launch);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
-61
@@ -1,61 +0,0 @@
|
||||
#include "type.h"
|
||||
#include "exception.h"
|
||||
#include <cassert>
|
||||
#include <tvm/ffi/optional.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
const char *TypeToString(Type type) {
|
||||
switch (type) {
|
||||
#define CASE_ENUM(type, str, ctype) \
|
||||
case Type::type: \
|
||||
return str;
|
||||
TYPE_TABLE(CASE_ENUM)
|
||||
#undef CASE_ENUM
|
||||
default:
|
||||
throw UnknownTypeException(type);
|
||||
}
|
||||
}
|
||||
|
||||
tvm::ffi::Optional<Type> StringToType(const tvm::ffi::String &name) {
|
||||
if (name.starts_with("*")) {
|
||||
return Type::PTR;
|
||||
}
|
||||
if (name == "constexpr") {
|
||||
return Type::CONSTEXPR;
|
||||
}
|
||||
#define IF_ENUM(type, str, ctype) \
|
||||
if (name == str) { \
|
||||
return Type::type; \
|
||||
}
|
||||
TYPE_TABLE(IF_ENUM)
|
||||
#undef IF_ENUM
|
||||
if (name.starts_with("tensordesc") || name == "nvTmaDesc") {
|
||||
throw NotImplementedException(
|
||||
"tensordesc and nvTmaDesc are not supported in triton-tvm-ffi yet.");
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const char *TypeToCType(Type type) {
|
||||
switch (type) {
|
||||
#define CASE_ENUM(type, str, ctype) \
|
||||
case Type::type: \
|
||||
return #ctype;
|
||||
TYPE_TABLE(CASE_ENUM)
|
||||
#undef CASE_ENUM
|
||||
default:
|
||||
throw UnknownTypeException(type);
|
||||
}
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef()
|
||||
.def("triton_tvm_ffi.type_to_string", TypeToString)
|
||||
.def("triton_tvm_ffi.string_to_type", StringToType)
|
||||
.def("triton_tvm_ffi.type_to_ctype", TypeToCType);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
-113
@@ -1,113 +0,0 @@
|
||||
#include "macro.h"
|
||||
#include <cuda.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/reflection/registry.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
#ifndef NDEBUG
|
||||
#include <cassert>
|
||||
#endif
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
tvm::ffi::Map<tvm::ffi::String, int32_t> GetDeviceProperties(int device_id) {
|
||||
tvm::ffi::cuda_api::DeviceHandle device;
|
||||
CUDA_CHECK(cuDeviceGet(&device, device_id));
|
||||
int maxSharedMem = 0;
|
||||
int maxNumRegs = 0;
|
||||
int multiprocessorCount = 0;
|
||||
int warpSize = 0;
|
||||
int smClockRate = 0;
|
||||
int memClockRate = 0;
|
||||
int memBusWidth = 0;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&maxSharedMem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&maxNumRegs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&multiprocessorCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
||||
CUDA_CHECK(
|
||||
cuDeviceGetAttribute(&warpSize, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(&smClockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE,
|
||||
device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&memClockRate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&memBusWidth, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
||||
return {{"max_shared_mem", maxSharedMem},
|
||||
{"max_num_regs", maxNumRegs},
|
||||
{"multiprocessor_count", multiprocessorCount},
|
||||
{"warpSize", warpSize},
|
||||
{"sm_clock_rate", smClockRate},
|
||||
{"mem_clock_rate", memClockRate},
|
||||
{"mem_bus_width", memBusWidth}};
|
||||
}
|
||||
|
||||
tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>
|
||||
LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data,
|
||||
int32_t shared, CUdevice device) {
|
||||
CUcontext pctx;
|
||||
CUfunction fun;
|
||||
CUmodule mod;
|
||||
int32_t nRegs = 0;
|
||||
int32_t nSpills = 0;
|
||||
int32_t nMaxThreads = 0;
|
||||
int32_t sharedOptin = 0;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
CUDA_CHECK(cuModuleLoadData(&mod, data.data()));
|
||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data()));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
||||
CUDA_CHECK(
|
||||
cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&nMaxThreads,
|
||||
CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&sharedOptin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
device));
|
||||
static constexpr int64_t kExpectedMaxDynamicSharedMemory = 49152;
|
||||
if (shared > kExpectedMaxDynamicSharedMemory &&
|
||||
sharedOptin > kExpectedMaxDynamicSharedMemory) {
|
||||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
||||
int32_t sharedTotal = 0, sharedStatic = 0;
|
||||
CUDA_CHECK(cuDeviceGetAttribute(
|
||||
&sharedTotal, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
||||
device));
|
||||
CUDA_CHECK(cuFuncGetAttribute(&sharedStatic,
|
||||
CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
||||
CUDA_CHECK(
|
||||
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
sharedOptin - sharedStatic));
|
||||
}
|
||||
return tvm::ffi::Tuple<uint64_t, uint64_t, int32_t, int32_t, int32_t>{
|
||||
mod, fun, nRegs, nSpills, nMaxThreads};
|
||||
}
|
||||
|
||||
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||
namespace refl = tvm::ffi::reflection;
|
||||
refl::GlobalDef()
|
||||
.def_packed("triton_tvm_ffi.utils.build_signature_metadata",
|
||||
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
throw NotImplementedException("build_signature_metadata");
|
||||
})
|
||||
.def_packed("triton_tvm_ffi.utils.cuOccupancyMaxActiveClusters",
|
||||
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
throw NotImplementedException(
|
||||
"cuOccupancyMaxActiveClusters");
|
||||
})
|
||||
.def_packed("triton_tvm_ffi.utils.fill_tma_descriptor",
|
||||
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
throw NotImplementedException("fill_tma_descriptor");
|
||||
})
|
||||
.def_packed("triton_tvm_ffi.utils.set_printf_fifo_size",
|
||||
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
|
||||
throw NotImplementedException("set_printf_fifo_size");
|
||||
})
|
||||
.def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
|
||||
.def("triton_tvm_ffi.utils.load_binary", LoadBinary);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
Reference in New Issue
Block a user