Commit 222b5aa

nupurkmr9 (Nupur Kumari), sayakpaul, and patrickvonplaten authored
adding custom diffusion training to diffusers examples (huggingface#3031)
* diffusers==0.14.0 update
* custom diffusion update
* custom diffusion update
* custom diffusion update
* custom diffusion update
* custom diffusion update
* custom diffusion update
* custom diffusion
* custom diffusion
* custom diffusion
* custom diffusion
* custom diffusion
* apply formatting and get rid of bare except.
* refactor readme and other minor changes.
* misc refactor.
* fix: repo_id issue and loaders logging bug.
* fix: save_model_card.
* fix: save_model_card.
* fix: save_model_card.
* add: doc entry.
* refactor doc.
* custom diffusion
* custom diffusion
* custom diffusion
* apply style.
* remove trailing whitespace.
* fix: toctree entry.
* remove unnecessary print.
* custom diffusion
* custom diffusion
* custom diffusion test
* custom diffusion xformer update
* custom diffusion xformer update
* custom diffusion xformer update

---------

Co-authored-by: Nupur Kumari <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Nupur Kumari <[email protected]>
1 parent 278b134 commit 222b5aa
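
As a rough orientation for what the new loading path enables, here is a minimal, hypothetical usage sketch (not part of this commit): the repository id "user/custom-diffusion-cat", the "<new1>" modifier token, the prompt, and the use of load_textual_inversion for the token embedding are assumptions; only the weight file name is taken from the CUSTOM_DIFFUSION_WEIGHT_NAME constant introduced in loaders.py below.

# Hypothetical usage sketch -- model repo, token name, and prompt are made up.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Load the fine-tuned cross-attention weights written by save_attn_procs();
# the file name matches CUSTOM_DIFFUSION_WEIGHT_NAME added in this commit.
pipe.unet.load_attn_procs(
    "user/custom-diffusion-cat", weight_name="pytorch_custom_diffusion_weights.bin"
)

# The training example also learns an embedding for a modifier token (assumed
# here to be "<new1>" and saved as "<new1>.bin"), loaded via the textual
# inversion loader.
pipe.load_textual_inversion("user/custom-diffusion-cat", weight_name="<new1>.bin")

image = pipe("<new1> cat swimming in a pool", num_inference_steps=50).images[0]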

File tree

2 files changed (+250, -9 lines)

  loaders.py
  models/attention_processor.py
loaders.py

Lines changed: 61 additions & 9 deletions
@@ -19,7 +19,11 @@
 import torch
 from huggingface_hub import hf_hub_download
 
-from .models.attention_processor import LoRAAttnProcessor
+from .models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnProcessor,
+)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -48,6 +52,9 @@
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -215,6 +222,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -231,9 +239,38 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
-
+        elif is_custom_diffusion:
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                    attn_processors[key].load_state_dict(value_dict)
         else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
 
         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
@@ -287,16 +324,31 @@ def save_function(weights, filename):
 
         os.makedirs(save_directory, exist_ok=True)
 
-        model_to_save = AttnProcsLayers(self.attn_processors)
-
-        # Save the model
-        state_dict = model_to_save.state_dict()
+        is_custom_diffusion = any(
+            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
 
         if weight_name is None:
             if safe_serialization:
-                weight_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
             else:
-                weight_name = LORA_WEIGHT_NAME
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
 
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))
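
Note the asymmetry in how keys are regrouped in the loading path above: `to_out` entries keep three trailing components (module name, list index, parameter name), while the other projections keep two. A small self-contained sketch, with made-up example keys, illustrating the same split outside of the loader:

# Illustrative only -- the two keys below are made-up examples of the expected
# "<attention module path>.processor.<custom diffusion parameter>" layout.
from collections import defaultdict

example_keys = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_out_custom_diffusion.0.weight",
]

grouped = defaultdict(dict)
for key in example_keys:
    if "to_out" in key:
        # "to_out_custom_diffusion.0.weight" stays together as the sub key
        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    else:
        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
    grouped[attn_processor_key][sub_key] = key  # a real state dict stores the tensor here

# Both entries land under "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
# which is the granularity that set_attn_processor() expects.
print(dict(grouped))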

models/attention_processor.py

Lines changed: 189 additions & 0 deletions
@@ -149,6 +149,9 @@ def set_use_memory_efficient_attention_xformers(
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
         )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )
 
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -203,6 +217,16 @@ def set_use_memory_efficient_attention_xformers(
                 )
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
                 processor = AttnProcessor()
 
@@ -459,6 +483,84 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class AttnAddedKVProcessor:
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
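
For context on where these processors come from at training time, here is a rough sketch, not the shipped training script, of attaching them to a UNet: the checkpoint id, the learning rate, and the choice to train only cross-attention K/V are assumptions. Self-attention layers get a parameter-free processor (train_kv=False, train_q_out=False), which is why load_attn_procs above tolerates empty per-processor state dicts.

# Rough sketch; checkpoint id and learning rate are placeholders.
import torch
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.requires_grad_(False)

attn_procs = {}
for name in unet.attn_processors.keys():
    is_cross_attention = name.endswith("attn2.processor")
    cross_attention_dim = unet.config.cross_attention_dim if is_cross_attention else None
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    attn_procs[name] = CustomDiffusionAttnProcessor(
        train_kv=is_cross_attention,  # only cross-attention K/V carry new parameters here
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(attn_procs)

# Collect just the custom-diffusion parameters for the optimizer.
trainable_layers = AttnProcsLayers(unet.attn_processors)
optimizer = torch.optim.AdamW(trainable_layers.parameters(), lr=1e-5)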
@@ -699,6 +801,91 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
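
Switching the model to memory-efficient attention goes through the `set_use_memory_efficient_attention_xformers` branch patched above, which rebuilds each processor as the xformers variant and copies its weights over. A brief hedged sketch, assuming xformers is installed and reusing the `unet` from the previous snippet:

# Assumes xformers is installed and `unet` carries CustomDiffusionAttnProcessor
# instances (see the previous sketch).
from diffusers.models.attention_processor import CustomDiffusionXFormersAttnProcessor

unet.enable_xformers_memory_efficient_attention()

# Every custom-diffusion processor should now be the xformers variant, with the
# same to_k/to_v (and optional to_q/to_out) weights loaded into it.
assert all(
    isinstance(proc, CustomDiffusionXFormersAttnProcessor)
    for proc in unet.attn_processors.values()
)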
@@ -834,4 +1021,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
 ]
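
Finally, on the saving side, the updated save_attn_procs in loaders.py detects custom-diffusion processors and defaults to the new weight file names. A short hedged sketch, continuing from the sketches above; the output directory is a placeholder:

# Continues from the earlier sketches; "./custom-diffusion-weights" is an arbitrary path.
unet.save_attn_procs("./custom-diffusion-weights", safe_serialization=False)
# -> ./custom-diffusion-weights/pytorch_custom_diffusion_weights.bin
#    (pytorch_custom_diffusion_weights.safetensors with safe_serialization=True),
#    since is_custom_diffusion is inferred from the processor classes.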

0 commit comments
