summaryrefslogtreecommitdiffstats
path: root/pkgs/development/cuda-modules/nccl/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/cuda-modules/nccl/default.nix')
-rw-r--r--pkgs/development/cuda-modules/nccl/default.nix30
1 files changed, 16 insertions, 14 deletions
diff --git a/pkgs/development/cuda-modules/nccl/default.nix b/pkgs/development/cuda-modules/nccl/default.nix
index 8043adae4d1e..dd767d2781f0 100644
--- a/pkgs/development/cuda-modules/nccl/default.nix
+++ b/pkgs/development/cuda-modules/nccl/default.nix
@@ -17,9 +17,10 @@ let
cuda_cccl
cuda_cudart
cuda_nvcc
+ cudaAtLeast
cudaFlags
+ cudaOlder
cudatoolkit
- cudaVersion
;
in
backendStdenv.mkDerivation (finalAttrs: {
@@ -33,6 +34,7 @@ backendStdenv.mkDerivation (finalAttrs: {
hash = "sha256-IF2tILwW8XnzSmfn7N1CO7jXL95gUp02guIW5n1eaig=";
};
+ __structuredAttrs = true;
strictDeps = true;
outputs = [
@@ -46,12 +48,12 @@ backendStdenv.mkDerivation (finalAttrs: {
autoAddDriverRunpath
python3
]
- ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
- ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ];
+ ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
+ ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ];
buildInputs =
- lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
- ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
+ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
+ ++ lib.optionals (cudaAtLeast "11.4") [
cuda_nvcc.dev # crt/host_config.h
cuda_cudart
]
@@ -59,25 +61,25 @@ backendStdenv.mkDerivation (finalAttrs: {
# against other version, like below, it's important that we use the same format. Otherwise,
# we'll get incorrect results.
# For example, lib.versionAtLeast "12.0" "12.0.0" == false.
- ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ];
+ ++ lib.optionals (cudaAtLeast "12.0") [ cuda_cccl ];
env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];
- preConfigure = ''
+ postPatch = ''
patchShebangs ./src/device/generate.py
- makeFlagsArray+=(
- "NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}"
- )
'';
- makeFlags =
- [ "PREFIX=$(out)" ]
- ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [
+ makeFlagsArray =
+ [
+ "PREFIX=$(out)"
+ "NVCC_GENCODE=${cudaFlags.gencodeString}"
+ ]
+ ++ lib.optionals (cudaOlder "11.4") [
"CUDA_HOME=${cudatoolkit}"
"CUDA_LIB=${lib.getLib cudatoolkit}/lib"
"CUDA_INC=${lib.getDev cudatoolkit}/include"
]
- ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
+ ++ lib.optionals (cudaAtLeast "11.4") [
"CUDA_HOME=${cuda_nvcc}"
"CUDA_LIB=${lib.getLib cuda_cudart}/lib"
"CUDA_INC=${lib.getDev cuda_cudart}/include"