Files
triton-tvm-ffi/examples/add/add.py
Jinjie Liu b7bf598fde support softmax
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
2026-02-04 16:38:33 +08:00

75 lines
2.0 KiB
Python

from pathlib import Path
import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
# Torch device reported by Triton's active driver. The tensors in the
# __main__ section are created on it, and add_triton asserts that its
# inputs and output live on it.
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton_tvm_ffi.jit
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise ``output = x + y`` over ``n_elements`` linearly indexed values.

    Each program instance along axis 0 covers one ``BLOCK_SIZE``-wide run of
    consecutive offsets. Offsets at or beyond ``n_elements`` are masked out of
    both the loads and the store, so the final, partially filled block never
    touches memory past the end of the buffers.
    """
    # Absolute element offsets owned by this program instance.
    lanes = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements

    lhs = tl.load(x_ptr + lanes, mask=in_bounds)
    rhs = tl.load(y_ptr + lanes, mask=in_bounds)
    tl.store(output_ptr + lanes, lhs + rhs, mask=in_bounds)
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x + y`` computed by launching ``add_kernel`` directly.

    Args:
        x: Left operand; must live on ``DEVICE``.
        y: Right operand; must live on ``DEVICE`` and have the same shape as
            ``x``.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` holding the sum.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape.
        AssertionError: If any tensor involved is not on ``DEVICE``.
    """
    # The kernel masks every access with x's element count only. A smaller y
    # would be read past its allocation, and a larger or differently shaped y
    # would be silently mis-added, so reject the mismatch up front.
    if x.shape != y.shape:
        raise ValueError(
            "add_triton expects tensors of identical shape, got "
            f"{tuple(x.shape)} and {tuple(y.shape)}"
        )

    output: torch.Tensor = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE

    # NOTE: add_kernel addresses all three buffers as `ptr + offset`, i.e. as
    # flat memory; contiguity of the inputs is assumed and not checked here.
    n_elements: int = output.numel()
    BLOCK_SIZE: int = 1024
    # One program instance per BLOCK_SIZE chunk, rounded up to cover the tail.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
    return output
# Torch-facing entry point declared through triton_tvm_ffi.torch_wrap, which is
# given add_kernel and the path of the add.cc file that sits next to this
# script. The Python body is only a typed stub (`...`).
# NOTE(review): presumably the decorator replaces the stub with the
# implementation from add.cc — confirm against the triton_tvm_ffi docs.
@triton_tvm_ffi.torch_wrap(
    [add_kernel],
    Path(__file__).parent / "add.cc",
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
if __name__ == "__main__":
    # Self-check followed by a micro-benchmark of three ways to add two
    # vectors: eager PyTorch, the raw Triton launch, and the TVM FFI wrapper.

    def _sync() -> None:
        """Wait for queued device work so a timestamp brackets finished work."""
        # GPU launches are asynchronous: without this, kernels queued by one
        # timing loop could still be running while the next loop is timed.
        # Backends whose torch device type is not "cuda" are left
        # unsynchronised.
        if DEVICE.type == "cuda":
            torch.cuda.synchronize(DEVICE)

    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device=DEVICE)
    y = torch.rand(size, device=DEVICE)

    # Correctness: both implementations must match the eager PyTorch result.
    output_torch = x + y
    output_triton = add_triton(x, y)
    output_tvm_ffi = add(x, y)
    assert torch.allclose(output_torch, output_triton)
    assert torch.allclose(output_torch, output_tvm_ffi)
    # Call the wrapper a second time and re-check the result.
    output_tvm_ffi = add(x, y)
    assert torch.allclose(output_torch, output_tvm_ffi)

    # Named `rounds` rather than `round` so the builtin is not shadowed.
    rounds = 1000

    _sync()
    cp0 = time.perf_counter_ns()
    for _ in range(rounds):
        x + y  # result intentionally discarded; only the cost is of interest
    _sync()
    cp1 = time.perf_counter_ns()
    for _ in range(rounds):
        add_triton(x, y)
    _sync()
    cp2 = time.perf_counter_ns()
    for _ in range(rounds):
        add(x, y)
    _sync()
    cp3 = time.perf_counter_ns()

    # perf_counter_ns() yields nanoseconds; * 1e-6 converts the per-call
    # average to milliseconds.
    print(
        f"PyTorch: {(cp1 - cp0) / rounds * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / rounds * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / rounds * 1e-6:.3f} ms"
    )