mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-01 19:42:10 +08:00
use templates to substitute parts of macros
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
#include "tvm/ffi/function.h"
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/dlpack.h>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef ADD_KERNEL_STUB
|
||||
#define ADD_KERNEL_STUB(grid, stream, args, kwargs)
|
||||
#define ADD_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef ADD_NAME
|
||||
@@ -27,7 +27,7 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {x, y, output, numel, 1024};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
|
||||
ADD_KERNEL_STUB(grid, stream, args, kwargs);
|
||||
ADD_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef MATMUL_KERNEL_STUB
|
||||
#define MATMUL_KERNEL_STUB(grid, stream, args, kwargs)
|
||||
#define MATMUL_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef MATMUL_NAME
|
||||
@@ -45,7 +45,7 @@ tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b,
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {
|
||||
{"ACTIVATION", activation},
|
||||
};
|
||||
MATMUL_KERNEL_STUB(grid, stream, args, kwargs);
|
||||
MATMUL_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
#ifndef SOFTMAX_KERNEL_STUB
|
||||
#define SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs)
|
||||
#define SOFTMAX_KERNEL_STUB(grid, device, stream, args, kwargs)
|
||||
#endif
|
||||
|
||||
#ifndef SOFTMAX_NAME
|
||||
@@ -24,7 +24,7 @@ tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) {
|
||||
tvm::ffi::Array<tvm::ffi::Any> args = {y, x, xStride, yStride,
|
||||
nRows, nCols, BLOCK_SIZE};
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> kwargs = {};
|
||||
SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs);
|
||||
SOFTMAX_KERNEL_STUB(grid, device.device_id, stream, args, kwargs);
|
||||
return y;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
#define TRITON_TVM_FFI_GRID_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <tvm/ffi/extra/cuda/base.h>
|
||||
#include <tvm/ffi/tvm_ffi.h>
|
||||
|
||||
namespace triton_tvm_ffi {
|
||||
|
||||
template <typename T>
|
||||
inline tvm::ffi::Tuple<int32_t, int32_t, int32_t>
|
||||
MakeGridDim(const T &grid,
|
||||
@@ -28,4 +29,6 @@ MakeGridDim<tvm::ffi::Function>(
|
||||
return MakeGridDim(tuple, meta);
|
||||
}
|
||||
|
||||
} // namespace triton_tvm_ffi
|
||||
|
||||
#endif
|
||||
|
||||
39
include/triton_tvm_ffi/kernel.h
Normal file
39
include/triton_tvm_ffi/kernel.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TRITON_TVM_FFI_KERNEL_H_
#define TRITON_TVM_FFI_KERNEL_H_

#include "macro.h"
#include <cstdint>
#include <cuda.h>
#include <unordered_map>

namespace triton_tvm_ffi {

/// Loads — and caches, one entry per device — the CUDA kernel named
/// \p kFnName from the embedded cubin image \p kCubin.
///
/// \tparam kFnName null-terminated kernel symbol name inside the cubin.
/// \tparam kCubin  cubin image bytes passed to cuModuleLoadData.
/// \tparam kSMem   dynamic shared memory (bytes) the kernel will request.
/// \param device   device ordinal used as the cache key.
/// \return the cached CUfunction handle for \p device.
/// \throws std::runtime_error (via __CUDA_CHECK) on any driver-API failure.
template <const char kFnName[], const char kCubin[], size_t kSMem>
inline CUfunction GetKernel(int32_t device) {
  // Per-device cache; lives for the process lifetime, so the underlying
  // CUmodule is deliberately never unloaded.
  // NOTE(review): not synchronized — confirm callers never race on first use.
  static std::unordered_map<int32_t, CUfunction> functions = {};
  auto it = functions.find(device);
  if (it == functions.end()) {
    CUmodule module;
    CUfunction func;
    __CUDA_CHECK(cuModuleLoadData(&module, kCubin));
    __CUDA_CHECK(cuModuleGetFunction(&func, module, kFnName));
    // Kernels needing more than the default 48 KiB of shared memory must
    // opt in to the device's larger per-block limit.
    if (kSMem > 49152) {
      int32_t shared_optin, shared_static;
      // NOTE(review): passes the device ordinal where a CUdevice handle is
      // expected; the two coincide on current drivers — confirm, or route
      // through cuDeviceGet().
      __CUDA_CHECK(cuDeviceGetAttribute(
          &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
          device));
      // Cast to avoid signed/unsigned comparison with the size_t kSMem.
      if (shared_optin >= 0 && static_cast<size_t>(shared_optin) >= kSMem) {
        __CUDA_CHECK(cuFuncGetAttribute(
            &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
        __CUDA_CHECK(cuFuncSetAttribute(
            func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_optin - shared_static));
      }
    }
    // Single insertion instead of the original find + operator[] + find.
    it = functions.emplace(device, func).first;
  }
  return it->second;
}

} // namespace triton_tvm_ffi

#endif
|
||||
22
include/triton_tvm_ffi/macro.h
Normal file
22
include/triton_tvm_ffi/macro.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_TVM_FFI_MACRO_H_
#define TRITON_TVM_FFI_MACRO_H_

#include <cuda.h>
#include <sstream>
#include <stdexcept>
#include <string>

// Checks the result of a CUDA driver API call; on failure throws
// std::runtime_error carrying the CUDA error name, description, and call
// site.
//
// The argument is captured into a local exactly once, so it is safe (and
// intended) to pass a full driver-API call expression, e.g.
//   __CUDA_CHECK(cuModuleLoadData(&module, data));
// The original expanded `__code` three times, re-executing the failing call
// twice on the error path.
#define __CUDA_CHECK(__code)                                                   \
  do {                                                                         \
    const CUresult __result = (__code);                                        \
    if (__result != CUDA_SUCCESS) {                                            \
      const char *errorName = nullptr, *errorStr = nullptr;                    \
      cuGetErrorName(__result, &errorName);                                    \
      cuGetErrorString(__result, &errorStr);                                   \
      /* Guard: streaming a null const char* is UB; the lookups fail for */    \
      /* unrecognized error codes. */                                          \
      if (errorName == nullptr) errorName = "CUDA_ERROR_UNKNOWN";              \
      if (errorStr == nullptr) errorStr = "unrecognized error code";           \
      std::ostringstream __oss;                                                \
      __oss << "[" << errorName << "] " << errorStr << ", at " << __FILE__     \
            << ":" << __LINE__;                                                \
      throw std::runtime_error(__oss.str());                                   \
    }                                                                          \
  } while (false)

#endif
|
||||
52
include/triton_tvm_ffi/meta.h
Normal file
52
include/triton_tvm_ffi/meta.h
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef TRITON_TVM_FFI_META_H_
#define TRITON_TVM_FFI_META_H_

#include <tvm/ffi/tvm_ffi.h>

namespace triton_tvm_ffi {

// Recursive worker behind FillMeta. Each Ks element is a compile-time key
// (a char-array NTTP) to bind into `meta`. The primary template is declared
// only; the empty-pack and head/tail specializations below cover every
// instantiation.
template <const char... Ks[]> struct FillMetaImpl {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs);
};

// Base case: no keys left to bind — nothing to do.
template <> struct FillMetaImpl<> {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {}
};

// Recursive case: bind the head key K, then recurse on the tail Ks...
template <const char K[], const char... Ks[]> struct FillMetaImpl<K, Ks...> {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
        tvm::ffi::Array<tvm::ffi::Any>::iterator &argsBegin,
        const tvm::ffi::Array<tvm::ffi::Any>::iterator &argsEnd,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
    // Positional arguments are consumed in order and take precedence; once
    // they are exhausted, fall back to looking the key up in kwargs. A key
    // present in neither is simply left unset in `meta`.
    if (argsBegin != argsEnd) {
      meta.Set(K, *argsBegin++);
    } else if (auto val = kwargs.Get(K)) {
      meta.Set(K, *val);
    }
    FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
  }
};

// Public entry point: fills `meta` by matching the compile-time key list
// Ks... against the runtime positional `args` (in order) and keyword
// `kwargs` (by name), positional entries winning.
template <const char... Ks[]> struct FillMeta {
  static inline void
  apply(tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &meta,
        const tvm::ffi::Array<tvm::ffi::Any> &args,
        const tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> &kwargs) {
    tvm::ffi::Array<tvm::ffi::Any>::iterator argsBegin = args.begin();
    tvm::ffi::Array<tvm::ffi::Any>::iterator argsEnd = args.end();
    FillMetaImpl<Ks...>::apply(meta, argsBegin, argsEnd, kwargs);
  }
};

} // namespace triton_tvm_ffi

#endif
|
||||
@@ -40,7 +40,8 @@ class TVMFFIJITFunction(object):
|
||||
grid: Union[
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
_: int,
|
||||
_device: int,
|
||||
_stream: int,
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
):
|
||||
|
||||
@@ -1,41 +1,19 @@
|
||||
#include <cassert>
|
||||
#include <cuda.h>
|
||||
#include <optional>
|
||||
#include <tvm/ffi/extra/cuda/cubin_launcher.h>
|
||||
#include <tvm/ffi/function.h>
|
||||
#include "triton_tvm_ffi/grid.h"
|
||||
#include "triton_tvm_ffi/kernel.h"
|
||||
#include "triton_tvm_ffi/macro.h"
|
||||
#include "triton_tvm_ffi/meta.h"
|
||||
|
||||
#define {{ name | upper }}_NAME "{{ uniquename }}"
|
||||
{% for fn in fns %}
|
||||
{% if fn.ctypes is none %}
|
||||
#define {{ fn.fnname | upper }}_STUB tvm::ffi::Function::GetGlobalRequired("{{ fn.fullname }}")
|
||||
{% else %}
|
||||
static const char __cubin[] = "{{ fn.kernel_cstr }}";
|
||||
static constexpr char __fnname_{{ fn.fnname }}[] = "{{ fn.fnname }}";
|
||||
static constexpr char __cubin_{{ fn.fnname }}[] = "{{ fn.kernel_cstr }}";
|
||||
|
||||
#define __CUDA_CHECK(code) assert((code) == CUDA_SUCCESS)
|
||||
|
||||
static CUfunction __Get{{ fn.fnname }}Kernel() {
|
||||
static std::optional<CUfunction> function = std::nullopt;
|
||||
if (!function) {
|
||||
CUmodule module;
|
||||
CUfunction func;
|
||||
__CUDA_CHECK(cuModuleLoadData(&module, __cubin));
|
||||
__CUDA_CHECK(cuModuleGetFunction(&func, module, "{{ fn.fnname }}"));
|
||||
{% if fn.shmem > 49152 %}
|
||||
int shared_optin, shared_static;
|
||||
__CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, /* TODO: we assume the device id is 0 here, but this may not work on devices with more than one gpu */0));
|
||||
if (shared_optin >= 49152) {
|
||||
__CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
|
||||
__CUDA_CHECK(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
|
||||
}
|
||||
{% endif %}
|
||||
function = func;
|
||||
}
|
||||
return *function;
|
||||
}
|
||||
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \
|
||||
const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __device, __stream, __args, __kwargs) do { \
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% if fn.best_config != none %}
|
||||
{% for k, v in fn.best_config.items() %}
|
||||
@@ -43,24 +21,19 @@ tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
}; \
|
||||
for (size_t __i = 0, __size_args = __args.size(); __i < sizeof(__signature) / sizeof(const char *); ++__i) { \
|
||||
if (__i < __size_args) { \
|
||||
__meta.Set(__signature[__i], __args[__i]); \
|
||||
} else if (auto __val = __kwargs.Get(__signature[__i])) { \
|
||||
__meta.Set(__signature[__i], *__val); \
|
||||
} \
|
||||
} \
|
||||
CUfunction __function = __Get{{ fn.fnname }}Kernel(); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = MakeGridDim(__grid, __meta); \
|
||||
{% for type in fn.signature %}
|
||||
static constexpr char __varname{{ loop.index0 }}[] = "{{ type }}"; \
|
||||
{% endfor %}
|
||||
triton_tvm_ffi::FillMeta<{% for type in fn.signature %}__varname{{ loop.index0 }}{% if not loop.last %}, {% endif %}{% endfor %}>::apply(__meta, __args, __kwargs); \
|
||||
CUfunction __function = triton_tvm_ffi::GetKernel<__fnname_{{ fn.fnname }}, __cubin_{{ fn.fnname }}, {{ fn.shmem }}>(__device); \
|
||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> __gridDim = triton_tvm_ffi::MakeGridDim(__grid, __meta); \
|
||||
void *dummy = nullptr; \
|
||||
{% for ctype in fn.ctypes %}
|
||||
{% if ctype != none %}
|
||||
{% if ctype == "CUdeviceptr" %}
|
||||
void *__arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<tvm::ffi::TensorView>().data_ptr(); \
|
||||
{% else %}
|
||||
{% elif ctype != none %}
|
||||
{{ ctype }} __arg{{ loop.index0 }} = __args[{{ loop.index0 }}].cast<{{ ctype }}>(); \
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
void *__params[] = { {% for ctype in fn.ctypes %}{% if ctype != none %}&__arg{{ loop.index0 }}, {% endif %}{% endfor %}&dummy, &dummy }; \
|
||||
__CUDA_CHECK(cuLaunchKernel(__function, __gridDim.get<0>(), __gridDim.get<1>(), __gridDim.get<2>(), 32 * {{ fn.num_warps }}, 1, 1, {{ fn.shmem }}, reinterpret_cast<CUstream>(__stream), __params, nullptr)); \
|
||||
|
||||
Reference in New Issue
Block a user