enable lambda function for grid descriptor

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-05 15:59:22 +08:00
parent 8b8aa6cb84
commit f6c7a48c1b
6 changed files with 104 additions and 57 deletions

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
from functools import cached_property
from typing import Any, Dict, Final, List, Optional, Tuple
import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
import torch
from triton.compiler import CompiledKernel
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fn: Final[JITFunction] = fn
self.signature: Optional[List[str]] = None
self.ctypes: Optional[List[Optional[str]]] = None
self.kernel: Optional[bytes] = None
self.num_warps: Optional[int] = None
@tvm_ffi.register_global_func(self.fullname)
def _(
grid: Tuple[int, int, int],
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
_: int,
num_warps: Optional[int],
num_stages: Optional[int],
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
kwargs["num_stages"] = num_stages
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
self.num_warps, _, _ = kernel.packed_metadata
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
self.kernel = kernel.kernel
return kernel
def __getitem__(self, grid: Tuple[int, int, int]):
def __getitem__(
self,
grid: Union[
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
],
):
return self.fn[grid]
@property