mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-05-02 03:52:11 +08:00
put num_warps and num_stages in kwargs
Signed-off-by: jinjieliu <jinjie.liu@usc.edu>
This commit is contained in:
@@ -41,8 +41,6 @@ class TVMFFIJITFunction(object):
|
||||
Callable[[Dict[str, Any]], Tuple[int, int, int]], Tuple[int, int, int]
|
||||
],
|
||||
_: int,
|
||||
num_warps: Optional[int],
|
||||
num_stages: Optional[int],
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
):
|
||||
@@ -50,10 +48,6 @@ class TVMFFIJITFunction(object):
|
||||
kwargs: Dict[str, Any] = {
|
||||
k: v for k, v in zip(self.signature, args) if v is not None
|
||||
} | {k: self.canonicalize(v) for k, v in kwargs.items()}
|
||||
if num_warps is not None:
|
||||
kwargs["num_warps"] = num_warps
|
||||
if num_stages is not None:
|
||||
kwargs["num_stages"] = num_stages
|
||||
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
|
||||
self.num_warps, _, self.shmem = kernel.packed_metadata
|
||||
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]
|
||||
|
||||
@@ -35,7 +35,7 @@ static CUfunction __Get{{ fn.fnname }}Kernel() {
|
||||
return *function;
|
||||
}
|
||||
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __numWarps, __numStages, __args, __kwargs) do { \
|
||||
#define {{ fn.fnname | upper }}_STUB(__grid, __stream, __args, __kwargs) do { \
|
||||
const char *__signature[] = { "{{ fn.signature | join("\", \"") }}" }; \
|
||||
tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> __meta = { \
|
||||
{% if fn.best_config != none %}
|
||||
|
||||
Reference in New Issue
Block a user