summaryrefslogtreecommitdiffstats
path: root/pkgs/development/python-modules/tensorflow/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/tensorflow/default.nix')
-rw-r--r--pkgs/development/python-modules/tensorflow/default.nix7
1 files changed, 6 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/tensorflow/default.nix b/pkgs/development/python-modules/tensorflow/default.nix
index 2d5a302521b4..3a8eba3ba97f 100644
--- a/pkgs/development/python-modules/tensorflow/default.nix
+++ b/pkgs/development/python-modules/tensorflow/default.nix
@@ -17,7 +17,7 @@
# that in nix as well. It would make some things easier and less confusing, but
# it would also make the default tensorflow package unfree. See
# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
-, cudaSupport ? false, cudatoolkit ? null, cudnn ? null, nccl ? null
+, cudaSupport ? false, cudaPackages ? {}
, mklSupport ? false, mkl ? null
, tensorboardSupport ? true
# XLA without CUDA is broken
@@ -31,6 +31,10 @@
, Foundation, Security, cctools, llvmPackages_11
}:
+let
+ inherit (cudaPackages) cudatoolkit cudnn nccl;
+in
+
assert cudaSupport -> cudatoolkit != null
&& cudnn != null;
@@ -514,6 +518,7 @@ in buildPythonPackage {
# Regression test for #77626 removed because not more `tensorflow.contrib`.
passthru = {
+ inherit cudaPackages;
deps = bazel-build.deps;
libtensorflow = bazel-build.out;
};