Skip to content

Commit 62bceb7

Browse files
author
yiyixuxu
committed
Merge branch 'main' into refactor-embedding-rest
2 parents d507a22 + 3e71a20 commit 62bceb7

File tree

9 files changed

+135
-84
lines changed

9 files changed

+135
-84
lines changed

.github/workflows/push_tests_fast.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ jobs:
9898
- name: Run example PyTorch CPU tests
9999
if: ${{ matrix.config.framework == 'pytorch_examples' }}
100100
run: |
101+
python -m pip install peft
101102
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
102103
--make-reports=tests_${{ matrix.config.report }} \
103104
examples

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,17 @@ def main(args):
991991
text_encoder_one.add_adapter(text_lora_config)
992992
text_encoder_two.add_adapter(text_lora_config)
993993

994+
# Make sure the trainable params are in float32.
995+
if args.mixed_precision == "fp16":
996+
models = [unet]
997+
if args.train_text_encoder:
998+
models.extend([text_encoder_one, text_encoder_two])
999+
for model in models:
1000+
for param in model.parameters():
1001+
# only upcast trainable parameters (LoRA) into fp32
1002+
if param.requires_grad:
1003+
param.data = param.to(torch.float32)
1004+
9941005
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
9951006
def save_model_hook(models, weights, output_dir):
9961007
if accelerator.is_main_process:

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,13 @@ def main():
460460
vae.to(accelerator.device, dtype=weight_dtype)
461461
text_encoder.to(accelerator.device, dtype=weight_dtype)
462462

463+
# Add adapter and make sure the trainable params are in float32.
463464
unet.add_adapter(unet_lora_config)
465+
if args.mixed_precision == "fp16":
466+
for param in unet.parameters():
467+
# only upcast trainable parameters (LoRA) into fp32
468+
if param.requires_grad:
469+
param.data = param.to(torch.float32)
464470

465471
if args.enable_xformers_memory_efficient_attention:
466472
if is_xformers_available():
@@ -888,39 +894,42 @@ def collate_fn(examples):
888894
ignore_patterns=["step_*", "epoch_*"],
889895
)
890896

891-
# Final inference
892-
# Load previous pipeline
893-
pipeline = DiffusionPipeline.from_pretrained(
894-
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
895-
)
896-
pipeline = pipeline.to(accelerator.device)
897+
# Final inference
898+
# Load previous pipeline
899+
if args.validation_prompt is not None:
900+
pipeline = DiffusionPipeline.from_pretrained(
901+
args.pretrained_model_name_or_path,
902+
revision=args.revision,
903+
variant=args.variant,
904+
torch_dtype=weight_dtype,
905+
)
906+
pipeline = pipeline.to(accelerator.device)
897907

898-
# load attention processors
899-
pipeline.unet.load_attn_procs(args.output_dir)
908+
# load attention processors
909+
pipeline.load_lora_weights(args.output_dir)
900910

901-
# run inference
902-
generator = torch.Generator(device=accelerator.device)
903-
if args.seed is not None:
904-
generator = generator.manual_seed(args.seed)
905-
images = []
906-
for _ in range(args.num_validation_images):
907-
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
911+
# run inference
912+
generator = torch.Generator(device=accelerator.device)
913+
if args.seed is not None:
914+
generator = generator.manual_seed(args.seed)
915+
images = []
916+
for _ in range(args.num_validation_images):
917+
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
908918

909-
if accelerator.is_main_process:
910-
for tracker in accelerator.trackers:
911-
if len(images) != 0:
912-
if tracker.name == "tensorboard":
913-
np_images = np.stack([np.asarray(img) for img in images])
914-
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
915-
if tracker.name == "wandb":
916-
tracker.log(
917-
{
918-
"test": [
919-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
920-
for i, image in enumerate(images)
921-
]
922-
}
923-
)
919+
for tracker in accelerator.trackers:
920+
if len(images) != 0:
921+
if tracker.name == "tensorboard":
922+
np_images = np.stack([np.asarray(img) for img in images])
923+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
924+
if tracker.name == "wandb":
925+
tracker.log(
926+
{
927+
"test": [
928+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
929+
for i, image in enumerate(images)
930+
]
931+
}
932+
)
924933

925934
accelerator.end_training()
926935

src/diffusers/models/controlnetxs.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
2525
from ..utils import BaseOutput, logging
26-
from .attention_processor import (
27-
AttentionProcessor,
28-
)
26+
from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
2927
from .autoencoders import AutoencoderKL
3028
from .lora import LoRACompatibleConv
3129
from .modeling_utils import ModelMixin
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
817815
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
818816
norm_kwargs["num_channels"] += by # surgery done here
819817
# conv1
820-
conv1_args = (
821-
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
822-
)
818+
conv1_args = [
819+
"in_channels",
820+
"out_channels",
821+
"kernel_size",
822+
"stride",
823+
"padding",
824+
"dilation",
825+
"groups",
826+
"bias",
827+
"padding_mode",
828+
]
829+
if not USE_PEFT_BACKEND:
830+
conv1_args.append("lora_layer")
831+
823832
for a in conv1_args:
824833
assert hasattr(old_conv1, a)
834+
825835
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
826836
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
827837
conv1_kwargs["in_channels"] += by # surgery done here
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
839849
}
840850
# swap old with new modules
841851
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
842-
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
843-
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
852+
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
853+
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
854+
)
855+
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
856+
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
857+
)
844858
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
845859

846860

847861
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
848862
"""Increase channels sizes to allow for additional concatted information from base model"""
849863
old_down = unet.down_blocks[block_no].downsamplers[0].conv
850-
# conv1
851-
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
852-
" "
853-
)
864+
865+
args = [
866+
"in_channels",
867+
"out_channels",
868+
"kernel_size",
869+
"stride",
870+
"padding",
871+
"dilation",
872+
"groups",
873+
"bias",
874+
"padding_mode",
875+
]
876+
if not USE_PEFT_BACKEND:
877+
args.append("lora_layer")
878+
854879
for a in args:
855880
assert hasattr(old_down, a)
856881
kwargs = {a: getattr(old_down, a) for a in args}
857882
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
858883
kwargs["in_channels"] += by # surgery done here
859884
# swap old with new modules
860-
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
885+
unet.down_blocks[block_no].downsamplers[0].conv = (
886+
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
887+
)
861888
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
862889

863890

@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
871898
assert hasattr(old_norm1, a)
872899
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
873900
norm_kwargs["num_channels"] += by # surgery done here
874-
# conv1
875-
conv1_args = (
876-
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
877-
)
878-
for a in conv1_args:
879-
assert hasattr(old_conv1, a)
901+
conv1_args = [
902+
"in_channels",
903+
"out_channels",
904+
"kernel_size",
905+
"stride",
906+
"padding",
907+
"dilation",
908+
"groups",
909+
"bias",
910+
"padding_mode",
911+
]
912+
if not USE_PEFT_BACKEND:
913+
conv1_args.append("lora_layer")
914+
880915
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
881916
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
882917
conv1_kwargs["in_channels"] += by # surgery done here
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
894929
}
895930
# swap old with new modules
896931
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
897-
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
898-
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
932+
unet.mid_block.resnets[0].conv1 = (
933+
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
934+
)
935+
unet.mid_block.resnets[0].conv_shortcut = (
936+
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
937+
)
899938
unet.mid_block.resnets[0].in_channels += by # surgery done here
900939

901940

src/diffusers/models/embeddings.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def forward(
733733
return objs
734734

735735

736-
class CombinedTimestepSizeEmbeddings(nn.Module):
736+
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
737737
"""
738738
For PixArt-Alpha.
739739
@@ -750,45 +750,27 @@ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool
750750

751751
self.use_additional_conditions = use_additional_conditions
752752
if use_additional_conditions:
753-
self.use_additional_conditions = True
754753
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
755754
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
756755
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
757756

758-
def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
759-
if size.ndim == 1:
760-
size = size[:, None]
761-
762-
if size.shape[0] != batch_size:
763-
size = size.repeat(batch_size // size.shape[0], 1)
764-
if size.shape[0] != batch_size:
765-
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
766-
767-
current_batch_size, dims = size.shape[0], size.shape[1]
768-
size = size.reshape(-1)
769-
size_freq = self.additional_condition_proj(size).to(size.dtype)
770-
771-
size_emb = embedder(size_freq)
772-
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
773-
return size_emb
774-
775757
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
776758
timesteps_proj = self.time_proj(timestep)
777759
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
778760

779761
if self.use_additional_conditions:
780-
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
781-
aspect_ratio = self.apply_condition(
782-
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
783-
)
784-
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
762+
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
763+
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
764+
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
765+
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
766+
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
785767
else:
786768
conditioning = timesteps_emb
787769

788770
return conditioning
789771

790772

791-
class CaptionProjection(nn.Module):
773+
class PixArtAlphaTextProjection(nn.Module):
792774
"""
793775
Projects caption embeddings. Also handles dropout for classifier-free guidance.
794776
@@ -800,9 +782,8 @@ def __init__(self, in_features, hidden_size, num_tokens=120):
800782
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
801783
self.act_1 = nn.GELU(approximate="tanh")
802784
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
803-
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
804785

805-
def forward(self, caption, force_drop_ids=None):
786+
def forward(self, caption):
806787
hidden_states = self.linear_1(caption)
807788
hidden_states = self.act_1(hidden_states)
808789
hidden_states = self.linear_2(hidden_states)

src/diffusers/models/normalization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.nn.functional as F
2121

2222
from .activations import get_activation
23-
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
23+
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
2424

2525

2626
class AdaLayerNorm(nn.Module):
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
9191
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
9292
super().__init__()
9393

94-
self.emb = CombinedTimestepSizeEmbeddings(
94+
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
9595
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
9696
)
9797

src/diffusers/models/transformer_2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ..models.embeddings import ImagePositionalEmbeddings
2323
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
2424
from .attention import BasicTransformerBlock
25-
from .embeddings import CaptionProjection, PatchEmbed
25+
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
2626
from .lora import LoRACompatibleConv, LoRACompatibleLinear
2727
from .modeling_utils import ModelMixin
2828
from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ def __init__(
235235

236236
self.caption_projection = None
237237
if caption_channels is not None:
238-
self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
238+
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
239239

240240
self.gradient_checkpointing = False
241241

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,11 @@ def __call__(
853853
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
854854
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
855855
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
856+
857+
if do_classifier_free_guidance:
858+
resolution = torch.cat([resolution, resolution], dim=0)
859+
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
860+
856861
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
857862

858863
# 7. Denoising loop

tests/pipelines/controlnetxs/test_controlnetxs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
enable_full_determinism,
3535
load_image,
3636
load_numpy,
37+
numpy_cosine_similarity_distance,
3738
require_python39_or_higher,
3839
require_torch_2,
3940
require_torch_gpu,
@@ -273,7 +274,9 @@ def test_canny(self):
273274

274275
original_image = image[-3:, -3:, -1].flatten()
275276
expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701])
276-
assert np.allclose(original_image, expected_image, atol=1e-04)
277+
278+
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
279+
assert max_diff < 1e-4
277280

278281
def test_depth(self):
279282
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth")
@@ -298,7 +301,9 @@ def test_depth(self):
298301

299302
original_image = image[-3:, -3:, -1].flatten()
300303
expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703])
301-
assert np.allclose(original_image, expected_image, atol=1e-04)
304+
305+
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
306+
assert max_diff < 1e-4
302307

303308
@require_python39_or_higher
304309
@require_torch_2

0 commit comments

Comments
 (0)