diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 26c011d..eea20b6 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -12,9 +12,36 @@ class TVMFFIUtils(object): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + from triton.backends.nvidia.driver import CudaUtils + + self._utils: CudaUtils = CudaUtils() + + def load_binary(self, *args, **kwargs): + return self._utils.load_binary(*args, **kwargs) + + def get_device_properties(self, *args, **kwargs): + return self._utils.get_device_properties(*args, **kwargs) + + def cuOccupancyMaxActiveClusters(self, *args, **kwargs): + return self._utils.cuOccupancyMaxActiveClusters(*args, **kwargs) + + def set_printf_fifo_size(self, *args, **kwargs): + return self._utils.set_printf_fifo_size(*args, **kwargs) + + def fill_tma_descriptor(self, *args, **kwargs): + return self._utils.fill_tma_descriptor(*args, **kwargs) + + def launch(self, *args, **kwargs): + return self._utils.launch(*args, **kwargs) + + def build_signature_metadata(self, *args, **kwargs): + return self._utils.build_signature_metadata(*args, **kwargs) -class TVMFFIDriver(CudaDriver): ... +class TVMFFIDriver(CudaDriver): + def __init__(self, *args, **kwargs) -> TVMFFIDriver: + super().__init__(*args, **kwargs) + self.utils: TVMFFIUtils = TVMFFIUtils() del CudaDriver