diff --git a/include/launch.h b/include/launch.h index 04b6e77..083f004 100644 --- a/include/launch.h +++ b/include/launch.h @@ -2,6 +2,7 @@ #define TRITON_TVM_FFI_LAUNCH_H_ #include "type.h" +#include #include namespace triton_tvm_ffi { @@ -13,13 +14,9 @@ public: TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default; TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default; void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, - uint64_t function, - tvm::ffi::Tuple kernelMetadata, - tvm::ffi::ObjectRef launchMetadata, - tvm::ffi::ObjectRef launchEnterHook, - tvm::ffi::ObjectRef launchExitHook, - tvm::ffi::ObjectRef globalScratchObject, - tvm::ffi::ObjectRef profileScratchObject, + uint64_t function, int32_t numWarps, int32_t numCtas, + int32_t sharedMemory, uint64_t globalScratch, + uint64_t profileScratch, const tvm::ffi::Array &kernelArgs) const; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl", TVMFFILauncherImplObj, tvm::ffi::Object); @@ -37,13 +34,9 @@ public: using tvm::ffi::ObjectRef::ObjectRef; using tvm::ffi::ObjectRef::operator=; void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, - uint64_t function, - tvm::ffi::Tuple kernelMetadata, - tvm::ffi::ObjectRef launchMetadata, - tvm::ffi::ObjectRef launchEnterHook, - tvm::ffi::ObjectRef launchExitHook, - tvm::ffi::ObjectRef globalScratchObject, - tvm::ffi::ObjectRef profileScratchObject, + uint64_t function, int32_t numWarps, int32_t numCtas, + int32_t sharedMemory, uint64_t globalScratch, + uint64_t profileScratch, const tvm::ffi::Array &kernelArgs) const; TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl, tvm::ffi::ObjectRef, diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index a872302..da4ff6b 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object): if TYPE_CHECKING: @staticmethod def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ... - def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ... + def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ... # fmt: on # tvm-ffi-stubgen(end) diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index c34eaab..1e224ae 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -2,7 +2,7 @@ from __future__ import annotations from functools import cached_property import os -from typing import Any, Callable, Final, List, Sequence, Type +from typing import Any, Final, List, Type import jinja2 from triton.backends.nvidia.driver import CudaDriver @@ -24,19 +24,16 @@ class TVMLauncher(object): self.profile_scratch_align: Final[int] = metadata.profile_scratch_align self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid self.launch_pdl: Final[bool] = metadata.launch_pdl - self.enable_jit: Final[bool] = ( - os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None - ) - if self.enable_jit: - mod = tvm_ffi.cpp.load_inline( + if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}: + mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline( "launch", - cpp_sources=self.codegen, + cpp_sources=[self.codegen], extra_ldflags=["-Wl,--no-as-needed", "-lcuda"], extra_include_paths=[ f"{tvm_ffi.cpp.extension._find_cuda_home()}/include" ], ) - launch = mod.get_function("launch") + launch: tvm_ffi.Function = mod.get_function("launch") self.launch = launch else: self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( @@ -50,10 +47,9 @@ class TVMLauncher(object): grid_z, stream, function, - kernel_metadata, - launch_metadata, - launch_enter_hook, - launch_exit_hook, + num_warps, + num_ctas, + shared_memory, global_scratch, profile_scratch, *args: self.impl.launch( @@ -62,10 +58,9 @@ class TVMLauncher(object): grid_z, stream, function, - kernel_metadata, - launch_metadata, - launch_enter_hook, - launch_exit_hook, + num_warps, + num_ctas, + shared_memory, global_scratch, profile_scratch, args, @@ -101,37 +96,36 @@ class TVMLauncher(object): self.profile_scratch_align, _allocation._profile_allocator, ) - assert not self.launch_cooperative_grid - assert not self.launch_pdl - if self.enable_jit: - (num_warps, num_ctas, shared_memory) = kernel_metadata - return self.launch( - gridX, - gridY, - gridZ, - stream, - function, - num_warps, - num_ctas, - shared_memory, - *args, - ) - else: - return self.launch( - gridX, - gridY, - gridZ, - stream, - function, - kernel_metadata, - launch_metadata, - launch_enter_hook, - launch_exit_hook, - global_scratch, - profile_scratch, - *args, - ) + def canonicalize(obj: Any) -> int: + if obj is None: + return 0 + elif isinstance(obj, int): + return obj + elif get_ptr := getattr(obj, "data_ptr", None): + return get_ptr() + else: + raise TypeError(f"cannot canonicalize object of type {type(obj)}") + + (num_warps, num_ctas, shared_memory) = kernel_metadata + if launch_enter_hook: + launch_enter_hook(launch_metadata) + ret = self.launch( + gridX, + gridY, + gridZ, + stream, + function, + num_warps, + num_ctas, + shared_memory, + canonicalize(global_scratch), + canonicalize(profile_scratch), + *args, + ) + if launch_exit_hook: + launch_exit_hook(launch_metadata) + return ret @cached_property def codegen(self) -> str: diff --git a/python/triton_tvm_ffi/templates/launch.c.j2 b/python/triton_tvm_ffi/templates/launch.c.j2 index 94209ca..c318749 100644 --- a/python/triton_tvm_ffi/templates/launch.c.j2 +++ b/python/triton_tvm_ffi/templates/launch.c.j2 @@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in int32_t numWarps = args[5].v_int64; int32_t numCtas = args[6].v_int64; int32_t sharedMemory = args[7].v_int64; - // TODO: Implement the launch logic - CUdeviceptr globalScratch = 0; - // TODO: check `profileScratchObject` - CUdeviceptr profileScratch = 0; + uint64_t globalScratch = args[8].v_uint64; + uint64_t profileScratch = args[9].v_uint64; if (gridX * gridY * gridZ > 0) { CUlaunchAttribute launchAttr[4]; CUlaunchConfig config; @@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in config.numAttrs = numAttrs; {% for type in signature %} {% if type == "void *" %} - {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data; + {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data; {% elif type == "int32_t" %} - {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64; + {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64; {% else %} assert(false, "unsupported type yet {{ type }}"); {% endif %} diff --git a/src/launch.cc b/src/launch.cc index b54015b..a525116 100644 --- a/src/launch.cc +++ b/src/launch.cc @@ -1,6 +1,6 @@ #include "launch.h" #include "macro.h" -#include +#include #include namespace triton_tvm_ffi { @@ -14,19 +14,11 @@ TVMFFILauncherImplObj::TVMFFILauncherImplObj( void TVMFFILauncherImplObj::Launch( int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, - uint64_t function, - tvm::ffi::Tuple kernelMetadata, - tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook, - tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject, - tvm::ffi::ObjectRef profileScratchObject, + uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory, + uint64_t globalScratch, uint64_t profileScratch, const tvm::ffi::Array &kernelArgs) const { CUstream cStream = reinterpret_cast(stream); CUfunction cFunction = reinterpret_cast(function); - auto [numWarps, numCtas, sharedMemory] = kernelMetadata; - // TODO: Implement the launch logic - CUdeviceptr globalScratch = 0; - // TODO: check `profileScratchObject` - CUdeviceptr profileScratch = 0; if (gridX * gridY * gridZ > 0) { CUlaunchAttribute launchAttr[4]; CUlaunchConfig config; @@ -41,8 +33,6 @@ void TVMFFILauncherImplObj::Launch( config.hStream = cStream; config.attrs = launchAttr; int32_t numAttrs = 0; - // TODO: check `launchPdl` - // TODO: check `launchCooperativeGrid` if (numCtas != 1) { CUlaunchAttribute clusterAttr; clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; @@ -109,7 +99,6 @@ void TVMFFILauncherImplObj::Launch( params[j + 1] = &profileScratch; CUDA_CHECK(cuLaunchKernelEx(&config, cFunction, params, nullptr)); } - // TODO: call `launchExitHook` } TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array signature, @@ -120,15 +109,11 @@ TVMFFILauncherImpl::TVMFFILauncherImpl(tvm::ffi::Array signature, void TVMFFILauncherImpl::Launch( int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream, - uint64_t function, - tvm::ffi::Tuple kernelMetadata, - tvm::ffi::ObjectRef launchMetadata, tvm::ffi::ObjectRef launchEnterHook, - tvm::ffi::ObjectRef launchExitHook, tvm::ffi::ObjectRef globalScratchObject, - tvm::ffi::ObjectRef profileScratchObject, + uint64_t function, int32_t numWarps, int32_t numCtas, int32_t sharedMemory, + uint64_t globalScratch, uint64_t profileScratch, const tvm::ffi::Array &kernelArgs) const { - get()->Launch(gridX, gridY, gridZ, stream, function, kernelMetadata, - launchMetadata, launchEnterHook, launchExitHook, - globalScratchObject, profileScratchObject, kernelArgs); + get()->Launch(gridX, gridY, gridZ, stream, function, numWarps, numCtas, + sharedMemory, globalScratch, profileScratch, kernelArgs); } TVM_FFI_STATIC_INIT_BLOCK() {