@@ -68,6 +68,21 @@ def __init__(
68
68
self .height , self .width = height // patch_size , width // patch_size
69
69
self .base_size = height // patch_size
70
70
71
def pe_selection_index_based_on_dim(self, h, w):
    """Select flat positional-embedding indices for an ``h x w`` latent.

    The stored positional embedding is viewed as a square 2D grid with
    ``int(pos_embed_max_size ** 0.5)`` entries per side. A centered window of
    ``(h // patch_size) x (w // patch_size)`` grid cells is cropped out and its
    indices are returned flattened, so they line up with the flattened token
    order produced by patchification.

    NOTE(review): assumes ``pos_embed.shape[1] == pos_embed_max_size`` and that
    ``pos_embed_max_size`` is a perfect square — confirm against ``__init__``.
    """
    tokens_h, tokens_w = h // self.patch_size, w // self.patch_size
    grid_side = int(self.pos_embed_max_size**0.5)
    # Lay every PE index out on the 2D grid it conceptually occupies.
    index_grid = torch.arange(self.pos_embed.shape[1]).view(grid_side, grid_side)
    # Centered crop of tokens_h x tokens_w cells from that grid.
    top = grid_side // 2 - tokens_h // 2
    left = grid_side // 2 - tokens_w // 2
    window = index_grid[top : top + tokens_h, left : left + tokens_w]
    return window.flatten()
71
86
def forward (self , latent ):
72
87
batch_size , num_channels , height , width = latent .size ()
73
88
latent = latent .view (
@@ -80,7 +95,8 @@ def forward(self, latent):
80
95
)
81
96
latent = latent .permute (0 , 2 , 4 , 1 , 3 , 5 ).flatten (- 3 ).flatten (1 , 2 )
82
97
latent = self .proj (latent )
83
- return latent + self .pos_embed
98
+ pe_index = self .pe_selection_index_based_on_dim (height , width )
99
+ return latent + self .pos_embed [:, pe_index ]
84
100
85
101
86
102
# Taken from the original Aura flow inference code.
0 commit comments