summaryrefslogtreecommitdiffstats
path: root/pkgs/development/cuda-modules/cutensor/extension.nix
blob: 5fdf356df916e47ebfa5be2b26514caa945e463b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
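# NOTE: the link below is cuDNN's support matrix, not cuTENSOR's; cuTENSOR's own
# documentation lives under https://docs.nvidia.com/cuda/cutensor/.
# https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-880/support-matrix/index.html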
#
# TODO(@connorbaker):
# This mirrors the strategy used for CUDA/cuDNN:
#
# - Get all versions supported by the current release of CUDA
# - Build all of them
# - Make the newest the default
#
# Unique twist:
#
# - Instead of providing a separate release for each version of CUDA, cuTENSOR ships multiple
#   subdirectories in `lib` -- one for each supported version of CUDA.
{
  cudaVersion,
  flags,
  lib,
  mkVersionedPackageName,
  stdenv,
}:
let
  inherit (lib)
    attrsets
    lists
    modules
    versions
    strings
    trivial
    ;

  inherit (stdenv) hostPlatform;

  # Passed as `redistName` to the generic manifest builder and used as the base of the
  # versioned attribute names produced by `computeName`.
  redistName = "cutensor";

  # Attribute under which each manifest stores this library's release; also handed to the
  # generic manifest builder as `pname`.
  pname = "libcutensor";

  # Releases that have a redistrib/feature manifest pair under ./manifests.
  # Keep this sorted oldest -> newest: the last supported entry becomes the default `cutensor`.
  cutensorVersions = [ "1.3.3" "1.4.0" "1.5.0" "1.6.2" "1.7.0" ];

  # Manifests :: { redistrib, feature }

  # For every cuTENSOR release, evaluate its redistrib/feature JSON manifests through the
  # shared module definitions in ../modules and keep only the resulting
  # `config.cutensor.manifests` value. The output list has the same order as `cutensorVersions`.
  # listOfManifests :: List Manifests
  listOfManifests =
    let
      # manifestsFor :: String -> Manifests
      manifestsFor =
        version:
        let
          evaluated = modules.evalModules {
            modules = [
              ../modules
              # The raw manifests are nested under `cutensor.manifests` so the module
              # system can evaluate them.
              {
                cutensor.manifests = {
                  redistrib = trivial.importJSON (./manifests + "/redistrib_${version}.json");
                  feature = trivial.importJSON (./manifests + "/feature_${version}.json");
                };
              }
            ];
          };
        in
        evaluated.config.cutensor.manifests;
    in
    lists.map manifestsFor cutensorVersions;

  # `cudaVersion` is the CUDA version we build against. cuTENSOR ships one lib/ subdirectory
  # per supported CUDA version rather than one release per CUDA version; the directory names
  # typically look like:
  #
  # - 10.2
  # - 11
  # - 11.0
  # - 12
  #
  # CUDA 10.2 is addressed by its major.minor directory; every other CUDA version is
  # addressed by its major version alone.
  # libPath :: String
  libPath =
    let
      majorMinor = versions.majorMinor cudaVersion;
    in
    if majorMinor != "10.2" then versions.major cudaVersion else majorMinor;

  # Redist architecture name for the host platform, as computed by `flags.getRedistArch`.
  # redistArch :: String
  redistArch = flags.getRedistArch hostPlatform.system;

  # A release counts as supported on this platform when its feature manifest has a
  # non-null entry for `pname` under our `redistArch`.
  # NOTE: this does NOT verify that the release ships a lib/ subdirectory matching `libPath`
  # for our CUDA version; which CUDA versions are shipped can differ between platforms of
  # the same release.
  # platformIsSupported :: Manifests -> Boolean
  platformIsSupported = { feature, ... }: (feature.${pname}.${redistArch} or null) != null;

  # TODO(@connorbaker): With an auxiliary file tracking the CUDA versions each release supports,
  # we could filter out releases that don't support our CUDA version.
  # Until then, releases are filtered by platform only, and we make a best-effort attempt to
  # build cuTENSOR with whatever `libPath` corresponds to our CUDA version.
  # Filtering keeps the order of `cutensorVersions`, so the last element is the newest release.
  # supportedManifests :: List Manifests
  supportedManifests = lists.filter platformIsSupported listOfManifests;

  # Attribute name for a release inside this package set, derived from `redistName` and the
  # release's `version` via `mkVersionedPackageName` (defined elsewhere). The intent is that
  # only major.minor appears, so patch bumps keep the same name -- confirm against
  # `mkVersionedPackageName`.
  # computeName :: RedistribRelease -> String
  computeName = release: mkVersionedPackageName redistName release.version;
in
final: _:
let
  # Build one cuTENSOR release and pair it with its versioned attribute name.
  #
  # The generic manifest builder produces the base derivation. That derivation is then
  # extended with the CUDA libraries cuTENSOR depends on and with cuTENSOR-specific metadata.
  # buildCutensorPackage :: Manifests -> { name :: String; value :: Derivation; }
  buildCutensorPackage =
    { redistrib, feature }:
    let
      # This release's entry in the redistrib manifest (carries `version`).
      redistribRelease = redistrib.${pname};

      # For some reason, the 1.4.x release of cuTENSOR requires the cudart library.
      # Compare major.minor exactly, so that a hypothetical 1.40+ release is not caught by a
      # plain string-prefix match on "1.4".
      needsCudart = versions.majorMinor redistribRelease.version == "1.4";

      drv = final.callPackage ../generic-builders/manifest.nix {
        inherit
          pname
          redistName
          libPath
          redistribRelease
          ;
        featureRelease = feature.${pname};
      };

      fixedDrv = drv.overrideAttrs (prevAttrs: {
        buildInputs =
          (prevAttrs.buildInputs or [ ])
          ++ (
            # CUDA < 11.4: depend on `cudatoolkit`.
            # CUDA >= 11.4: depend on `libcublas` (plus `cuda_cudart` for cuTENSOR 1.4.x).
            if strings.versionOlder cudaVersion "11.4" then
              [ final.cudatoolkit ]
            else
              [ final.libcublas.lib ] ++ lists.optionals needsCudart [ final.cuda_cudart.lib ]
          );
        meta = (prevAttrs.meta or { }) // {
          description = "cuTENSOR: A High-Performance CUDA Library For Tensor Primitives";
          homepage = "https://developer.nvidia.com/cutensor";
          # Tolerate a base derivation that declares no maintainers.
          maintainers = (prevAttrs.meta.maintainers or [ ]) ++ [ lib.maintainers.obsidian-systems-maintenance ];
          license = lib.licenses.unfreeRedistributable // {
            shortName = "cuTENSOR EULA";
            name = "cuTENSOR SUPPLEMENT TO SOFTWARE LICENSE AGREEMENT FOR NVIDIA SOFTWARE DEVELOPMENT KITS";
            url = "https://docs.nvidia.com/cuda/cutensor/license.html";
          };
        };
      });
    in
    attrsets.nameValuePair (computeName redistribRelease) fixedDrv;

  # Result of this extension: one versioned attribute per supported release, plus `cutensor`
  # aliasing the newest supported release (the last element of `supportedManifests`) when at
  # least one release is supported. `lists.last` is only reached in that non-empty case.
  extension =
    let
      packages = builtins.listToAttrs (lists.map buildCutensorPackage supportedManifests);
      defaultAlias = attrsets.optionalAttrs (supportedManifests != [ ]) {
        cutensor = packages.${computeName (lists.last supportedManifests).redistrib.${pname}};
      };
    in
    packages // defaultAlias;
in
extension