From 6a19a6b06d3fc92fea2a4085bff5f5ea174b962c Mon Sep 17 00:00:00 2001 From: jinjieliu Date: Sat, 7 Feb 2026 14:25:10 +0800 Subject: [PATCH] put num_warps and num_stages in kwargs Signed-off-by: jinjieliu --- examples/add/add.cc | 5 ++--- examples/mm/mm.cc | 5 ++--- examples/softmax/softmax.cc | 12 +++++------- python/triton_tvm_ffi/jit.py | 6 ------ python/triton_tvm_ffi/templates/gendef.cc.j2 | 2 +- 5 files changed, 10 insertions(+), 20 deletions(-) diff --git a/examples/add/add.cc b/examples/add/add.cc index ed67885..bddda09 100644 --- a/examples/add/add.cc +++ b/examples/add/add.cc @@ -5,7 +5,7 @@ #include #ifndef ADD_KERNEL_STUB -#define ADD_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) +#define ADD_KERNEL_STUB(grid, stream, args, kwargs) #endif #ifndef ADD_NAME @@ -23,12 +23,11 @@ tvm::ffi::Tensor Add(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { const int32_t BLOCK_SIZE = meta["BLOCK_SIZE"].cast(); return tvm::ffi::Tuple((numel + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1); }); - tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt; DLDevice device = x.device(); void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); tvm::ffi::Array args = {x, y, output, numel, 1024}; tvm::ffi::Map kwargs = {}; - ADD_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); + ADD_KERNEL_STUB(grid, stream, args, kwargs); return output; } diff --git a/examples/mm/mm.cc b/examples/mm/mm.cc index dbb8543..c4b84a9 100644 --- a/examples/mm/mm.cc +++ b/examples/mm/mm.cc @@ -4,7 +4,7 @@ #include #ifndef MATMUL_KERNEL_STUB -#define MATMUL_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) +#define MATMUL_KERNEL_STUB(grid, stream, args, kwargs) #endif #ifndef MATMUL_NAME @@ -27,7 +27,6 @@ tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b, ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N), 1, 1}; }); - tvm::ffi::Optional numWarps = std::nullopt, numStages = std::nullopt; DLDevice device = a.device(); void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); tvm::ffi::Tensor c = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ctorch)); @@ -46,7 +45,7 @@ tvm::ffi::Tensor Matmul(tvm::ffi::Tensor a, tvm::ffi::Tensor b, tvm::ffi::Map kwargs = { {"ACTIVATION", activation}, }; - MATMUL_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); + MATMUL_KERNEL_STUB(grid, stream, args, kwargs); return c; } diff --git a/examples/softmax/softmax.cc b/examples/softmax/softmax.cc index 4892252..a3c5709 100644 --- a/examples/softmax/softmax.cc +++ b/examples/softmax/softmax.cc @@ -4,7 +4,7 @@ #include #ifndef SOFTMAX_KERNEL_STUB -#define SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs) +#define SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs) #endif #ifndef SOFTMAX_NAME @@ -14,19 +14,17 @@ tvm::ffi::Tensor Softmax(tvm::ffi::Tensor x) { at::Tensor xtorch = at::fromDLPack(x.ToDLPack()); at::Tensor ytorch = at::empty_like(xtorch); - uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1), numWarps = 8, - numStages = 4, xStride = xtorch.stride(0), - yStride = ytorch.stride(0), + uint32_t nRows = xtorch.size(0), nCols = xtorch.size(1), + xStride = xtorch.stride(0), yStride = ytorch.stride(0), BLOCK_SIZE = 1u << (32 - __builtin_clz(nCols - 1)); tvm::ffi::Tensor y = tvm::ffi::Tensor::FromDLPack(at::toDLPack(ytorch)); tvm::ffi::Tuple grid{nRows / 1024, 1, 1}; DLDevice device = x.device(); - void* stream = - TVMFFIEnvGetStream(device.device_type, device.device_id); + void *stream = TVMFFIEnvGetStream(device.device_type, device.device_id); tvm::ffi::Array args = {y, x, xStride, yStride, nRows, nCols, BLOCK_SIZE}; tvm::ffi::Map kwargs = {}; - SOFTMAX_KERNEL_STUB(grid, stream, numWarps, numStages, args, kwargs); + SOFTMAX_KERNEL_STUB(grid, stream, args, kwargs); return y; } diff --git a/python/triton_tvm_ffi/jit.py b/python/triton_tvm_ffi/jit.py index ae95af0..cd6b115 100644 --- a/python/triton_tvm_ffi/jit.py +++ b/python/triton_tvm_ffi/jit.py @@ -41,8 +41,6 @@ class TVMFFIJITFunction(object): Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int] ], _: int, - num_warps: Optional[int], - num_stages: Optional[int], args: Sequence[Any], kwargs: Mapping[str, Any], ): @@ -50,10 +48,6 @@ class TVMFFIJITFunction(object): kwargs: Dict[str, Any] = { k: v for k, v in zip(self.signature, args) if v is not None } | {k: self.canonicalize(v) for k, v in kwargs.items()} - if num_warps is not None: - kwargs["num_warps"] = num_warps - if num_stages is not None: - kwargs["num_stages"] = num_stages kernel: CompiledKernel = self.fn[grid](*args, **kwargs) self.num_warps, _, self.shmem = kernel.packed_metadata self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()] diff --git a/python/triton_tvm_ffi/templates/gendef.cc.j2 b/python/triton_tvm_ffi/templates/gendef.cc.j2 index ffa09eb..cedda4b 100644 --- a/python/triton_tvm_ffi/templates/gendef.cc.j2 +++ b/python/triton_tvm_ffi/templates/gendef.cc.j2 @@ -35,7 +35,7 @@ static CUfunction __Get{{ fn.fnname }}Kernel() { return *function; } -#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __numWarps, __numStages, __args, __kwargs) do { \ +#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \ const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \ tvm::ffi::Map __meta = { \ {% if fn.best_config != none %}