[Examples] InstructPix2Pix instruct training script #2478

Merged
sayakpaul merged 49 commits into main from training/instruct-pix2pix on Mar 23, 2023

Conversation

@sayakpaul (Member) commented Feb 24, 2023

Closes #2288.

I opted for using WandB tables instead of just images, as I think tables are better suited here.

Here's how they would look.
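For reference, here is a minimal sketch (not the script's exact code) of how edited validation images can be logged as a WandB table; the project name and column names are illustrative assumptions:

# Minimal sketch of logging validation results to a WandB table.
# Project and column names are illustrative, not the script's exact values.
import wandb
from PIL import Image

wandb.init(project="instruct-pix2pix")

original = Image.new("RGB", (256, 256))   # placeholder for the validation image
edited = Image.new("RGB", (256, 256))     # placeholder for the pipeline's output
prompt = "make the mountains snowy"

table = wandb.Table(columns=["original_image", "edited_image", "edit_prompt"])
table.add_data(wandb.Image(original), wandb.Image(edited), prompt)
wandb.log({"validation": table})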

Command to launch an experiment (with wandb logging)

accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
 --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
 --dataset_name=sayakpaul/instructpix2pix-1000-samples \
 --use_ema \
 --enable_xformers_memory_efficient_attention \
 --resolution=256 --random_flip \
 --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
 --max_train_steps=15000 \
 --checkpointing_steps=5000 --checkpoints_total_limit=1 \
 --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
 --conditioning_dropout_prob=0.05 \
 --mixed_precision=fp16 \
 --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
 --validation_prompt="make the mountains snowy" \
 --seed=42 \
 --report_to=wandb 

Run ongoing here: https://wandb.ai/sayakpaul/instruct-pix2pix/runs/kjmrkjop

Testing machine

diffusers-cli env

  • diffusers version: 0.14.0.dev0
  • Platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.10
  • Python version: 3.8.16
  • PyTorch version (GPU?): 1.13.1+cu116 (True)
  • Huggingface_hub version: 0.12.0
  • Transformers version: 4.26.1
  • Accelerate version: 0.16.0
  • xFormers version: 0.0.16
  • Using GPU in script?: Yes (A100 40 GB)
  • Using distributed or parallel set-up in script?: No

TODO

  • Get the script to work
  • Run experiments to find reasonable hyperparameters (for now)
  • Add requirements.txt
  • Fill in the README
  • Add an entry to the training section in the docs

@sayakpaul sayakpaul changed the title add: initial implementation of the pix2pix instruct training script. [Examples] InstructPix2Pix instruct training script Feb 24, 2023
@sayakpaul sayakpaul self-assigned this Feb 24, 2023
@HuggingFaceDocBuilderDev commented Feb 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul (Member, Author) commented:

@patrickvonplaten I ran two experiments with the current state of the training script and I will post my analysis soon.

@sayakpaul (Member, Author) commented Mar 1, 2023

@patrickvonplaten @patil-suraj so I ran another experiment with:

accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
--use_ema \
--enable_xformers_memory_efficient_attention \
--resolution=512 --random_flip \
--train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=15000 \
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
--learning_rate=5e-05 --lr_warmup_steps=0 \
--conditioning_dropout_prob=0.05 \
--mixed_precision=fp16 \
--val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
--validation_prompt="make the mountains snowy" \
--seed=42 \
--report_to=wandb 

The final trained model is here: https://huggingface.co/sayakpaul/instruct-pix2pix.

I ran inference with this model:

import PIL.Image
import PIL.ImageOps
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "sayakpaul/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"


def download_image(url):
    image = PIL.Image.open(requests.get(url, stream=True).raw)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


image = download_image(url)

prompt = "wipe out the lake"

for num_inference_steps in [20, 50, 100]:
    for image_guidance_scale in [1.0, 1.5, 2.0]:
        for guidance_scale in [5, 6, 7, 8, 10]:
            edited_image = pipe(prompt, 
                image=image, 
                num_inference_steps=num_inference_steps, 
                image_guidance_scale=image_guidance_scale, 
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]
            edited_image.save(
                f"lake_inf_steps@{num_inference_steps}_imgs@{image_guidance_scale}_gs@{guidance_scale}.png"
            )

Here are the results:
inference_images.zip

inf_steps denotes num_inference_steps, img denotes image_guidance_scale, gs denotes guidance_scale.

Note that the sample I used for inference is actually present in the mini training dataset I used for the example script.

You will notice that the results are not up to the mark. I suspect:

  • The network is still under-trained.
  • Lack of training instances.

Both these aspects can be improved with more rigorous hyperparameter tuning. Regardless of those, I believe it would be great to review the script's implementation details if you have time.

@patrickvonplaten (Contributor) commented:

Sorry for the delay here - going through it now

vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
Contributor

Suggested change:

)
unet.register_to_config(in_channel=8)
conv_in = nn.Linear(..., bias=False)
conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
unet.conv_in = conv_in

I think just those three lines would be nicer than creating a whole new function.

Member Author

Thanks! This looks more elegant and indeed cleaner. One cannot do conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) because in-place operations cannot be performed on leaf variables that require grads. So, if one puts the in-place ops within a torch.no_grad() context it should work.

Full working snippet:

out_channels = unet.conv_in.out_channels
# InstructPix2Pix conditions the UNet on the input image's latents as well,
# so conv_in takes 8 channels: 4 noisy latents + 4 image latents.
in_channels = 8

with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    new_conv_in.weight.zero_()
    # Copy the pretrained weights into the first 4 input channels; the extra
    # image-conditioning channels start from zero.
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    unet.conv_in = new_conv_in
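As a tiny standalone illustration (not taken from the script) of why the torch.no_grad() context is needed:

# An in-place copy into (a view of) a leaf parameter that requires grad raises
# a RuntimeError; the same copy inside torch.no_grad() succeeds.
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 320, kernel_size=3, padding=1)    # stand-in for the new conv_in
pretrained_weight = torch.randn(320, 4, 3, 3)         # stand-in for unet.conv_in.weight

try:
    conv.weight[:, :4, :, :].copy_(pretrained_weight)
except RuntimeError as err:
    print(err)  # complains about an in-place op on a leaf variable requiring grad

with torch.no_grad():
    conv.weight[:, :4, :, :].copy_(pretrained_weight)  # works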


# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
Contributor

The text encoder is never trained for InstructPix2Pix?

Member Author

Nope. But feel free to verify this from Section 3.2 of the original paper.
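For illustration, a minimal sketch of that setup, assuming (per Section 3.2 of the paper and the diff above) that only the UNet is optimized; the model id and learning rate are reused from earlier in this thread:

# Load the Stable Diffusion components, freeze the VAE and text encoder,
# and hand only the UNet's parameters to the optimizer.
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

vae.requires_grad_(False)
text_encoder.requires_grad_(False)

optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-5)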

@sayakpaul (Member, Author) commented:

@williamberman I am facing something weird here.

The final pipeline inference fails mid-way:

  File "train_instruct_pix2pix.py", line 1010, in <module>
    main()
  File "train_instruct_pix2pix.py", line 987, in main
    pipeline(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py", line 380, in __call__
    image = self.decode_latents(latents)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py", line 643, in decode_latents
    image = self.vae.decode(latents).sample
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 185, in decode
    decoded = self._decode(z).sample
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 171, in _decode
    z = self.post_quant_conv(z)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::Half) should be the same

Here's a command to reproduce it quickly:

accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
  --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
  --dataset_name=sayakpaul/instructpix2pix-1000-samples \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --random_flip --train_batch_size=2 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=2 --max_train_samples 10 \
  --checkpointing_steps=1 \
  --checkpoints_total_limit=1 \
  --learning_rate=5e-05 --lr_warmup_steps=0 \
  --conditioning_dropout_prob=0.05 --mixed_precision=fp16 \
  --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
  --validation_prompt="make the mountains snowy" \
  --seed=42 --report_to=wandb

Outputs of the print statements I have added for debugging:

Pipeline device: cuda:0
UNet: cuda:0 Text Encoder: cuda:0 VAE: cuda:0
UNet: torch.float16 Text Encoder: torch.float16 VAE: torch.float16

When you get a chance, could you look into it?

Cc: @patrickvonplaten
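For context, a tiny standalone repro of the same class of error (an assumption about the root cause, not code from the script): a float32 input reaching float16 conv weights fails, while torch.autocast casts the input and succeeds. Assumes a CUDA device:

import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=3).cuda().half()   # fp16 weights and bias
x = torch.randn(1, 4, 64, 64, device="cuda")          # fp32 input

try:
    conv(x)
except RuntimeError as err:
    print(err)  # dtype-mismatch error similar to the one in the traceback above

with torch.autocast("cuda"):
    out = conv(x)          # autocast casts the input to fp16, so this runs
print(out.dtype)           # torch.float16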

f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}"
)

edited_images.append(
Contributor

Suggested change:

with torch.autocast("cuda"):
    edited_images.append(

Should solve the problem. See reasoning here: #2568
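As a standalone illustration of that suggestion (reusing the checkpoint, image URL, and prompt already posted in this thread; not the training script's exact validation code):

# Run the fp16 InstructPix2Pix pipeline under torch.autocast("cuda") so that
# any fp32 intermediates are cast before they reach fp16 weights.
import PIL.Image
import PIL.ImageOps
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "sayakpaul/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

url = "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
image = PIL.ImageOps.exif_transpose(
    PIL.Image.open(requests.get(url, stream=True).raw)
).convert("RGB")

with torch.autocast("cuda"):
    edited = pipe("make the mountains snowy", image=image).images[0]
edited.save("mountain_snowy.png")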

Member Author

Thank you!

But why would it be needed when the dtypes and devices are uniform across the sub-models?

@sayakpaul (Member, Author) commented on Mar 17, 2023:

@patrickvonplaten, it still fails with the change (fbc626a).

@patrickvonplaten (Contributor) commented:

Thanks for the detailed error description, @sayakpaul! This issue should explain how to solve it: https://github.com/huggingface/diffusers/pull/2478/files#r1138923835 :-)

@sayakpaul (Member, Author) commented:

@patrickvonplaten issue resolved. Here's what I had to do following #2631:

  • I unwrapped the vae and text_encoder along with unet.
  • I wrapped the final inference block within autocast as you suggested. We should be good now :)

@patrickvonplaten (Contributor) commented:

> @patrickvonplaten issue resolved. Here's what I had to do following #2631:
>
>   • I unwrapped the vae and text_encoder along with unet.
>   • I wrapped the final inference block within autocast as you suggested. We should be good now :)

Ah, interesting! Thanks for the pointer. The PR is good to merge for me :-) Nice job!

@sayakpaul sayakpaul merged commit 9dc8444 into main Mar 23, 2023
@sayakpaul sayakpaul deleted the training/instruct-pix2pix branch March 23, 2023 04:45
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* add: initial implementation of the pix2pix instruct training script.
* shorten cli arg.
* fix: main process check.
* fix: dataset column names.
* simplify tokenization.
* proper placement of null conditions.
* apply styling.
* remove debugging message for conditioning do.
* complete license.
* add: requirements.tzt
* wandb column name order.
* fix: augmentation.
* change: dataset_id.
* fix: convert_to_np() call.
* fix: reshaping.
* fix: final ema copy.
* Apply suggestions from code review (Co-authored-by: Patrick von Platen <[email protected]>)
* address PR comments.
* add: readme details.
* config fix.
* downgrade version.
* reduce image width in the readme.
* note on hyperparameters during generation.
* add: output images.
* update readme.
* minor edits to readme.
* debugging statement.
* explicitly placement of the pipeline.
* bump minimum diffusers version.
* fix: device attribute error.
* weight dtype.
* debugging.
* add dtype inform.
* add seoarate te and vae.
* add: explicit casting/
* remove casting.
* up.
* up 2.
* up 3.
* autocast.
* disable mixed-precision in the final inference.
* debugging information.
* autocasting.
* add: instructpix2pix training section to the docs.
* Empty-Commit

---------

Co-authored-by: Patrick von Platen <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Successfully merging this pull request may close these issues.

[Examples] Implement a training script for InstructPix2Pix
5 participants