include header files by c/cpp instead of jinja

Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
jinjieliu
2026-02-07 17:16:49 +08:00
parent 6a19a6b06d
commit 24237a6313
8 changed files with 26 additions and 16 deletions

8
CMakeLists.txt Normal file
View File

@@ -0,0 +1,8 @@
cmake_minimum_required(VERSION 3.18)
project(${SKBUILD_PROJECT_NAME})
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include
DESTINATION ${CMAKE_INSTALL_PREFIX}/triton_tvm_ffi
)

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_TVM_FFI_GRID_H
#define TRITON_TVM_FFI_GRID_H
#ifndef TRITON_TVM_FFI_GRID_H_
#define TRITON_TVM_FFI_GRID_H_
#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>

View File

@@ -1,6 +0,0 @@
def main():
print("Hello from triton-tvm-ffi!")
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,7 @@
[project]
name = "triton-tvm-ffi"
version = "0.1.0"
description = "Add your description here"
description = "A Python package for the FFI bindings of Triton TVM."
readme = "README.md"
dependencies = [
"apache-tvm-ffi",
@@ -9,8 +9,8 @@ dependencies = [
]
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
requires = ["scikit-build-core"]
build-backend = "scikit_build_core.build"
[tool.setuptools]
packages = ["triton_tvm_ffi"]

View File

@@ -1,4 +1,5 @@
from .jit import jit
from .utils import include_paths
from .wrap import torch_wrap, wrap
__all__ = ["jit", "torch_wrap", "wrap"]
__all__ = ["include_paths", "jit", "torch_wrap", "wrap"]

View File

@@ -3,8 +3,7 @@
#include <optional>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>
{% include "grid.h" %}
#include "triton_tvm_ffi/grid.h"
#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}

View File

@@ -1,8 +1,15 @@
from typing import Optional
import importlib.resources
from importlib.resources.abc import Traversable
from typing import List, Optional
from triton.backends.nvidia.driver import ty_to_cpp
def include_paths() -> List[str]:
pkg_path: Traversable = importlib.resources.files("triton_tvm_ffi")
return [str(pkg_path / "include"), str(pkg_path / "../../include")]
def type_canonicalize(ty: str) -> Optional[str]:
if ty == "constexpr":
return None

View File

@@ -8,6 +8,7 @@ import torch.utils.cpp_extension
import tvm_ffi
from .jit import TVMFFIJITFunction
from .utils import include_paths
class TVMFFIWrapperFunction(object):
@@ -103,7 +104,7 @@ def wrap(
extra_cflags,
extra_cuda_cflags,
extra_ldflags,
extra_include_paths,
include_paths() + (extra_include_paths or []),
)
return decorate