
Commit a08a28e

datumbox authored and facebook-github-bot committed
[fbsync] [ViT] Support fine-tuning with different image resolution (#5025)
Summary:
* add from_checkpoint method for vit
* remove useless change
* Making interpolate_embeddings a utility function
* remove logging
* fix type hint
* fix return type check
* add returns in docstring & unify type hint
* remove useless import
* fix issue: 'type' object is not subscriptable
* Fixing typing issues
* Making interpolation mode configurable
* formatting

Reviewed By: prabhat00155

Differential Revision: D33253466

fbshipit-source-id: 79bf6855f2dcee3c2fef6c05c243a0dc8dfee25e

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 66a5b76 commit a08a28e

File tree

1 file changed: +75 / -0 lines


torchvision/prototype/models/vision_transformer.py

Lines changed: 75 additions & 0 deletions
@@ -379,3 +379,78 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True
         progress=progress,
         **kwargs,
     )
+
+
+def interpolate_embeddings(
+    image_size: int,
+    patch_size: int,
+    model_state: "OrderedDict[str, torch.Tensor]",
+    interpolation_mode: str = "bicubic",
+    reset_heads: bool = False,
+) -> "OrderedDict[str, torch.Tensor]":
+    """This function helps interpolate positional embeddings during checkpoint loading,
+    especially when you want to apply a pre-trained model on images with a different resolution.
+
+    Args:
+        image_size (int): Image size of the new model.
+        patch_size (int): Patch size of the new model.
+        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
+        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
+        reset_heads (bool): If true, the state of the heads is not copied over. Default: False.
+
+    Returns:
+        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
+    """
+    # Shape of pos_embedding is (1, seq_length, hidden_dim)
+    pos_embedding = model_state["encoder.pos_embedding"]
+    n, seq_length, hidden_dim = pos_embedding.shape
+    if n != 1:
+        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
+
+    new_seq_length = (image_size // patch_size) ** 2 + 1
+
+    # Need to interpolate the weights for the position embedding.
+    # We do this by reshaping the position embeddings to a 2d grid, performing
+    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
+    if new_seq_length != seq_length:
+        # The class token embedding shouldn't be interpolated, so we split it off.
+        seq_length -= 1
+        new_seq_length -= 1
+        pos_embedding_token = pos_embedding[:, :1, :]
+        pos_embedding_img = pos_embedding[:, 1:, :]
+
+        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
+        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
+        seq_length_1d = int(math.sqrt(seq_length))
+        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
+
+        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
+        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
+        new_seq_length_1d = image_size // patch_size
+
+        # Perform interpolation.
+        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
+        new_pos_embedding_img = nn.functional.interpolate(
+            pos_embedding_img,
+            size=new_seq_length_1d,
+            mode=interpolation_mode,
+            align_corners=True,
+        )
+
+        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
+        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
+
+        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
+        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
+        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
+
+        model_state["encoder.pos_embedding"] = new_pos_embedding
+
+        if reset_heads:
+            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
+            for k, v in model_state.items():
+                if not k.startswith("heads"):
+                    model_state_copy[k] = v
+            model_state = model_state_copy
+
+    return model_state
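
For context, a rough usage sketch (not part of this commit): a ViT-B/16 checkpoint trained at 224x224 has (224 // 16)^2 + 1 = 197 positional embeddings, while fine-tuning at 384x384 needs (384 // 16)^2 + 1 = 577, so the image positions are interpolated before loading. The sketch assumes the utility and the vit_b_16 builder / ViT_B_16_Weights enum are importable from the stable torchvision.models namespace (as in later releases) rather than the prototype namespace touched here.

from collections import OrderedDict

from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models.vision_transformer import interpolate_embeddings

# State dict of a model pre-trained at 224x224.
pretrained = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model_state = OrderedDict(pretrained.state_dict())

# Interpolate the positional embeddings for a 384x384 fine-tuning resolution.
new_state = interpolate_embeddings(
    image_size=384,
    patch_size=16,
    model_state=model_state,
    interpolation_mode="bicubic",
    reset_heads=False,  # keep the pre-trained classification head
)

# Build a fresh model at the new resolution and load the adapted weights.
model = vit_b_16(image_size=384)
model.load_state_dict(new_state)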

0 commit comments
