support softmax

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-04 16:38:33 +08:00
parent 192dc95ac0
commit b7bf598fde
4 changed files with 125 additions and 5 deletions

View File

@@ -4,7 +4,6 @@ import time
import torch
import triton
import triton.language as tl
import triton_tvm_ffi
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -29,9 +28,6 @@ def add_kernel(
tl.store(output_ptr + offsets, output, mask=mask)
add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.empty_like(x)
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
@@ -43,7 +39,7 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Stub declaration of `add`: its body is only `...`, and it is decorated with
# `triton_tvm_ffi.torch_wrap`, which is given a kernel list and the path of an
# `add.cc` file that sits in the same directory as this module.
# NOTE(review): this text comes from a diff hunk whose +/- markers were lost;
# the hunk header reports the same line count before and after (7 -> 7), so the
# two list lines below are presumably the old and the new form of a single
# argument rather than two separate arguments -- confirm against the real file.
@triton_tvm_ffi.torch_wrap(
[add_kernel_tvm_ffi],
[add_kernel],
Path(__file__).parent / "add.cc",
)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...