mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,6 +11,7 @@ wheels/
|
|||||||
|
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
.cache
|
||||||
.clangd
|
.clangd
|
||||||
.python-version
|
.python-version
|
||||||
uv.lock
|
uv.lock
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
import triton_tvm_ffi
|
import triton_tvm_ffi
|
||||||
|
|
||||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||||
@@ -29,9 +28,6 @@ def add_kernel(
|
|||||||
tl.store(output_ptr + offsets, output, mask=mask)
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
|
||||||
|
|
||||||
|
|
||||||
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
output: torch.Tensor = torch.empty_like(x)
|
output: torch.Tensor = torch.empty_like(x)
|
||||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||||
@@ -43,7 +39,7 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
@triton_tvm_ffi.torch_wrap(
|
@triton_tvm_ffi.torch_wrap(
|
||||||
[add_kernel_tvm_ffi],
|
[add_kernel],
|
||||||
Path(__file__).parent / "add.cc",
|
Path(__file__).parent / "add.cc",
|
||||||
)
|
)
|
||||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|||||||
35
examples/softmax/softmax.cc
Normal file
35
examples/softmax/softmax.cc
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#include <ATen/DLConvertor.h>
|
||||||
|
#include <ATen/dlpack.h>
|
||||||
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|
||||||
|
#ifndef SOFTMAX_KERNEL_STUB
|
||||||
|
#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, output, input, \
|
||||||
|
inputStride, outputStride, nRows, nCols, \
|
||||||
|
BLOCK_SIZE)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef SOFTMAX_NAME
|
||||||
|
#define SOFTMAX_NAME ""
|
||||||
|
#endif
|
||||||
|
|
||||||
|
tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
|
||||||
|
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||||
|
at::Tensor ytorch = at::empty_like(xtorch);
|
||||||
|
uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1), numWarps = 8,
|
||||||
|
numStages = 4, xStride = xtorch.stride(0),
|
||||||
|
yStride = ytorch.stride(0),
|
||||||
|
BLOCK_SIZE = 1u << (32 - __builtin_clz(nCols - 1));
|
||||||
|
tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch));
|
||||||
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{nRows / 1024, 1, 1};
|
||||||
|
DLDevice device = x.device();
|
||||||
|
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||||
|
SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, y, x, xStride, yStride,
|
||||||
|
nRows, nCols, BLOCK_SIZE);
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
TVM_FFI_STATIC_INIT_BLOCK() {
|
||||||
|
namespace refl = tvm::ffi::reflection;
|
||||||
|
refl::GlobalDef().def(SOFTMAX_NAME, Softmax);
|
||||||
|
}
|
||||||
88
examples/softmax/softmax.py
Normal file
88
examples/softmax/softmax.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
import triton_tvm_ffi
|
||||||
|
|
||||||
|
|
||||||
|
@triton_tvm_ffi.jit
|
||||||
|
@triton.jit
|
||||||
|
def softmax_kernel(
|
||||||
|
output_ptr,
|
||||||
|
input_ptr,
|
||||||
|
input_row_stride,
|
||||||
|
output_row_stride,
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
row_start = tl.program_id(0)
|
||||||
|
row_step = tl.num_programs(0)
|
||||||
|
for row_idx in tl.range(row_start, n_rows, row_step):
|
||||||
|
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
input_ptrs = row_start_ptr + col_offsets
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
|
||||||
|
row_minus_max = row - tl.max(row, axis=0)
|
||||||
|
numerator = tl.exp(row_minus_max)
|
||||||
|
denominator = tl.sum(numerator, axis=0)
|
||||||
|
softmax_output = numerator / denominator
|
||||||
|
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||||
|
output_ptrs = output_row_start_ptr + col_offsets
|
||||||
|
tl.store(output_ptrs, softmax_output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_triton(x):
|
||||||
|
n_rows, n_cols = x.shape
|
||||||
|
BLOCK_SIZE = triton.next_power_of_2(n_cols)
|
||||||
|
num_warps = 8
|
||||||
|
num_stages = 4
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
softmax_kernel[(n_rows, 1, 1)](
|
||||||
|
y,
|
||||||
|
x,
|
||||||
|
x.stride(0),
|
||||||
|
y.stride(0),
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
|
)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@triton_tvm_ffi.torch_wrap(
|
||||||
|
[softmax_kernel],
|
||||||
|
Path(__file__).parent / "softmax.cc",
|
||||||
|
)
|
||||||
|
def softmax(x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
x = torch.randn(1823, 781, device="cuda")
|
||||||
|
y_torch = torch.softmax(x, axis=1)
|
||||||
|
y_triton = softmax_triton(x)
|
||||||
|
y_tvm_ffi = softmax(x)
|
||||||
|
assert torch.allclose(y_torch, y_triton), (y_torch, y_triton)
|
||||||
|
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
|
||||||
|
y_tvm_ffi = softmax(x)
|
||||||
|
assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
|
||||||
|
|
||||||
|
round = 1000
|
||||||
|
cp0 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
torch.softmax(x, axis=1)
|
||||||
|
cp1 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
softmax_triton(x)
|
||||||
|
cp2 = time.perf_counter_ns()
|
||||||
|
for _ in range(round):
|
||||||
|
softmax(x)
|
||||||
|
cp3 = time.perf_counter_ns()
|
||||||
|
print(
|
||||||
|
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user