Commit 5a7d35e

Fix InstructPix2Pix training in multi-GPU mode (#2978)
* fix: norm group test for UNet3D.
* fix: unet rejig.
* fix: unwrapping when running validation inputs.
* unwrapping the unet too.
* fix: device.
* better unwrapping.
* unwrapping before ema.
* unwrapping.
1 parent 0c72006 commit 5a7d35e


examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 20 additions & 16 deletions
@@ -451,19 +451,18 @@ def main():
     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
     # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
     # initialized to zero.
-    if accelerator.is_main_process:
-        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
-        in_channels = 8
-        out_channels = unet.conv_in.out_channels
-        unet.register_to_config(in_channels=in_channels)
-
-        with torch.no_grad():
-            new_conv_in = nn.Conv2d(
-                in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
-            )
-            new_conv_in.weight.zero_()
-            new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
-        unet.conv_in = new_conv_in
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+    unet.conv_in = new_conv_in
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
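This hunk drops the `accelerator.is_main_process` guard around the `conv_in` surgery. Under multi-GPU training every process builds its own replica of the UNet before `accelerator.prepare`, so gating the architectural change on the main process leaves all other ranks with the original 4-channel `conv_in` and the replicas out of sync. Below is a minimal sketch of the rank-agnostic pattern, using a hypothetical `widen_conv_in` helper that is not part of the script:

```python
import torch
import torch.nn as nn


def widen_conv_in(unet, in_channels: int = 8):
    # Hypothetical helper mirroring the hunk above: widen conv_in, zero-init
    # the weights, then copy the pretrained kernel into the first channels.
    old = unet.conv_in
    new_conv_in = nn.Conv2d(in_channels, old.out_channels, old.kernel_size, old.stride, old.padding)
    with torch.no_grad():
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, : old.in_channels].copy_(old.weight)
        new_conv_in.bias.copy_(old.bias)  # extra step; the script keeps the new conv's default bias
    unet.conv_in = new_conv_in
    unet.register_to_config(in_channels=in_channels)


# Call on EVERY rank, before accelerator.prepare(unet, ...),
# not inside `if accelerator.is_main_process:`.
```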
@@ -892,9 +891,12 @@ def collate_fn(examples):
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
                     ema_unet.copy_to(unet.parameters())
+                # The models need unwrapping for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=unet,
+                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
+                    vae=accelerator.unwrap_model(vae),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
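The unwrapping hunk matters because `accelerator.prepare` wraps the models in a distributed container (e.g. `torch.nn.parallel.DistributedDataParallel`) when training on multiple GPUs, and that wrapper hides the underlying model's attributes behind `.module`. `StableDiffusionInstructPix2PixPipeline.from_pretrained` needs the bare diffusers modules, which `accelerator.unwrap_model` recovers. A runnable toy illustration (not from the commit):

```python
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(4, 4))  # DDP-wrapped when launched on >1 process
bare = accelerator.unwrap_model(model)        # always the plain nn.Linear

# This is why the pipeline receives accelerator.unwrap_model(unet) rather than
# the prepared `unet`: the DDP wrapper lacks model attributes such as `config`.
print(type(model).__name__, "->", type(bare).__name__)
```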
@@ -904,7 +906,9 @@ def collate_fn(examples):
                 # run inference
                 original_image = download_image(args.val_image_url)
                 edited_images = []
-                with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
+                with torch.autocast(
+                    str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+                ):
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
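The autocast change fixes the device string: in multi-GPU mode `str(accelerator.device)` yields something like "cuda:0", while `torch.autocast` accepts only a device *type* such as "cuda" or "cpu" and raises on an indexed device string. A small sketch of the distinction, assuming a CUDA machine:

```python
import torch

device = torch.device("cuda:0")
device_type = str(device).replace(":0", "")  # "cuda:0" -> "cuda", as in the diff

with torch.autocast(device_type, enabled=True):
    x = torch.randn(2, 2, device=device)
    y = x @ x  # matmul runs under autocast

# `device.type` would be a more general alternative, covering any device index
# rather than only ":0".
```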
@@ -959,7 +963,7 @@ def collate_fn(examples):
         if args.validation_prompt is not None:
             edited_images = []
             pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device)):
+            with torch.autocast(str(accelerator.device).replace(":0", "")):
                 for _ in range(args.num_validation_images):
                     edited_images.append(
                         pipeline(
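The last hunk applies the same device-type fix to the inference pass that runs after training, so final validation also works when the accelerator reports an indexed device.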
