@@ -33,6 +33,9 @@ def _is_disagg() -> bool:

# TODO: update the reference config path once Nano v2 VLM is released.
IMAGE_TOKEN_ID = 131072
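+# Special tokens marking image content in the prompt.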
+IMG_CONTEXT_TOKEN = "<image>"
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"


class SquaredReLU(nn.Module):
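    """Squared ReLU activation: computes relu(x) ** 2."""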
@@ -41,8 +44,7 @@ def forward(self, x):
        return torch.pow(torch.nn.functional.relu(x), 2)


-class NanoV2VLVisionEncoder(transformers.PreTrainedModel,
-                            transformers.generation.GenerationMixin):
+class NanoV2VLVisionEncoder(transformers.PreTrainedModel):

    def __init__(self,
                 model_config: ModelConfig[transformers.PretrainedConfig]):
@@ -61,20 +63,21 @@ def __init__(self,
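        # Vision projector (mlp1): RMSNorm -> Linear -> SquaredReLU -> Linear,
        # projecting pixel-shuffled ViT features into the LLM hidden size.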
        self.llm_hidden_size = config.llm_config.hidden_size
        self.mlp1 = nn.Sequential(
            nn.RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
-                       eps=config.llm_config.rms_norm_eps),
+                       eps=config.llm_config.rms_norm_eps,
+                       dtype=config.torch_dtype),
            nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      self.vision_projection_hidden_size,
-                      bias=False), SquaredReLU(),
+                      bias=False,
+                      dtype=config.torch_dtype), SquaredReLU(),
            nn.Linear(self.vision_projection_hidden_size,
                      self.llm_hidden_size,
-                      bias=False))
-        self.mlp1 = self.mlp1.to(config.torch_dtype)
+                      bias=False,
+                      dtype=config.torch_dtype))

        # Construct the vision encoder.
        vision_model_config = copy.deepcopy(model_config)
        vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
        self.vision_model = RADIOVisionModel(vision_model_config)
-        self.vision_model.to(config.torch_dtype)

    def load_weights(self, weights):
        # Load mlp1 weights.
@@ -111,7 +114,6 @@ def pixel_shuffle(self, x, scale_factor=0.5):

    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(pixel_values)
-        vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
        # Down-sampling and projection.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
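        # -> [num_patches, h, w, vit_hidden_size] before pixel shuffle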
@@ -131,11 +133,11 @@ def forward(self, multimodal_params: List[MultimodalParams]):
        ],
                                         dim=0)
        # -> [num_patches, channel, height, width]
-        batched_num_patches = torch.cat([
+        patch_list = [
            multimodal_param.multimodal_data["num_patches"]
            for multimodal_param in multimodal_params
-        ],
-                                        dim=0).tolist()
+        ]
+        batched_num_patches = torch.cat(patch_list, dim=0).tolist()
        # -> list of [num_patches1, num_patches2, ...]
        batched_image_embeds = self.extract_feature(batched_pixel_values)
        # -> [num_patches, num_image_token, hidden_size]
@@ -176,9 +178,10 @@ def __init__(self,

        self.processor = transformers.AutoImageProcessor.from_pretrained(
            model_path, trust_remote_code=True, use_fast=self.use_fast)
-        self.img_context_token = "<image>"
-        self.img_start_token = "<img>"
-        self.img_end_token = "</img>"
+
+        self.img_context_token = IMG_CONTEXT_TOKEN
+        self.img_start_token = IMG_START_TOKEN
+        self.img_end_token = IMG_END_TOKEN
        self.dtype = model_config.torch_dtype

    def get_vocab_size(self):
@@ -194,7 +197,7 @@ def get_num_tokens_per_image(
        **kwargs,
    ):

-        def get_internvl_target_ratios(
+        def _get_internvl_target_ratios(
            min_num: int,
            max_num: int,
        ) -> list[tuple[int, int]]:
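            """Enumerate candidate tile grids whose tile count lies in [min_num, max_num]."""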
@@ -205,8 +208,8 @@ def get_internvl_target_ratios(
                             if min_num <= i * j <= max_num}
            return sorted(target_ratios, key=lambda x: x[0] * x[1])

-        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
-                                      height, image_size):
+        def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
+                                       height, image_size):
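            """Pick the candidate ratio whose aspect ratio and covered area best match the image."""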
            best_factor = float('-inf')
            best_ratio = (1, 1)
            area = width * height
@@ -221,7 +224,7 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
                    best_ratio = ratio
            return best_ratio

-        def calculate_targets(
+        def _calculate_targets(
            orig_width: int,
            orig_height: int,
            target_ratios: list[tuple[int, int]],
@@ -230,7 +233,7 @@ def calculate_targets(
            aspect_ratio = orig_width / orig_height

            # find the closest aspect ratio to the target
-            target_aspect_ratio = find_closest_aspect_ratio(
+            target_aspect_ratio = _find_closest_aspect_ratio(
                aspect_ratio,
                target_ratios,
                width=orig_width,
@@ -243,10 +246,10 @@ def calculate_targets(

        image_height = image.height
        image_width = image.width
-        target_ratios = get_internvl_target_ratios(1,
-                                                   self.processor.max_num_tiles)
-        blocks = calculate_targets(image_width, image_height, target_ratios,
-                                   self.image_size)
+        target_ratios = _get_internvl_target_ratios(
+            1, self.processor.max_num_tiles)
+        blocks = _calculate_targets(image_width, image_height, target_ratios,
+                                    self.image_size)
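        # A thumbnail tile is appended whenever the image is split into multiple tiles.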
        if self.processor.use_thumbnail and blocks != 1:
            blocks += 1
        num_image_tokens = self.num_image_token * blocks
@@ -309,7 +312,7 @@ def __call__(
            model_type="NemotronH_Nano_VL_V2",
            placeholder_metadata=MultimodalPlaceholderMetadata(
                placeholder_map={
-                    "image": "<image>",
+                    "image": IMG_CONTEXT_TOKEN,
                },
                placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
                placeholders_separator="",
@@ -332,7 +335,6 @@ def __init__(self, model_config: ModelConfig):

        if not _is_disagg():
            self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
-            self.vision_encoder.to(config.torch_dtype)

        llm_model_config = copy.deepcopy(model_config)
        llm_model_config.pretrained_config = llm_model_config.pretrained_config.llm_config
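        # Mirror of the vision setup above: the LLM branch is built from the nested llm_config.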