mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
@@ -4,7 +4,6 @@ import time
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import triton_tvm_ffi
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
@@ -29,9 +28,6 @@ def add_kernel(
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
|
||||
|
||||
|
||||
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output: torch.Tensor = torch.empty_like(x)
|
||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||
@@ -43,7 +39,7 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@triton_tvm_ffi.torch_wrap(
|
||||
[add_kernel_tvm_ffi],
|
||||
[add_kernel],
|
||||
Path(__file__).parent / "add.cc",
|
||||
)
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
Reference in New Issue
Block a user