// Launcher declarations for the triton_tvm_ffi namespace:
//   - TVMFFILauncherImplObj: a tvm-ffi object node that stores a signature
//     and two launch flags and declares Launch().
//   - TVMFFILauncherImpl: the corresponding tvm-ffi reference handle.
#ifndef TRITON_TVM_FFI_LAUNCH_H_
#define TRITON_TVM_FFI_LAUNCH_H_

#include "type.h"
// NOTE(review): the two include directives below have no header name in this
// copy of the file. They are presumably tvm-ffi headers that provide
// tvm::ffi::Object, tvm::ffi::ObjectRef, tvm::ffi::Array and the
// TVM_FFI_* macros. Restore the targets. The fixed-width integer types used
// below (int32_t, uint64_t) must also be declared by one of the includes,
// e.g. via <cstdint>.
#include
#include

namespace triton_tvm_ffi {

// Object node holding a signature plus two boolean launch flags that are fixed
// at construction (they are const members). It exposes Launch().
//
// NOTE(review): `tvm::ffi::Array` is spelled without template arguments
// throughout this copy. If Array is the tvm-ffi class template, its element
// types were lost and must be restored. Confirm against the original file.
class TVMFFILauncherImplObj : public tvm::ffi::Object {
public:
  // Constructs the node from a signature and the two launch flags.
  TVMFFILauncherImplObj(const tvm::ffi::Array &signature,
                        bool launchCooperativeGrid, bool launchAsync);
  // Copy and move construction are defaulted. Because launchCooperativeGrid_
  // and launchAsync_ are `const`, the implicitly-declared copy and move
  // assignment operators are deleted. An existing node therefore cannot be
  // reassigned, only constructed.
  TVMFFILauncherImplObj(const TVMFFILauncherImplObj &other) = default;
  TVMFFILauncherImplObj(TVMFFILauncherImplObj &&other) = default;

  // Launch entry point. The implementation is not visible in this header.
  // It is declared `const`, so it cannot modify the stored signature or flags.
  // Parameter roles below are inferred from their names and are hedged:
  //   gridX, gridY, gridZ   - grid extents.
  //   stream, function      - passed as uint64_t; presumably raw stream and
  //                           kernel-function handles. TODO confirm.
  //   numWarps, numCtas     - warp / CTA counts for the launch (by name).
  //   sharedMemory          - shared-memory amount, presumably in bytes.
  //                           Confirm.
  //   globalScratch,
  //   profileScratch        - passed as uint64_t; presumably addresses of
  //                           scratch buffers. Confirm.
  //   kernelArgs            - kernel arguments; presumably interpreted
  //                           according to signature_. Confirm.
  void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
              uint64_t function, int32_t numWarps, int32_t numCtas,
              int32_t sharedMemory, uint64_t globalScratch,
              uint64_t profileScratch,
              const tvm::ffi::Array &kernelArgs) const;

  // tvm-ffi type information for this node. The arguments are the type key
  // string "triton_tvm_ffi.TVMFFILauncherImpl", this class, and the parent
  // tvm::ffi::Object. Per the macro name, the type is declared final.
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("triton_tvm_ffi.TVMFFILauncherImpl",
                                    TVMFFILauncherImplObj, tvm::ffi::Object);

private:
  // Presumably initialized from the constructor argument of the same name.
  // This member is not const, unlike the two flags below.
  tvm::ffi::Array signature_;
  // Launch flags. They are const, so each is fixed for the node's lifetime.
  // Presumably initialized from the constructor arguments of the same names.
  const bool launchCooperativeGrid_;
  const bool launchAsync_;
};

// Reference handle to a TVMFFILauncherImplObj. The NOTNULLABLE variant of the
// ref-methods macro suggests the handle is not meant to hold null. Confirm in
// the tvm-ffi documentation.
class TVMFFILauncherImpl : public tvm::ffi::ObjectRef {
public:
  // Creates a handle from a signature and the two launch flags. Unlike the
  // node constructor, `signature` is taken by value, presumably so it can be
  // moved into the node.
  TVMFFILauncherImpl(tvm::ffi::Array signature, bool launchCooperativeGrid,
                     bool launchAsync);
  // Re-expose ObjectRef's constructors and assignment operators.
  using tvm::ffi::ObjectRef::ObjectRef;
  using tvm::ffi::ObjectRef::operator=;

  // Same parameter list as TVMFFILauncherImplObj::Launch. It presumably
  // forwards to the referenced node; see the implementation file.
  void Launch(int32_t gridX, int32_t gridY, int32_t gridZ, uint64_t stream,
              uint64_t function, int32_t numWarps, int32_t numCtas,
              int32_t sharedMemory, uint64_t globalScratch,
              uint64_t profileScratch,
              const tvm::ffi::Array &kernelArgs) const;

  // tvm-ffi boilerplate that ties this reference type to its parent
  // (tvm::ffi::ObjectRef) and its node type (TVMFFILauncherImplObj).
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TVMFFILauncherImpl,
                                                tvm::ffi::ObjectRef,
                                                TVMFFILauncherImplObj);
};

} // namespace triton_tvm_ffi

#endif