Skip to content

Commit 8ad1a81

Browse files
committed
Check for wandb installation and parity with the onnxruntime script
1 parent e17d516 commit 8ad1a81

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
2222
from diffusers.optimization import get_scheduler
2323
from diffusers.training_utils import EMAModel
24-
from diffusers.utils import check_min_version
24+
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
2525

2626

2727
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -220,6 +220,7 @@ def parse_args():
220220
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
221221
)
222222
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
223+
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
223224
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
224225
parser.add_argument(
225226
"--checkpointing_steps",
@@ -271,6 +272,11 @@ def main(args):
271272
logging_dir=logging_dir,
272273
)
273274

275+
if args.logger == "wandb":
276+
if not is_wandb_available():
277+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
278+
import wandb
279+
274280
# Make one log on every process with the configuration for debugging.
275281
logging.basicConfig(
276282
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -552,19 +558,17 @@ def transform_images(examples):
552558
generator=generator,
553559
batch_size=args.eval_batch_size,
554560
output_type="numpy",
555-
num_inference_steps=args.ddpm_num_steps,
561+
num_inference_steps=args.ddpm_num_inference_steps,
556562
).images
557563

558564
# denormalize the images and save to tensorboard
559565
images_processed = (images * 255).round().astype("uint8")
560566

561-
if args.logger == "tensorboard":
567+
if args.logger == "tensorboard" and is_tensorboard_available():
562568
accelerator.get_tracker("tensorboard").add_images(
563569
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
564570
)
565571
elif args.logger == "wandb":
566-
import wandb
567-
568572
accelerator.get_tracker("wandb").log(
569573
{"test_samples": [wandb.Image(img) for img in images_processed]}, step=global_step
570574
)

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
2323
from diffusers.optimization import get_scheduler
2424
from diffusers.training_utils import EMAModel
25-
from diffusers.utils import check_min_version, is_tensorboard_available
25+
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
2626

2727

2828
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -280,6 +280,11 @@ def main(args):
280280
logging_dir=logging_dir,
281281
)
282282

283+
if args.logger == "wandb":
284+
if not is_wandb_available():
285+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
286+
import wandb
287+
283288
# `accelerate` 0.16.0 will have better support for customized saving
284289
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
285290
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -599,7 +604,6 @@ def transform_images(examples):
599604
batch_size=args.eval_batch_size,
600605
num_inference_steps=args.ddpm_num_inference_steps,
601606
output_type="numpy",
602-
num_inference_steps=args.ddpm_num_steps,
603607
).images
604608

605609
# denormalize the images and save to tensorboard
@@ -610,8 +614,6 @@ def transform_images(examples):
610614
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
611615
)
612616
elif args.logger == "wandb":
613-
import wandb
614-
615617
accelerator.get_tracker("wandb").log(
616618
{"test_samples": [wandb.Image(img) for img in images_processed]}, step=global_step
617619
)

0 commit comments

Comments (0)