mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable lambda function for grid descriptor
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements: int = output.numel()
     BLOCK_SIZE: int = 1024
-    grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
-    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
+    add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
+        x, y, output, n_elements, BLOCK_SIZE
+    )
     return output
 
 
Reference in New Issue
Block a user