mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
supports decorator for jit and wrapper
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from functools import cached_property
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Final, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
@@ -88,7 +88,7 @@ class TVMFFIWrapperFunction(object):
|
||||
if fn.kernel is not None
|
||||
},
|
||||
)
|
||||
return tvm_ffi.get_global_func(self.uniquename, allow_missing=True)
|
||||
return tvm_ffi.get_global_func(self.uniquename)
|
||||
|
||||
@staticmethod
|
||||
def gendef(fn: TVMFFIJITFunction) -> str:
|
||||
@@ -127,7 +127,6 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
|
||||
|
||||
|
||||
def wrap(
|
||||
name: str,
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
@@ -135,19 +134,21 @@ def wrap(
|
||||
extra_ldflags: Optional[Sequence[str]] = None,
|
||||
extra_include_paths: Optional[Sequence[Union[str, Path]]] = None,
|
||||
) -> TVMFFIWrapperFunction:
|
||||
return TVMFFIWrapperFunction(
|
||||
name,
|
||||
fns,
|
||||
code,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
)
|
||||
def decorate(fn: Union[str, Callable[..., Any]]) -> TVMFFIWrapperFunction:
|
||||
return TVMFFIWrapperFunction(
|
||||
fn if isinstance(fn, str) else fn.__name__,
|
||||
fns,
|
||||
code,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def torch_wrap(
|
||||
name: str,
|
||||
fns: List[TVMFFIJITFunction],
|
||||
code: Union[str, Path, TextIOWrapper],
|
||||
extra_cflags: Optional[Sequence[str]] = None,
|
||||
@@ -157,7 +158,6 @@ def torch_wrap(
|
||||
) -> TVMFFIWrapperFunction:
|
||||
cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home()
|
||||
return wrap(
|
||||
name,
|
||||
fns,
|
||||
code,
|
||||
extra_ldflags=[
|
||||
|
||||
Reference in New Issue
Block a user