mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
Enable lambda functions as grid descriptors
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "tvm/ffi/function.h"
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
@@ -15,10 +16,14 @@
|
||||
tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
||||
at::Tensor xtorch = at::fromDLPack(x.ToDLPack());
|
||||
at::Tensor otorch = at::empty_like(xtorch);
|
||||
int64_t numel = otorch.numel();
|
||||
int32_t numel = otorch.numel();
|
||||
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
|
||||
// TODO: check the performance loss after enabling `Optional`
|
||||
tvm::ffi::Function grid = tvm::ffi::Function::FromTyped(
|
||||
[numel](const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta)
|
||||
-> tvm::ffi::Tuple<int32_t, int32_t, int32_t> {
|
||||
const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast<int32_t>();
|
||||
return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1);
|
||||
});
|
||||
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
|
||||
DLDevice device = x.device();
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
|
||||
@@ -33,8 +33,9 @@ def add_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
|
||||
n_elements: int = output.numel()
|
||||
BLOCK_SIZE: int = 1024
|
||||
grid = (triton.cdiv(n_elements, BLOCK_SIZE), 1, 1)
|
||||
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)
|
||||
add_kernel[lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), 1, 1)](
|
||||
x, y, output, n_elements, BLOCK_SIZE
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Final, List, Optional, Tuple
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from triton.compiler import CompiledKernel
|
||||
@@ -15,13 +16,16 @@ class TVMFFIJITFunction(object):
|
||||
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fn: Final[JITFunction] = fn
|
||||
self.signature: Optional[List[str]] = None
|
||||
self.ctypes: Optional[List[Optional[str]]] = None
|
||||
self.kernel: Optional[bytes] = None
|
||||
self.num_warps: Optional[int] = None
|
||||
|
||||
@tvm_ffi.register_global_func(self.fullname)
|
||||
def _(
|
||||
grid: Tuple[int, int, int],
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
_: int,
|
||||
num_warps: Optional[int],
|
||||
num_stages: Optional[int],
|
||||
@@ -38,11 +42,17 @@ class TVMFFIJITFunction(object):
|
||||
kwargs["num_stages"] = num_stages
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, _ = kernel.packed_metadata
|
||||
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
self.kernel = kernel.kernel
|
||||
return kernel
|
||||
|
||||
def __getitem__(self, grid: Tuple[int, int, int]):
|
||||
def __getitem__(
|
||||
self,
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
):
|
||||
return self.fn[grid]
|
||||
|
||||
@property
|
||||
|
||||
42
python/triton_tvm_ffi/templates/gendef.cc.j2
Normal file
42
python/triton_tvm_ffi/templates/gendef.cc.j2
Normal file
@@ -0,0 +1,42 @@
|
||||
{# Generated C++ source: embeds each Triton kernel's cubin and emits one #}
{# <FNNAME>_STUB(...) launch macro per kernel. Rendered with `fns`,      #}
{# `name`, `uniquename`, and the user `code` appended at the bottom.     #}
#include <cuda.h>
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
#include <tvm/ffi/function.h>

{% include "grid.h" %}

#define {{ name | upper }}_NAME "{{ uniquename }}"
{% for fn in fns %}
{% if fn.ctypes is none %}
{# No compiled signature: resolve a registered global FFI function.      #}
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
{% else %}
{# Compiled path: embed the cubin and define a do/while launch macro.    #}
{# __grid may be a static (x, y, z) tuple or an FFI callback; MakeGridDim #}
{# (from grid.h) normalizes both forms using the __meta argument map.    #}
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
{% for name in fn.signature %}
{ "{{ name }}", __arg{{ loop.index0 }} }, \
{% endfor %}
}; \
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
void *dummy = nullptr
{%- for ctype in fn.ctypes -%}
{%- if ctype == "CUdeviceptr" -%}
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
{%- endif -%}
{%- endfor -%}; \
void *__params[] = {
{%- for ctype in fn.ctypes -%}
{%- if ctype != none -%}
&__arg{{ loop.index0 }}
{%- if ctype == "CUdeviceptr" -%}
_ptr
{%- endif -%},
{%- endif -%}
{%- endfor -%}&dummy, &dummy }; \
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
} while (false)
{% endif %}
{% endfor %}

{{ code }}
|
||||
29
python/triton_tvm_ffi/templates/grid.h
Normal file
29
python/triton_tvm_ffi/templates/grid.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// Helpers that normalize a Triton launch grid into a tvm::ffi::dim3.
// The grid is given either as a static (x, y, z) tuple or as an FFI
// callback that computes the tuple from the kernel's meta-parameter
// map, mirroring Triton's `grid = lambda meta: ...` idiom.
#ifndef TRITON_TVM_FFI_GRID_H
#define TRITON_TVM_FFI_GRID_H

#include <cstdint>
#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/tvm_ffi.h>

// Primary template: declared without a generic definition here, so only
// the explicit specializations below are usable.
template <typename T>
inline tvm::ffi::dim3
MakeGridDim(const T &grid,
            const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);

// Static grid: the (x, y, z) dimensions are known up front; the
// meta-parameter map is unused.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
    const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
  return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
}

// Dynamic grid: invoke the user-supplied callback with the kernel's
// meta-parameters, then convert the resulting tuple via the static
// specialization above.
template <>
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
    const tvm::ffi::Function &grid,
    const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
  tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
      grid(meta).cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
  return MakeGridDim(tuple, meta);
}

#endif
|
||||
@@ -1,8 +1,9 @@
|
||||
from functools import cached_property
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Final, List, Optional, Sequence, Union
|
||||
|
||||
import jinja2
|
||||
import torch.utils.cpp_extension
|
||||
import tvm_ffi
|
||||
|
||||
@@ -38,6 +39,11 @@ class TVMFFIWrapperFunction(object):
|
||||
self.extra_include_paths: Optional[Sequence[Union[str, Path]]] = (
|
||||
extra_include_paths
|
||||
)
|
||||
self.env: Final[jinja2.Environment] = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"),
|
||||
trim_blocks=True,
|
||||
)
|
||||
self.tpl: Final[jinja2.Template] = self.env.get_template("gendef.cc.j2")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> None:
|
||||
func: tvm_ffi.Function = self.compile()
|
||||
@@ -53,19 +59,9 @@ class TVMFFIWrapperFunction(object):
|
||||
|
||||
@property
|
||||
def emit(self) -> str:
|
||||
defs: str = "\n".join(
|
||||
[
|
||||
"#include <cuda.h>",
|
||||
"#include <tvm/ffi/extra/cuda/cubin_launcher.h>",
|
||||
"#include <tvm/ffi/function.h>",
|
||||
f'#define {self.name.upper()}_NAME "{self.uniquename}"',
|
||||
*map(
|
||||
self.gendef,
|
||||
self.fns,
|
||||
),
|
||||
]
|
||||
return self.tpl.render(
|
||||
code=self.code, fns=self.fns, name=self.name, uniquename=self.uniquename
|
||||
)
|
||||
return f"{defs}\n{self.code}"
|
||||
|
||||
@property
|
||||
def uniquename(self) -> str:
|
||||
@@ -90,42 +86,6 @@ class TVMFFIWrapperFunction(object):
|
||||
)
|
||||
return tvm_ffi.get_global_func(self.uniquename)
|
||||
|
||||
@staticmethod
def gendef(fn: TVMFFIJITFunction) -> str:
    """Generate the C++ ``#define <FNNAME>_STUB ...`` snippet for one kernel.

    When the function has no compiled signature (``fn.ctypes is None``)
    the stub resolves a registered global FFI function by its full name.
    Otherwise it emits a ``do { ... } while (false)`` macro that embeds
    the kernel's cubin and launches it with the supplied grid tuple,
    stream, warp count, and arguments.
    """
    if fn.ctypes is None:
        # Fallback path: dispatch through the TVM FFI global registry.
        return f'#define {fn.fnname.upper()}_STUB tvm::ffi::Function::GetGlobalRequired("{fn.fullname}")'
    else:
        # Pair each argument's C type with its positional macro name __argN.
        ctype_arg_list: List[Tuple[str, str]] = [
            (ctype, f"__arg{idx}") for idx, ctype in enumerate(fn.ctypes)
        ]

        # NOTE(review): if no entry in fn.ctypes is "CUdeviceptr" the
        # fourth join below is empty and the generated line reads
        # `void *dummy = nullptr, ;` (invalid C++) — confirm that every
        # compiled kernel takes at least one device-pointer argument.
        return """
TVM_FFI_EMBED_CUBIN(triton_{fnname});
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
void *dummy = nullptr, {}; \\
void *__params[] = {{{}, &dummy, &dummy}}; \\
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
}} while (false)
""".format(
            fn.fnname.upper(),
            # Macro parameter list: one __argN per kernel argument.
            ", ".join(arg for _, arg in ctype_arg_list),
            # Block x-dim: a compile-time num_warps wins over the runtime
            # __numWarps macro argument.
            fn.num_warps if fn.num_warps is not None else "__numWarps",
            # Device-pointer arguments are unwrapped once into __argN_ptr.
            ", ".join(
                f"*{arg}_ptr = {arg}.data_ptr()"
                for ctype, arg in ctype_arg_list
                if ctype == "CUdeviceptr"
            ),
            # Launch parameter array; entries whose ctype is None are skipped.
            ", ".join(
                f"&{arg}" if ctype != "CUdeviceptr" else f"&{arg}_ptr"
                for ctype, arg in ctype_arg_list
                if ctype is not None
            ),
            fnname=fn.fnname,
        ).strip()
|
||||
|
||||
|
||||
def wrap(
|
||||
fns: List[TVMFFIJITFunction],
|
||||
|
||||
Reference in New Issue
Block a user