Feature extraction in torchvision.models.vit_b_16 #5718

@DavidTorpey

🐛 Describe the bug

Hi

It’s easy enough to obtain output features from the CNNs in torchvision.models by doing this:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
# Drop the last child (the fully connected classifier) and chain the rest.
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))  # shape: (1, 512, 1, 1)

However, when I attempt to do this with torchvision.models.vit_b_16:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vit_b_16()
# Same approach as above: drop the last child and chain the rest.
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))  # raises the AssertionError

I get the following error:

AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 14, 14])

Any help would be greatly appreciated.

Versions

Torch version: 1.11.0+cu102
Torchvision version: 0.12.0+cu102

cc @datumbox
