mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable optional for numwarps and numstages
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -18,7 +18,8 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
|
|||||||
int64_t numel = otorch.numel();
|
int64_t numel = otorch.numel();
|
||||||
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
|
||||||
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
|
tvm::ffi::Tuple<int32_t, int32_t, int32_t> grid{(numel + 1023) / 1024, 1, 1};
|
||||||
size_t numWarps = 4, numStages = 3;
|
// TODO: check the performance loss after enabling `Optional`
|
||||||
|
tvm::ffi::Optional<int32_t> numWarps = std::nullopt, numStages = std::nullopt;
|
||||||
DLDevice device = x.device();
|
DLDevice device = x.device();
|
||||||
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
|
||||||
ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024);
|
ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024);
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ class TVMFFIJITFunction(object):
|
|||||||
self.fn: Final[JITFunction] = fn
|
self.fn: Final[JITFunction] = fn
|
||||||
self.ctypes: Optional[List[Optional[str]]] = None
|
self.ctypes: Optional[List[Optional[str]]] = None
|
||||||
self.kernel: Optional[bytes] = None
|
self.kernel: Optional[bytes] = None
|
||||||
|
self.num_warps: Optional[int] = None
|
||||||
|
|
||||||
@tvm_ffi.register_global_func(self.fullname)
|
@tvm_ffi.register_global_func(self.fullname)
|
||||||
def _(
|
def _(
|
||||||
grid: Tuple[int, int, int],
|
grid: Tuple[int, int, int],
|
||||||
_: int,
|
_: int,
|
||||||
num_warps: int,
|
num_warps: Optional[int],
|
||||||
num_stages: int,
|
num_stages: Optional[int],
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -31,9 +32,12 @@ class TVMFFIJITFunction(object):
|
|||||||
kwargs: Dict[str, Any] = {
|
kwargs: Dict[str, Any] = {
|
||||||
k: self.canonicalize(v) for k, v in kwargs.items()
|
k: self.canonicalize(v) for k, v in kwargs.items()
|
||||||
}
|
}
|
||||||
kernel: CompiledKernel = self.fn[grid](
|
if num_warps is not None:
|
||||||
*args, **kwargs, num_warps=num_warps, num_stages=num_stages
|
kwargs["num_warps"] = num_warps
|
||||||
)
|
if num_stages is not None:
|
||||||
|
kwargs["num_stages"] = num_stages
|
||||||
|
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||||
|
self.num_warps, _, _ = kernel.packed_metadata
|
||||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||||
self.kernel = kernel.kernel
|
self.kernel = kernel.kernel
|
||||||
return kernel
|
return kernel
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ TVM_FFI_EMBED_CUBIN(triton_{fnname});
|
|||||||
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
||||||
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
||||||
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
||||||
tvm::ffi::dim3 __block(__numWarps * 32, 1, 1); \\
|
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
|
||||||
void *dummy = nullptr, {}; \\
|
void *dummy = nullptr, {}; \\
|
||||||
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
||||||
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
||||||
@@ -112,6 +112,7 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
|
|||||||
""".format(
|
""".format(
|
||||||
fn.fnname.upper(),
|
fn.fnname.upper(),
|
||||||
", ".join(arg for _, arg in ctype_arg_list),
|
", ".join(arg for _, arg in ctype_arg_list),
|
||||||
|
fn.num_warps if fn.num_warps is not None else "__numWarps",
|
||||||
", ".join(
|
", ".join(
|
||||||
f"*{arg}_ptr = {arg}.data_ptr()"
|
f"*{arg}_ptr = {arg}.data_ptr()"
|
||||||
for ctype, arg in ctype_arg_list
|
for ctype, arg in ctype_arg_list
|
||||||
|
|||||||
Reference in New Issue
Block a user