@@ -1034,14 +1034,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
        H_LOCAL = H // group_world_size

        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .transpose(0, 2).contiguous()
-                  )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+

        if group_world_size > 1:
            # maybe we need to use the _all_to_all_single helper here to avoid contiguity issues
-            out = funcol.all_to_all_single(x_temp, None, None, group=group)
-            out = _wait_tensor(out)
+            out = _all_to_all_single(x_temp, group=group)
+            # out = _wait_tensor(out)
        else:
            out = x_temp
        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
@@ -1053,14 +1052,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
        H = H_LOCAL * group_world_size
        S_LOCAL = S // group_world_size

-        #
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-                  .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)

        if group_world_size > 1:
            # maybe we need to use the _all_to_all_single helper here to avoid contiguity issues
-            output = funcol.all_to_all_single(x_temp, None, None, group)
-            output = _wait_tensor(output)
+            output = _all_to_all_single(x_temp, group)
+            # output = _wait_tensor(output)
        else:
            output = x_temp
        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
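Note on the helper the patch switches to: `_all_to_all_single` is defined elsewhere in this file and is not shown in these hunks. As a rough sketch only (an assumption about its shape, not the actual definition), it presumably forces contiguity, wraps the functional collective the old code called directly, and waits on the async result so callers get a plain tensor back:

    import torch
    import torch.distributed._functional_collectives as funcol

    def _all_to_all_single(x: torch.Tensor, group=None) -> torch.Tensor:
        # Exchange equal dim-0 chunks across the ranks of `group`, then block
        # until the result is materialized so callers never handle an async tensor.
        x = x.contiguous()
        out = funcol.all_to_all_single(x, None, None, group=group)
        return funcol.wait_tensor(out)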
@@ -1079,8 +1077,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1):
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
-    def backward(ctx, *grad_outputs):
-        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
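Why the new backward simply re-applies the op with the ids swapped: the dim exchange is a fixed permutation of tensor elements, so its vector-Jacobian product is the inverse permutation, i.e. the same exchange with scatter and gather reversed; re-entering through SeqAllToAllDim.apply (rather than calling _all_to_all_dim_exchange directly) should also keep the gradient itself differentiable. A single-process toy check of the permutation argument, not part of the patch, with a plain transpose standing in for the exchange:

    import torch

    x = torch.randn(2, 8, 4, 16, requires_grad=True)  # (B, S_LOCAL, H, D)
    y = x.transpose(1, 2)                             # stand-in for the dim exchange
    g = torch.randn_like(y)                           # upstream gradient
    y.backward(g)
    # The gradient is the upstream gradient pushed through the inverse permutation,
    # i.e. the same exchange applied with the dimensions swapped back.
    assert torch.equal(x.grad, g.transpose(1, 2))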
@@ -1302,62 +1306,64 @@ def backward(

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None

-class TemplatedUnifiedAttention(torch.nn.Module):
-    @staticmethod
-    def forward(ctx: torch.autograd.function.FunctionCtx,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_mask: Optional[torch.Tensor],
-        dropout_p: float,
-        is_causal: bool,
-        scale: Optional[float],
-        enable_gqa: bool,
-        return_lse: bool,
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+    ring_group = ring_mesh.get_group()
+    # hardcoded for now
+    scatter_idx = 2
+    gather_idx = 1
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
        forward_op,
        backward_op,
-        _parallel_config: Optional["ParallelConfig"] = None,
-    ):
-        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
-        ulysses_group = ulysses_mesh.get_group()
-        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-        ring_group = ring_mesh.get_group()
-        # hardcoded for now
-        scatter_idx = 2
-        gather_idx = 1
-
-        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
-        out = TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
-        if return_lse:
-            context_layer, lse, *_ = out
-        else:
-            context_layer = out
-        output = SeqAllToAllDim.apply(
-            ulysses_group,
-            context_layer,
-            gather_idx,
-            scatter_idx,
-        )
-        if return_lse:
-            # not sure if this is correct
-            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
-            return (output, lse)
-        return output
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Assuming (based on the forward op implementations) context_layer has shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # Assuming (based on the forward ops in TemplatedRingAttention) lse has shape (B, S, H_LOCAL);
+        # reverse the exchange the same way as the output, passing the ids positionally.
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+        lse = lse.squeeze(-1)  # back to (B, S_LOCAL, H)
+        return (output, lse)
+    return output
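For orientation (not part of the patch), a shape trace of the unified path, assuming per-rank q/k/v arrive as (B, S_LOCAL, H, D) with the sequence sharded over both mesh axes, and using hypothetical sizes:

    # Hypothetical sizes: B=2, S=4096, H=16, D=64, ulysses_degree=4, ring_degree=2.
    # per-rank q/k/v:          (2, 4096 // (4 * 2), 16, 64) = (2,  512, 16, 64)
    # after SeqAllToAllDim:    (2, 4096 // 2,  16 // 4, 64) = (2, 2048,  4, 64)  # scatter heads, gather sequence over the Ulysses group
    # TemplatedRingAttention:  ring attention over the remaining sequence shards -> (2, 2048, 4, 64)
    # reverse SeqAllToAllDim:  back to (2, 512, 16, 64), matching the input layout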

def _templated_context_parallel_attention(
    query: torch.Tensor,
@@ -1382,7 +1388,22 @@ def _templated_context_parallel_attention(
        raise ValueError("GQA is not yet supported for templated attention.")

    # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
        return TemplatedRingAttention.apply(
            query,
            key,