train_text_to_image.py multi-GPU training: CUDA out of memory error, but sufficient memory when using a single GPU #3382

@TrueWheelProgramming

Describe the bug

The train_text_to_image.py example script works great on a single NVIDIA A10G GPU. However, on 4x NVIDIA A10G, using the same input arguments plus the --multi_gpu accelerate flag, all 4 GPUs run out of memory before the first step completes.

In the single-GPU case I can train with a batch size of 1 and a resolution of 512.

In the multi-GPU case, even a batch size of 1 and a resolution of 64 results in a CUDA out-of-memory error.

Is there any reason why multi-GPU training would use significantly more memory?

Is there an issue with my (i) accelerate config or (ii) script arguments?

Thanks in advance.

Reproduction

Accelerate Config

compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
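
One thing I am unsure about: the config above has num_processes: 1 even though I launch with --multi_gpu. My understanding (unverified) is that for a single-machine 4-GPU run the config would normally set the process count to the number of GPUs, something like:

compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
gpu_ids: all
machine_rank: 0
mixed_precision: fp16
num_machines: 1
num_processes: 4   # one process per GPU (my assumption)
use_cpu: false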

Command

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=64 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
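
In case it is relevant, I believe the process count can also be passed on the command line instead of relying on the config file; an untried variant of the launch command (abbreviated arguments) would be:

accelerate launch --multi_gpu --num_processes=4 --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=64 --train_batch_size=1 \
  --output_dir="sd-pokemon-model"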

Logs

RuntimeError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 2; 22.20 GiB total capacity; 20.13 GiB already allocated; 26.06 MiB free; 20.29 GiB reserved in total by 
PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF
Steps:   0%|                                                                                                  | 1/15000 [00:06<26:26:53,  6.35s/it, lr=1e-5, step_loss=0.44]
[10:27:31] ERROR    failed (exitcode: 1) local_rank: 0 (pid: 75374) of binary: /opt/conda/bin/python3.9                                                           api.py:671
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/.local/bin/accelerate:8 in <module>                                                 │
│                                                                                                  │
│   5 from accelerate.commands.accelerate_cli import main                                          │
│   6 if __name__ == '__main__':                                                                   │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])                         │
│ ❱ 8 │   sys.exit(main())                                                                         │
│   9                                                                                              │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/accelerate/commands/accelerate_cli.py:45 in main │
│                                                                                                  │
│   42 │   │   exit(1)                                                                             │
│   43 │                                                                                           │
│   44 │   # Run                                                                                   │
│ ❱ 45 │   args.func(args)                                                                         │
│   46                                                                                             │
│   47                                                                                             │
│   48 if __name__ == "__main__":                                                                  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/accelerate/commands/launch.py:909 in             │
│ launch_command                                                                                   │
│                                                                                                  │
│   906 │   elif args.use_megatron_lm and not args.cpu:                                            │
│   907 │   │   multi_gpu_launcher(args)                                                           │
│   908 │   elif args.multi_gpu and not args.cpu:                                                  │
│ ❱ 909 │   │   multi_gpu_launcher(args)                                                           │
│   910 │   elif args.tpu and not args.cpu:                                                        │
│   911 │   │   if args.tpu_use_cluster:                                                           │
│   912 │   │   │   tpu_pod_launcher(args)                                                         │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/accelerate/commands/launch.py:604 in             │
│ multi_gpu_launcher                                                                               │
│                                                                                                  │
│   601 │   )                                                                                      │
│   602 │   with patch_environment(**current_env):                                                 │
│   603 │   │   try:                                                                               │
│ ❱ 604 │   │   │   distrib_run.run(args)                                                          │
│   605 │   │   except Exception:                                                                  │
│   606 │   │   │   if is_rich_available() and debug:                                              │
│   607 │   │   │   │   console = get_console()                                                    │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/torch/distributed/run.py:752 in run              │
│                                                                                                  │
│   749 │   │   )                                                                                  │
│   750 │                                                                                          │
│   751 │   config, cmd, cmd_args = config_from_args(args)                                         │
│ ❱ 752 │   elastic_launch(                                                                        │
│   753 │   │   config=config,                                                                     │
│   754 │   │   entrypoint=cmd,                                                                    │
│   755 │   )(*cmd_args)                                                                           │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/torch/distributed/launcher/api.py:131 in         │
│ __call__                                                                                         │
│                                                                                                  │
│   128 │   │   self._entrypoint = entrypoint                                                      │
│   129 │                                                                                          │
│   130 │   def __call__(self, *args):                                                             │
│ ❱ 131 │   │   return launch_agent(self._config, self._entrypoint, list(args))                    │
│   132                                                                                            │
│   133                                                                                            │
│   134 def _get_entrypoint_name(                                                                  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.9/site-packages/torch/distributed/launcher/api.py:245 in         │
│ launch_agent                                                                                     │
│                                                                                                  │
│   242 │   │   │   # if the error files for the failed children exist                             │
│   243 │   │   │   # @record will copy the first error (root cause)                               │
│   244 │   │   │   # to the error file of the launcher process.                                   │
│ ❱ 245 │   │   │   raise ChildFailedError(                                                        │
│   246 │   │   │   │   name=entrypoint_name,                                                      │
│   247 │   │   │   │   failures=result.failures,                                                  │
│   248 │   │   │   )                                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ChildFailedError: 
============================================================
train_text_to_image.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-05-10_10:27:31
  host      : ip-10-0-6-163.eu-west-1.compute.internal
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 75375)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-05-10_10:27:31
  host      : ip-10-0-6-163.eu-west-1.compute.internal
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 75376)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-05-10_10:27:31
  host      : ip-10-0-6-163.eu-west-1.compute.internal
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 75377)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-10_10:27:31
  host      : ip-10-0-6-163.eu-west-1.compute.internal
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 75374)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
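
For completeness: I have not yet tried the allocator setting suggested in the error message. If it is relevant, my understanding is it would be set in the environment before launching, e.g.:

export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128   # value of 128 is an illustrative guess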

System Info

  • diffusers version: 0.17.0.dev0
  • Platform: Linux-5.15.0-1015-aws-x86_64-with-glibc2.31
  • Python version: 3.9.4
  • PyTorch version (GPU?): 1.12.1+cu116 (True)
  • Huggingface_hub version: 0.14.1
  • Transformers version: 4.26.1
  • Accelerate version: 0.19.0
  • xFormers version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes
