113 changes: 72 additions & 41 deletions src/diffusers/models/transformers/transformer_z_image.py
@@ -97,7 +97,8 @@ def __call__(
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@@ -113,17 +114,26 @@ def __call__(
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo

if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# # Apply RoPE
# def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
# with torch.amp.autocast("cuda", enabled=False):
# x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
# freqs_cis = freqs_cis.unsqueeze(2)
# x_out = torch.view_as_real(x * freqs_cis).flatten(3)
# return x_out.type_as(x_in) # todo

def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
freqs_cos = freqs_cos.unsqueeze(2) # [batch, seq, 1, head_dim//2]
freqs_sin = freqs_sin.unsqueeze(2)
x = x_in.reshape(*x_in.shape[:-1], -1, 2)
x0, x1 = x[..., 0], x[..., 1]
out0 = x0 * freqs_cos - x1 * freqs_sin
out1 = x0 * freqs_sin + x1 * freqs_cos
return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)

if freqs_cos is not None and freqs_sin is not None:
query = apply_rotary_emb(query, freqs_cos, freqs_sin)
key = apply_rotary_emb(key, freqs_cos, freqs_sin)

# Cast to correct dtype
dtype = query.dtype
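
A minimal, self-contained sketch (toy shapes; all names are illustrative and not part of this file) verifying that the cos/sin rotation used above is numerically equivalent to the complex `freqs_cis` multiplication it replaces:

```python
import torch

# Toy tensors: rotate interleaved (x0, x1) pairs with cos/sin and compare
# against the equivalent complex multiplication of the old freqs_cis path.
batch, seq, heads, head_dim = 2, 4, 3, 8
x = torch.randn(batch, seq, heads, head_dim)
angles = torch.randn(batch, seq, head_dim // 2)
freqs_cos, freqs_sin = angles.cos(), angles.sin()

# Real-valued rotation (same math as the new apply_rotary_emb).
pairs = x.reshape(*x.shape[:-1], -1, 2)
x0, x1 = pairs[..., 0], pairs[..., 1]
cos, sin = freqs_cos.unsqueeze(2), freqs_sin.unsqueeze(2)
real_out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)

# Complex rotation (same math as the removed freqs_cis path).
x_complex = torch.view_as_complex(pairs.contiguous())
freqs_cis = torch.polar(torch.ones_like(angles), angles).unsqueeze(2)
complex_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)

print(torch.allclose(real_out, complex_out, atol=1e-5))  # True
```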
@@ -219,7 +229,8 @@ def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
):
if self.modulation:
@@ -232,7 +243,8 @@ def forward(
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
attention_mask=attn_mask,
freqs_cis=freqs_cis,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)
x = x + gate_msa * self.attention_norm2(attn_out)

@@ -247,7 +259,8 @@ def forward(
attn_out = self.attention(
self.attention_norm1(x),
attention_mask=attn_mask,
freqs_cis=freqs_cis,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)
x = x + self.attention_norm2(attn_out)

@@ -290,39 +303,48 @@ def __init__(
self.axes_dims = axes_dims
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
self.freqs_cos = None
self.freqs_sin = None

@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
with torch.device("cpu"):
freqs_cis = []
freqs_cos = []
freqs_sin = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
freqs_cos.append(freqs.cos())
freqs_sin.append(freqs.sin())
# freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
# freqs_cis.append(freqs_cis_i)

return freqs_cis
return freqs_cos, freqs_sin

def __call__(self, ids: torch.Tensor):
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device

if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
if self.freqs_cos is None or self.freqs_sin is None:
self.freqs_cos, self.freqs_sin = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
self.freqs_sin = [f.to(device) for f in self.freqs_sin]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
if self.freqs_cos[0].device != device:
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
if self.freqs_sin[0].device != device:
self.freqs_sin = [f.to(device) for f in self.freqs_sin]

result = []
cos_result = []
sin_result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
cos_result.append(self.freqs_cos[i][index])
sin_result.append(self.freqs_sin[i][index])
return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)


class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
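
A hedged sketch of the same precompute-then-gather pattern used by the RoPE embedder above: build one cos table and one sin table per positional axis, then index rows by position ids and concatenate along the last dimension. Function names and sizes here are illustrative, not the module's API.

```python
import torch
from typing import List, Tuple

def precompute_axis_tables(dims: List[int], lens: List[int], theta: float = 256.0):
    # One (len_i, dim_i // 2) cos table and sin table per positional axis.
    cos_tables, sin_tables = [], []
    for d, e in zip(dims, lens):
        inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
        angles = torch.outer(torch.arange(e, dtype=torch.float64), inv_freq).float()
        cos_tables.append(angles.cos())
        sin_tables.append(angles.sin())
    return cos_tables, sin_tables

def gather_rope(ids: torch.Tensor, cos_tables, sin_tables) -> Tuple[torch.Tensor, torch.Tensor]:
    # ids: (num_tokens, num_axes); pick the row for each axis and concatenate.
    cos = torch.cat([cos_tables[i][ids[:, i]] for i in range(ids.shape[-1])], dim=-1)
    sin = torch.cat([sin_tables[i][ids[:, i]] for i in range(ids.shape[-1])], dim=-1)
    return cos, sin

# Usage: three toy axes (e.g. time, height, width).
cos_t, sin_t = precompute_axis_tables([8, 16, 16], [4, 32, 32])
pos_ids = torch.tensor([[0, 3, 5], [1, 7, 2]])
freqs_cos, freqs_sin = gather_rope(pos_ids, cos_t, sin_t)
print(freqs_cos.shape)  # torch.Size([2, 20])  -> (8 + 16 + 16) // 2 = 20
```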
@@ -587,20 +609,23 @@ def forward(
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x_freqs_cos, x_freqs_sin = self.rope_embedder(torch.cat(x_pos_ids, dim=0))
x_freqs_cos = list(x_freqs_cos.split(x_item_seqlens, dim=0))
x_freqs_sin = list(x_freqs_sin.split(x_item_seqlens, dim=0))

x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_freqs_cos = pad_sequence(x_freqs_cos, batch_first=True, padding_value=0.0)
x_freqs_sin = pad_sequence(x_freqs_sin, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1

if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
else:
for layer in self.noise_refiner:
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
x = layer(x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)

# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
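
For reference, a toy sketch (illustrative shapes, not this model's tensors) of the split-then-pad pattern applied above to the per-item frequencies, together with the boolean attention mask built from item lengths:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three items with different token counts; frequencies computed for all tokens at once.
item_seqlens = [5, 3, 7]
freqs_cos = torch.randn(sum(item_seqlens), 20)   # (total_tokens, rope_dim)

# Split back into per-item chunks, then right-pad to the longest item.
per_item = list(freqs_cos.split(item_seqlens, dim=0))
padded = pad_sequence(per_item, batch_first=True, padding_value=0.0)
print(padded.shape)  # torch.Size([3, 7, 20])

# Boolean mask marking real tokens vs. padding, as used for the attention mask.
mask = torch.zeros(len(item_seqlens), max(item_seqlens), dtype=torch.bool)
for i, n in enumerate(item_seqlens):
    mask[i, :n] = True
print(mask.sum(dim=1))  # tensor([5, 3, 7])
```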
@@ -611,47 +636,53 @@ def forward(
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_freqs_cos, cap_freqs_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
cap_freqs_cos = list(cap_freqs_cos.split(cap_item_seqlens, dim=0))
cap_freqs_sin = list(cap_freqs_sin.split(cap_item_seqlens, dim=0))

cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_freqs_cos = pad_sequence(cap_freqs_cos, batch_first=True, padding_value=0.0)
cap_freqs_sin = pad_sequence(cap_freqs_sin, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1

if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)

# unified
unified = []
unified_freqs_cis = []
unified_freqs_cos = []
unified_freqs_sin = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_freqs_cos.append(torch.cat([x_freqs_cos[i][:x_len], cap_freqs_cos[i][:cap_len]]))
unified_freqs_sin.append(torch.cat([x_freqs_sin[i][:x_len], cap_freqs_sin[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)

unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_freqs_cos = pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0)
unified_freqs_sin = pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1

if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.layers:
unified = self._gradient_checkpointing_func(
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
layer, unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input
)
else:
for layer in self.layers:
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
unified = layer(unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input)

unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
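
A generic sketch of gradient checkpointing with the extra cos/sin positional arguments threaded through to a block's forward. It uses plain `torch.utils.checkpoint.checkpoint` rather than the model's own `_gradient_checkpointing_func`, and all names and shapes are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

class ToyBlock(torch.nn.Module):
    def __init__(self, dim: int = 20):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, attn_mask, freqs_cos, freqs_sin, adaln_input=None):
        # Stand-in for a transformer block; the point is the argument plumbing.
        return self.proj(x) * attn_mask.unsqueeze(-1) + freqs_cos.mean() + freqs_sin.mean()

block = ToyBlock()
x = torch.randn(3, 7, 20, requires_grad=True)
mask = torch.ones(3, 7, dtype=torch.bool)
cos = torch.randn(3, 7, 20)
sin = torch.randn(3, 7, 20)

# Extra positional tensors (the cos/sin pair) are passed straight through to forward.
out = checkpoint(block, x, mask, cos, sin, None, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([3, 7, 20])
```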