|
25 | 25 |
|
26 | 26 | from accelerate.utils import write_basic_config
|
27 | 27 |
|
| 28 | +from diffusers import DiffusionPipeline, UNet2DConditionModel |
| 29 | + |
28 | 30 |
|
29 | 31 | logging.basicConfig(level=logging.DEBUG)
|
30 | 32 |
|
@@ -140,6 +142,85 @@ def test_dreambooth(self):
|
140 | 142 | self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
|
141 | 143 | self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
142 | 144 |
|
| 145 | + def test_dreambooth_checkpointing(self): |
| 146 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 147 | + instance_prompt = "photo" |
| 148 | + pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" |
| 149 | + |
| 150 | + # Run training script with checkpointing |
| 151 | + # max_train_steps == 5, checkpointing_steps == 2 |
| 152 | + # Should create checkpoints at steps 2, 4 |
| 153 | + |
| 154 | + initial_run_args = f""" |
| 155 | + examples/dreambooth/train_dreambooth.py |
| 156 | + --pretrained_model_name_or_path {pretrained_model_name_or_path} |
| 157 | + --instance_data_dir docs/source/en/imgs |
| 158 | + --instance_prompt {instance_prompt} |
| 159 | + --resolution 64 |
| 160 | + --train_batch_size 1 |
| 161 | + --gradient_accumulation_steps 1 |
| 162 | + --max_train_steps 5 |
| 163 | + --learning_rate 5.0e-04 |
| 164 | + --scale_lr |
| 165 | + --lr_scheduler constant |
| 166 | + --lr_warmup_steps 0 |
| 167 | + --output_dir {tmpdir} |
| 168 | + --checkpointing_steps=2 |
| 169 | + --seed=0 |
| 170 | + """.split() |
| 171 | + |
| 172 | + run_command(self._launch_args + initial_run_args) |
| 173 | + |
| 174 | + # check can run the original fully trained output pipeline |
| 175 | + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) |
| 176 | + pipe(instance_prompt, num_inference_steps=2) |
| 177 | + |
| 178 | + # check checkpoint directories exist |
| 179 | + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) |
| 180 | + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) |
| 181 | + |
| 182 | + # check can run an intermediate checkpoint |
| 183 | + unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") |
| 184 | + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None) |
| 185 | + pipe(instance_prompt, num_inference_steps=2) |
| 186 | + |
| 187 | + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming |
| 188 | + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) |
| 189 | + |
| 190 | + # Run training script for 7 total steps resuming from checkpoint 4 |
| 191 | + |
| 192 | + resume_run_args = f""" |
| 193 | + examples/dreambooth/train_dreambooth.py |
| 194 | + --pretrained_model_name_or_path {pretrained_model_name_or_path} |
| 195 | + --instance_data_dir docs/source/en/imgs |
| 196 | + --instance_prompt {instance_prompt} |
| 197 | + --resolution 64 |
| 198 | + --train_batch_size 1 |
| 199 | + --gradient_accumulation_steps 1 |
| 200 | + --max_train_steps 7 |
| 201 | + --learning_rate 5.0e-04 |
| 202 | + --scale_lr |
| 203 | + --lr_scheduler constant |
| 204 | + --lr_warmup_steps 0 |
| 205 | + --output_dir {tmpdir} |
| 206 | + --checkpointing_steps=2 |
| 207 | + --resume_from_checkpoint=checkpoint-4 |
| 208 | + --seed=0 |
| 209 | + """.split() |
| 210 | + |
| 211 | + run_command(self._launch_args + resume_run_args) |
| 212 | + |
| 213 | + # check can run new fully trained pipeline |
| 214 | + pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) |
| 215 | + pipe(instance_prompt, num_inference_steps=2) |
| 216 | + |
| 217 | + # check old checkpoints do not exist |
| 218 | + self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) |
| 219 | + |
| 220 | + # check new checkpoints exist |
| 221 | + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) |
| 222 | + self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) |
| 223 | + |
143 | 224 | def test_text_to_image(self):
|
144 | 225 | with tempfile.TemporaryDirectory() as tmpdir:
|
145 | 226 | test_args = f"""
|
|
0 commit comments