diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 3ad835ceeeb0..49512afb1c8e 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -97,7 +97,8 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        freqs_cis: Optional[torch.Tensor] = None,
+        freqs_cos: Optional[torch.Tensor] = None,
+        freqs_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
@@ -113,17 +114,26 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        # Apply RoPE
-        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
-                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-                freqs_cis = freqs_cis.unsqueeze(2)
-                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-                return x_out.type_as(x_in)  # todo
-
-        if freqs_cis is not None:
-            query = apply_rotary_emb(query, freqs_cis)
-            key = apply_rotary_emb(key, freqs_cis)
+        # # Apply RoPE
+        # def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+        #     with torch.amp.autocast("cuda", enabled=False):
+        #         x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+        #         freqs_cis = freqs_cis.unsqueeze(2)
+        #         x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+        #         return x_out.type_as(x_in)  # todo
+
+        def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
+            freqs_cos = freqs_cos.unsqueeze(2)  # [batch, seq, 1, head_dim//2]
+            freqs_sin = freqs_sin.unsqueeze(2)
+            x = x_in.reshape(*x_in.shape[:-1], -1, 2)
+            x0, x1 = x[..., 0], x[..., 1]
+            out0 = x0 * freqs_cos - x1 * freqs_sin
+            out1 = x0 * freqs_sin + x1 * freqs_cos
+            return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)
+
+        if freqs_cos is not None and freqs_sin is not None:
+            query = apply_rotary_emb(query, freqs_cos, freqs_sin)
+            key = apply_rotary_emb(key, freqs_cos, freqs_sin)
 
         # Cast to correct dtype
         dtype = query.dtype
@@ -219,7 +229,8 @@ def forward(
         self,
         x: torch.Tensor,
         attn_mask: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
         adaln_input: Optional[torch.Tensor] = None,
     ):
         if self.modulation:
@@ -232,7 +243,8 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x) * scale_msa,
                 attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                freqs_cos=freqs_cos,
+                freqs_sin=freqs_sin,
             )
             x = x + gate_msa * self.attention_norm2(attn_out)
 
@@ -247,7 +259,8 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x),
                 attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                freqs_cos=freqs_cos,
+                freqs_sin=freqs_sin,
             )
             x = x + self.attention_norm2(attn_out)
 
@@ -290,39 +303,48 @@ def __init__(
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
         assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
-        self.freqs_cis = None
+        self.freqs_cos = None
+        self.freqs_sin = None
 
     @staticmethod
     def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
         with torch.device("cpu"):
-            freqs_cis = []
+            freqs_cos = []
+            freqs_sin = []
             for i, (d, e) in enumerate(zip(dim, end)):
                 freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                 timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                 freqs = torch.outer(timestep, freqs).float()
-                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
-                freqs_cis.append(freqs_cis_i)
+                freqs_cos.append(freqs.cos())
+                freqs_sin.append(freqs.sin())
+                # freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
+                # freqs_cis.append(freqs_cis_i)
 
-        return freqs_cis
+        return freqs_cos, freqs_sin
 
     def __call__(self, ids: torch.Tensor):
         assert ids.ndim == 2
         assert ids.shape[-1] == len(self.axes_dims)
         device = ids.device
 
-        if self.freqs_cis is None:
-            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        if self.freqs_cos is None or self.freqs_sin is None:
+            self.freqs_cos, self.freqs_sin = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+            self.freqs_cos = [f.to(device) for f in self.freqs_cos]
+            self.freqs_sin = [f.to(device) for f in self.freqs_sin]
         else:
             # Ensure freqs_cis are on the same device as ids
-            if self.freqs_cis[0].device != device:
-                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+            if self.freqs_cos[0].device != device:
+                self.freqs_cos = [f.to(device) for f in self.freqs_cos]
+            if self.freqs_sin[0].device != device:
+                self.freqs_sin = [f.to(device) for f in self.freqs_sin]
 
-        result = []
+        cos_result = []
+        sin_result = []
         for i in range(len(self.axes_dims)):
             index = ids[:, i]
-            result.append(self.freqs_cis[i][index])
-        return torch.cat(result, dim=-1)
+            cos_result.append(self.freqs_cos[i][index])
+            sin_result.append(self.freqs_sin[i][index])
+        return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)
 
 
 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
@@ -587,20 +609,23 @@ def forward(
         adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cos, x_freqs_sin = self.rope_embedder(torch.cat(x_pos_ids, dim=0))
+        x_freqs_cos = list(x_freqs_cos.split(x_item_seqlens, dim=0))
+        x_freqs_sin = list(x_freqs_sin.split(x_item_seqlens, dim=0))
         x = pad_sequence(x, batch_first=True, padding_value=0.0)
-        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        x_freqs_cos = pad_sequence(x_freqs_cos, batch_first=True, padding_value=0.0)
+        x_freqs_sin = pad_sequence(x_freqs_sin, batch_first=True, padding_value=0.0)
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.noise_refiner:
-                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
         else:
             for layer in self.noise_refiner:
-                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
+                x = layer(x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
 
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -611,35 +636,41 @@ def forward(
         cap_feats = self.cap_embedder(cap_feats)
         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cos, cap_freqs_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
+        cap_freqs_cos = list(cap_freqs_cos.split(cap_item_seqlens, dim=0))
+        cap_freqs_sin = list(cap_freqs_sin.split(cap_item_seqlens, dim=0))
         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
-        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        cap_freqs_cos = pad_sequence(cap_freqs_cos, batch_first=True, padding_value=0.0)
+        cap_freqs_sin = pad_sequence(cap_freqs_sin, batch_first=True, padding_value=0.0)
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.context_refiner:
-                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
         else:
             for layer in self.context_refiner:
-                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
+                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
 
 
         # unified
         unified = []
-        unified_freqs_cis = []
+        unified_freqs_cos = []
+        unified_freqs_sin = []
         for i in range(bsz):
             x_len = x_item_seqlens[i]
             cap_len = cap_item_seqlens[i]
             unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
-            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+            unified_freqs_cos.append(torch.cat([x_freqs_cos[i][:x_len], cap_freqs_cos[i][:cap_len]]))
+            unified_freqs_sin.append(torch.cat([x_freqs_sin[i][:x_len], cap_freqs_sin[i][:cap_len]]))
         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)
 
         unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
-        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+        unified_freqs_cos = pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0)
+        unified_freqs_sin = pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0)
         unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(unified_item_seqlens):
             unified_attn_mask[i, :seq_len] = 1
@@ -647,11 +678,11 @@ def forward(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.layers:
                 unified = self._gradient_checkpointing_func(
-                    layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
+                    layer, unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input
                 )
         else:
             for layer in self.layers:
-                unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+                unified = layer(unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input)
 
         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
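
As a quick sanity check (not part of the patch), the sketch below compares the removed complex-valued RoPE path against the new real-valued cos/sin path. The tensor shapes (`batch`, `seq`, `heads`, `head_dim`) and the random `angles` stand-in for the precomputed frequencies are illustrative assumptions, not values taken from the model:

```python
import torch

batch, seq, heads, head_dim = 2, 16, 4, 64  # hypothetical sizes for the check

x = torch.randn(batch, seq, heads, head_dim)
angles = torch.randn(batch, seq, head_dim // 2)  # stand-in for the rotation angles

# Complex path, as in the removed code: rotate (x0, x1) pairs via complex multiplication.
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64, unit magnitude
x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
ref = torch.view_as_real(x_c * freqs_cis.unsqueeze(2)).flatten(3)

# Real path, as in the new code: (x0 + i*x1) * (cos + i*sin) expanded into real arithmetic.
cos, sin = angles.cos().unsqueeze(2), angles.sin().unsqueeze(2)
xr = x.reshape(*x.shape[:-1], -1, 2)
x0, x1 = xr[..., 0], xr[..., 1]
out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)

# The two formulations should agree up to float rounding.
torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
```

A plausible motivation for the rewrite is that complex tensors (`torch.polar`, `torch.view_as_complex`) are unsupported or fragile under autocast and some compile/export paths, so expressing the rotation with real cos/sin arithmetic keeps the attention processor in plain float ops while producing the same result.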