@@ -827,6 +827,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
827827 {"checkpoint-4" , "checkpoint-6" },
828828 )
829829
830+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit (self ):
831+ prompt = "a prompt"
832+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
833+
834+ with tempfile .TemporaryDirectory () as tmpdir :
835+ # Run training script with checkpointing
836+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
837+ # Should create checkpoints at steps 2, 4, 6
838+ # with checkpoint at step 2 deleted
839+
840+ initial_run_args = f"""
841+ examples/text_to_image/train_text_to_image_lora_sdxl.py
842+ --pretrained_model_name_or_path { pipeline_path }
843+ --dataset_name hf-internal-testing/dummy_image_text_data
844+ --resolution 64
845+ --train_batch_size 1
846+ --gradient_accumulation_steps 1
847+ --max_train_steps 7
848+ --learning_rate 5.0e-04
849+ --scale_lr
850+ --lr_scheduler constant
851+ --lr_warmup_steps 0
852+ --output_dir { tmpdir }
853+ --checkpointing_steps=2
854+ --checkpoints_total_limit=2
855+ """ .split ()
856+
857+ run_command (self ._launch_args + initial_run_args )
858+
859+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
860+ pipe .load_lora_weights (tmpdir )
861+ pipe (prompt , num_inference_steps = 2 )
862+
863+ # check checkpoint directories exist
864+ self .assertEqual (
865+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
866+ # checkpoint-2 should have been deleted
867+ {"checkpoint-4" , "checkpoint-6" },
868+ )
869+
870+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit (self ):
871+ prompt = "a prompt"
872+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
873+
874+ with tempfile .TemporaryDirectory () as tmpdir :
875+ # Run training script with checkpointing
876+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
877+ # Should create checkpoints at steps 2, 4, 6
878+ # with checkpoint at step 2 deleted
879+
880+ initial_run_args = f"""
881+ examples/text_to_image/train_text_to_image_lora_sdxl.py
882+ --pretrained_model_name_or_path { pipeline_path }
883+ --dataset_name hf-internal-testing/dummy_image_text_data
884+ --resolution 64
885+ --train_batch_size 1
886+ --gradient_accumulation_steps 1
887+ --max_train_steps 7
888+ --learning_rate 5.0e-04
889+ --scale_lr
890+ --lr_scheduler constant
891+ --train_text_encoder
892+ --lr_warmup_steps 0
893+ --output_dir { tmpdir }
894+ --checkpointing_steps=2
895+ --checkpoints_total_limit=2
896+ """ .split ()
897+
898+ run_command (self ._launch_args + initial_run_args )
899+
900+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
901+ pipe .load_lora_weights (tmpdir )
902+ pipe (prompt , num_inference_steps = 2 )
903+
904+ # check checkpoint directories exist
905+ self .assertEqual (
906+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
907+ # checkpoint-2 should have been deleted
908+ {"checkpoint-4" , "checkpoint-6" },
909+ )
910+
830911 def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints (self ):
831912 pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
832913 prompt = "a prompt"
0 commit comments