
Commit 9ebcff5

bug fix, lse calculation, testing
bug fixes, lse calculation: switched to the _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues
1 parent 3a407d8 commit 9ebcff5

File tree

src/diffusers/models/attention_dispatch.py
tests/others/test_unified_sp_attention.py

2 files changed: +89 -70 lines


src/diffusers/models/attention_dispatch.py

Lines changed: 88 additions & 67 deletions
```diff
@@ -1034,14 +1034,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H_LOCAL = H // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .transpose(0, 2).contiguous()
-                  )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-            out = funcol.all_to_all_single(x_temp, None, None, group=group)
-            out = _wait_tensor(out)
+            out = _all_to_all_single(x_temp, group=group)
+            #out = _wait_tensor(out)
         else:
             out = x_temp
         # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
```
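The switch above replaces the raw `funcol.all_to_all_single` call with the `_all_to_all_single` helper that the in-code comment already pointed to, so the contiguity handling lives in one place. The helper's body is not part of this diff; the sketch below is only an assumption about what such a wrapper does (the names `_all_to_all_single` and `_wait_tensor` come from the diff, the implementation is illustrative):

```python
import torch
import torch.distributed._functional_collectives as funcol


def _wait_tensor(t: torch.Tensor) -> torch.Tensor:
    # Functional collectives can return an async wrapper; materialize it before use.
    if isinstance(t, funcol.AsyncCollectiveTensor):
        t = t.wait()
    return t


def _all_to_all_single(x: torch.Tensor, group=None) -> torch.Tensor:
    # Equal-split all-to-all along dim 0; callers reshape so that dim 0 equals the
    # group's world size. Forcing contiguity here is the point of the helper.
    x = x.contiguous()
    out = funcol.all_to_all_single(x, output_split_sizes=None, input_split_sizes=None, group=group)
    return _wait_tensor(out)
```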
```diff
@@ -1053,14 +1052,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H = H_LOCAL * group_world_size
         S_LOCAL = S // group_world_size
 
-        #
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-                  .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-            output = funcol.all_to_all_single(x_temp, None, None, group)
-            output = _wait_tensor(output)
+            output = _all_to_all_single(x_temp, group)
+            #output = _wait_tensor(output)
         else:
             output = x_temp
         output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
```
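The reshape/transpose bookkeeping around the collective can be sanity-checked on a single process. The sizes below are invented, and the post-collective reshape is only shape tracking that matches the comments in the diff, not the library's exact code:

```python
import torch

B, S_LOCAL, H, D = 2, 16, 8, 64
group_world_size = 4
H_LOCAL = H // group_world_size

x = torch.randn(B, S_LOCAL, H, D)

# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D (what is fed to the all-to-all)
x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
assert x_temp.shape == (group_world_size, S_LOCAL, B, H_LOCAL, D)

# After the all-to-all each rank holds every sequence shard for its own head shard:
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
out = x_temp.reshape(group_world_size * S_LOCAL, B, H_LOCAL, D).transpose(0, 1)
assert out.shape == (B, group_world_size * S_LOCAL, H_LOCAL, D)
```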
```diff
@@ -1079,8 +1077,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1):
         return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
 
     @staticmethod
-    def backward(ctx, *grad_outputs):
-        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
 
 
 
```
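The rewritten backward simply re-applies the exchange with `gather_id` and `scatter_id` swapped. That is the correct gradient because the exchange only permutes elements across ranks and between the sequence and head dimensions, so its vector-Jacobian product is the inverse permutation. A self-contained, single-process illustration of that inverse relationship (the split/concat layout here is for illustration only, not the library's exact data ordering):

```python
import torch

W = 4                     # ulysses world size
B, S_LOCAL, H, D = 2, 8, 8, 16
H_LOCAL = H // W


def exchange_heads_for_seq(shards):
    # shards[r]: (B, S_LOCAL, H, D) on rank r  ->  (B, S_LOCAL * W, H_LOCAL, D)
    out = []
    for r in range(W):
        # rank r gathers its head block r from every rank, in rank (= sequence) order
        pieces = [s[:, :, r * H_LOCAL:(r + 1) * H_LOCAL, :] for s in shards]
        out.append(torch.cat(pieces, dim=1))
    return out


def exchange_seq_for_heads(shards):
    # shards[r]: (B, S_LOCAL * W, H_LOCAL, D)  ->  (B, S_LOCAL, H, D): the inverse map
    out = []
    for r in range(W):
        pieces = [s[:, r * S_LOCAL:(r + 1) * S_LOCAL, :, :] for s in shards]
        out.append(torch.cat(pieces, dim=2))
    return out


x = [torch.randn(B, S_LOCAL, H, D) for _ in range(W)]
roundtrip = exchange_seq_for_heads(exchange_heads_for_seq(x))
assert all(torch.equal(a, b) for a, b in zip(x, roundtrip))
# Because the exchange is a pure permutation of elements across ranks and dims, its
# vector-Jacobian product is this inverse map, which is why backward() re-applies
# SeqAllToAllDim with gather_id and scatter_id swapped.
```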
```diff
@@ -1302,62 +1306,64 @@ def backward(
 
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
-class TemplatedUnifiedAttention(torch.nn.Module):
-    @staticmethod
-    def forward(ctx: torch.autograd.function.FunctionCtx,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_mask: Optional[torch.Tensor],
-        dropout_p: float,
-        is_causal: bool,
-        scale: Optional[float],
-        enable_gqa: bool,
-        return_lse: bool,
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+    ring_group = ring_mesh.get_group()
+    #hardcoded for now
+    scatter_idx = 2
+    gather_idx = 1
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
         forward_op,
         backward_op,
-        _parallel_config: Optional["ParallelConfig"] = None,
-    ):
-        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
-        ulysses_group = ulysses_mesh.get_group()
-        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-        ring_group = ring_mesh.get_group()
-        #hardcoded for now
-        scatter_idx = 2
-        gather_idx = 1
-
-        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
-        out = TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
-        if return_lse:
-            context_layer, lse, *_ = out
-        else:
-            context_layer = out
-        output = SeqAllToAllDim.apply(
-            ulysses_group,
-            context_layer,
-            gather_idx,
-            scatter_idx,
-        )
-        if return_lse:
-            # not sure if this is correct
-            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
-            return (output, lse)
-        return output
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # not sure if this is correct: Assuming (based on forward ops in ringAttention)
+        # the lse is of shape (B, S, H_LOCAL)
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
 
 def _templated_context_parallel_attention(
     query: torch.Tensor,
```
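To follow the new path end to end, here is a per-rank shape trace with invented degrees and sizes (the comments restate what the code above assumes about the ring attention outputs):

```python
# Hypothetical degrees and sizes, chosen only to make the arithmetic concrete.
B, S, H, D = 1, 1024, 16, 64
ulysses, ring = 4, 2

S_LOCAL = S // (ulysses * ring)   # 128: sequence shard each rank holds on entry
H_LOCAL = H // ulysses            # 4:   head shard each rank attends over after the exchange

shape_in = (B, S_LOCAL, H, D)                    # query/key/value entering TemplatedUnifiedAttention
shape_ring = (B, S_LOCAL * ulysses, H_LOCAL, D)  # after SeqAllToAllDim(scatter_idx=2, gather_idx=1)
shape_out = (B, S_LOCAL, H, D)                   # after the reversed exchange on context_layer

# TemplatedRingAttention consumes shape_ring and, per the comments in the diff, returns a
# context_layer of the same layout plus an LSE of shape (B, S_LOCAL * ulysses, H_LOCAL),
# so the reversed exchange restores the original per-rank layout.
assert shape_out == shape_in
```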
```diff
@@ -1382,7 +1388,22 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
```
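The dispatch change boils down to a three-way branch on the configured degrees. A tiny illustrative helper (the class names mirror the diff; the Ulysses-only branch is assumed, since it is not shown in this hunk):

```python
def select_attention_path(ring_degree: int, ulysses_degree: int) -> str:
    # Mirrors the branching in _templated_context_parallel_attention after this commit.
    if ring_degree > 1 and ulysses_degree > 1:
        return "TemplatedUnifiedAttention"    # new combined ring + ulysses path
    elif ring_degree > 1:
        return "TemplatedRingAttention"
    elif ulysses_degree > 1:
        return "TemplatedUlyssesAttention"    # assumed remaining branch (not shown here)
    return "local attention"


assert select_attention_path(2, 4) == "TemplatedUnifiedAttention"
assert select_attention_path(2, 1) == "TemplatedRingAttention"
```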

tests/others/test_unified_sp_attention.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -102,10 +102,8 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):
         grad_v,
     )
 
-    attn = TemplatedUnifiedAttention()
 
-    out = attn(
-        None,
+    out = TemplatedUnifiedAttention(
         q, k, v, None,
         dropout_p=0.0,
         is_causal=False,
```
