5555from diffusers .utils .torch_utils import is_compiled_module
5656
5757
58+ if is_wandb_available ():
59+ import wandb
60+
5861# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962check_min_version ("0.26.0.dev0" )
6063
6770TORCH_DTYPE_MAPPING = {"fp32" : torch .float32 , "fp16" : torch .float16 , "bf16" : torch .bfloat16 }
6871
6972
73+ def log_validation (
74+ pipeline ,
75+ args ,
76+ accelerator ,
77+ generator ,
78+ global_step ,
79+ is_final_validation = False ,
80+ ):
81+ logger .info (
82+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
83+ f" { args .validation_prompt } ."
84+ )
85+
86+ pipeline = pipeline .to (accelerator .device )
87+ pipeline .set_progress_bar_config (disable = True )
88+
89+ if not is_final_validation :
90+ val_save_dir = os .path .join (args .output_dir , "validation_images" )
91+ if not os .path .exists (val_save_dir ):
92+ os .makedirs (val_save_dir )
93+
94+ original_image = (
95+ lambda image_url_or_path : load_image (image_url_or_path )
96+ if urlparse (image_url_or_path ).scheme
97+ else Image .open (image_url_or_path ).convert ("RGB" )
98+ )(args .val_image_url_or_path )
99+
100+ with torch .autocast (str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16" ):
101+ edited_images = []
102+ # Run inference
103+ for val_img_idx in range (args .num_validation_images ):
104+ a_val_img = pipeline (
105+ args .validation_prompt ,
106+ image = original_image ,
107+ num_inference_steps = 20 ,
108+ image_guidance_scale = 1.5 ,
109+ guidance_scale = 7 ,
110+ generator = generator ,
111+ ).images [0 ]
112+ edited_images .append (a_val_img )
113+ # Save validation images
114+ if not is_final_validation :
115+ a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
116+
117+ for tracker in accelerator .trackers :
118+ if tracker .name == "wandb" :
119+ wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
120+ for edited_image in edited_images :
121+ wandb_table .add_data (
122+ wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
123+ )
124+ logger_name = "test" if is_final_validation else "validation"
125+ tracker .log ({logger_name : wandb_table })
126+
127+
70128def import_model_class_from_model_name_or_path (
71129 pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
72130):
@@ -447,11 +505,6 @@ def main():
447505
448506 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
449507
450- if args .report_to == "wandb" :
451- if not is_wandb_available ():
452- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
453- import wandb
454-
455508 # Make one log on every process with the configuration for debugging.
456509 logging .basicConfig (
457510 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -1111,11 +1164,6 @@ def collate_fn(examples):
11111164 ### BEGIN: Perform validation every `validation_epochs` steps
11121165 if global_step % args .validation_steps == 0 :
11131166 if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1114- logger .info (
1115- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1116- f" { args .validation_prompt } ."
1117- )
1118-
11191167 # create pipeline
11201168 if args .use_ema :
11211169 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1183,16 @@ def collate_fn(examples):
11351183 variant = args .variant ,
11361184 torch_dtype = weight_dtype ,
11371185 )
1138- pipeline = pipeline .to (accelerator .device )
1139- pipeline .set_progress_bar_config (disable = True )
1140-
1141- # run inference
1142- # Save validation images
1143- val_save_dir = os .path .join (args .output_dir , "validation_images" )
1144- if not os .path .exists (val_save_dir ):
1145- os .makedirs (val_save_dir )
1146-
1147- original_image = (
1148- lambda image_url_or_path : load_image (image_url_or_path )
1149- if urlparse (image_url_or_path ).scheme
1150- else Image .open (image_url_or_path ).convert ("RGB" )
1151- )(args .val_image_url_or_path )
1152- with torch .autocast (
1153- str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16"
1154- ):
1155- edited_images = []
1156- for val_img_idx in range (args .num_validation_images ):
1157- a_val_img = pipeline (
1158- args .validation_prompt ,
1159- image = original_image ,
1160- num_inference_steps = 20 ,
1161- image_guidance_scale = 1.5 ,
1162- guidance_scale = 7 ,
1163- generator = generator ,
1164- ).images [0 ]
1165- edited_images .append (a_val_img )
1166- a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
1167-
1168- for tracker in accelerator .trackers :
1169- if tracker .name == "wandb" :
1170- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1171- for edited_image in edited_images :
1172- wandb_table .add_data (
1173- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1174- )
1175- tracker .log ({"validation" : wandb_table })
1186+
1187+ log_validation (
1188+ pipeline ,
1189+ args ,
1190+ accelerator ,
1191+ generator ,
1192+ global_step ,
1193+ is_final_validation = False ,
1194+ )
1195+
11761196 if args .use_ema :
11771197 # Switch back to the original UNet parameters.
11781198 ema_unet .restore (unet .parameters ())
@@ -1187,7 +1207,6 @@ def collate_fn(examples):
11871207 # Create the pipeline using the trained modules and save it.
11881208 accelerator .wait_for_everyone ()
11891209 if accelerator .is_main_process :
1190- unet = unwrap_model (unet )
11911210 if args .use_ema :
11921211 ema_unet .copy_to (unet .parameters ())
11931212
@@ -1198,10 +1217,11 @@ def collate_fn(examples):
11981217 tokenizer = tokenizer_1 ,
11991218 tokenizer_2 = tokenizer_2 ,
12001219 vae = vae ,
1201- unet = unet ,
1220+ unet = unwrap_model ( unet ) ,
12021221 revision = args .revision ,
12031222 variant = args .variant ,
12041223 )
1224+
12051225 pipeline .save_pretrained (args .output_dir )
12061226
12071227 if args .push_to_hub :
@@ -1212,30 +1232,15 @@ def collate_fn(examples):
12121232 ignore_patterns = ["step_*" , "epoch_*" ],
12131233 )
12141234
1215- if args .validation_prompt is not None :
1216- edited_images = []
1217- pipeline = pipeline .to (accelerator .device )
1218- with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1219- for _ in range (args .num_validation_images ):
1220- edited_images .append (
1221- pipeline (
1222- args .validation_prompt ,
1223- image = original_image ,
1224- num_inference_steps = 20 ,
1225- image_guidance_scale = 1.5 ,
1226- guidance_scale = 7 ,
1227- generator = generator ,
1228- ).images [0 ]
1229- )
1230-
1231- for tracker in accelerator .trackers :
1232- if tracker .name == "wandb" :
1233- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1234- for edited_image in edited_images :
1235- wandb_table .add_data (
1236- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1237- )
1238- tracker .log ({"test" : wandb_table })
1235+ if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1236+ log_validation (
1237+ pipeline ,
1238+ args ,
1239+ accelerator ,
1240+ generator ,
1241+ global_step = None ,
1242+ is_final_validation = True ,
1243+ )
12391244
12401245 accelerator .end_training ()
12411246
0 commit comments