Skip to content

Commit 9d0ce55

Browse files
committed
Check for wandb installation
1 parent 4e8e21c commit 9d0ce55

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
2222
from diffusers.optimization import get_scheduler
2323
from diffusers.training_utils import EMAModel
24-
from diffusers.utils import check_min_version
24+
from diffusers.utils import check_min_version, is_wandb_available
2525

2626

2727
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -271,6 +271,11 @@ def main(args):
271271
logging_dir=logging_dir,
272272
)
273273

274+
if args.logger == "wandb":
275+
if not is_wandb_available():
276+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
277+
import wandb
278+
274279
# Make one log on every process with the configuration for debugging.
275280
logging.basicConfig(
276281
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -562,8 +567,6 @@ def transform_images(examples):
562567
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
563568
)
564569
elif args.logger == "wandb":
565-
import wandb
566-
567570
accelerator.get_tracker("wandb").log(
568571
{"test_samples": [wandb.Image(img) for img in images_processed]}, step=global_step
569572
)

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
2323
from diffusers.optimization import get_scheduler
2424
from diffusers.training_utils import EMAModel
25-
from diffusers.utils import check_min_version
25+
from diffusers.utils import check_min_version, is_wandb_available
2626

2727

2828
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -273,6 +273,11 @@ def main(args):
273273
logging_dir=logging_dir,
274274
)
275275

276+
if args.logger == "wandb":
277+
if not is_wandb_available():
278+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
279+
import wandb
280+
276281
# `accelerate` 0.16.0 will have better support for customized saving
277282
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
278283
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -597,8 +602,6 @@ def transform_images(examples):
597602
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
598603
)
599604
elif args.logger == "wandb":
600-
import wandb
601-
602605
accelerator.get_tracker("wandb").log(
603606
{"test_samples": [wandb.Image(img) for img in images_processed]}, step=global_step
604607
)

0 commit comments

Comments
 (0)