
[Examples] Test all examples on CPU #2289


Merged: 5 commits, merged on Feb 8, 2023

12 changes: 12 additions & 0 deletions .github/workflows/pr_tests.yml
@@ -36,6 +36,11 @@ jobs:
            runner: docker-cpu
            image: diffusers/diffusers-onnxruntime-cpu
            report: onnx_cpu
+          - name: PyTorch Example CPU tests on Ubuntu
+            framework: pytorch_examples
+            runner: docker-cpu
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_cpu

    name: ${{ matrix.config.name }}

@@ -90,6 +95,13 @@ jobs:
          --make-reports=tests_${{ matrix.config.report }} \
          tests/

+      - name: Run example PyTorch CPU tests
+        if: ${{ matrix.config.framework == 'pytorch_examples' }}
+        run: |
+          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+            --make-reports=tests_${{ matrix.config.report }} \
+            examples/test_examples.py
+
      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
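To reproduce this new CI step locally, here is a minimal sketch (assuming pytest and pytest-xdist are installed, and that the repository's conftest.py defines the custom --make-reports option used above):

    # Local reproduction of the example-test CI step (a sketch; the report name
    # "tests_torch_cpu" matches matrix.config.report above).
    import sys

    import pytest

    ret = pytest.main(
        [
            "-n", "2",                         # two pytest-xdist workers, as in CI
            "--max-worker-restart=0",          # do not restart crashed workers
            "--dist=loadfile",                 # keep each test file on a single worker
            "--make-reports=tests_torch_cpu",  # custom option from the repo's conftest.py
            "examples/test_examples.py",
        ]
    )
    sys.exit(ret)

A bare `python -m pytest examples/test_examples.py` also works; the extra flags only control parallelism and report generation.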
69 changes: 55 additions & 14 deletions examples/test_examples.py
@@ -25,8 +25,6 @@

from accelerate.utils import write_basic_config

-from diffusers.utils import slow
-

logging.basicConfig(level=logging.DEBUG)

@@ -74,51 +72,94 @@ def tearDownClass(cls):
        super().tearDownClass()
        shutil.rmtree(cls._tmpdir)

-    @slow
    def test_train_unconditional(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/unconditional_image_generation/train_unconditional.py
-                --dataset_name huggan/few-shot-aurora
+                --dataset_name hf-internal-testing/dummy_image_class_data
+                --model_config_name_or_path diffusers/ddpm_dummy
                --resolution 64
                --output_dir {tmpdir}
-                --train_batch_size 4
+                --train_batch_size 2
                --num_epochs 1
                --gradient_accumulation_steps 1
+                --ddpm_num_inference_steps 2
                --learning_rate 1e-3
                --lr_warmup_steps 5
-                --mixed_precision fp16
                """.split()

            run_command(self._launch_args + test_args, return_stdout=True)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
            # logging test
            self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0)

-    @slow
    def test_textual_inversion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
-                --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
-                --initializer_token toy
+                --initializer_token a
                --resolution 64
                --train_batch_size 1
-                --gradient_accumulation_steps 2
-                --max_train_steps 10
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
-                --mixed_precision fp16
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))

+    def test_dreambooth(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+    def test_text_to_image(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/text_to_image/train_text_to_image.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --center_crop
+                --random_flip
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
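The tests above drive each training script through `accelerate launch` in a subprocess; the harness itself sits above the diffed region. A minimal sketch of its assumed shape, inferred from the visible calls (`run_command`, `self._launch_args`, and the `write_basic_config` import shown earlier) — the exact helper bodies are an assumption, not part of this diff:

    import os
    import shutil
    import subprocess
    import tempfile
    import unittest

    from accelerate.utils import write_basic_config


    def run_command(command, return_stdout=False):
        # Run the command and surface the subprocess output if it fails.
        try:
            output = subprocess.check_output(command, stderr=subprocess.STDOUT)
            if return_stdout:
                return output.decode("utf-8") if hasattr(output, "decode") else output
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Command {command} failed:\n{e.output.decode()}") from e


    class ExamplesTests(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls._tmpdir = tempfile.mkdtemp()
            # A default accelerate config lets `accelerate launch` run
            # non-interactively on the CPU-only CI runner.
            cls.config_path = os.path.join(cls._tmpdir, "default_config.yml")
            write_basic_config(save_location=cls.config_path)
            cls._launch_args = ["accelerate", "launch", "--config_file", cls.config_path]

        @classmethod
        def tearDownClass(cls):
            super().tearDownClass()
            shutil.rmtree(cls._tmpdir)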
62 changes: 37 additions & 25 deletions examples/unconditional_image_generation/train_unconditional.py
@@ -22,7 +22,7 @@
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_tensorboard_available
Member review comment: tensorboard is in the requirements for the text_to_image and textual_inversion examples, but I actually prefer it this way.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -67,6 +67,12 @@ def parse_args():
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
+    parser.add_argument(
+        "--model_config_name_or_path",
+        type=str,
+        default=None,
+        help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+    )
Member review comment on lines +70 to +75: Nice!
    parser.add_argument(
        "--train_data_dir",
        type=str,
@@ -222,6 +228,7 @@ def parse_args():
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
parser.add_argument(
"--checkpointing_steps",
@@ -340,29 +347,33 @@ def load_model_hook(models, input_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    # Initialize the model
-    model = UNet2DModel(
-        sample_size=args.resolution,
-        in_channels=3,
-        out_channels=3,
-        layers_per_block=2,
-        block_out_channels=(128, 128, 256, 256, 512, 512),
-        down_block_types=(
-            "DownBlock2D",
-            "DownBlock2D",
-            "DownBlock2D",
-            "DownBlock2D",
-            "AttnDownBlock2D",
-            "DownBlock2D",
-        ),
-        up_block_types=(
-            "UpBlock2D",
-            "AttnUpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-        ),
-    )
+    if args.model_config_name_or_path is None:
+        model = UNet2DModel(
+            sample_size=args.resolution,
+            in_channels=3,
+            out_channels=3,
+            layers_per_block=2,
+            block_out_channels=(128, 128, 256, 256, 512, 512),
+            down_block_types=(
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "AttnDownBlock2D",
+                "DownBlock2D",
+            ),
+            up_block_types=(
+                "UpBlock2D",
+                "AttnUpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+            ),
+        )
+    else:
+        config = UNet2DModel.load_config(args.model_config_name_or_path)
+        model = UNet2DModel.from_config(config)

    # Create EMA for the model.
    if args.use_ema:
@@ -586,13 +597,14 @@ def transform_images(examples):
                images = pipeline(
                    generator=generator,
                    batch_size=args.eval_batch_size,
+                    num_inference_steps=args.ddpm_num_inference_steps,
                    output_type="numpy",
                ).images

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")

-                if args.logger == "tensorboard":
+                if args.logger == "tensorboard" and is_tensorboard_available():
                    accelerator.get_tracker("tensorboard").add_images(
                        "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                    )
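The new --model_config_name_or_path flag is what makes the CPU test above practical: instead of building the full-size DDPM UNet, the test points the script at the tiny diffusers/ddpm_dummy config. A sketch of producing an equivalent config yourself (the directory name tiny_unet_config and the block layout are illustrative choices, not from the PR):

    from diffusers import UNet2DModel

    # A deliberately tiny UNet: two blocks, narrow channels.
    tiny = UNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    tiny.save_config("tiny_unet_config")  # writes tiny_unet_config/config.json

    # Mirrors the else-branch added above:
    config = UNet2DModel.load_config("tiny_unet_config")
    model = UNet2DModel.from_config(config)

Together with --ddpm_num_inference_steps 2 in the evaluation pipeline call, this keeps the end-to-end run small enough for CPU CI.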
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -52,6 +52,7 @@
    is_onnx_available,
    is_safetensors_available,
    is_scipy_available,
+    is_tensorboard_available,
    is_tf_available,
    is_torch_available,
    is_torch_version,
18 changes: 18 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -224,6 +224,13 @@
except importlib_metadata.PackageNotFoundError:
    _omegaconf_available = False

+_tensorboard_available = importlib.util.find_spec("tensorboard")
+try:
+    _tensorboard_version = importlib_metadata.version("tensorboard")
+    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
+except importlib_metadata.PackageNotFoundError:
+    _tensorboard_available = False


def is_torch_available():
    return _torch_available
@@ -285,6 +292,10 @@ def is_omegaconf_available():
    return _omegaconf_available


+def is_tensorboard_available():
+    return _tensorboard_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -351,6 +362,12 @@ def is_omegaconf_available():
install omegaconf`
"""

+# docstyle-ignore
+TENSORBOARD_IMPORT_ERROR = """
+{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
+install tensorboard`
+"""

BACKENDS_MAPPING = OrderedDict(
    [
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
@@ -364,6 +381,7 @@ def is_omegaconf_available():
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
]
)

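A usage sketch for the new public helper, mirroring the guard added to train_unconditional.py (the SummaryWriter calls here are illustrative, not part of this PR):

    from diffusers.utils import is_tensorboard_available

    if is_tensorboard_available():
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(log_dir="logs/demo")  # illustrative log directory
        writer.add_scalar("train/loss", 0.123, global_step=0)
        writer.close()
    else:
        print("tensorboard is not installed; skipping tensorboard logging")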