Skip to content

Commit 09c5ddd

Browse files
authored
Adding missing named param check on ViT (#5196)
1 parent f9bee39 commit 09c5ddd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchvision/prototype/models/vision_transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
1212
from ._api import WeightsEnum, Weights
1313
from ._meta import _IMAGENET_CATEGORIES
14-
from ._utils import handle_legacy_interface
14+
from ._utils import handle_legacy_interface, _ovewrite_named_param
1515

1616
__all__ = [
1717
"VisionTransformer",
@@ -111,6 +111,9 @@ def _vision_transformer(
111111
) -> VisionTransformer:
112112
image_size = kwargs.pop("image_size", 224)
113113

114+
if weights is not None:
115+
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
116+
114117
model = VisionTransformer(
115118
image_size=image_size,
116119
patch_size=patch_size,

0 commit comments

Comments
 (0)