
Commit 0db19da

Log Unconditional Image Generation Samples to W&B (#2287)
* Log Unconditional Image Generation Samples to WandB
* Check for wandb installation and parity between onnxruntime script
* Log epoch to wandb
* Check for tensorboard logger early on
* style fixes

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 62b3c9e commit 0db19da
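
The heart of the change is the logging call used when the script is run with --logger wandb: at each sampling epoch, the generated images are wrapped in wandb.Image objects and logged together with the epoch number at the current global step. A minimal, self-contained sketch of that pattern follows (the project name, image shapes, and counters are placeholders; the training script itself goes through accelerator.get_tracker("wandb") rather than calling wandb directly):

import numpy as np
import wandb

# Offline mode so the sketch runs without a W&B account; the project name is made up.
wandb.init(project="ddpm-unconditional-demo", mode="offline")

# Stand-in for the pipeline output: a batch of HWC uint8 images,
# like `images_processed` in the training script.
images_processed = np.random.randint(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)

epoch, global_step = 0, 100  # placeholders for the training-loop counters
wandb.log(
    {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
    step=global_step,
)
wandb.finish()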

File tree

2 files changed: +33 −4 lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 17 additions & 2 deletions
@@ -21,7 +21,7 @@
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -220,6 +220,7 @@ def parse_args():
         help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
     parser.add_argument(
         "--checkpointing_steps",
@@ -271,6 +272,15 @@ def main(args):
         logging_dir=logging_dir,
     )

+    if args.logger == "tensorboard":
+        if not is_tensorboard_available():
+            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+    elif args.logger == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+        import wandb
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -552,7 +562,7 @@ def transform_images(examples):
                     generator=generator,
                     batch_size=args.eval_batch_size,
                     output_type="numpy",
-                    num_inference_steps=args.ddpm_num_steps,
+                    num_inference_steps=args.ddpm_num_inference_steps,
                 ).images

                 # denormalize the images and save to tensorboard
@@ -562,6 +572,11 @@ def transform_images(examples):
                     accelerator.get_tracker("tensorboard").add_images(
                         "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                     )
+                elif args.logger == "wandb":
+                    accelerator.get_tracker("wandb").log(
+                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+                        step=global_step,
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
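
Previously the onnxruntime script reused --ddpm_num_steps (the scheduler's training-time step count) as the number of sampling steps; this commit introduces a separate --ddpm_num_inference_steps so evaluation can denoise with a different, typically smaller, number of steps. A rough standalone sketch of that separation, with a deliberately tiny model configuration and made-up step counts:

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

# Tiny UNet purely for illustration; the training script builds a much larger one.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=1000)  # plays the role of --ddpm_num_steps
pipeline = DDPMPipeline(unet=model, scheduler=scheduler)

generator = torch.Generator().manual_seed(0)
images = pipeline(
    generator=generator,
    batch_size=2,
    num_inference_steps=50,  # plays the role of --ddpm_num_inference_steps
    output_type="numpy",
).images  # numpy array of shape (2, 32, 32, 3)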

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 16 additions & 2 deletions
@@ -22,7 +22,7 @@
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, is_tensorboard_available
+from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -280,6 +280,15 @@ def main(args):
         logging_dir=logging_dir,
     )

+    if args.logger == "tensorboard":
+        if not is_tensorboard_available():
+            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+    elif args.logger == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+        import wandb
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -604,10 +613,15 @@ def transform_images(examples):
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")

-                if args.logger == "tensorboard" and is_tensorboard_available():
+                if args.logger == "tensorboard":
                     accelerator.get_tracker("tensorboard").add_images(
                         "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                     )
+                elif args.logger == "wandb":
+                    accelerator.get_tracker("wandb").log(
+                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+                        step=global_step,
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
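
Note that the availability check for the selected logger now happens once at the top of main() (via is_tensorboard_available / is_wandb_available), so a missing tensorboard or wandb installation fails fast with an ImportError before any training work starts; the per-epoch "and is_tensorboard_available()" guard in the sampling loop is dropped accordingly, and "import wandb" is deferred until the wandb logger is actually selected.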
