@@ -1376,7 +1376,7 @@ def _maybe_convert_scalar_types_to_dtypes(
13761376class Work (_Work ):
13771377 def __init__ (self ) -> None :
13781378 super ().__init__ ()
1379- self .event = torch .cuda .Event ()
1379+ self .event = torch .xpu .Event ()
13801380 self .event .record ()
13811381
13821382 def wait (self , timeout : timedelta = timedelta (seconds = 0 )) -> bool :
@@ -1421,7 +1421,7 @@ def _low_contention_all_gather_meta(
14211421 group_size = c10d ._get_group_size_by_name (group_name )
14221422 return tensor .new_empty (tensor .shape [0 ] * group_size , * tensor .shape [1 :])
14231423
1424-
1424+ @ torch . library . impl ( lib , "_low_contention_all_gather" , "XPU" )
14251425@torch .library .impl (lib , "_low_contention_all_gather" , "CUDA" )
14261426def _low_contention_all_gather (
14271427 tensor : torch .Tensor ,
@@ -1454,7 +1454,7 @@ def _low_contention_all_gather(
14541454 output = tensor .new_empty (tensor .shape [0 ] * world_size , * tensor .shape [1 :])
14551455 chunks = output .chunk (world_size )
14561456
1457- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1457+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
14581458 with _get_backend_stream ():
14591459 if not input_is_symm_mem :
14601460 local_buf = symm_mem .get_buffer (rank , tensor .shape , tensor .dtype )
@@ -1492,7 +1492,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
14921492 a2a_res = torch .empty_like (tensor )
14931493 chunks = a2a_res .chunk (world_size )
14941494
1495- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1495+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
14961496 with _get_backend_stream ():
14971497 # pull + offline reduction
14981498 symm_mem .barrier ()
@@ -1529,7 +1529,7 @@ def _low_contention_reduce_scatter_with_workspace(
15291529 assert tensor .shape [0 ] % world_size == 0
15301530 chunks = tensor .chunk (world_size )
15311531
1532- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1532+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
15331533 with _get_backend_stream ():
15341534 # push + offline reduction
15351535 workspace .barrier ()
@@ -1552,7 +1552,7 @@ def _low_contention_reduce_scatter_with_workspace(
15521552 torch ._C ._distributed_c10d ._register_work (ret , Work ())
15531553 return ret
15541554
1555-
1555+ @ torch . library . impl ( lib , "_low_contention_reduce_scatter" , "XPU" )
15561556@torch .library .impl (lib , "_low_contention_reduce_scatter" , "CUDA" )
15571557def _low_contention_reduce_scatter (
15581558 tensor : torch .Tensor ,
0 commit comments