From 82a4a35b20dd46a17c13b39279c0de432de99483 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Wed, 21 Oct 2020 15:23:59 +0200 Subject: [PATCH 1/4] style: Added typing in models._utils --- torchvision/models/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 291041d7b5f..122accb9f4c 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,6 +1,7 @@ from collections import OrderedDict import torch +from torch import Tensor from torch import nn from torch.jit.annotations import Dict @@ -41,7 +42,7 @@ class IntermediateLayerGetter(nn.ModuleDict): "return_layers": Dict[str, str], } - def __init__(self, model, return_layers): + def __init__(self, model: nn.Module, return_layers: Dict[str, str]): if not set(return_layers).issubset([name for name, _ in model.named_children()]): raise ValueError("return_layers are not present in model") orig_return_layers = return_layers @@ -57,7 +58,7 @@ def __init__(self, model, return_layers): super(IntermediateLayerGetter, self).__init__(layers) self.return_layers = orig_return_layers - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: out = OrderedDict() for name, module in self.items(): x = module(x) From 159bc782a08e4a19e9f7b38d80962faea84c5d0e Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 17:53:15 +0200 Subject: [PATCH 2/4] fix: Removed non-necessary import --- torchvision/models/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 122accb9f4c..f2e9c683ca8 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,7 +1,6 @@ from collections import OrderedDict import torch -from torch import Tensor from torch import nn from torch.jit.annotations import Dict @@ -58,7 +57,7 @@ def __init__(self, model: nn.Module, return_layers: Dict[str, str]): super(IntermediateLayerGetter, self).__init__(layers) self.return_layers = orig_return_layers - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: out = OrderedDict() for name, module in self.items(): x = module(x) From 2ae0057fb1681e01dbce676891e1bb262a0651f8 Mon Sep 17 00:00:00 2001 From: frgfm Date: Wed, 21 Oct 2020 18:18:37 +0200 Subject: [PATCH 3/4] fix: Removed type annotation of forward method --- torchvision/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index f2e9c683ca8..e76f641503b 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -57,7 +57,7 @@ def __init__(self, model: nn.Module, return_layers: Dict[str, str]): super(IntermediateLayerGetter, self).__init__(layers) self.return_layers = orig_return_layers - def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward(self, x): out = OrderedDict() for name, module in self.items(): x = module(x) From 74e6b2bc8f12d770c54045ff9c24090c37e3d832 Mon Sep 17 00:00:00 2001 From: frgfm Date: Thu, 22 Oct 2020 12:11:01 +0200 Subject: [PATCH 4/4] refactor: Removed un-necessary import --- torchvision/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index e76f641503b..7d8008c4f27 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch.jit.annotations import Dict +from typing import Dict class IntermediateLayerGetter(nn.ModuleDict):