diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 7635a7167..4beb11569 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -36,7 +36,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > > - Python minimum requirement >= 3.8 -> - CUDA 11.8 +> - CUDA 12.1 > - NVIDIA Driver version 535.104.05 To set up a virtual enviornment and install this repository @@ -68,12 +68,19 @@ To set up a virtual enviornment and install this repository For **PyTorch** + Note: the command below assumes you have CUDA 12.1 installed locally. + This is the default in the provided Docker image. + We recommend you match this CUDA version, but if you decide to run + with a different local CUDA version, please find the appropriate wheel + URL to pass to the `pip install` command for `pytorch`. + ```bash pip3 install -e '.[jax_cpu]' - pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121' pip3 install -e '.[full]' ``` +
Per workload installations diff --git a/README.md b/README.md index 65bae4d54..6bd9f7acc 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ pip3 install -e '.[full]' ```bash pip3 install -e '.[jax_cpu]' -pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html' +pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121' pip3 install -e '.[full]' ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index bc3b51649..9b72aea86 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -36,8 +36,6 @@ ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch -RUN cd /algorithmic-efficiency && pip install -e '.[full]' - RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ @@ -47,17 +45,19 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." 
>&2 \ && exit 1 ; \ fi +RUN cd /algorithmic-efficiency && pip install -e '.[full]' + RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' RUN cd /algorithmic-efficiency && git fetch origin diff --git a/docker/Singularity.def b/docker/Singularity.def index 5f5c31d60..d3ae3f186 100644 --- a/docker/Singularity.def +++ b/docker/Singularity.def @@ -51,12 +51,12 @@ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[jax_cpu]' \ -&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ -&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ +&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ diff --git a/setup.cfg b/setup.cfg index 89275d5e3..792f38420 100644 --- a/setup.cfg +++ b/setup.cfg @@ -123,15 +123,16 @@ jax_core_deps = # JAX CPU jax_cpu = - %(jax_core_deps)s jax==0.4.10 jaxlib==0.4.10 + %(jax_core_deps)s # JAX GPU +# Note: this installs both jax and jaxlib. jax_gpu = + jax==0.4.10 + jaxlib==0.4.10+cuda12.cudnn88 %(jax_core_deps)s - jax[cuda]==0.4.10 - jaxlib==0.4.10+cuda11.cudnn86 # PyTorch CPU pytorch_cpu = @@ -139,9 +140,11 @@ pytorch_cpu = torchvision==0.16.0 # PyTorch GPU +# Note: omitting the CUDA suffix and installing from the appropriate +# wheel will result in using the locally installed CUDA. pytorch_gpu = - torch==2.1.0+cu118 - torchvision==0.16.0+cu118 + torch==2.1.0 + torchvision==0.16.0 # wandb wandb =