support softmax

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-04 16:38:33 +08:00
parent 192dc95ac0
commit b7bf598fde
4 changed files with 125 additions and 5 deletions

1
.gitignore vendored
View File

@@ -11,6 +11,7 @@ wheels/
.vscode/
.cache
.clangd
.python-version
uv.lock

View File

@@ -4,7 +4,6 @@ import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -29,9 +28,6 @@ def add_kernel(
tl.store(output_ptr + offsets, output, mask=mask)
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.empty_like(x)
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
@@ -43,7 +39,7 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@triton_tvm_ffi.torch_wrap(
[add_kernel_tvm_ffi],
[add_kernel],
Path(__file__).parent / "add.cc",
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...

View File

@@ -0,0 +1,35 @@
#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/tvm_ffi.h>
// Fallback no-op definitions so this translation unit parses on its own.
// NOTE(review): the real SOFTMAX_KERNEL_STUB launcher and SOFTMAX_NAME are
// presumably injected as -D compile definitions by triton_tvm_ffi.torch_wrap
// when it builds this file together with the Triton-generated cubin — confirm
// against the triton_tvm_ffi build step.
#ifndef SOFTMAX_KERNEL_STUB
#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, output, input, \
inputStride, outputStride, nRows, nCols, \
BLOCK_SIZE)
#endif
// Name under which the Softmax function is registered in the FFI registry.
#ifndef SOFTMAX_NAME
#define SOFTMAX_NAME ""
#endif
// Row-wise softmax over a 2-D tensor, dispatched to the Triton-generated CUDA
// kernel through the SOFTMAX_KERNEL_STUB launcher macro.
//
// x: rank-2 tensor on a CUDA device; each row is reduced independently.
// Returns a freshly allocated tensor of the same shape/dtype/strides.
tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
  at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
  at::Tensor ytorch = at::empty_like(xtorch);
  uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1), numWarps = 8,
           numStages = 4, xStride = xtorch.stride(0),
           yStride = ytorch.stride(0),
           // Smallest power of two >= nCols. Guard nCols <= 1 explicitly:
           // __builtin_clz(0) is undefined behavior.
           BLOCK_SIZE = nCols <= 1 ? 1u : 1u << (32 - __builtin_clz(nCols - 1));
  tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch));
  // The kernel walks rows with a stride of num_programs(0), so fewer blocks
  // than rows is fine — but nRows / 1024 truncates to ZERO when nRows < 1024,
  // which would launch no work at all and return uninitialized memory.
  // Clamp to at least one block.
  uint32_t nBlocks = nRows / 1024;
  if (nBlocks == 0) {
    nBlocks = 1;
  }
  tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{
      static_cast<int32_t>(nBlocks), 1, 1};
  DLDevice device = x.device();
  void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
  SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, y, x, xStride, yStride,
                      nRows, nCols, BLOCK_SIZE);
  return y;
}
// Register Softmax in the TVM FFI global function table at static
// initialization time, under the name supplied via the SOFTMAX_NAME macro.
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(SOFTMAX_NAME, Softmax);
}

View File

@@ -0,0 +1,88 @@
from pathlib import Path
import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
@triton_tvm_ffi.jit
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Row-wise, numerically stable softmax. Each program instance handles
    # rows row_start, row_start + row_step, ... — a grid-stride loop over
    # rows, so a launch grid smaller than n_rows still covers every row.
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step):
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # BLOCK_SIZE is a power of two >= n_cols (see the launcher); lanes
        # beyond n_cols are masked out.
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        # Masked lanes load -inf so they contribute exp(-inf) = 0 to the sum.
        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        # Subtract the row maximum before exponentiating for stability.
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
def softmax_triton(x):
    """Row-wise softmax of a 2-D tensor, launched directly via Triton.

    Allocates the output, sizes the block to the next power of two that
    holds one full row, and launches one program per row.
    """
    rows, cols = x.shape
    out = torch.empty_like(x)
    # One block must span an entire (padded) row for the single-pass reduce.
    block = triton.next_power_of_2(cols)
    softmax_kernel[(rows, 1, 1)](
        out,
        x,
        x.stride(0),
        out.stride(0),
        rows,
        cols,
        block,
        num_warps=8,
        num_stages=4,
    )
    return out
# NOTE(review): torch_wrap receives the jitted kernel and the path to
# softmax.cc; presumably it compiles the C++ wrapper against the generated
# kernel stub and binds the result as this function's implementation — the
# Python body is intentionally a stub. Confirm against triton_tvm_ffi docs.
@triton_tvm_ffi.torch_wrap(
    [softmax_kernel],
    Path(__file__).parent / "softmax.cc",
)
def softmax(x: torch.Tensor) -> torch.Tensor: ...
if __name__ == "__main__":
    x = torch.randn(1823, 781, device="cuda")

    # Correctness: both implementations must match the PyTorch reference.
    # (dim= is the canonical keyword; axis= is a legacy alias.)
    y_torch = torch.softmax(x, dim=1)
    y_triton = softmax_triton(x)
    y_tvm_ffi = softmax(x)
    assert torch.allclose(y_torch, y_triton), (y_torch, y_triton)
    assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
    # Second call exercises the already-compiled TVM FFI path.
    y_tvm_ffi = softmax(x)
    assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)

    # Benchmark. CUDA kernel launches are asynchronous, so every checkpoint
    # must be preceded by torch.cuda.synchronize(); otherwise perf_counter_ns
    # only measures launch/queueing overhead, not kernel execution time.
    # (`rounds` also avoids shadowing the round() builtin.)
    rounds = 1000
    torch.cuda.synchronize()
    cp0 = time.perf_counter_ns()
    for _ in range(rounds):
        torch.softmax(x, dim=1)
    torch.cuda.synchronize()
    cp1 = time.perf_counter_ns()
    for _ in range(rounds):
        softmax_triton(x)
    torch.cuda.synchronize()
    cp2 = time.perf_counter_ns()
    for _ in range(rounds):
        softmax(x)
    torch.cuda.synchronize()
    cp3 = time.perf_counter_ns()
    print(
        f"PyTorch: {(cp1 - cp0) / rounds * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / rounds * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / rounds * 1e-6:.3f} ms"
    )