diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index a202ed625b5..7c6530d66c4 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,7 +3,7 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, Dict, cast +from typing import Any, Callable, Dict, Mapping, cast from torchvision._utils import StrEnum @@ -59,7 +59,7 @@ def verify(cls, obj: Any) -> Any: ) return obj - def get_state_dict(self, progress: bool) -> Dict[str, Any]: + def get_state_dict(self, progress: bool) -> Mapping[str, Any]: return load_state_dict_from_url(self.url, progress=progress) def __repr__(self) -> str: