Commit ec4d381 (parent 3fbd1cf)

bug fix
1 file changed: 2 additions, 2 deletions


src/diffusers/models/attention_dispatch.py

@@ -1036,7 +1036,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
     if group_world_size > 1:
         # maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-        out = _all_to_all_single(x_temp, None, None, group=group)
+        out = _all_to_all_single(x_temp, group=group)
         # out = _wait_tensor(out)
     else:
         out = x_temp
@@ -1055,7 +1055,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
     if group_world_size > 1:
         # maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-        output = _all_to_all_single(x_temp, None, None, group)
+        output = _all_to_all_single(x_temp, group)
         # output = _wait_tensor(output)
     else:
         output = x_temp
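Why this change fixes the bug: the diff does not show the body of _all_to_all_single, but the corrected call sites suggest the helper accepts only the input tensor and an optional process group. Under that assumption, the old calls passed two stale positional None arguments (presumably leftover output/input split-size parameters), which would either raise a TypeError or be bound to the wrong parameters. A minimal sketch of such a helper, assuming it is a thin wrapper over torch.distributed.all_to_all_single; the actual implementation in attention_dispatch.py may differ:

import torch
import torch.distributed as dist

def _all_to_all_single(x: torch.Tensor, group=None) -> torch.Tensor:
    # Hypothetical helper body for illustration only; the real helper in
    # diffusers may differ. all_to_all_single requires a contiguous input,
    # which is likely the "contiguity issues" the code comment refers to.
    x = x.contiguous()
    out = torch.empty_like(x)
    # Exchange equal-sized chunks of x across all ranks in `group`.
    dist.all_to_all_single(out, x, group=group)
    return out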
