summaryrefslogtreecommitdiffstats
path: root/pkgs/top-level
diff options
context:
space:
mode:
authorAlexander Tsvyashchenko <ndl@endl.ch>2021-12-28 01:19:10 +0100
committerGitHub <noreply@github.com>2021-12-27 16:19:10 -0800
commitbe5272250926e352427b3c62c6066a95c6592375 (patch)
tree826d9be930dc2c701209d84eb6abbda59cff853c /pkgs/top-level
parent8efd318b108e44673cfcb0643ddd1fd224e25dc1 (diff)
python3Packages.jaxlib: refactor to support Nix-based builds (#151909)
* python3Packages.jaxlib: rename to `jaxlib-bin` Refactoring `jaxlib` to have a similar structure to `tensorflow` with the 'bin' and 'build' options. * python3Packages.jaxlib: init the 'build' variant at 0.1.75 Similar to `tensorflow-build`, now there's an option to build `jaxlib` using Nix-provided environment and dependencies. * python3Packages.jax: 0.2.24 -> 0.2.26 * Addressed review comments. * Fixed `cudaSupport` missing property on some arches. * Unified the versions of CUDA-related packages with TF. Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Diffstat (limited to 'pkgs/top-level')
-rw-r--r--pkgs/top-level/python-packages.nix28
1 files changed, 22 insertions, 6 deletions
diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix
index 1faf86289bc3..2383658c76a0 100644
--- a/pkgs/top-level/python-packages.nix
+++ b/pkgs/top-level/python-packages.nix
@@ -100,6 +100,12 @@ let
disabledIf = x: drv: if x then disabled drv else drv;
+ # CUDA-related packages that are compatible with the currently packaged version
+ # of TensorFlow, used to keep these versions in sync in related packages like `jaxlib`.
+ tensorflow_compat_cudatoolkit = pkgs.cudatoolkit_11_2;
+ tensorflow_compat_cudnn = pkgs.cudnn_cudatoolkit_11_2;
+ tensorflow_compat_nccl = pkgs.nccl_cudatoolkit_11;
+
in {
inherit pkgs stdenv;
@@ -4053,7 +4059,17 @@ in {
jax = callPackage ../development/python-modules/jax { };
- jaxlib = callPackage ../development/python-modules/jaxlib { };
+ jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { };
+
+ jaxlib-build = callPackage ../development/python-modules/jaxlib {
+ # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
+ cudaSupport = pkgs.config.cudaSupport or false;
+ cudatoolkit = tensorflow_compat_cudatoolkit;
+ cudnn = tensorflow_compat_cudnn;
+ nccl = tensorflow_compat_nccl;
+ };
+
+ jaxlib = self.jaxlib-build;
JayDeBeApi = callPackage ../development/python-modules/JayDeBeApi { };
@@ -9453,16 +9469,16 @@ in {
tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix {
cudaSupport = pkgs.config.cudaSupport or false;
- cudatoolkit = pkgs.cudatoolkit_11_2;
- cudnn = pkgs.cudnn_cudatoolkit_11_2;
+ cudatoolkit = tensorflow_compat_cudatoolkit;
+ cudnn = tensorflow_compat_cudnn;
};
tensorflow-build = callPackage ../development/python-modules/tensorflow {
inherit (pkgs.darwin) cctools;
cudaSupport = pkgs.config.cudaSupport or false;
- cudatoolkit = pkgs.cudatoolkit_11_2;
- cudnn = pkgs.cudnn_cudatoolkit_11_2;
- nccl = pkgs.nccl_cudatoolkit_11;
+ cudatoolkit = tensorflow_compat_cudatoolkit;
+ cudnn = tensorflow_compat_cudnn;
+ nccl = tensorflow_compat_nccl;
inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security;
flatbuffers-core = pkgs.flatbuffers;
flatbuffers-python = self.flatbuffers;