diff --git a/.gitignore b/.gitignore
index 4e2888b..5a9d978 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ wheels/
 .vscode/
+.cache
 .clangd
 .python-version
 uv.lock
 
 
diff --git a/examples/add/add.py b/examples/add/add.py
index 8977432..050ee5d 100644
--- a/examples/add/add.py
+++ b/examples/add/add.py
@@ -4,7 +4,6 @@ import time
 import torch
 import triton
 import triton.language as tl
-
 import triton_tvm_ffi
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
@@ -29,9 +28,6 @@ def add_kernel(
     tl.store(output_ptr + offsets, output, mask=mask)
 
 
-add_kernel_tvm_ffi = triton_tvm_ffi.jit(add_kernel)
-
-
 def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     output: torch.Tensor = torch.empty_like(x)
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
@@ -43,7 +39,7 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 @triton_tvm_ffi.torch_wrap(
-    [add_kernel_tvm_ffi],
+    [add_kernel],
     Path(__file__).parent / "add.cc",
 )
 def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
diff --git a/examples/softmax/softmax.cc b/examples/softmax/softmax.cc
new file mode 100644
index 0000000..a060572
--- /dev/null
+++ b/examples/softmax/softmax.cc
@@ -0,0 +1,35 @@
+#include <ATen/ATen.h>
+#include <ATen/DLConvertor.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/reflection/registry.h>
+
+#ifndef SOFTMAX_KERNEL_STUB
+#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, output, input, \
+                            inputStride, outputStride, nRows, nCols,          \
+                            BLOCK_SIZE)
+#endif
+
+#ifndef SOFTMAX_NAME
+#define SOFTMAX_NAME ""
+#endif
+
+tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
+  at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
+  at::Tensor ytorch = at::empty_like(xtorch);
+  uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1), numWarps = 8,
+           numStages = 4, xStride = xtorch.stride(0),
+           yStride = ytorch.stride(0),
+           BLOCK_SIZE = 1u << (32 - __builtin_clz(nCols - 1));
+  tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch));
+  tvm::ffi::Tuple grid{nRows / 1024, 1, 1};
+  DLDevice device = x.device();
+  void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
+  SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, y, x, xStride, yStride,
+                      nRows, nCols, BLOCK_SIZE);
+  return y;
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def(SOFTMAX_NAME, Softmax);
+}
diff --git a/examples/softmax/softmax.py b/examples/softmax/softmax.py
new file mode 100644
index 0000000..f03d146
--- /dev/null
+++ b/examples/softmax/softmax.py
@@ -0,0 +1,88 @@
+from pathlib import Path
+import time
+
+import torch
+import triton
+import triton.language as tl
+import triton_tvm_ffi
+
+
+@triton_tvm_ffi.jit
+@triton.jit
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_start = tl.program_id(0)
+    row_step = tl.num_programs(0)
+    for row_idx in tl.range(row_start, n_rows, row_step):
+        row_start_ptr = input_ptr + row_idx * input_row_stride
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
+        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+        row_minus_max = row - tl.max(row, axis=0)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, softmax_output, mask=mask)
+
+
+def softmax_triton(x):
+    n_rows, n_cols = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    num_warps = 8
+    num_stages = 4
+    y = torch.empty_like(x)
+    softmax_kernel[(n_rows, 1, 1)](
+        y,
+        x,
+        x.stride(0),
+        y.stride(0),
+        n_rows,
+        n_cols,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    return y
+
+
+@triton_tvm_ffi.torch_wrap(
+    [softmax_kernel],
+    Path(__file__).parent / "softmax.cc",
+)
+def softmax(x: torch.Tensor) -> torch.Tensor: ...
+
+
+if __name__ == "__main__":
+    x = torch.randn(1823, 781, device="cuda")
+    y_torch = torch.softmax(x, axis=1)
+    y_triton = softmax_triton(x)
+    y_tvm_ffi = softmax(x)
+    assert torch.allclose(y_torch, y_triton), (y_torch, y_triton)
+    assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
+    y_tvm_ffi = softmax(x)
+    assert torch.allclose(y_torch, y_tvm_ffi), (y_torch, y_tvm_ffi)
+
+    round = 1000
+    cp0 = time.perf_counter_ns()
+    for _ in range(round):
+        torch.softmax(x, axis=1)
+    cp1 = time.perf_counter_ns()
+    for _ in range(round):
+        softmax_triton(x)
+    cp2 = time.perf_counter_ns()
+    for _ in range(round):
+        softmax(x)
+    cp3 = time.perf_counter_ns()
+    print(
+        f"PyTorch: {(cp1 - cp0) / round * 1e-6:.3f} ms\nTriton: {(cp2 - cp1) / round * 1e-6:.3f} ms\nTVM FFI: {(cp3 - cp2) / round * 1e-6:.3f} ms"
+    )