From 10c2a48e972a17a87bcee5a38d0ce73aa8215164 Mon Sep 17 00:00:00 2001 From: fullyz Date: Mon, 4 Oct 2021 13:21:48 +0000 Subject: [PATCH] Use enumerate to get index of ModuleList --- torchvision/models/detection/ssd.py | 4 +--- torchvision/ops/feature_pyramid_network.py | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index e67c4930b30..8defb40f022 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -62,12 +62,10 @@ def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor: num_blocks = len(self.module_list) if idx < 0: idx += num_blocks - i = 0 out = x - for module in self.module_list: + for i, module in enumerate(self.module_list): if i == idx: out = module(x) - i += 1 return out def forward(self, x: List[Tensor]) -> Tensor: diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 7d72769ab07..d62adbdf510 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -102,12 +102,10 @@ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor: num_blocks = len(self.inner_blocks) if idx < 0: idx += num_blocks - i = 0 out = x - for module in self.inner_blocks: + for i, module in enumerate(self.inner_blocks): if i == idx: out = module(x) - i += 1 return out def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: @@ -118,12 +116,10 @@ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor: num_blocks = len(self.layer_blocks) if idx < 0: idx += num_blocks - i = 0 out = x - for module in self.layer_blocks: + for i, module in enumerate(self.layer_blocks): if i == idx: out = module(x) - i += 1 return out def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: