
Commit 9dee8f8

bug fix, lse calculation, testing
bug fixes, lse calculation - switched to the _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues.
1 parent 925da4e commit 9dee8f8

File tree: 2 files changed, +89 −70 lines

src/diffusers/models/attention_dispatch.py
tests/others/test_unified_sp_attention.py


src/diffusers/models/attention_dispatch.py

Lines changed: 88 additions & 67 deletions
@@ -1030,14 +1030,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H_LOCAL = H // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-            .transpose(0, 2).contiguous()
-        )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-            out = funcol.all_to_all_single(x_temp, None, None, group=group)
-            out = _wait_tensor(out)
+            out = _all_to_all_single(x_temp, group=group)
+            #out = _wait_tensor(out)
         else:
             out = x_temp
         # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
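Note: the _all_to_all_single helper that this hunk (and the commit message) switches to is not shown in the diff. A minimal sketch of what such a wrapper around the functional collective typically does, with the name and signature inferred from the call sites above rather than taken from the actual file:

    import torch
    import torch.distributed._functional_collectives as funcol

    def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
        # all_to_all_single expects a contiguous input; forcing that here is what sidesteps
        # the contiguity issues mentioned in the commit message (illustrative sketch only).
        x = x.contiguous()
        out = funcol.all_to_all_single(x, None, None, group=group)
        # functional collectives return an async tensor; wait before using the result eagerly
        return funcol.wait_tensor(out)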
@@ -1049,14 +1048,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H = H_LOCAL * group_world_size
         S_LOCAL = S // group_world_size
 
-        #
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-            .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
 
         if group_world_size >1:
             #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-            output = funcol.all_to_all_single(x_temp, None, None, group)
-            output = _wait_tensor(output)
+            output = _all_to_all_single(x_temp, group)
+            #output = _wait_tensor(output)
         else:
             output = x_temp
         output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
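Note: a local shape walkthrough of the branch above (gather heads, scatter sequence), using made-up sizes and a hypothetical world size of 2; the collective itself is skipped here since it does not change the tensor shape:

    import torch

    B, S, H_LOCAL, D, world = 2, 16, 4, 64, 2
    S_LOCAL, H = S // world, H_LOCAL * world

    x = torch.randn(B, S, H_LOCAL, D)  # full sequence, local heads
    x_temp = x.reshape(B, world, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(world, H_LOCAL, S_LOCAL, B, D)
    out = x_temp  # stand-in for _all_to_all_single(x_temp, group); shape unchanged by the exchange
    out = out.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
    assert out.shape == (B, S_LOCAL, H, D)  # local sequence shard, all heads gathered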
@@ -1075,8 +1073,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1):
         return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
 
     @staticmethod
-    def backward(ctx, *grad_outputs):
-        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
 
 
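Note: the reworked backward follows the usual rule that the gradient of a pure data redistribution is the inverse redistribution applied to the incoming gradient, hence the swapped scatter/gather ids. A single-process analogue of that pattern, with a fixed permutation standing in for the all-to-all exchange (names here are illustrative, not part of the diffusers API):

    import torch

    perm = torch.tensor([2, 0, 3, 1])
    inv = torch.argsort(perm)  # inverse permutation

    class Redistribute(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x[perm]  # stand-in for the forward exchange

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out[inv]  # apply the reversed exchange to the gradient

    x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(Redistribute.apply, (x,))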

@@ -1298,62 +1302,64 @@ def backward(
 
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
-class TemplatedUnifiedAttention(torch.nn.Module):
-    @staticmethod
-    def forward(ctx: torch.autograd.function.FunctionCtx,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_mask: Optional[torch.Tensor],
-        dropout_p: float,
-        is_causal: bool,
-        scale: Optional[float],
-        enable_gqa: bool,
-        return_lse: bool,
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+    ring_group = ring_mesh.get_group()
+    #hardcoded for now
+    scatter_idx = 2
+    gather_idx = 1
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
         forward_op,
         backward_op,
-        _parallel_config: Optional["ParallelConfig"] = None,
-    ):
-        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
-        ulysses_group = ulysses_mesh.get_group()
-        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-        ring_group = ring_mesh.get_group()
-        #hardcoded for now
-        scatter_idx = 2
-        gather_idx = 1
-
-        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
-        out = TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
-        if return_lse:
-            context_layer, lse, *_ = out
-        else:
-            context_layer = out
-        output = SeqAllToAllDim.apply(
-            ulysses_group,
-            context_layer,
-            gather_idx,
-            scatter_idx,
-        )
-        if return_lse:
-            # not sure if this is correct
-            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
-            return (output, lse)
-        return output
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # not sure if this is correct: Assuming (based on forward ops in ringAttention)
+        # the lse is of shape (B, S, H_LOCAL)
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
 
 def _templated_context_parallel_attention(
     query: torch.Tensor,
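Note on the lse handling added above: SeqAllToAllDim operates on 4-D (B, S, H, D) tensors, while the log-sum-exp coming out of the ring attention step is assumed (per the in-code comment) to be 3-D (B, S, H_LOCAL), hence the unsqueeze/squeeze pair around the exchange. A shape-only illustration with made-up sizes; the collective is not run here and ulysses_group is referenced only in a comment:

    import torch

    B, S, H_LOCAL = 2, 16, 4
    lse = torch.randn(B, S, H_LOCAL)   # assumed layout of the ring-attention LSE
    lse4d = lse.unsqueeze(-1)          # (B, S, H_LOCAL, 1): dummy last dim so the 4-D exchange applies
    # ... SeqAllToAllDim.apply(ulysses_group, lse4d, ...) would redistribute it across the ulysses group ...
    lse_out = lse4d.squeeze(-1)        # drop the dummy dim again after the exchange
    assert lse_out.shape == (B, S, H_LOCAL)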
@@ -1378,7 +1384,22 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
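Note: with this change the dispatcher takes the unified path only when both degrees exceed 1, while ring-only input keeps the existing TemplatedRingAttention path; the remaining branches are outside this hunk. A standalone restatement of the new condition with a hypothetical helper (_route is illustrative; the real code reads the fields from _parallel_config.context_parallel_config):

    def _route(ring_degree: int, ulysses_degree: int) -> str:
        # mirrors the updated condition above; branches not shown in the hunk are omitted
        if ring_degree > 1 and ulysses_degree > 1:
            return "TemplatedUnifiedAttention"
        elif ring_degree > 1:
            return "TemplatedRingAttention"
        return "handled by code outside this hunk"

    assert _route(2, 2) == "TemplatedUnifiedAttention"
    assert _route(2, 1) == "TemplatedRingAttention"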

tests/others/test_unified_sp_attention.py

Lines changed: 1 addition & 3 deletions
@@ -102,10 +102,8 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):
         grad_v,
     )
 
-attn = TemplatedUnifiedAttention()
 
-out = attn(
-    None,
+out = TemplatedUnifiedAttention(
     q, k, v, None,
     dropout_p=0.0,
     is_causal=False,
