mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
include header files by c/cpp instead of jinja
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .jit import jit
|
||||
from .utils import include_paths
|
||||
from .wrap import torch_wrap, wrap
|
||||
|
||||
__all__ = ["jit", "torch_wrap", "wrap"]
|
||||
__all__ = ["include_paths", "jit", "torch_wrap", "wrap"]
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
#include <optional>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
|
||||
{% include "grid.h" %}
|
||||
#include "triton_tvm_ffi/grid.h"
|
||||
|
||||
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||
{% for fn in fns %}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
#ifndef TRITON_TVM_FFI_GRID_H
|
||||
#define TRITON_TVM_FFI_GRID_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/extra/cuda/base.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
template <typename T>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim(const T &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
|
||||
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
|
||||
return grid;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Function>(
|
||||
const tvm::ffi::Function &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
|
||||
grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
|
||||
return MakeGridDim(tuple, meta);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,8 +1,15 @@
|
||||
from typing import Optional
|
||||
import importlib.resources
|
||||
from importlib.resources.abc import Traversable
|
||||
from typing import List, Optional
|
||||
|
||||
from triton.backends.nvidia.driver import ty_to_cpp
|
||||
|
||||
|
||||
def include_paths() -> List[str]:
|
||||
pkg_path: Traversable = importlib.resources.files("triton_tvm_ffi")
|
||||
return [str(pkg_path / "include"), str(pkg_path / "../../include")]
|
||||
|
||||
|
||||
def type_canonicalize(ty: str) -> Optional[str]:
|
||||
if ty == "constexpr":
|
||||
return None
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
|
||||
from .jit import TVMFFIJITFunction
|
||||
from .utils import include_paths
|
||||
|
||||
|
||||
class TVMFFIWrapperFunction(object):
|
||||
@@ -103,7 +104,7 @@ def wrap(
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
include_paths() + (extra_include_paths or []),
|
||||
)
|
||||
|
||||
return decorate
|
||||
|
||||
Reference in New Issue
Block a user