Fix InstructPix2Pix training in multi-GPU mode #2978
Looks good!
As pointed out by @whbzju here, we also need to unwrap the model when running validation inference. I fixed that and tested with the following command:

accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
--use_ema \
--enable_xformers_memory_efficient_attention \
--resolution=512 --random_flip \
--train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=20 \
--checkpointing_steps=10 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
--validation_prompt="make the mountains snowy" \
--seed=42 \
--report_to=wandb

@patrickvonplaten could you take a look again?
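To illustrate why the unwrapping mentioned above matters for validation, here is a minimal sketch in plain Python. `TinyUNet`, `DistributedWrapper`, and `unwrap` are hypothetical stand-ins written for this example, not code from the training script. The assumption is that the multi-GPU wrapper behaves like a DDP-style wrapper: it keeps the real model on `.module` and forwards calls, but not other attributes. In a script built on Accelerate, the corresponding call would be `accelerator.unwrap_model(...)`.

```python
class TinyUNet:
    """Hypothetical stand-in for the real UNet; only has what the demo needs."""

    def __init__(self):
        self.config = {"in_channels": 8}

    def __call__(self, x):
        return [v * 2 for v in x]


class DistributedWrapper:
    """Hypothetical stand-in for a DDP-style multi-GPU wrapper.

    It stores the real model on `.module` and forwards calls to it,
    but it does not forward other attributes such as `config`.
    """

    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        return self.module(x)


def unwrap(model):
    """Peel off wrapper layers until the underlying model is reached."""
    while isinstance(model, DistributedWrapper):
        model = model.module
    return model


wrapped = DistributedWrapper(TinyUNet())

# Model-specific attributes are not reachable through the wrapper...
try:
    wrapped.config
    wrapper_has_config = True
except AttributeError:
    wrapper_has_config = False

# ...so validation code should use the unwrapped model instead.
unet_for_validation = unwrap(wrapped)
print(wrapper_has_config)
print(unet_for_validation.config)
```

The same idea applies to every wrapped component handed to the validation pipeline: unwrap first, then build the pipeline from the underlying model.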
9b6a6db to 91854ff
* fix: norm group test for UNet3D.
* fix: unet rejig.
* fix: unwrapping when running validation inputs.
* unwrapping the unet too.
* fix: device.
* better unwrapping.
* unwrapping before ema.
* unwrapping.
Should close #2966.
Command used for testing:
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
--use_ema \
--enable_xformers_memory_efficient_attention \
--resolution=512 --random_flip \
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=100 \
--checkpointing_steps=10 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--seed=42
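As a rough aid for reading the command above, the sketch below works out the effective batch size per optimizer step in a multi-GPU run. The per-device batch size and accumulation steps come from the command. The number of processes is an assumption (2 here), since the PR does not say how many GPUs were used, and `effective_batch_size` is a hypothetical helper, not part of the script.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_processes):
    """Samples contributing to one optimizer step across all processes."""
    return per_device_batch * grad_accum_steps * num_processes


# From the command: --train_batch_size=4 --gradient_accumulation_steps=4
num_processes = 2  # assumption: two GPUs; the PR does not state the count
total = effective_batch_size(4, 4, num_processes)
print(total)
```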