diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..536b718 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.18) + +project(${SKBUILD_PROJECT_NAME}) + +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include + DESTINATION ${CMAKE_INSTALL_PREFIX}/triton_tvm_ffi +) diff --git a/python/triton_tvm_ffi/templates/grid.h b/include/triton_tvm_ffi/grid.h similarity index 93% rename from python/triton_tvm_ffi/templates/grid.h rename to include/triton_tvm_ffi/grid.h index ea3ad13..3480158 100644 --- a/python/triton_tvm_ffi/templates/grid.h +++ b/include/triton_tvm_ffi/grid.h @@ -1,5 +1,5 @@ -#ifndef TRITON_TVM_FFI_GRID_H -#define TRITON_TVM_FFI_GRID_H +#ifndef TRITON_TVM_FFI_GRID_H_ +#define TRITON_TVM_FFI_GRID_H_ #include #include diff --git a/main.py b/main.py deleted file mode 100644 index d20d5d4..0000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from triton-tvm-ffi!") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index b0fb141..e7c48e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "triton-tvm-ffi" version = "0.1.0" -description = "Add your description here" +description = "A Python package for the FFI bindings of Triton TVM." readme = "README.md" dependencies = [ "apache-tvm-ffi", @@ -9,8 +9,8 @@ dependencies = [ ] [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["scikit-build-core"] +build-backend = "scikit_build_core.build" [tool.setuptools] packages = ["triton_tvm_ffi"] diff --git a/python/triton_tvm_ffi/__init__.py b/python/triton_tvm_ffi/__init__.py index 2200bf7..060caef 100644 --- a/python/triton_tvm_ffi/__init__.py +++ b/python/triton_tvm_ffi/__init__.py @@ -1,4 +1,5 @@ from .jit import jit +from .utils import include_paths from .wrap import torch_wrap, wrap -__all__ = ["jit", "torch_wrap", "wrap"] +__all__ = ["include_paths", "jit", "torch_wrap", "wrap"] diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2 index cedda4b..09c4e6e 100644 --- a/python/triton_tvm_ffi/templates/gendef.cc.j2 +++ b/python/triton_tvm_ffi/templates/gendef.cc.j2 @@ -3,8 +3,7 @@ #include #include #include - -{% include "grid.h" %} +#include "triton_tvm_ffi/grid.h" #define {{ name | upper }}_NAME "{{ uniquename }}" {% for fn in fns %} diff --git a/python/triton_tvm_ffi/utils.py b/python/triton_tvm_ffi/utils.py index 3874f2a..694b438 100644 --- a/python/triton_tvm_ffi/utils.py +++ b/python/triton_tvm_ffi/utils.py @@ -1,8 +1,15 @@ -from typing import Optional +import importlib.resources +from importlib.resources.abc import Traversable +from typing import List, Optional from triton.backends.nvidia.driver import ty_to_cpp +def include_paths() -> List[str]: + pkg_path: Traversable = importlib.resources.files("triton_tvm_ffi") + return [str(pkg_path / "include"), str(pkg_path / "../../include")] + + def type_canonicalize(ty: str) -> Optional[str]: if ty == "constexpr": return None diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py index 2908991..2e3c68c 100644 --- a/python/triton_tvm_ffi/wrap.py +++ b/python/triton_tvm_ffi/wrap.py @@ -8,6 +8,7 @@ import torch.utils.cpp_extension import tvm_ffi from .jit import TVMFFIJITFunction +from .utils import include_paths class TVMFFIWrapperFunction(object): @@ -103,7 +104,7 @@ def wrap( extra_cflags, extra_cuda_cflags, extra_ldflags, - extra_include_paths, + include_paths() + (extra_include_paths or []), ) return decorate