Skip to content

Commit b86c4bd

Browse files
committed
use a symlinked toolkit instead to make cmake working
1 parent 92e4d29 commit b86c4bd

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

flake.nix

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
);
3636
pkgs = import nixpkgs { inherit system; };
3737
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
38+
cudatoolkit_joined = with pkgs; symlinkJoin {
39+
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
40+
# see https://github.com/NixOS/nixpkgs/issues/224291
41+
# copied from jaxlib
42+
name = "${cudaPackages.cudatoolkit.name}-merged";
43+
paths = [
44+
cudaPackages.cudatoolkit.lib
45+
cudaPackages.cudatoolkit.out
46+
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
47+
# for some reason some of the required libs are in the targets/x86_64-linux
48+
# directory; not sure why but this works around it
49+
"${cudaPackages.cudatoolkit}/targets/${system}"
50+
];
51+
};
3852
llama-python =
3953
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
4054
postPatch = ''
@@ -71,31 +85,10 @@
7185
};
7286
packages.cuda = pkgs.stdenv.mkDerivation {
7387
inherit name src meta postPatch nativeBuildInputs postInstall;
74-
buildInputs = with pkgs; buildInputs ++ [ cudaPackages.cudatoolkit ];
75-
76-
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit, so we force gnu make
77-
# see https://github.com/NixOS/nixpkgs/issues/224291
78-
dontUseCmakeConfigure = true;
79-
dontUseNinjaBuild = true;
80-
81-
buildFlags = [ "LLAMA_CUBLAS=1" ];
82-
installPhase = ''
83-
runHook preInstall
84-
85-
mkdir -p $out/bin
86-
87-
# TODO(Green-Sky): add install target to Makefile, or wait for cmake support
88-
mv main $out/bin/
89-
mv server $out/bin/
90-
mv speculative $out/bin/
91-
mv perplexity $out/bin/
92-
mv embedding $out/bin/
93-
mv quantize $out/bin/
94-
mv llama-bench $out/bin/
95-
mv train-text-from-scratch $out/bin/
96-
97-
runHook postInstall
98-
'';
88+
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
89+
cmakeFlags = cmakeFlags ++ [
90+
"-DLLAMA_CUBLAS=ON"
91+
];
9992
};
10093
packages.rocm = pkgs.stdenv.mkDerivation {
10194
inherit name src meta postPatch nativeBuildInputs postInstall;

0 commit comments

Comments
 (0)