4848 DDPMScheduler ,
4949 DiffusionPipeline ,
5050 DPMSolverMultistepScheduler ,
51+ StableDiffusionPipeline ,
5152 UNet2DConditionModel ,
5253)
5354from diffusers .optimization import get_scheduler
@@ -111,18 +112,15 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
111112 f .write (yaml + model_card )
112113
113114
114- def log_validation (
115- text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch
116- ):
115+ def log_validation (text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch ):
117116 logger .info (
118117 f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
119118 f" { args .validation_prompt } ."
120119 )
121- # create pipeline (note: unet and vae are loaded again in float32)
122120 pipeline = DiffusionPipeline .from_pretrained (
123121 args .pretrained_model_name_or_path ,
124122 text_encoder = accelerator .unwrap_model (text_encoder_1 ),
125- text_encoder_2 = accelerator . unwrap_model ( text_encoder_2 ) ,
123+ text_encoder_2 = text_encoder_2 ,
126124 tokenizer = tokenizer_1 ,
127125 tokenizer_2 = tokenizer_2 ,
128126 unet = unet ,
@@ -361,7 +359,7 @@ def parse_args():
361359 parser .add_argument (
362360 "--validation_prompt" ,
363361 type = str ,
364- default = None ,
362+ default = "A <cat-toy> backpack" ,
365363 help = "A prompt that is used during validation to verify that the model is learning." ,
366364 )
367365 parser .add_argument (
@@ -380,16 +378,6 @@ def parse_args():
380378 " and logging the images."
381379 ),
382380 )
383- parser .add_argument (
384- "--validation_epochs" ,
385- type = int ,
386- default = None ,
387- help = (
388- "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
389- " `args.validation_prompt` multiple times: `args.num_validation_images`"
390- " and logging the images."
391- ),
392- )
393381 parser .add_argument ("--local_rank" , type = int , default = - 1 , help = "For distributed training: local_rank" )
394382 parser .add_argument (
395383 "--checkpointing_steps" ,
@@ -418,11 +406,6 @@ def parse_args():
418406 parser .add_argument (
419407 "--enable_xformers_memory_efficient_attention" , action = "store_true" , help = "Whether or not to use xformers."
420408 )
421- parser .add_argument (
422- "--no_safe_serialization" ,
423- action = "store_true" ,
424- help = "If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead." ,
425- )
426409
427410 args = parser .parse_args ()
428411 env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -529,6 +512,7 @@ def __init__(
529512
530513 self .templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
531514 self .flip_transform = transforms .RandomHorizontalFlip (p = self .flip_p )
515+ self .crop = transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size )
532516
533517 def __len__ (self ):
534518 return self ._length
@@ -543,6 +527,18 @@ def __getitem__(self, i):
543527 placeholder_string = self .placeholder_token
544528 text = random .choice (self .templates ).format (placeholder_string )
545529
530+ example ["original_size" ] = (image .height , image .width )
531+
532+ if self .center_crop :
533+ y1 = max (0 , int (round ((image .height - self .size ) / 2.0 )))
534+ x1 = max (0 , int (round ((image .width - self .size ) / 2.0 )))
535+ image = self .crop (image )
536+ else :
537+ y1 , x1 , h , w = self .crop .get_params (image , (self .size , self .size ))
538+ image = transforms .functional .crop (image , y1 , x1 , h , w )
539+
540+ example ["crop_top_left" ] = (y1 , x1 )
541+
546542 example ["input_ids_1" ] = self .tokenizer_1 (
547543 text ,
548544 padding = "max_length" ,
@@ -564,13 +560,7 @@ def __getitem__(self, i):
564560
565561 if self .center_crop :
566562 crop = min (img .shape [0 ], img .shape [1 ])
567- (
568- h ,
569- w ,
570- ) = (
571- img .shape [0 ],
572- img .shape [1 ],
573- )
563+ (h , w ,) = (img .shape [0 ], img .shape [1 ],)
574564 img = img [(h - crop ) // 2 : (h + crop ) // 2 , (w - crop ) // 2 : (w + crop ) // 2 ]
575565
576566 image = Image .fromarray (img )
@@ -646,6 +636,7 @@ def main():
646636 args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args .variant
647637 )
648638
639+
649640 # Add the placeholder token in tokenizer_1
650641 placeholder_tokens = [args .placeholder_token ]
651642
@@ -686,21 +677,14 @@ def main():
686677 # Freeze vae and unet
687678 vae .requires_grad_ (False )
688679 unet .requires_grad_ (False )
680+ text_encoder_2 .requires_grad_ (False )
689681 # Freeze all parameters except for the token embeddings in text encoder
690682 text_encoder_1 .text_model .encoder .requires_grad_ (False )
691683 text_encoder_1 .text_model .final_layer_norm .requires_grad_ (False )
692684 text_encoder_1 .text_model .embeddings .position_embedding .requires_grad_ (False )
693- text_encoder_2 .text_model .encoder .requires_grad_ (False )
694- text_encoder_2 .text_model .final_layer_norm .requires_grad_ (False )
695- text_encoder_2 .text_model .embeddings .position_embedding .requires_grad_ (False )
696685
697686 if args .gradient_checkpointing :
698- # Keep unet in train mode if we are using gradient checkpointing to save memory.
699- # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
700- unet .train ()
701687 text_encoder_1 .gradient_checkpointing_enable ()
702- text_encoder_2 .gradient_checkpointing_enable ()
703- unet .enable_gradient_checkpointing ()
704688
705689 if args .enable_xformers_memory_efficient_attention :
706690 if is_xformers_available ():
@@ -749,15 +733,6 @@ def main():
749733 train_dataloader = torch .utils .data .DataLoader (
750734 train_dataset , batch_size = args .train_batch_size , shuffle = True , num_workers = args .dataloader_num_workers
751735 )
752- if args .validation_epochs is not None :
753- warnings .warn (
754- f"FutureWarning: You are doing logging with validation_epochs={ args .validation_epochs } ."
755- " Deprecated validation_epochs in favor of `validation_steps`"
756- f"Setting `args.validation_steps` to { args .validation_epochs * len (train_dataset )} " ,
757- FutureWarning ,
758- stacklevel = 2 ,
759- )
760- args .validation_steps = args .validation_epochs * len (train_dataset )
761736
762737 # Scheduler and math around the number of training steps.
763738 overrode_max_train_steps = False
@@ -791,7 +766,7 @@ def main():
791766 # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
792767 unet .to (accelerator .device , dtype = weight_dtype )
793768 vae .to (accelerator .device , dtype = weight_dtype )
794- text_encoder_2 = text_encoder_2 .to (accelerator .device , dtype = weight_dtype )
769+ text_encoder_2 .to (accelerator .device , dtype = weight_dtype )
795770
796771 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
797772 num_update_steps_per_epoch = math .ceil (len (train_dataloader ) / args .gradient_accumulation_steps )
@@ -876,27 +851,18 @@ def main():
876851 noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
877852
878853 # Get the text embedding for conditioning
879- encoder_hidden_states_1 = (
880- text_encoder_1 (batch ["input_ids_1" ], output_hidden_states = True )
881- .hidden_states [- 2 ]
882- .to (dtype = weight_dtype )
883- )
884- encoder_output_2 = text_encoder_2 (
885- batch ["input_ids_2" ].reshape (batch ["input_ids_1" ].shape [0 ], - 1 ), output_hidden_states = True
886- )
854+ encoder_hidden_states_1 = text_encoder_1 (batch ["input_ids_1" ], output_hidden_states = True ).hidden_states [- 2 ].to (dtype = weight_dtype )
855+ encoder_output_2 = text_encoder_2 (batch ["input_ids_2" ].reshape (batch ["input_ids_1" ].shape [0 ], - 1 ), output_hidden_states = True )
887856 encoder_hidden_states_2 = encoder_output_2 .hidden_states [- 2 ].to (dtype = weight_dtype )
888- sample_size = unet .config .sample_size * (2 ** (len (vae .config .block_out_channels ) - 1 ))
889- original_size = (sample_size , sample_size )
890- add_time_ids = torch .tensor (
891- [list (original_size + (0 , 0 ) + original_size )], dtype = weight_dtype , device = accelerator .device
892- )
857+ original_size = [(batch ["original_size" ][0 ][i ].item (), batch ["original_size" ][1 ][i ].item ()) for i in range (args .train_batch_size )]
858+ crop_top_left = [(batch ["crop_top_left" ][0 ][i ].item (), batch ["crop_top_left" ][1 ][i ].item ()) for i in range (args .train_batch_size )]
859+ target_size = (args .resolution , args .resolution )
860+ add_time_ids = torch .cat ([torch .tensor (original_size [i ] + crop_top_left [i ] + target_size ) for i in range (args .train_batch_size )]).to (accelerator .device , dtype = weight_dtype )
893861 added_cond_kwargs = {"text_embeds" : encoder_output_2 [0 ], "time_ids" : add_time_ids }
894862 encoder_hidden_states = torch .cat ([encoder_hidden_states_1 , encoder_hidden_states_2 ], dim = - 1 )
895863
896864 # Predict the noise residual
897- model_pred = unet (
898- noisy_latents , timesteps , encoder_hidden_states , added_cond_kwargs = added_cond_kwargs
899- ).sample
865+ model_pred = unet (noisy_latents , timesteps , encoder_hidden_states , added_cond_kwargs = added_cond_kwargs ).sample
900866
901867 # Get the target for loss depending on the prediction type
902868 if noise_scheduler .config .prediction_type == "epsilon" :
@@ -929,19 +895,15 @@ def main():
929895 progress_bar .update (1 )
930896 global_step += 1
931897 if global_step % args .save_steps == 0 :
932- weight_name = (
933- f"learned_embeds-steps-{ global_step } .bin"
934- if args .no_safe_serialization
935- else f"learned_embeds-steps-{ global_step } .safetensors"
936- )
898+ weight_name = (f"learned_embeds-steps-{ global_step } .safetensors" )
937899 save_path = os .path .join (args .output_dir , weight_name )
938900 save_progress (
939901 text_encoder_1 ,
940902 placeholder_token_ids ,
941903 accelerator ,
942904 args ,
943905 save_path ,
944- safe_serialization = not args . no_safe_serialization ,
906+ safe_serialization = True ,
945907 )
946908
947909 if accelerator .is_main_process :
@@ -972,16 +934,7 @@ def main():
972934
973935 if args .validation_prompt is not None and global_step % args .validation_steps == 0 :
974936 images = log_validation (
975- text_encoder_1 ,
976- text_encoder_2 ,
977- tokenizer_1 ,
978- tokenizer_2 ,
979- unet ,
980- vae ,
981- args ,
982- accelerator ,
983- weight_dtype ,
984- epoch ,
937+ text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch
985938 )
986939
987940 logs = {"loss" : loss .detach ().item (), "lr" : lr_scheduler_1 .get_last_lr ()[0 ]}
@@ -993,6 +946,10 @@ def main():
993946 # Create the pipeline using the trained modules and save it.
994947 accelerator .wait_for_everyone ()
995948 if accelerator .is_main_process :
949+ images = log_validation (
950+ text_encoder_1 , text_encoder_2 , tokenizer_1 , tokenizer_2 , unet , vae , args , accelerator , weight_dtype , epoch
951+ )
952+
996953 if args .push_to_hub and not args .save_as_full_pipeline :
997954 logger .warn ("Enabling full model saving because --push_to_hub=True was specified." )
998955 save_full_model = True
@@ -1002,23 +959,23 @@ def main():
1002959 pipeline = DiffusionPipeline .from_pretrained (
1003960 args .pretrained_model_name_or_path ,
1004961 text_encoder = accelerator .unwrap_model (text_encoder_1 ),
1005- text_encoder_2 = accelerator . unwrap_model ( text_encoder_2 ) ,
962+ text_encoder_2 = text_encoder_2 ,
1006963 vae = vae ,
1007964 unet = unet ,
1008965 tokenizer = tokenizer_1 ,
1009966 tokenizer_2 = tokenizer_2 ,
1010967 )
1011968 pipeline .save_pretrained (args .output_dir )
1012969 # Save the newly trained embeddings
1013- weight_name = "learned_embeds.bin" if args . no_safe_serialization else "learned_embeds. safetensors"
970+ weight_name = "learned_embeds.safetensors"
1014971 save_path = os .path .join (args .output_dir , weight_name )
1015972 save_progress (
1016973 text_encoder_1 ,
1017974 placeholder_token_ids ,
1018975 accelerator ,
1019976 args ,
1020977 save_path ,
1021- safe_serialization = not args . no_safe_serialization ,
978+ safe_serialization = True ,
1022979 )
1023980
1024981 if args .push_to_hub :
@@ -1035,6 +992,9 @@ def main():
1035992 ignore_patterns = ["step_*" , "epoch_*" ],
1036993 )
1037994
995+ for i in range (len (images )):
996+ images [i ].save (f"cat-backpack_sdxl_test_{ i } .png" )
997+
1038998 accelerator .end_training ()
1039999
10401000
0 commit comments