mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable lambda function for grid descriptor
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Final, List, Optional, Tuple
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from triton.compiler import CompiledKernel
|
||||
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
|
||||
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fn: Final[JITFunction] = fn
|
||||
self.signature: Optional[List[str]] = None
|
||||
self.ctypes: Optional[List[Optional[str]]] = None
|
||||
self.kernel: Optional[bytes] = None
|
||||
self.num_warps: Optional[int] = None
|
||||
|
||||
@tvm_ffi.register_global_func(self.fullname)
|
||||
def _(
|
||||
grid: Tuple[int, int, int],
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
_: int,
|
||||
num_warps: Optional[int],
|
||||
num_stages: Optional[int],
|
||||
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
|
||||
kwargs["num_stages"] = num_stages
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, _ = kernel.packed_metadata
|
||||
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
self.kernel = kernel.kernel
|
||||
return kernel
|
||||
|
||||
def __getitem__(self, grid: Tuple[int, int, int]):
|
||||
def __getitem__(
|
||||
self,
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
):
|
||||
return self.fn[grid]
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user