From 8b8aa6cb84b7366d3082d3678cfbeef23665461e Mon Sep 17 00:00:00 2001
From: jinjieliu
Date: Thu, 5 Feb 2026 01:01:49 +0800
Subject: [PATCH] enable optional for numwarps and numstages

Signed-off-by: jinjieliu
---
 examples/add/add.cc           |  3 ++-
 python/triton_tvm_ffi/jit.py  | 14 +++++++++-----
 python/triton_tvm_ffi/wrap.py |  3 ++-
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/examples/add/add.cc b/examples/add/add.cc
index 4e4aadd..cbfb364 100644
--- a/examples/add/add.cc
+++ b/examples/add/add.cc
@@ -18,7 +18,8 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   int64_t numel = otorch.numel();
   tvm::ffi::Tensor output = tvm::ffi::Tensor::FromDLPack(at::toDLPack(otorch));
   tvm::ffi::Tuple grid{(numel + 1023) / 1024, 1, 1};
-  size_t numWarps = 4, numStages = 3;
+  // TODO: check the performance loss after enabling `Optional`
+  tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt;
   DLDevice device = x.device();
   void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id);
   ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024);
diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py
index 1611a0f..ac2c0a5 100644
--- a/python/triton_tvm_ffi/jit.py
+++ b/python/triton_tvm_ffi/jit.py
@@ -17,13 +17,14 @@ class TVMFFIJITFunction(object):
         self.fn: Final[JITFunction] = fn
         self.ctypes: Optional[List[Optional[str]]] = None
         self.kernel: Optional[bytes] = None
+        self.num_warps: Optional[int] = None
 
         @tvm_ffi.register_global_func(self.fullname)
         def _(
             grid: Tuple[int, int, int],
             _: int,
-            num_warps: int,
-            num_stages: int,
+            num_warps: Optional[int],
+            num_stages: Optional[int],
             *args,
             **kwargs,
         ):
@@ -31,9 +32,12 @@ class TVMFFIJITFunction(object):
             kwargs: Dict[str, Any] = {
                 k: self.canonicalize(v) for k, v in kwargs.items()
             }
-            kernel: CompiledKernel = self.fn[grid](
-                *args, **kwargs, num_warps=num_warps, num_stages=num_stages
-            )
+            if num_warps is not None:
+                kwargs["num_warps"] = num_warps
+            if num_stages is not None:
+                kwargs["num_stages"] = num_stages
+            kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
+            self.num_warps, _, _ = kernel.packed_metadata
             self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
             self.kernel = kernel.kernel
             return kernel
diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py
index 7340e2b..6b96b3d 100644
--- a/python/triton_tvm_ffi/wrap.py
+++ b/python/triton_tvm_ffi/wrap.py
@@ -104,7 +104,7 @@ TVM_FFI_EMBED_CUBIN(triton_{fnname});
 #define {}_STUB(__gtuple, __stream, __numWarps, __numStages, {}) do {{ \\
 static auto __kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(triton_{fnname}, "{fnname}"); \\
 tvm::ffi::dim3 __grid(__gtuple.get<0>(), __gtuple.get<1>(), __gtuple.get<2>()); \\
-tvm::ffi::dim3 __block(__numWarps * 32, 1, 1); \\
+tvm::ffi::dim3 __block({} * 32, 1, 1); \\
 void *dummy = nullptr, {}; \\
 void *__params[] = {{{}, &dummy, &dummy}}; \\
 TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __block, static_cast(__stream))); \\
@@ -112,6 +112,7 @@ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(__kernel.Launch(__params, __grid, __bloc
 """.format(
     fn.fnname.upper(),
     ", ".join(arg for _, arg in ctype_arg_list),
+    fn.num_warps if fn.num_warps is not None else "__numWarps",
     ", ".join(
         f"*{arg}_ptr = {arg}.data_ptr()" for ctype, arg in ctype_arg_list
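
Not part of the patch: a short caller-side sketch of how a hand-written
wrapper such as examples/add/add.cc can drive the generated stub after this
change. It assumes an int64_t element type for tvm::ffi::Optional (the
template argument is not visible in the hunk above); variable names are
illustrative only.

  // Leave both knobs unset: wrap.py then bakes the num_warps recorded from
  // kernel.packed_metadata into the generated _STUB macro, so the block size
  // no longer depends on the __numWarps argument.
  tvm::ffi::Optional<int64_t> numWarps = std::nullopt, numStages = std::nullopt;
  ADD_KERNEL_STUB(grid, stream, numWarps, numStages, x, y, output, numel, 1024);

  // The stub's parameter list is unchanged, so explicit values can still be
  // passed through the same two arguments.
  tvm::ffi::Optional<int64_t> pinnedWarps = 4, pinnedStages = 3;
  ADD_KERNEL_STUB(grid, stream, pinnedWarps, pinnedStages, x, y, output, numel, 1024);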