Skip to content

Commit dc89891

Browse files
committed
tests
1 parent b7d2395 commit dc89891

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

examples/test_examples.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
827827
{"checkpoint-4", "checkpoint-6"},
828828
)
829829

830+
def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
831+
prompt = "a prompt"
832+
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
833+
834+
with tempfile.TemporaryDirectory() as tmpdir:
835+
# Run training script with checkpointing
836+
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
837+
# Should create checkpoints at steps 2, 4, 6
838+
# with checkpoint at step 2 deleted
839+
840+
initial_run_args = f"""
841+
examples/text_to_image/train_text_to_image_lora_sdxl.py
842+
--pretrained_model_name_or_path {pipeline_path}
843+
--dataset_name hf-internal-testing/dummy_image_text_data
844+
--resolution 64
845+
--train_batch_size 1
846+
--gradient_accumulation_steps 1
847+
--max_train_steps 7
848+
--learning_rate 5.0e-04
849+
--scale_lr
850+
--lr_scheduler constant
851+
--lr_warmup_steps 0
852+
--output_dir {tmpdir}
853+
--checkpointing_steps=2
854+
--checkpoints_total_limit=2
855+
""".split()
856+
857+
run_command(self._launch_args + initial_run_args)
858+
859+
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
860+
pipe.load_lora_weights(tmpdir)
861+
pipe(prompt, num_inference_steps=2)
862+
863+
# check checkpoint directories exist
864+
self.assertEqual(
865+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
866+
# checkpoint-2 should have been deleted
867+
{"checkpoint-4", "checkpoint-6"},
868+
)
869+
870+
def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
871+
prompt = "a prompt"
872+
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
873+
874+
with tempfile.TemporaryDirectory() as tmpdir:
875+
# Run training script with checkpointing
876+
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
877+
# Should create checkpoints at steps 2, 4, 6
878+
# with checkpoint at step 2 deleted
879+
880+
initial_run_args = f"""
881+
examples/text_to_image/train_text_to_image_lora_sdxl.py
882+
--pretrained_model_name_or_path {pipeline_path}
883+
--dataset_name hf-internal-testing/dummy_image_text_data
884+
--resolution 64
885+
--train_batch_size 1
886+
--gradient_accumulation_steps 1
887+
--max_train_steps 7
888+
--learning_rate 5.0e-04
889+
--scale_lr
890+
--lr_scheduler constant
891+
--train_text_encoder
892+
--lr_warmup_steps 0
893+
--output_dir {tmpdir}
894+
--checkpointing_steps=2
895+
--checkpoints_total_limit=2
896+
""".split()
897+
898+
run_command(self._launch_args + initial_run_args)
899+
900+
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
901+
pipe.load_lora_weights(tmpdir)
902+
pipe(prompt, num_inference_steps=2)
903+
904+
# check checkpoint directories exist
905+
self.assertEqual(
906+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
907+
# checkpoint-2 should have been deleted
908+
{"checkpoint-4", "checkpoint-6"},
909+
)
910+
830911
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
831912
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
832913
prompt = "a prompt"

0 commit comments

Comments
 (0)