diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e47eaf73aab..05b35fe87f0 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,7 +1,6 @@ import importlib import inspect import sys -from collections import OrderedDict from dataclasses import dataclass, fields from inspect import signature from typing import Any, Callable, Dict, cast @@ -60,7 +59,7 @@ def verify(cls, obj: Any) -> Any: ) return obj - def get_state_dict(self, progress: bool) -> OrderedDict: + def get_state_dict(self, progress: bool) -> Dict[str, Any]: return load_state_dict_from_url(self.url, progress=progress) def __repr__(self) -> str: