@@ -1030,9 +1030,8 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         H_LOCAL = H // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .transpose(0, 2).contiguous()
-                  )
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
 
         if group_world_size > 1:
             # maybe we should use the _all_to_all_single helper here to avoid contiguity issues
@@ -1050,8 +1049,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         S_LOCAL = S // group_world_size
 
         # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
-                  .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
 
         if group_world_size > 1:
             # maybe we should use the _all_to_all_single helper here to avoid contiguity issues
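
The reshape/transpose in the first hunk lays the tensor out so that dim 0 can be exchanged rank-to-rank (scatter heads, gather sequence). Below is a minimal single-process sketch of that bookkeeping: it simulates the all-to-all with plain indexing instead of torch.distributed, and the sizes B, S, H, D, the world size W, and the recombination step after the exchange are illustrative assumptions, not the repo's code.

# Single-process sketch of the scatter-heads / gather-sequence layout above.
# The simulated all_to_all and the recombination are assumptions for illustration.
import torch

B, S, H, D = 2, 8, 4, 16           # batch, full sequence, full head count, head dim
W = 4                              # hypothetical group_world_size
S_LOCAL, H_LOCAL = S // W, H // W

# Each "rank" starts with all heads but only its sequence shard: (B, S_LOCAL, H, D)
full = torch.randn(B, S, H, D)
shards = list(full.chunk(W, dim=1))

# Per-rank prep, as in the first hunk: B, S_LOCAL, H, D -> W, S_LOCAL, B, H_LOCAL, D
prepped = [x.reshape(B, S_LOCAL, W, H_LOCAL, D).transpose(0, 2).contiguous() for x in shards]

# Simulated all-to-all over dim 0: rank dst receives slice dst from every source rank.
received = [torch.stack([prepped[src][dst] for src in range(W)]) for dst in range(W)]
# Each received tensor is (W, S_LOCAL, B, H_LOCAL, D): full sequence, local heads.

# Recombine the source-rank chunks back into the full sequence axis.
outputs = [r.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3) for r in received]

# Sanity check: rank r now holds heads [r*H_LOCAL:(r+1)*H_LOCAL] over the whole sequence.
for rank, out in enumerate(outputs):
    assert torch.equal(out, full[:, :, rank * H_LOCAL:(rank + 1) * H_LOCAL, :])
print("scatter-heads / gather-sequence layout verified")

The second hunk is, per its shape comment, the inverse layout (permute(1, 3, 2, 0, 4)), which presumably scatters the sequence dimension and gathers the full head dimension back after the exchange.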