mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
enable optional for numwarps and numstages
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -104,7 +104,7 @@ TVM_FFI_EMBED_CUBIN(triton_{fnname});
|
||||
#define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
|
||||
static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
|
||||
tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
|
||||
tvm::ffi::dim3 __block(__numWarps * 32, 1, 1); \\
|
||||
tvm::ffi::dim3 __block({} * 32, 1, 1); \\
|
||||
void *dummy = nullptr, {}; \\
|
||||
void *__params[] = {{{}, &dummy, &dummy}}; \\
|
||||
TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast<tvm::ffi::cuda_api::StreamHandle>(__stream))); \\
|
||||
@@ -112,6 +112,7 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
|
||||
""".format(
|
||||
fn.fnname.upper(),
|
||||
", ".join(arg for _, arg in ctype_arg_list),
|
||||
fn.num_warps if fn.num_warps is not None else "__numWarps",
|
||||
", ".join(
|
||||
f"*{arg}_ptr = {arg}.data_ptr()"
|
||||
for ctype, arg in ctype_arg_list
|
||||
|
||||
Reference in New Issue
Block a user