21
21
from diffusers import DDPMPipeline , DDPMScheduler , UNet2DModel
22
22
from diffusers .optimization import get_scheduler
23
23
from diffusers .training_utils import EMAModel
24
- from diffusers .utils import check_min_version
24
+ from diffusers .utils import check_min_version , is_tensorboard_available , is_wandb_available
25
25
26
26
27
27
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -220,6 +220,7 @@ def parse_args():
220
220
help = "Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'." ,
221
221
)
222
222
parser .add_argument ("--ddpm_num_steps" , type = int , default = 1000 )
223
+ parser .add_argument ("--ddpm_num_inference_steps" , type = int , default = 1000 )
223
224
parser .add_argument ("--ddpm_beta_schedule" , type = str , default = "linear" )
224
225
parser .add_argument (
225
226
"--checkpointing_steps" ,
@@ -271,6 +272,15 @@ def main(args):
271
272
logging_dir = logging_dir ,
272
273
)
273
274
275
+ if args .logger == "tensorboard" :
276
+ if not is_tensorboard_available ():
277
+ raise ImportError ("Make sure to install tensorboard if you want to use it for logging during training." )
278
+
279
+ elif args .logger == "wandb" :
280
+ if not is_wandb_available ():
281
+ raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
282
+ import wandb
283
+
274
284
# Make one log on every process with the configuration for debugging.
275
285
logging .basicConfig (
276
286
format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -552,7 +562,7 @@ def transform_images(examples):
552
562
generator = generator ,
553
563
batch_size = args .eval_batch_size ,
554
564
output_type = "numpy" ,
555
- num_inference_steps = args .ddpm_num_steps ,
565
+ num_inference_steps = args .ddpm_num_inference_steps ,
556
566
).images
557
567
558
568
# denormalize the images and save to tensorboard
@@ -562,6 +572,11 @@ def transform_images(examples):
562
572
accelerator .get_tracker ("tensorboard" ).add_images (
563
573
"test_samples" , images_processed .transpose (0 , 3 , 1 , 2 ), epoch
564
574
)
575
+ elif args .logger == "wandb" :
576
+ accelerator .get_tracker ("wandb" ).log (
577
+ {"test_samples" : [wandb .Image (img ) for img in images_processed ], "epoch" : epoch },
578
+ step = global_step ,
579
+ )
565
580
566
581
if epoch % args .save_model_epochs == 0 or epoch == args .num_epochs - 1 :
567
582
# save the model
0 commit comments