mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
include header files by c/cpp instead of jinja
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
8
CMakeLists.txt
Normal file
8
CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.18)
|
||||||
|
|
||||||
|
project(${SKBUILD_PROJECT_NAME})
|
||||||
|
|
||||||
|
install(
|
||||||
|
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||||
|
DESTINATION ${CMAKE_INSTALL_PREFIX}/triton_tvm_ffi
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
#ifndef TRITON_TVM_FFI_GRID_H
|
#ifndef TRITON_TVM_FFI_GRID_H_
|
||||||
#define TRITON_TVM_FFI_GRID_H
|
#define TRITON_TVM_FFI_GRID_H_
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <tvm/ffi/extra/cuda/base.h>
|
#include <tvm/ffi/extra/cuda/base.h>
|
||||||
6
main.py
6
main.py
@@ -1,6 +0,0 @@
|
|||||||
def main():
|
|
||||||
print("Hello from triton-tvm-ffi!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "triton-tvm-ffi"
|
name = "triton-tvm-ffi"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "A Python package for the FFI bindings of Triton TVM."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"apache-tvm-ffi",
|
"apache-tvm-ffi",
|
||||||
@@ -9,8 +9,8 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools"]
|
requires = ["scikit-build-core"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
packages = ["triton_tvm_ffi"]
|
packages = ["triton_tvm_ffi"]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .jit import jit
|
from .jit import jit
|
||||||
|
from .utils import include_paths
|
||||||
from .wrap import torch_wrap, wrap
|
from .wrap import torch_wrap, wrap
|
||||||
|
|
||||||
__all__ = ["jit", "torch_wrap", "wrap"]
|
__all__ = ["include_paths", "jit", "torch_wrap", "wrap"]
|
||||||
|
|||||||
@@ -3,8 +3,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
#include <tvm/ffi/function.h>
|
#include <tvm/ffi/function.h>
|
||||||
|
#include "triton_tvm_ffi/grid.h"
|
||||||
{% include "grid.h" %}
|
|
||||||
|
|
||||||
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||||
{% for fn in fns %}
|
{% for fn in fns %}
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
from typing import Optional
|
import importlib.resources
|
||||||
|
from importlib.resources.abc import Traversable
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from triton.backends.nvidia.driver import ty_to_cpp
|
from triton.backends.nvidia.driver import ty_to_cpp
|
||||||
|
|
||||||
|
|
||||||
|
def include_paths() -> List[str]:
|
||||||
|
pkg_path: Traversable = importlib.resources.files("triton_tvm_ffi")
|
||||||
|
return [str(pkg_path / "include"), str(pkg_path / "../../include")]
|
||||||
|
|
||||||
|
|
||||||
def type_canonicalize(ty: str) -> Optional[str]:
|
def type_canonicalize(ty: str) -> Optional[str]:
|
||||||
if ty == "constexpr":
|
if ty == "constexpr":
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch.utils.cpp_extension
|
|||||||
import tvm_ffi
|
import tvm_ffi
|
||||||
|
|
||||||
from .jit import TVMFFIJITFunction
|
from .jit import TVMFFIJITFunction
|
||||||
|
from .utils import include_paths
|
||||||
|
|
||||||
|
|
||||||
class TVMFFIWrapperFunction(object):
|
class TVMFFIWrapperFunction(object):
|
||||||
@@ -103,7 +104,7 @@ def wrap(
|
|||||||
extra_cflags,
|
extra_cflags,
|
||||||
extra_cuda_cflags,
|
extra_cuda_cflags,
|
||||||
extra_ldflags,
|
extra_ldflags,
|
||||||
extra_include_paths,
|
include_paths() + (extra_include_paths or []),
|
||||||
)
|
)
|
||||||
|
|
||||||
return decorate
|
return decorate
|
||||||
|
|||||||
Reference in New Issue
Block a user