Skip to content

Commit 35889b3

Browse files
committed
dreambooth checkpointing tests and docs
1 parent 6782b70 commit 35889b3

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,11 @@ def parse_args(input_args=None):
188188
type=int,
189189
default=500,
190190
help=(
191-
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
192-
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
193-
" training using `--resume_from_checkpoint`."
191+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
192+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
193+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
194+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
195+
"instructions."
194196
),
195197
)
196198
parser.add_argument(

examples/test_examples.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
from accelerate.utils import write_basic_config
2727

28+
from diffusers import DiffusionPipeline, UNet2DConditionModel
29+
2830

2931
logging.basicConfig(level=logging.DEBUG)
3032

@@ -140,6 +142,85 @@ def test_dreambooth(self):
140142
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
141143
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
142144

145+
def test_dreambooth_checkpointing(self):
    """Train dreambooth with periodic checkpointing, then resume from a checkpoint.

    Verifies that checkpoint directories are created at the configured interval,
    that an intermediate checkpoint's UNet can be swapped into the base pipeline
    for inference, and that resuming via --resume_from_checkpoint continues
    training and writes only the later checkpoints.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        prompt = "photo"
        model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"

        # First run: max_train_steps=5 with checkpointing_steps=2 should
        # leave checkpoint-2 and checkpoint-4 next to the final pipeline.
        first_run_args = f"""
            examples/dreambooth/train_dreambooth.py
            --pretrained_model_name_or_path {model_id}
            --instance_data_dir docs/source/en/imgs
            --instance_prompt {prompt}
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 5
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --lr_warmup_steps 0
            --output_dir {tmpdir}
            --checkpointing_steps=2
            --seed=0
            """.split()

        run_command(self._launch_args + first_run_args)

        # The fully trained pipeline written to the output dir loads and runs.
        pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
        pipe(prompt, num_inference_steps=2)

        # Both expected checkpoint directories were created.
        self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
        self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

        # An intermediate checkpoint's UNet can be loaded on its own and
        # combined with the original base pipeline for inference.
        unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
        pipe = DiffusionPipeline.from_pretrained(model_id, unet=unet, safety_checker=None)
        pipe(prompt, num_inference_steps=2)

        # Delete checkpoint-2 so we can confirm the resumed run does not
        # recreate checkpoints earlier than the resume point.
        shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

        # Second run: resume from checkpoint-4 and train to 7 total steps,
        # which should add checkpoint-6.
        resume_run_args = f"""
            examples/dreambooth/train_dreambooth.py
            --pretrained_model_name_or_path {model_id}
            --instance_data_dir docs/source/en/imgs
            --instance_prompt {prompt}
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 7
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --lr_warmup_steps 0
            --output_dir {tmpdir}
            --checkpointing_steps=2
            --resume_from_checkpoint=checkpoint-4
            --seed=0
            """.split()

        run_command(self._launch_args + resume_run_args)

        # The updated final pipeline still loads and runs.
        pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
        pipe(prompt, num_inference_steps=2)

        # The removed early checkpoint was not recreated...
        self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

        # ...while the resume-point checkpoint and the new one both exist.
        self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
        self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
143224
def test_text_to_image(self):
144225
with tempfile.TemporaryDirectory() as tmpdir:
145226
test_args = f"""

0 commit comments

Comments
 (0)