Mirror of https://github.com/sgjzfzzf/triton-tvm-ffi.git (synced 2026-05-02 03:52:11 +08:00)
Allow a lambda (callable) to be used as the grid descriptor
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
#include "tvm/ffi/function.h"
|
||||||
#include <ATen/DLConvertor.h>
|
#include <ATen/DLConvertor.h>
|
||||||
#include <ATen/dlpack.h>
|
#include <ATen/dlpack.h>
|
||||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
@@ -15,10 +16,14 @@
|
|||||||
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
||||||
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||||
at::Tensor otorch = at::empty_like(xtorch);
|
at::Tensor otorch = at::empty_like(xtorch);
|
||||||
int64_t numel = otorch.numel();
|
int32_t numel = otorch.numel();
|
||||||
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
|
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||||
// TODO: check the performance loss after enabling `Optional`
|
[numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||||
|
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||||
|
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
|
||||||
|
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
|
||||||
|
});
|
||||||
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
|
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
|
||||||
DLDevice device = x.device();
|
DLDevice device = x.device();
|
||||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||||
|
|||||||
@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|||||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||||
n_elements: int = output.numel()
|
n_elements: int = output.numel()
|
||||||
BLOCK_SIZE: int = 1024
|
BLOCK_SIZE: int = 1024
|
||||||
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
|
add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
|
||||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
|
x, y, output, n_elements, BLOCK_SIZE
|
||||||
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Final, List, Optional, Tuple
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from triton.compiler import CompiledKernel
|
from triton.compiler import CompiledKernel
|
||||||
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
|
|||||||
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.fn: Final[JITFunction] = fn
|
self.fn: Final[JITFunction] = fn
|
||||||
|
self.signature: Optional[List[str]] = None
|
||||||
self.ctypes: Optional[List[Optional[str]]] = None
|
self.ctypes: Optional[List[Optional[str]]] = None
|
||||||
self.kernel: Optional[bytes] = None
|
self.kernel: Optional[bytes] = None
|
||||||
self.num_warps: Optional[int] = None
|
self.num_warps: Optional[int] = None
|
||||||
|
|
||||||
@tvm_ffi.register_global_func(self.fullname)
|
@tvm_ffi.register_global_func(self.fullname)
|
||||||
def _(
|
def _(
|
||||||
grid: Tuple[int, int, int],
|
grid: Union[
|
||||||
|
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||||
|
],
|
||||||
_: int,
|
_: int,
|
||||||
num_warps: Optional[int],
|
num_warps: Optional[int],
|
||||||
num_stages: Optional[int],
|
num_stages: Optional[int],
|
||||||
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
|
|||||||
kwargs["num_stages"] = num_stages
|
kwargs["num_stages"] = num_stages
|
||||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||||
self.num_warps, _, _ = kernel.packed_metadata
|
self.num_warps, _, _ = kernel.packed_metadata
|
||||||
|
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
|
||||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||||
self.kernel = kernel.kernel
|
self.kernel = kernel.kernel
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
def __getitem__(self, grid: Tuple[int, int, int]):
|
def __getitem__(
|
||||||
|
self,
|
||||||
|
grid: Union[
|
||||||
|
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||||
|
],
|
||||||
|
):
|
||||||
return self.fn[grid]
|
return self.fn[grid]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
python/triton_tvm_ffi/templates/gendef.cc.j2 (new normal file, 42 lines)
@@ -0,0 +1,42 @@
|
|||||||
|
#include <cuda.h>
|
||||||
|
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||||
|
#include <tvm/ffi/function.h>
|
||||||
|
|
||||||
|
{% include "grid.h" %}
|
||||||
|
|
||||||
|
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||||
|
{% for fn in fns %}
|
||||||
|
{% if fn.ctypes is none %}
|
||||||
|
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
|
||||||
|
{% else %}
|
||||||
|
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
|
||||||
|
#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
|
||||||
|
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||||
|
{% for name in fn.signature %}
|
||||||
|
{ "{{ name }}", __arg{{ loop.index0 }} }, \
|
||||||
|
{% endfor %}
|
||||||
|
}; \
|
||||||
|
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
|
||||||
|
tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
|
||||||
|
tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
|
||||||
|
void *dummy = nullptr
|
||||||
|
{%- for ctype in fn.ctypes -%}
|
||||||
|
{%- if ctype == "CUdeviceptr" -%}
|
||||||
|
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}; \
|
||||||
|
void *__params[] = {
|
||||||
|
{%- for ctype in fn.ctypes -%}
|
||||||
|
{%- if ctype != none -%}
|
||||||
|
&__arg{{ loop.index0 }}
|
||||||
|
{%- if ctype == "CUdeviceptr" -%}
|
||||||
|
_ptr
|
||||||
|
{%- endif -%},
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}&dummy, &dummy }; \
|
||||||
|
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
|
||||||
|
} while (false)
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{{ code }}
|
||||||
python/triton_tvm_ffi/templates/grid.h (new normal file, 29 lines)
@@ -0,0 +1,29 @@
|
|||||||
|
#ifndef TRITON_TVM_FFI_GRID_H
|
||||||
|
#define TRITON_TVM_FFI_GRID_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <tvm/ffi/extra/cuda/base.h>
|
||||||
|
#include <tvm/ffi/tvm_ffi.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline tvm::ffi::dim3
|
||||||
|
MakeGridDim(const T &grid,
|
||||||
|
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
|
||||||
|
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
|
||||||
|
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
|
||||||
|
return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
|
||||||
|
const tvm::ffi::Function &grid,
|
||||||
|
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
|
||||||
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
|
||||||
|
grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
|
||||||
|
return MakeGridDim(tuple, meta);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Callable, Final, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import jinja2
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
import tvm_ffi
|
import tvm_ffi
|
||||||
|
|
||||||
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
|
|||||||
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
||||||
extra_include_paths
|
extra_include_paths
|
||||||
)
|
)
|
||||||
|
self.env: Final[jinja2.Environment] = jinja2.Environment(
|
||||||
|
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||||
|
trim_blocks=True,
|
||||||
|
)
|
||||||
|
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs) -> None:
|
def __call__(self, *args, **kwargs) -> None:
|
||||||
func: tvm_ffi.Function = self.compile()
|
func: tvm_ffi.Function = self.compile()
|
||||||
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def emit(self) -> str:
|
def emit(self) -> str:
|
||||||
defs: str = "\n".join(
|
return self.tpl.render(
|
||||||
[
|
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
|
||||||
"#include <cuda.h>",
|
|
||||||
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
|
|
||||||
"#include <tvm/ffi/function.h>",
|
|
||||||
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
|
|
||||||
*map(
|
|
||||||
self.gendef,
|
|
||||||
self.fns,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
return f"{defs}\n{self.code}"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def uniquename(self) -> str:
|
def uniquename(self) -> str:
|
||||||
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
|
|||||||
)
|
)
|
||||||
return tvm_ffi.get_global_func(self.uniquename)
|
return tvm_ffi.get_global_func(self.uniquename)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def gendef(fn: TVMFFIJITFunction) -> str:
|
|
||||||
if fn.ctypes is None:
|
|
||||||
return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
|
|
||||||
else:
|
|
||||||
ctype_arg_list: List[Tuple[str, str]] = [
|
|
||||||
(ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
|
|
||||||
]
|
|
||||||
|
|
||||||
return """
|
|
||||||
TVM_FFI_EMBED_CUBIN(triton_{fnname});
|
|
||||||
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
|
||||||
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
|
||||||
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
|
||||||
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
|
|
||||||
void *dummy = nullptr, {}; \\
|
|
||||||
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
|
||||||
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
|
||||||
}} while (false)
|
|
||||||
""".format(
|
|
||||||
fn.fnname.upper(),
|
|
||||||
", ".join(arg for _, arg in ctype_arg_list),
|
|
||||||
fn.num_warps if fn.num_warps is not None else "__numWarps",
|
|
||||||
", ".join(
|
|
||||||
f"*{arg}_ptr = {arg}.data_ptr()"
|
|
||||||
for ctype, arg in ctype_arg_list
|
|
||||||
if ctype == "CUdeviceptr"
|
|
||||||
),
|
|
||||||
", ".join(
|
|
||||||
f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
|
|
||||||
for ctype, arg in ctype_arg_list
|
|
||||||
if ctype is not None
|
|
||||||
),
|
|
||||||
fnname=fn.fnname,
|
|
||||||
).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def wrap(
|
def wrap(
|
||||||
fns: List[TVMFFIJITFunction],
|
fns: List[TVMFFIJITFunction],
|
||||||
|
|||||||
Reference in New Issue
Block a user