@@ -1030,14 +1030,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H_LOCAL = H // group_world_size

         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .transpose(0, 2).contiguous()
-                  )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+

         if group_world_size > 1:
             # maybe we need to use the _all_to_all_single helper here to avoid contiguity issues
-            out = funcol.all_to_all_single(x_temp, None, None, group=group)
-            out = _wait_tensor(out)
+            out = _all_to_all_single(x_temp, group=group)
+            # out = _wait_tensor(out)
         else:
             out = x_temp
         # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
@@ -1049,14 +1048,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H = H_LOCAL * group_world_size
         S_LOCAL = S // group_world_size

-        #
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-                  .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)

         if group_world_size > 1:
             # maybe we need to use the _all_to_all_single helper here to avoid contiguity issues
-            output = funcol.all_to_all_single(x_temp, None, None, group)
-            output = _wait_tensor(output)
+            output = _all_to_all_single(x_temp, group)
+            # output = _wait_tensor(output)
         else:
             output = x_temp
         output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
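For reference, a minimal single-rank sketch of the two exchange directions above (not part of the patch; the helper names, toy shapes, and the final reshape of the first path are assumptions based on the trailing comments of these hunks). The leading group_world_size dimension is where all_to_all_single would exchange data across ranks; with one rank the composition round-trips exactly.

import torch

def scatter_heads_gather_seq(x, ws):
    # (B, S_LOCAL, H, D) -> (ws, S_LOCAL, B, H_LOCAL, D) -> (B, S, H_LOCAL, D)
    B, S_LOCAL, H, D = x.shape
    H_LOCAL = H // ws
    x = x.reshape(B, S_LOCAL, ws, H_LOCAL, D).transpose(0, 2).contiguous()
    # on ws ranks, all_to_all_single would exchange the leading dimension here
    return x.reshape(ws * S_LOCAL, B, H_LOCAL, D).transpose(0, 1).contiguous()

def scatter_seq_gather_heads(x, ws):
    # (B, S, H_LOCAL, D) -> (ws, H_LOCAL, S_LOCAL, B, D) -> (B, S_LOCAL, H, D)
    B, S, H_LOCAL, D = x.shape
    S_LOCAL = S // ws
    x = x.reshape(B, ws, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4)
    x = x.reshape(ws, H_LOCAL, S_LOCAL, B, D)
    # on ws ranks, all_to_all_single would exchange the leading dimension here
    return x.reshape(ws * H_LOCAL, S_LOCAL, B, D).transpose(0, 2).contiguous()

x = torch.randn(2, 8, 4, 16)                  # (B, S_LOCAL, H, D)
y = scatter_heads_gather_seq(x, ws=1)         # (B, S, H_LOCAL, D)
assert torch.equal(scatter_seq_gather_heads(y, ws=1), x)  # round-trips on one rank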
@@ -1075,8 +1073,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1):
         return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

     @staticmethod
-    def backward(ctx, *grad_outputs):
-        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)



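A small sanity check of the idea behind the new backward (an editorial sketch, not from the patch): for a pure data-movement op the vector-Jacobian product is the inverse movement, which is why re-applying the exchange with the scatter and gather roles swapped is the correct gradient. The ToyExchange class below is an illustrative stand-in that replaces the cross-rank all-to-all with a plain permute.

import torch

class ToyExchange(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # forward moves dims (0, 1, 2) -> (1, 2, 0); a pure data movement
        return x.permute(1, 2, 0).contiguous()

    @staticmethod
    def backward(ctx, grad_out):
        # the gradient of a data movement is the inverse movement, which is
        # why SeqAllToAllDim.backward re-applies itself with the ids swapped
        return grad_out.permute(2, 0, 1).contiguous()

x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(ToyExchange.apply, (x,))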
@@ -1298,62 +1302,64 @@ def backward(

         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None

-class TemplatedUnifiedAttention(torch.nn.Module):
-    @staticmethod
-    def forward(ctx: torch.autograd.function.FunctionCtx,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_mask: Optional[torch.Tensor],
-        dropout_p: float,
-        is_causal: bool,
-        scale: Optional[float],
-        enable_gqa: bool,
-        return_lse: bool,
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+    ring_group = ring_mesh.get_group()
+    # hardcoded for now
+    scatter_idx = 2
+    gather_idx = 1
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
         forward_op,
         backward_op,
-        _parallel_config: Optional["ParallelConfig"] = None,
-        ):
-        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
-        ulysses_group = ulysses_mesh.get_group()
-        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-        ring_group = ring_mesh.get_group()
-        # hardcoded for now
-        scatter_idx = 2
-        gather_idx = 1
-
-        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
-        out = TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
-        if return_lse:
-            context_layer, lse, *_ = out
-        else:
-            context_layer = out
-        output = SeqAllToAllDim.apply(
-            ulysses_group,
-            context_layer,
-            gather_idx,
-            scatter_idx,
-        )
-        if return_lse:
-            # not sure if this is correct
-            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
-            return (output, lse)
-        return output
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Assuming (based on the forward op implementations) context_layer has shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # not sure if this is correct: assuming (based on the forward ops in ring attention)
+        # the lse has shape (B, S, H_LOCAL)
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx, gather_idx)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output

 def _templated_context_parallel_attention(
     query: torch.Tensor,
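An editorial sketch (not part of the patch) of why the Ulysses exchange around the ring attention is exact: attention is independent across heads, so running it on head shards of the full sequence and gathering the heads back reproduces full attention. The names sdpa_bshd and ws are illustrative, the ws pseudo-ranks are simulated on one process, and torch's scaled_dot_product_attention stands in for TemplatedRingAttention.

import torch
import torch.nn.functional as F

def sdpa_bshd(q, k, v):
    # inputs are (B, S, H, D); SDPA wants (B, H, S, D)
    return F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)

B, S, H, D, ws = 2, 32, 8, 64, 2
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

reference = sdpa_bshd(q, k, v)

# each pseudo-rank holds the full sequence but only H // ws heads, which is what
# SeqAllToAllDim with (scatter_idx=2, gather_idx=1) produces in the wrapper above
shards = [
    sdpa_bshd(*(t[:, :, r * (H // ws):(r + 1) * (H // ws)] for t in (q, k, v)))
    for r in range(ws)
]
ulysses_out = torch.cat(shards, dim=2)  # the reverse exchange gathers heads back

assert torch.allclose(ulysses_out, reference, atol=1e-5)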
@@ -1378,7 +1384,22 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")

     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,