mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
@@ -2,24 +2,38 @@ from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
from triton.compiler import CompiledKernel
|
||||
from triton.runtime import JITFunction
|
||||
from triton.runtime import Autotuner, JITFunction
|
||||
import tvm_ffi
|
||||
|
||||
from .utils import type_canonicalize
|
||||
|
||||
|
||||
class TVMFFIJITFunction(object):
|
||||
def __init__(self, fn: JITFunction, *args, **kwargs) -> None:
|
||||
def __init__(self, fn: Union[Autotuner, JITFunction], *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fn: Final[JITFunction] = fn
|
||||
self.signature: Optional[List[str]] = None
|
||||
self.fn: Final[Union[Autotuner, JITFunction]] = fn
|
||||
self.signature: List[str] = [*inspect.signature(self.basefn).parameters.keys()]
|
||||
self.best_config: Optional[Dict[str, Any]] = None
|
||||
self.ctypes: Optional[List[Optional[str]]] = None
|
||||
self.kernel: Optional[bytes] = None
|
||||
self.num_warps: Optional[int] = None
|
||||
self.shmem: int = 0
|
||||
|
||||
@tvm_ffi.register_global_func(self.fullname)
|
||||
def _(
|
||||
@@ -29,22 +43,23 @@ class TVMFFIJITFunction(object):
|
||||
_: int,
|
||||
num_warps: Optional[int],
|
||||
num_stages: Optional[int],
|
||||
*args,
|
||||
**kwargs,
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
):
|
||||
args: List[Any] = map(self.canonicalize, args)
|
||||
args: Iterator[Any] = map(self.canonicalize, args)
|
||||
kwargs: Dict[str, Any] = {
|
||||
k: self.canonicalize(v) for k, v in kwargs.items()
|
||||
}
|
||||
k: v for k, v in zip(self.signature, args) if v is not None
|
||||
} | {k: self.canonicalize(v) for k, v in kwargs.items()}
|
||||
if num_warps is not None:
|
||||
kwargs["num_warps"] = num_warps
|
||||
if num_stages is not None:
|
||||
kwargs["num_stages"] = num_stages
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, _ = kernel.packed_metadata
|
||||
self.signature = [*inspect.signature(self.fn.fn).parameters.keys()]
|
||||
self.num_warps, _, self.shmem = kernel.packed_metadata
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
self.kernel = kernel.kernel
|
||||
if isinstance(self.fn, Autotuner):
|
||||
self.best_config = self.fn.best_config.all_kwargs()
|
||||
return kernel
|
||||
|
||||
def __getitem__(
|
||||
@@ -55,6 +70,10 @@ class TVMFFIJITFunction(object):
|
||||
):
|
||||
return self.fn[grid]
|
||||
|
||||
@cached_property
|
||||
def basefn(self) -> Callable:
|
||||
return self.jitfn.fn
|
||||
|
||||
@property
|
||||
def cache_hash(self) -> int:
|
||||
return self.ctypes_hash ^ self.kernel_hash
|
||||
@@ -63,21 +82,35 @@ class TVMFFIJITFunction(object):
|
||||
def ctypes_hash(self) -> int:
|
||||
return hash(tuple(self.ctypes) if self.ctypes is not None else None)
|
||||
|
||||
@property
|
||||
def kernel_hash(self) -> int:
|
||||
return hash(self.kernel)
|
||||
|
||||
@cached_property
|
||||
def fnname(self) -> str:
|
||||
return self.fn.fn.__name__
|
||||
return self.basefn.__name__
|
||||
|
||||
@cached_property
|
||||
def fullname(self) -> str:
|
||||
return f"triton.{self.name}"
|
||||
|
||||
@cached_property
|
||||
def jitfn(self) -> JITFunction:
|
||||
fn: Union[Autotuner, JITFunction] = self.fn
|
||||
while not isinstance(fn, JITFunction):
|
||||
fn = fn.fn
|
||||
return fn
|
||||
|
||||
@property
|
||||
def kernel_hash(self) -> int:
|
||||
return hash(self.kernel)
|
||||
|
||||
@property
|
||||
def kernel_cstr(self) -> Optional[str]:
|
||||
if self.kernel is not None:
|
||||
return "".join(f"\\x{byte:02x}" for byte in self.kernel)
|
||||
else:
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def name(self) -> str:
|
||||
return f"{self.fnname}_{hash(self.fn.fn)}"
|
||||
return f"{self.fnname}_{hash(self.basefn)}"
|
||||
|
||||
@staticmethod
|
||||
def canonicalize(val: Any) -> Any:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <optional>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
|
||||
@@ -9,32 +11,60 @@
|
||||
{% if fn.ctypes is none %}
|
||||
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
|
||||
{% else %}
|
||||
TVM_FFI_EMBED_CUBIN(triton_{{ fn.fnname }});
|
||||
#define {{ fn.fnname | upper}}_STUB(__grid, __stream, __numWarps, __numStages{% for ctype in fn.ctypes %}, {{ "__arg" ~ loop.index0 }}{% endfor %}) do { \
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% for name in fn.signature %}
|
||||
{ "{{ name }}", __arg{{ loop.index0 }} }, \
|
||||
static const char __cubin[] = "{{ fn.kernel_cstr }}";
|
||||
|
||||
#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS)
|
||||
|
||||
static CUfunction __Get{{ fn.fnname }}Kernel() {
|
||||
static std::optional<CUfunction> function = std::nullopt;
|
||||
if (!function) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, __cubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}"));
|
||||
{% if fn.shmem > 49152 %}
|
||||
int shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0));
|
||||
if (shared_optin >= 49152) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
|
||||
}
|
||||
{% endif %}
|
||||
function = func;
|
||||
}
|
||||
return *function;
|
||||
}
|
||||
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __numWarps, __numStages, __args, __kwargs) do { \
|
||||
const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% if fn.best_config != none %}
|
||||
{% for k, v in fn.best_config.items() %}
|
||||
{ "{{ k }}", {{ v }} }, \
|
||||
{% endfor %}
|
||||
}; \
|
||||
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{{ fn.fnname }}, "{{ fn.fnname }}"); \
|
||||
tvm::ffi::dim3 __gridDim = MakeGridDim(__grid, __meta); \
|
||||
tvm::ffi::dim3 __block({% if fn.num_warps != none %}{{ fn.num_warps }}{% else %}__numWarps{% endif %} * 32, 1, 1); \
|
||||
void *dummy = nullptr
|
||||
{%- for ctype in fn.ctypes -%}
|
||||
{%- if ctype == "CUdeviceptr" -%}
|
||||
, *__arg{{ loop.index0 }}_ptr=__arg{{ loop.index0 }}.data_ptr()
|
||||
{%- endif -%}
|
||||
{%- endfor -%}; \
|
||||
void *__params[] = {
|
||||
{%- for ctype in fn.ctypes -%}
|
||||
{%- if ctype != none -%}
|
||||
&__arg{{ loop.index0 }}
|
||||
{%- if ctype == "CUdeviceptr" -%}
|
||||
_ptr
|
||||
{%- endif -%},
|
||||
{%- endif -%}
|
||||
{%- endfor -%}&dummy, &dummy }; \
|
||||
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __gridDim, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \
|
||||
{% endif %}
|
||||
}; \
|
||||
for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \
|
||||
if (__i < __size_args) { \
|
||||
__meta.Set(__signature[__i], __args[__i]); \
|
||||
} else if (auto __val = __kwargs.Get(__signature[__i])) { \
|
||||
__meta.Set(__signature[__i], *__val); \
|
||||
} \
|
||||
} \
|
||||
CUfunction __function = __Get{{ fn.fnname }}Kernel(); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = MakeGridDim(__grid, __meta); \
|
||||
void *dummy = nullptr; \
|
||||
{% for ctype in fn.ctypes %}
|
||||
{% if ctype != none %}
|
||||
{% if ctype == "CUdeviceptr" %}
|
||||
void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
||||
{% else %}
|
||||
{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
||||
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
|
||||
} while (false)
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
@@ -6,19 +6,21 @@
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
template <typename T>
|
||||
inline tvm::ffi::dim3
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim(const T &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta);
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>(
|
||||
const tvm::ffi::Tuple<int32_t, int32_t, int32_t> &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &) {
|
||||
return tvm::ffi::dim3(grid.get<0>(), grid.get<1>(), grid.get<2>());
|
||||
return grid;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline tvm::ffi::dim3 MakeGridDim<tvm::ffi::Function>(
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim<tvm::ffi::Function>(
|
||||
const tvm::ffi::Function &grid,
|
||||
const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta) {
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> tuple =
|
||||
|
||||
Reference in New Issue
Block a user