diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 4b25bf3ea523a..23d4640253e12 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -340,7 +340,7 @@ def register_backend( if devices is not None: for device in devices: - if device != "cpu" and device != "cuda": + if device not in Backend.default_device_backend_map: Backend.default_device_backend_map[device] = name.lower() Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM