mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable optional for numwarps and numstages
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -17,13 +17,14 @@ class TVMFFIJITFunction(object):
|
||||
self.fn: Final[JITFunction] = fn
|
||||
self.ctypes: Optional[List[Optional[str]]] = None
|
||||
self.kernel: Optional[bytes] = None
|
||||
self.num_warps: Optional[int] = None
|
||||
|
||||
@tvm_ffi.register_global_func(self.fullname)
|
||||
def _(
|
||||
grid: Tuple[int, int, int],
|
||||
_: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
num_warps: Optional[int],
|
||||
num_stages: Optional[int],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -31,9 +32,12 @@ class TVMFFIJITFunction(object):
|
||||
kwargs: Dict[str, Any] = {
|
||||
k: self.canonicalize(v) for k, v in kwargs.items()
|
||||
}
|
||||
kernel: CompiledKernel = self.fn[grid](
|
||||
*args, **kwargs, num_warps=num_warps, num_stages=num_stages
|
||||
)
|
||||
if num_warps is not None:
|
||||
kwargs["num_warps"] = num_warps
|
||||
if num_stages is not None:
|
||||
kwargs["num_stages"] = num_stages
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, _ = kernel.packed_metadata
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
self.kernel = kernel.kernel
|
||||
return kernel
|
||||
|
||||
Reference in New Issue
Block a user