|
21 | 21 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
22 | 22 | from diffusers.optimization import get_scheduler
|
23 | 23 | from diffusers.training_utils import EMAModel
|
24 |
| -from diffusers.utils import check_min_version |
| 24 | +from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available |
25 | 25 |
|
26 | 26 |
|
27 | 27 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
@@ -220,6 +220,7 @@ def parse_args():
|
220 | 220 | help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
221 | 221 | )
|
222 | 222 | parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
| 223 | + parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) |
223 | 224 | parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
224 | 225 | parser.add_argument(
|
225 | 226 | "--checkpointing_steps",
|
@@ -271,6 +272,11 @@ def main(args):
|
271 | 272 | logging_dir=logging_dir,
|
272 | 273 | )
|
273 | 274 |
|
| 275 | + if args.logger == "wandb": |
| 276 | + if not is_wandb_available(): |
| 277 | + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
| 278 | + import wandb |
| 279 | + |
274 | 280 | # Make one log on every process with the configuration for debugging.
|
275 | 281 | logging.basicConfig(
|
276 | 282 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
@@ -552,19 +558,17 @@ def transform_images(examples):
|
552 | 558 | generator=generator,
|
553 | 559 | batch_size=args.eval_batch_size,
|
554 | 560 | output_type="numpy",
|
555 |
| - num_inference_steps=args.ddpm_num_steps, |
| 561 | + num_inference_steps=args.ddpm_num_inference_steps, |
556 | 562 | ).images
|
557 | 563 |
|
558 | 564 | # denormalize the images and save to tensorboard
|
559 | 565 | images_processed = (images * 255).round().astype("uint8")
|
560 | 566 |
|
561 |
| - if args.logger == "tensorboard": |
| 567 | + if args.logger == "tensorboard" and is_tensorboard_available(): |
562 | 568 | accelerator.get_tracker("tensorboard").add_images(
|
563 | 569 | "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
564 | 570 | )
|
565 | 571 | elif args.logger == "wandb":
|
566 |
| - import wandb |
567 |
| - |
568 | 572 | accelerator.get_tracker("wandb").log(
|
569 | 573 | {"test_samples": [wandb.Image(img) for img in images_processed]}, step=global_step
|
570 | 574 | )
|
|
0 commit comments