Skip to content

add PAG support for SD Img2Img #9463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/api/pipelines/pag.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__

## StableDiffusionPAGImg2ImgPipeline
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
- all
- __call__

## StableDiffusionControlNetPAGPipeline
[[autodoc]] StableDiffusionControlNetPAGPipeline

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
Expand Down Expand Up @@ -795,6 +796,7 @@
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
Expand Down Expand Up @@ -569,6 +570,7 @@
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
Expand Down Expand Up @@ -131,6 +132,7 @@
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
Expand All @@ -54,6 +55,7 @@
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
Expand Down
1,091 changes: 1,091 additions & 0 deletions src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_torch_and_transformers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPAGImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

Expand Down
282 changes: 282 additions & 0 deletions tests/pipelines/pag/test_pag_sd_img2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import inspect
import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
AutoencoderKL,
AutoencoderTiny,
AutoPipelineForImage2Image,
EulerDiscreteScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionPAGImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
require_torch_gpu,
slow,
torch_device,
)

from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
)


enable_full_determinism()


class StableDiffusionPAGImg2ImgPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionPAGImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS

def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components

def get_dummy_tiny_autoencoder(self):
return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)

def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"pag_scale": 0.9,
"output_type": "np",
}
return inputs

def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

# base pipeline (expect same output when pag is disabled)
pipe_sd = StableDiffusionImg2ImgPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]

# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

# pag enabled
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3

def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = pipe_pag(**inputs).images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (
1,
32,
32,
3,
), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}"

expected_slice = np.array(
[0.44203848, 0.49598145, 0.42248967, 0.6707724, 0.5683791, 0.43603387, 0.58316565, 0.60077155, 0.5174199]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)


@slow
@require_torch_gpu
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test is not needed? (same as the one above it?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should i remove StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): also ?

class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGImg2ImgPipeline
repo_id = "Jiali/stable-diffusion-1.5"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"pag_scale": 3.0,
"output_type": "np",
}
return inputs

def test_pag_cfg(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)

inputs = self.get_inputs(torch_device)
image = pipeline(**inputs).images

image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
print(image_slice.flatten())
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"

def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
pipeline.set_progress_bar_config(disable=None)

inputs = self.get_inputs(torch_device, guidance_scale=0.0)
image = pipeline(**inputs).images

image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
print(image_slice.flatten())
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
Loading