diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 082b12404a85..112596057dd9 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -36,6 +36,11 @@ jobs: runner: docker-cpu image: diffusers/diffusers-onnxruntime-cpu report: onnx_cpu + - name: PyTorch Example CPU tests on Ubuntu + framework: pytorch_examples + runner: docker-cpu + image: diffusers/diffusers-pytorch-cpu + report: torch_cpu name: ${{ matrix.config.name }} @@ -90,6 +95,13 @@ jobs: --make-reports=tests_${{ matrix.config.report }} \ tests/ + - name: Run example PyTorch CPU tests + if: ${{ matrix.config.framework == 'pytorch_examples' }} + run: | + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ + --make-reports=tests_${{ matrix.config.report }} \ + examples/test_examples.py + - name: Failure short reports if: ${{ failure() }} run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt diff --git a/examples/test_examples.py b/examples/test_examples.py index d940c6d93b6f..329769656347 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -25,8 +25,6 @@ from accelerate.utils import write_basic_config -from diffusers.utils import slow - logging.basicConfig(level=logging.DEBUG) @@ -74,51 +72,94 @@ def tearDownClass(cls): super().tearDownClass() shutil.rmtree(cls._tmpdir) - @slow def test_train_unconditional(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" examples/unconditional_image_generation/train_unconditional.py - --dataset_name huggan/few-shot-aurora + --dataset_name hf-internal-testing/dummy_image_class_data + --model_config_name_or_path diffusers/ddpm_dummy --resolution 64 --output_dir {tmpdir} - --train_batch_size 4 + --train_batch_size 2 --num_epochs 1 --gradient_accumulation_steps 1 + --ddpm_num_inference_steps 2 --learning_rate 1e-3 --lr_warmup_steps 5 - --mixed_precision fp16 """.split() run_command(self._launch_args + test_args, return_stdout=True) # save_pretrained 
smoke test self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) - # logging test - self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0) - @slow def test_textual_inversion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" examples/textual_inversion/textual_inversion.py - --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe --train_data_dir docs/source/en/imgs --learnable_property object --placeholder_token - --initializer_token toy + --initializer_token a --resolution 64 --train_batch_size 1 - --gradient_accumulation_steps 2 - --max_train_steps 10 + --gradient_accumulation_steps 1 + --max_train_steps 2 --learning_rate 5.0e-04 --scale_lr --lr_scheduler constant --lr_warmup_steps 0 --output_dir {tmpdir} - --mixed_precision fp16 """.split() run_command(self._launch_args + test_args) # save_pretrained smoke test self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin"))) + + def test_dreambooth(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + + def test_text_to_image(self): + with 
tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/text_to_image/train_text_to_image.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 64 + --center_crop + --random_flip + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 324ddb0538e3..64ba126d0cce 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -22,7 +22,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_tensorboard_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
@@ -67,6 +67,12 @@ def parse_args(): default=None, help="The config of the Dataset, leave as None if there's only one config.", ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", + ) parser.add_argument( "--train_data_dir", type=str, @@ -222,6 +228,7 @@ def parse_args(): help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", ) parser.add_argument("--ddpm_num_steps", type=int, default=1000) + parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") parser.add_argument( "--checkpointing_steps", @@ -340,29 +347,33 @@ def load_model_hook(models, input_dir): os.makedirs(args.output_dir, exist_ok=True) # Initialize the model - model = UNet2DModel( - sample_size=args.resolution, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + if args.model_config_name_or_path is None: + model = UNet2DModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + else: + config = UNet2DModel.load_config(args.model_config_name_or_path) + model = UNet2DModel.from_config(config) # Create EMA for the model. 
if args.use_ema: @@ -586,13 +597,14 @@ def transform_images(examples): images = pipeline( generator=generator, batch_size=args.eval_batch_size, + num_inference_steps=args.ddpm_num_inference_steps, output_type="numpy", ).images # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") - if args.logger == "tensorboard": + if args.logger == "tensorboard" and is_tensorboard_available(): accelerator.get_tracker("tensorboard").add_images( "test_samples", images_processed.transpose(0, 3, 1, 2), epoch ) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d53cf9d634a5..8e61b5757eb5 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -52,6 +52,7 @@ is_onnx_available, is_safetensors_available, is_scipy_available, + is_tensorboard_available, is_tf_available, is_torch_available, is_torch_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 48428fcefaf0..cc607138758f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -224,6 +224,13 @@ except importlib_metadata.PackageNotFoundError: _omegaconf_available = False +_tensorboard_available = importlib.util.find_spec("tensorboard") is not None +try: + _tensorboard_version = importlib_metadata.version("tensorboard") + logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") +except importlib_metadata.PackageNotFoundError: + _tensorboard_available = False + def is_torch_available(): return _torch_available @@ -285,6 +292,10 @@ def is_omegaconf_available(): return _omegaconf_available +def is_tensorboard_available(): + return _tensorboard_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. 
Checkout the instructions on the @@ -351,6 +362,12 @@ def is_omegaconf_available(): install omegaconf` """ +# docstyle-ignore +TENSORBOARD_IMPORT_ERROR = """ +{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip +install tensorboard` +""" + BACKENDS_MAPPING = OrderedDict( [ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), @@ -364,6 +381,7 @@ def is_omegaconf_available(): ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ] )