[Examples] InstructPix2Pix instruct training script #2478
Conversation
The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten I ran two experiments with the current state of the training script and I will post my analysis soon.
@patrickvonplaten @patil-suraj so I ran another experiment with:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
  --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
  --dataset_name=sayakpaul/instructpix2pix-1000-samples \
  --use_ema \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --random_flip \
  --train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=15000 \
  --checkpointing_steps=5000 --checkpoints_total_limit=1 \
  --learning_rate=5e-05 --lr_warmup_steps=0 \
  --conditioning_dropout_prob=0.05 \
  --mixed_precision=fp16 \
  --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
  --validation_prompt="make the mountains snowy" \
  --seed=42 \
  --report_to=wandb
```

The final trained model is here: https://huggingface.co/sayakpaul/instruct-pix2pix. I ran inference with this model:

```python
import PIL
import requests
import torch

from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "sayakpaul/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"


def download_image(url):
    image = PIL.Image.open(requests.get(url, stream=True).raw)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


image = download_image(url)
prompt = "wipe out the lake"

for num_inference_steps in [20, 50, 100]:
    for image_guidance_scale in [1.0, 1.5, 2.0]:
        for guidance_scale in [5, 6, 7, 8, 10]:
            edited_image = pipe(prompt,
                image=image,
                num_inference_steps=num_inference_steps,
                image_guidance_scale=image_guidance_scale,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]
            edited_image.save(
                f"lake_inf_steps@{num_inference_steps}_imgs@{image_guidance_scale}_gs@{guidance_scale}.png"
            )
```

Here are the results:
Note that the sample I used for inference is actually present in the mini training dataset I used for the example script. You will notice that the results are not up to the mark. I suspect a couple of aspects are at play, both of which can be improved with more rigorous hyperparameter tuning. Regardless, I believe it would be great to review the script's implementation details if you have time.
Sorry for the delay here - going through it now
```python
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
```
Suggested change:

```python
)
unet.register_to_config(in_channel=8)
conv_in = nn.Linear(..., bias=False)
conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
unet.conv_in = conv_in
```
I think just those three lines would be nicer than creating a whole new function.
Thanks! This looks more elegant and indeed cleaner. One cannot do `conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)` directly because in-place operations cannot be performed on leaf variables that require grads. So, if one puts the in-place ops within a `torch.no_grad()` context, it should work. Zeroing the new weight first also means the four extra input channels start out as a no-op, so the expanded UNet initially behaves like the pretrained one.

Full working snippet:

```python
out_channels = unet.conv_in.out_channels
in_channels = 8
with torch.no_grad():
    new_conv_in = nn.Conv2d(
        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
    )
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    unet.conv_in = new_conv_in
```
```python
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
```
The text encoder is never trained for InstructPix2Pix?
Nope. But feel free to verify this from Section 3.2 of the original paper.
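For readers following the thread, here is a minimal, self-contained sketch of what this freezing setup amounts to: only the UNet's parameters go to the optimizer, while the VAE and text encoder stay frozen. The model id matches the base model used in the commands above, but the optimizer settings are placeholder assumptions, not the script's exact code.

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # same base model as in the training command
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# The VAE and text encoder only encode images/instructions, so they are frozen.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Only the UNet receives gradient updates.
optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-5)
```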
@williamberman I am facing something weird here. The final pipeline inference fails mid-way:

```
  File "train_instruct_pix2pix.py", line 1010, in <module>
    main()
  File "train_instruct_pix2pix.py", line 987, in main
    pipeline(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py", line 380, in __call__
    image = self.decode_latents(latents)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py", line 643, in decode_latents
    image = self.vae.decode(latents).sample
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 185, in decode
    decoded = self._decode(z).sample
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 171, in _decode
    z = self.post_quant_conv(z)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
```

Here's a command to reproduce it quickly:

```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
  --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
  --dataset_name=sayakpaul/instructpix2pix-1000-samples \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --random_flip --train_batch_size=2 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=2 --max_train_samples 10 \
  --checkpointing_steps=1 \
  --checkpoints_total_limit=1 \
  --learning_rate=5e-05 --lr_warmup_steps=0 \
  --conditioning_dropout_prob=0.05 --mixed_precision=fp16 \
  --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
  --validation_prompt="make the mountains snowy" \
  --seed=42 --report_to=wandb
```

Outputs of the print statements I have added for debugging:

```
Pipeline device: cuda:0
UNet: cuda:0 Text Encoder: cuda:0 VAE: cuda:0
UNet: torch.float16 Text Encoder: torch.float16 VAE: torch.float16
```

When you get a chance, could you look into it once?
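As an aside, this class of error is easy to reproduce in isolation. The snippet below is a minimal sketch (not taken from the script): a half-precision convolution, analogous to the fp16 VAE's `post_quant_conv`, fed a float32 tensor fails with a dtype mismatch, while running the same call under `torch.autocast` succeeds.

```python
import torch

# fp16 weights/bias, mimicking a module from a pipeline loaded with torch_dtype=torch.float16.
conv = torch.nn.Conv2d(4, 4, kernel_size=1).half().cuda()
latents = torch.randn(1, 4, 64, 64, device="cuda")  # float32 by default

try:
    conv(latents)
except RuntimeError as err:
    # dtype mismatch between the fp32 input and the fp16 weights/bias
    print(err)

# Under autocast the fp32 input is cast down to fp16, so the call succeeds.
with torch.autocast("cuda"):
    out = conv(latents)
print(out.dtype)  # torch.float16
```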
f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" | ||
) | ||
|
||
edited_images.append( |
Suggested change:

```python
with torch.autocast("cuda"):
    edited_images.append(
```
Should solve the problem. See reasoning here: #2568
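For concreteness, here is a rough, self-contained sketch of what that autocast wrapper looks like around the validation call. The checkpoint, prompt, and image URL are taken from earlier in this thread; the guidance values and seed are placeholders, not the script's exact settings.

```python
import requests
import torch
from PIL import Image

from diffusers import StableDiffusionInstructPix2PixPipeline

# Checkpoint and validation inputs borrowed from earlier in the thread.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "sayakpaul/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

url = "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
validation_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

edited_images = []
# Running the fp16 pipeline under autocast keeps intermediate tensors in a
# matching dtype, avoiding the float/half mismatch from the traceback above.
with torch.autocast("cuda"):
    edited_images.append(
        pipeline(
            "make the mountains snowy",
            image=validation_image,
            num_inference_steps=20,
            image_guidance_scale=1.5,
            guidance_scale=7,
            generator=torch.Generator("cuda").manual_seed(42),
        ).images[0]
    )
```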
Thank you!
But why would it be needed when the dtypes and devices are uniform across the sub-models?
@patrickvonplaten, it still fails with the change (fbc626a).
Thanks for the detailed error description @sayakpaul! This issue should explain how to solve it: https://github.com/huggingface/diffusers/pull/2478/files#r1138923835 :-)
@patrickvonplaten issue resolved. Here's what I had to do following #2631:
Ah interesting! Thanks for the pointer. The PR is good to merge for me :-) Nice job
* add: initial implementation of the pix2pix instruct training script.
* shorten cli arg.
* fix: main process check.
* fix: dataset column names.
* simplify tokenization.
* proper placement of null conditions.
* apply styling.
* remove debugging message for conditioning do.
* complete license.
* add: requirements.tzt
* wandb column name order.
* fix: augmentation.
* change: dataset_id.
* fix: convert_to_np() call.
* fix: reshaping.
* fix: final ema copy.
* Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]>
* address PR comments.
* add: readme details.
* config fix.
* downgrade version.
* reduce image width in the readme.
* note on hyperparameters during generation.
* add: output images.
* update readme.
* minor edits to readme.
* debugging statement.
* explicitly placement of the pipeline.
* bump minimum diffusers version.
* fix: device attribute error.
* weight dtype.
* debugging.
* add dtype inform.
* add seoarate te and vae.
* add: explicit casting/
* remove casting.
* up.
* up 2.
* up 3.
* autocast.
* disable mixed-precision in the final inference.
* debugging information.
* autocasting.
* add: instructpix2pix training section to the docs.
* Empty-Commit

---------

Co-authored-by: Patrick von Platen <[email protected]>
Closes #2288.
I opted for using WandB tables instead of just images, as I think tables are better suited here. Here's how they would look.
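For reference, here is a minimal sketch of how such a validation table can be logged with `wandb`. The column names and dummy images are assumptions for illustration, not the script's exact code.

```python
import numpy as np
import wandb
from PIL import Image

# Dummy stand-ins for the real validation images and the model's edited outputs.
original = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
edited = Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8))
prompt = "make the mountains snowy"

wandb.init(project="instruct-pix2pix")

# One row per validation sample: input image, edit instruction, edited output.
table = wandb.Table(columns=["original_image", "edit_prompt", "edited_image"])
table.add_data(wandb.Image(original), prompt, wandb.Image(edited))

wandb.log({"validation": table})
wandb.finish()
```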
Command to launch an experiment (with `wandb` logging)

Run ongoing here: https://wandb.ai/sayakpaul/instruct-pix2pix/runs/kjmrkjop

Testing machine (`diffusers-cli env`):
- `diffusers` version: 0.14.0.dev0

TODO
- Run experiments to find reasonable hyperparameters (for now)