@@ -451,19 +451,18 @@ def main():
     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
     # from the pre-trained checkpoints. The extra channels added to the first layer are
     # initialized to zero.
-    if accelerator.is_main_process:
-        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
-        in_channels = 8
-        out_channels = unet.conv_in.out_channels
-        unet.register_to_config(in_channels=in_channels)
-
-        with torch.no_grad():
-            new_conv_in = nn.Conv2d(
-                in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
-            )
-            new_conv_in.weight.zero_()
-            new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
-            unet.conv_in = new_conv_in
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in

     # Freeze vae and text_encoder
     vae.requires_grad_(False)
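
The zero-initialization above is what lets training start from the pretrained model: the first conv layer now sees eight channels (noisy latents plus the conditioning image latents), but the weights for the four extra channels are zero, so at initialization the expanded UNet computes exactly what the original one did. A minimal, self-contained sketch of the idea with stand-in shapes (not part of the diff; the script copies only the weight, so the bias is copied here as well to make the outputs match exactly):

import torch
import torch.nn as nn

old_conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)   # stand-in for unet.conv_in
new_conv = nn.Conv2d(8, 320, kernel_size=3, padding=1)   # 4 extra conditioning channels
with torch.no_grad():
    new_conv.weight.zero_()                               # extra channels contribute nothing
    new_conv.weight[:, :4].copy_(old_conv.weight)         # reuse the pretrained kernel
    new_conv.bias.copy_(old_conv.bias)

latents = torch.randn(1, 4, 64, 64)
image_latents = torch.randn(1, 4, 64, 64)                 # conditioning latents
out_old = old_conv(latents)
out_new = new_conv(torch.cat([latents, image_latents], dim=1))
print(torch.allclose(out_old, out_new, atol=1e-6))        # True
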
@@ -892,9 +891,12 @@ def collate_fn(examples):
         # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
         ema_unet.store(unet.parameters())
         ema_unet.copy_to(unet.parameters())
+    # The models need unwrapping for compatibility in distributed training mode.
     pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
-        unet=unet,
+        unet=accelerator.unwrap_model(unet),
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        vae=accelerator.unwrap_model(vae),
         revision=args.revision,
         torch_dtype=weight_dtype,
     )
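
The unwrap calls above are the substance of this hunk: under multi-GPU training, accelerator.prepare() wraps each model (for example in torch.nn.parallel.DistributedDataParallel), and from_pretrained() needs the bare modules rather than the wrappers. A minimal sketch of the pattern with a toy module standing in for the UNet (assumes accelerate is installed):

import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(4, 4))    # wrapped when the run is distributed
bare = accelerator.unwrap_model(model)          # always the underlying nn.Module
print(type(model).__name__, "->", type(bare).__name__)
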
@@ -904,7 +906,9 @@ def collate_fn(examples):
     # run inference
     original_image = download_image(args.val_image_url)
     edited_images = []
-    with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
+    with torch.autocast(
+        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+    ):
         for _ in range(args.num_validation_images):
             edited_images.append(
                 pipeline(
@@ -959,7 +963,7 @@ def collate_fn(examples):
     if args.validation_prompt is not None:
         edited_images = []
         pipeline = pipeline.to(accelerator.device)
-        with torch.autocast(str(accelerator.device)):
+        with torch.autocast(str(accelerator.device).replace(":0", "")):
             for _ in range(args.num_validation_images):
                 edited_images.append(
                     pipeline(
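
Both autocast changes are the same fix: torch.autocast expects a device type such as "cuda" or "cpu", but str(accelerator.device) evaluates to "cuda:0" on the first GPU of a distributed run, which autocast rejects. A minimal sketch of the idea, independent of the diff:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_type = str(device).replace(":0", "")     # "cuda:0" -> "cuda"

with torch.autocast(device_type, enabled=device_type == "cuda"):
    x = torch.ones(2, 2, device=device) @ torch.ones(2, 2, device=device)
print(x.dtype)  # torch.float16 under CUDA autocast, torch.float32 otherwise

An equivalent without the string surgery would be accelerator.device.type, which already holds the bare device type.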