supports decorator for jit and wrapper

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-04 10:46:14 +08:00
parent 6e4c2d4a43
commit 192dc95ac0
2 changed files with 24 additions and 32 deletions

View File

@@ -1,7 +1,7 @@
from functools import cached_property
from io import TextIOWrapper
from pathlib import Path
from typing import Final, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
import torch.utils.cpp_extension
import tvm_ffi
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
if fn.kernel is not None
},
)
return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
return tvm_ffi.get_global_func(self.uniquename)
@staticmethod
def gendef(fn: TVMFFIJITFunction) -> str:
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
def wrap(
name: str,
fns: List[TVMFFIJITFunction],
code: Union[str, Path, TextIOWrapper],
extra_cflags: Optional[Sequence[str]] = None,
@@ -135,19 +134,21 @@ def wrap(
extra_ldflags: Optional[Sequence[str]] = None,
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
) -> TVMFFIWrapperFunction:
return TVMFFIWrapperFunction(
name,
fns,
code,
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
extra_include_paths,
)
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
return TVMFFIWrapperFunction(
fn if isinstance(fn, str) else fn.__name__,
fns,
code,
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
extra_include_paths,
)
return decorate
def torch_wrap(
name: str,
fns: List[TVMFFIJITFunction],
code: Union[str, Path, TextIOWrapper],
extra_cflags: Optional[Sequence[str]] = None,
@@ -157,7 +158,6 @@ def torch_wrap(
) -> TVMFFIWrapperFunction:
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
return wrap(
name,
fns,
code,
extra_ldflags=[