nixpkgs/pkgs/development/python-modules/tinygrad/patch-cuda-paths.patch

39 lines
2 KiB
Diff

# Load the CUDA driver library from a fixed absolute path instead of looking
# it up by the bare library name 'cuda'. The new path is
# <driverLink placeholder>/lib/libcuda.so.
# NOTE(review): the placeholder is presumably substituted with a real store
# path by the packaging expression before this patch is applied - confirm
# against the derivation that consumes this file.
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 061c8d73f..65ec3ffd7 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -1,7 +1,7 @@
# mypy: ignore-errors
import ctypes
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
-dll = DLL('cuda', 'cuda')
+dll = DLL('cuda', '@driverLink@/lib/libcuda.so')
cuuint32_t = ctypes.c_uint32
cuuint64_t = ctypes.c_uint64
CUdeviceptr_v2 = ctypes.c_uint64
# Load NVRTC from a fixed absolute path instead of searching
# /usr/local/cuda/targets/<multiarch>/lib for a library named 'nvrtc'.
# The new path is <cuda_nvrtc placeholder>/lib/libnvrtc.so.
# NOTE(review): the placeholder is presumably substituted with a real store
# path by the packaging expression before this patch is applied - confirm.
# NOTE(review): the `import sysconfig` context line is left in place; it was
# only used by the removed line within this hunk, so it may now be unused.
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index 88085c45b..90518d403 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -2,7 +2,7 @@
import ctypes
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
import sysconfig
-dll = DLL('nvrtc', 'nvrtc', f'/usr/local/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib')
+dll = DLL('nvrtc', '@cuda_nvrtc@/lib/libnvrtc.so')
nvrtcResult = CEnum(ctypes.c_uint32)
NVRTC_SUCCESS = nvrtcResult.define('NVRTC_SUCCESS', 0)
NVRTC_ERROR_OUT_OF_MEMORY = nvrtcResult.define('NVRTC_ERROR_OUT_OF_MEMORY', 1)
# Point the NVRTC compile options at a single fixed include directory,
# <cuda_cudart placeholder>/include, replacing both the CUDA_PATH-derived
# include path and the fallback list (/usr/local/cuda/include, /usr/include,
# /opt/cuda/include).
# NOTE(review): after this change CUDA_PATH is no longer consulted on this
# line, so it cannot redirect the include directory - confirm that is intended.
diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py
index 8f71a9255..fdbf01bad 100644
--- a/tinygrad/runtime/support/compiler_cuda.py
+++ b/tinygrad/runtime/support/compiler_cuda.py
@@ -43,7 +43,7 @@ def cuda_disassemble(lib:bytes, arch:str):
class CUDACompiler(Compiler):
def __init__(self, arch:str, cache_key:str="cuda"):
self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}']
- self.compile_options += [f"-I{CUDA_PATH}/include"] if CUDA_PATH else ["-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"]
+ self.compile_options += ["-I@cuda_cudart@/include"]
nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
super().__init__(f"compile_{cache_key}_{self.arch}")