Skip to content

Commit d328c48

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent afb1847 commit d328c48

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
116116
assert isinstance(source.instance, (DataLoader, list))
117117
TPUSpawnStrategy._validate_dataloader(source.instance)
118118

119-
def connect(self, model: "pl.LightningModule") -> None: # type: ignore
119+
def connect(self, model: "pl.LightningModule") -> None: # type: ignore
120120
TPUSpawnStrategy._validate_patched_dataloaders(model)
121121
self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
122122
return super().connect(model)
@@ -182,7 +182,9 @@ def broadcast(self, obj: object, src: int = 0) -> Any:
182182
obj = torch.load(buffer)
183183
return obj
184184

185-
def reduce(self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
185+
def reduce(
186+
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
187+
) -> Tensor:
186188
if not isinstance(output, Tensor):
187189
output = torch.tensor(output, device=self.root_device)
188190

0 commit comments

Comments
 (0)