Commit 23c6b0d

bug fix
1 parent ec4d381 commit 23c6b0d

File tree

1 file changed: +3 −5 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 3 additions & 5 deletions
@@ -1030,9 +1030,8 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H_LOCAL = H // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-            .transpose(0, 2).contiguous()
-        )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
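As a quick sanity check on the first hunk's shape comment, here is a minimal standalone sketch, not part of the commit, with made-up tensor sizes: it verifies that the refactored one-liner really maps B, S_LOCAL, H, D to group_world_size, S_LOCAL, B, H_LOCAL, D.

import torch

# Hypothetical sizes, chosen only for illustration.
B, S_LOCAL, H, D = 2, 8, 16, 64
group_world_size = 4
H_LOCAL = H // group_world_size  # 4

x = torch.randn(B, S_LOCAL, H, D)

# The one-liner from the commit: split the head dim across ranks,
# then swap the batch and group dims so the group dim leads.
x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()

assert x_temp.shape == (group_world_size, S_LOCAL, B, H_LOCAL, D)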
@@ -1050,8 +1049,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         S_LOCAL = S // group_world_size
 
         #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-            .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
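Likewise for the second hunk, a minimal sketch, again with illustrative sizes, confirming the target layout group_world_size, H_LOCAL, S_LOCAL, B, D and that the refactored one-liner computes the same tensor as the removed parenthesized form:

import torch

# Hypothetical sizes, chosen only for illustration.
B, S, H_LOCAL, D = 2, 32, 4, 64
group_world_size = 4
S_LOCAL = S // group_world_size  # 8

x = torch.randn(B, S, H_LOCAL, D)

# Refactored one-liner from the commit.
new = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)

# Removed multi-line form; it produces an identical tensor.
old = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
       .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))

assert new.shape == (group_world_size, H_LOCAL, S_LOCAL, B, D)
assert torch.equal(new, old)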
