diff --git a/torchbenchmark/models/moco/__init__.py b/torchbenchmark/models/moco/__init__.py index b70605c30f..4287a3239a 100644 --- a/torchbenchmark/models/moco/__init__.py +++ b/torchbenchmark/models/moco/__init__.py @@ -70,6 +70,16 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): ) except RuntimeError: pass # already initialized? + elif device == "xpu": + try: + dist.init_process_group( + backend="xccl", + init_method="tcp://localhost:10001", + world_size=1, + rank=0, + ) + except RuntimeError: + pass # already initialized? elif device == "xla": import torch_xla.distributed.xla_backend