@@ -379,3 +379,78 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True
         progress=progress,
         **kwargs,
     )
+
+
+def interpolate_embeddings(
+    image_size: int,
+    patch_size: int,
+    model_state: "OrderedDict[str, torch.Tensor]",
+    interpolation_mode: str = "bicubic",
+    reset_heads: bool = False,
+) -> "OrderedDict[str, torch.Tensor]":
+    """This function helps interpolate positional embeddings during checkpoint loading,
+    especially when you want to apply a pre-trained model on images with a different resolution.
+
+    Args:
+        image_size (int): Image size of the new model.
+        patch_size (int): Patch size of the new model.
+        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
+        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
+        reset_heads (bool): If true, the state of the classification heads is not copied. Default: False.
+
+    Returns:
+        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
+    """
+    # Shape of pos_embedding is (1, seq_length, hidden_dim)
+    pos_embedding = model_state["encoder.pos_embedding"]
+    n, seq_length, hidden_dim = pos_embedding.shape
+    if n != 1:
+        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
+
+    new_seq_length = (image_size // patch_size) ** 2 + 1
+
+    # Need to interpolate the weights for the position embedding.
+    # We do this by reshaping the position embeddings to a 2d grid, performing
+    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
+    if new_seq_length != seq_length:
+        # The class token embedding shouldn't be interpolated, so we split it off.
+        seq_length -= 1
+        new_seq_length -= 1
+        pos_embedding_token = pos_embedding[:, :1, :]
+        pos_embedding_img = pos_embedding[:, 1:, :]
+
+        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
+        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
+        seq_length_1d = int(math.sqrt(seq_length))
+        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
+
+        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
+        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
+        new_seq_length_1d = image_size // patch_size
+
+        # Perform interpolation.
+        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
+        new_pos_embedding_img = nn.functional.interpolate(
+            pos_embedding_img,
+            size=new_seq_length_1d,
+            mode=interpolation_mode,
+            align_corners=True,
+        )
+
+        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
+        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
+
+        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
+        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
+        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
+
+        model_state["encoder.pos_embedding"] = new_pos_embedding
+
+        if reset_heads:
+            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
+            for k, v in model_state.items():
+                if not k.startswith("heads"):
+                    model_state_copy[k] = v
+            model_state = model_state_copy
+
+    return model_state
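For context, a minimal usage sketch (not part of the patch itself; it assumes the helper is exposed as `torchvision.models.vision_transformer.interpolate_embeddings`). It resizes a ViT-B/16 checkpoint trained at 224×224 (a 14×14 patch grid, so seq_length = 14² + 1 = 197) so it can be loaded into a 384×384 model (24×24 grid, seq_length = 577):

```python
from torchvision.models import ViT_B_16_Weights, vit_b_16
from torchvision.models.vision_transformer import interpolate_embeddings

# Checkpoint trained at 224x224: a 14x14 patch grid plus the class token (seq_length 197).
weights = ViT_B_16_Weights.IMAGENET1K_V1
state_dict = weights.get_state_dict(progress=True)

# Interpolate the positional embeddings to a 24x24 grid for 384x384 inputs
# (seq_length 577). reset_heads=True drops the classification heads, e.g.
# before fine-tuning on a dataset with a different number of classes.
state_dict = interpolate_embeddings(
    image_size=384,
    patch_size=16,
    model_state=state_dict,
    reset_heads=True,
)

model = vit_b_16(image_size=384)  # image_size is forwarded to VisionTransformer
model.load_state_dict(state_dict, strict=False)  # strict=False since the heads were reset
```

Note that `load_state_dict` is called with `strict=False` only because `reset_heads=True` removed the head weights; with `reset_heads=False` a strict load works as usual.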