support attention bwd

Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
This commit is contained in:
2026-02-10 17:01:28 +08:00
parent e41ec26329
commit 599957e156
6 changed files with 206 additions and 26 deletions

View File

@@ -47,8 +47,8 @@ class TVMFFIJITFunction(object):
):
args: Iterator[Any] = map(self.canonicalize, args)
kwargs: Dict[str, Any] = {
k: v for k, v in zip(self.signature, args) if v is not None
} | {k: self.canonicalize(v) for k, v in kwargs.items()}
k: self.canonicalize(v) for k, v in kwargs.items()
}
kernel: CompiledKernel = self.fn[grid](*args, **kwargs)
self.num_warps, _, self.shmem = kernel.packed_metadata
self.ctypes = [type_canonicalize(v) for v in kernel.src.signature.values()]