diff --git a/python/triton_tvm_ffi/wrap.py b/python/triton_tvm_ffi/wrap.py index c5cda5c..f0e8f6c 100644 --- a/python/triton_tvm_ffi/wrap.py +++ b/python/triton_tvm_ffi/wrap.py @@ -155,22 +155,28 @@ def torch_wrap( extra_ldflags: Optional[Sequence[str]] = None, extra_include_paths: Optional[Sequence[Union[str, Path]]] = None, ) -> TVMFFIWrapperFunction: + cuda_home: str = tvm_ffi.cpp.extension._find_cuda_home() return wrap( name, fns, code, extra_ldflags=[ "-Wl,--no-as-needed", + f"-L{cuda_home}/lib64", *map( lambda path: f"-L{path}", torch.utils.cpp_extension.library_paths(), ), + "-lcuda", "-lc10", "-ltorch", ] + (extra_ldflags or []), extra_cflags=extra_cflags, extra_cuda_cflags=extra_cuda_cflags, - extra_include_paths=[*torch.utils.cpp_extension.include_paths()] + extra_include_paths=[ + f"{cuda_home}/include", + *torch.utils.cpp_extension.include_paths(), + ] + (extra_include_paths or []), )