diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4daf0e7717e7..39ceadb5acef 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -29,7 +29,7 @@ StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline -from .flux import FluxPipeline +from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -108,6 +108,7 @@ ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), ("flux", FluxPipeline), + ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), ] ) @@ -126,6 +127,7 @@ ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), + ("flux", FluxImg2ImgPipeline), ] ) @@ -140,6 +142,7 @@ ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), + ("flux", FluxInpaintPipeline), ] ) @@ -660,12 +663,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + # the `orig_class_name` can be: + # `- *Pipeline` (for regular text-to-image checkpoint) + # `- *Img2ImgPipeline` (for refiner checkpoint) + to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -952,14 +960,17 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + # The `orig_class_name`` can be: + # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - or *Pipeline (for regular text-to-image checkpoint) + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" - orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) - + orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index 768026fa5460..d060963f49d0 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -235,9 +235,32 @@ def test_from_pretrained_img2img(self): pipe = AutoPipelineForImage2Image.from_pretrained(repo) assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline" + + pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline" + + def test_from_pretrained_img2img_refiner(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe" + + pipe = AutoPipelineForImage2Image.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline" + pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline" + def test_from_pipe_pag_img2img(self): # test from tableDiffusionXLPAGImg2ImgPipeline pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") @@ -265,6 +288,16 @@ def test_from_pretrained_inpaint(self): pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True) assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + def test_from_pretrained_inpaint_from_inpaint(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe" + + pipe = AutoPipelineForInpainting.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline" + + # make sure you can use pag with inpaint-specific pipeline + pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True) + assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + def test_from_pipe_pag_inpaint(self): # test from tableDiffusionXLPAGInpaintPipeline pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")