diff options
Diffstat (limited to 'pkgs/development/cuda-modules/nccl/default.nix')
-rw-r--r-- | pkgs/development/cuda-modules/nccl/default.nix | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/pkgs/development/cuda-modules/nccl/default.nix b/pkgs/development/cuda-modules/nccl/default.nix index 8043adae4d1e..dd767d2781f0 100644 --- a/pkgs/development/cuda-modules/nccl/default.nix +++ b/pkgs/development/cuda-modules/nccl/default.nix @@ -17,9 +17,10 @@ let cuda_cccl cuda_cudart cuda_nvcc + cudaAtLeast cudaFlags + cudaOlder cudatoolkit - cudaVersion ; in backendStdenv.mkDerivation (finalAttrs: { @@ -33,6 +34,7 @@ backendStdenv.mkDerivation (finalAttrs: { hash = "sha256-IF2tILwW8XnzSmfn7N1CO7jXL95gUp02guIW5n1eaig="; }; + __structuredAttrs = true; strictDeps = true; outputs = [ @@ -46,12 +48,12 @@ backendStdenv.mkDerivation (finalAttrs: { autoAddDriverRunpath python3 ] - ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] - ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ]; + ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ] + ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ]; buildInputs = - lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] - ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ + lib.optionals (cudaOlder "11.4") [ cudatoolkit ] + ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc.dev # crt/host_config.h cuda_cudart ] @@ -59,25 +61,25 @@ backendStdenv.mkDerivation (finalAttrs: { # against other version, like below, it's important that we use the same format. Otherwise, # we'll get incorrect results. # For example, lib.versionAtLeast "12.0" "12.0.0" == false. - ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ]; + ++ lib.optionals (cudaAtLeast "12.0") [ cuda_cccl ]; env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ]; - preConfigure = '' + postPatch = '' patchShebangs ./src/device/generate.py - makeFlagsArray+=( - "NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}" - ) ''; - makeFlags = - [ "PREFIX=$(out)" ] - ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ + makeFlagsArray = + [ + "PREFIX=$(out)" + "NVCC_GENCODE=${cudaFlags.gencodeString}" + ] + ++ lib.optionals (cudaOlder "11.4") [ "CUDA_HOME=${cudatoolkit}" "CUDA_LIB=${lib.getLib cudatoolkit}/lib" "CUDA_INC=${lib.getDev cudatoolkit}/include" ] - ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ + ++ lib.optionals (cudaAtLeast "11.4") [ "CUDA_HOME=${cuda_nvcc}" "CUDA_LIB=${lib.getLib cuda_cudart}/lib" "CUDA_INC=${lib.getDev cuda_cudart}/include" |