mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
verify tvm-ffi cpp wrapper on vector-add.py
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
86
examples/add/add.py
Normal file
86
examples/add/add.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
|
||||
# Support decorators here like
|
||||
# @triton_tvm_ffi.jit
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
||||
|
||||
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output: torch.Tensor = torch.empty_like(x)
|
||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||
n_elements: int = output.numel()
|
||||
BLOCK_SIZE: int = 1024
|
||||
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
|
||||
return output
|
||||
|
||||
|
||||
# TODO: it woule be more user-friendly to define wrapper functions like below
|
||||
# @triton_tvm_ffi.torch_wrap(
|
||||
# "add",
|
||||
# [add_kernel_tvm_ffi],
|
||||
# Path(__file__).parent / "add.cc",
|
||||
# )
|
||||
# def add_tvm_ffi(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
# ...
|
||||
add_tvm_ffi = triton_tvm_ffi.torch_wrap(
|
||||
"add",
|
||||
[add_kernel_tvm_ffi],
|
||||
Path(__file__).parent / "add.cc",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device=DEVICE)
|
||||
y = torch.rand(size, device=DEVICE)
|
||||
output_torch = x + y
|
||||
output_triton = add(x, y)
|
||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||
assert torch.allclose(output_torch, output_triton)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
output_tvm_ffi = add_tvm_ffi(x, y)
|
||||
assert torch.allclose(output_torch, output_tvm_ffi)
|
||||
|
||||
round = 1000
|
||||
cp0 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
x + y
|
||||
cp1 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add(x, y)
|
||||
cp2 = time.perf_counter_ns()
|
||||
for _ in range(round):
|
||||
add_tvm_ffi(x, y)
|
||||
cp3 = time.perf_counter_ns()
|
||||
print(
|
||||
f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
|
||||
)
|
||||
Reference in New Issue
Block a user