@@ -33,9 +33,10 @@ def _is_disagg() -> bool:
 
 # TODO: update the reference config path once Nano v2 VLM is released.
 IMAGE_TOKEN_ID = 131072
-IMG_CONTEXT_TOKEN = "<image>"
-IMG_START_TOKEN = "<img>"
-IMG_END_TOKEN = "</img>"
+IMG_CONTEXT_TOKEN = "<image>"  # nosec
+VIDEO_CONTEXT_TOKEN = "<video>"  # nosec
+IMG_START_TOKEN = "<img>"  # nosec
+IMG_END_TOKEN = "</img>"  # nosec
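+# Note: "# nosec" suppresses security-linter (e.g. bandit) warnings that can
+# flag these string literals; they are plain placeholder markers, not secrets.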
 
 
 class SquaredReLU(nn.Module):
@@ -127,11 +128,11 @@ def extract_feature(self, pixel_values):
     def forward(self, multimodal_params: List[MultimodalParams]):
         mm_embedding = []
         # Batch data.
-        batched_pixel_values = torch.cat([
+        pixel_values = [
             multimodal_param.multimodal_data["pixel_values"]
             for multimodal_param in multimodal_params
-        ],
-                                         dim=0)
+        ]
+        batched_pixel_values = torch.cat(pixel_values, dim=0)
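+        # Concatenating along dim 0 stacks the patches from every request into
+        # one batch, so all requests can be encoded together.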
         # -> [num_patches, channel, height, width]
         patch_list = [
             multimodal_param.multimodal_data["num_patches"]
@@ -180,6 +181,7 @@ def __init__(self,
             model_path, trust_remote_code=True, use_fast=self.use_fast)
 
         self.img_context_token = IMG_CONTEXT_TOKEN
+        self.video_context_token = VIDEO_CONTEXT_TOKEN
         self.img_start_token = IMG_START_TOKEN
         self.img_end_token = IMG_END_TOKEN
         self.dtype = model_config.torch_dtype
@@ -262,36 +264,76 @@ def __call__(
         text_prompt, mm_data = inputs.get("prompt"), inputs.get(
             "multi_modal_data", {})
         images = mm_data.get("image", None)
-
+        videos = mm_data.get("video", None)
         if images is not None:
             if isinstance(images[0], torch.Tensor):
                 # NanoV2VL can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255).
                 images = [
                     Image.fromarray((image.permute(1, 2, 0) * 255).to(
                         torch.uint8).cpu().numpy()) for image in images
                 ]
+            # Processing for multimodal data.
+            processed_images = self.processor(images=images,
+                                              return_tensors='pt').to(
+                                                  self.device)
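+            # The processor output carries both 'pixel_values' and a per-image
+            # 'num_patches' count, which sizes the placeholder expansion below.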
+            # Insert enough special tokens for image embedding.
+            parts = text_prompt.split(self.img_context_token)
+            if len(parts) - 1 != len(processed_images['num_patches']):
+                raise ValueError(
+                    f"Number of {self.img_context_token} tokens ({len(parts) - 1}) doesn't match num_patches_list length ({len(processed_images['num_patches'])})"
+                )
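+            # Each <image> placeholder expands to <img>, then one context token
+            # per vision embedding (num_patches * num_image_token), then </img>.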
+            processed_query = parts[0]
+            for num_patches, part in zip(processed_images['num_patches'],
+                                         parts[1:]):
+                feature_size = num_patches * self.num_image_token
+                image_repl = self.img_start_token + self.img_context_token * feature_size + self.img_end_token
+                processed_query += image_repl + part
+
+        elif videos is not None:
+            num_videos = len(videos)
+
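+            # Each video is expected to arrive as a list of frames (normalized
+            # tensors or PIL images); frames are processed like a batch of images.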
+            num_patches_list = []
+            pixel_values_list = []
+            parts = text_prompt.split(self.video_context_token)
+            if len(parts) - 1 != num_videos:
+                raise ValueError(
+                    f"Number of {self.video_context_token} tokens ({len(parts) - 1}) doesn't match number of videos ({num_videos})"
+                )
+            # Process videos one by one to build the processed_query correctly.
+            processed_query = ""
+            for video_index, video in enumerate(videos):
+                if isinstance(videos[0][0], torch.Tensor):
+                    # NanoV2VL can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255).
+                    images = [
+                        Image.fromarray((image.permute(1, 2, 0) * 255).to(
+                            torch.uint8).cpu().numpy()) for image in video
+                    ]
+                else:
+                    images = video
+                # Processing for multimodal data.
+                processed_images = self.processor(images=images,
+                                                  return_tensors='pt').to(
+                                                      self.device)
+                num_patches_list.append(processed_images['num_patches'])
+                pixel_values_list.append(processed_images['pixel_values'])
+
+                # Rebuild the text prompt: emit the text preceding this
+                # <video> placeholder, then one image replacement per patch group.
+                processed_query += parts[video_index]
+                for num_patches in processed_images['num_patches']:
+                    feature_size = num_patches * self.num_image_token
+                    image_repl = self.img_start_token + self.img_context_token * feature_size + self.img_end_token
+                    processed_query += image_repl
+            processed_query += parts[num_videos]
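+            # Collapse the per-frame patch counts into a single total per video
+            # and concatenate all pixel values, so each video looks like one
+            # multimodal item downstream, mirroring the image path.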
+            processed_images['num_patches'] = torch.tensor(
+                [sum(num_patches) for num_patches in num_patches_list])
+            processed_images['pixel_values'] = torch.cat(pixel_values_list,
+                                                         dim=0)
         else:
             input_ids = self.tokenizer.encode(text_prompt,
                                               add_special_tokens=False,
                                               return_tensors="pt")
             return input_ids[0].to(torch.int32).tolist(), {}
 
-        # Processing for multimodal data.
-        processed_images = self.processor(images=images,
-                                          return_tensors='pt').to(self.device)
-
-        # Insert enough special tokens for image embedding.
-        parts = text_prompt.split(self.img_context_token)
-        if len(parts) - 1 != len(processed_images['num_patches']):
-            raise ValueError(
-                f"Number of {self.img_context_token} tokens ({len(parts) - 1}) doesn't match num_patches_list length ({len(processed_images['num_patches'])})"
-            )
-        processed_query = parts[0]
-        for num_patches, part in zip(processed_images['num_patches'],
-                                     parts[1:]):
-            feature_size = num_patches * self.num_image_token
-            image_repl = self.img_start_token + self.img_context_token * feature_size + self.img_end_token
-            processed_query += image_repl + part
         input_ids = self.tokenizer.encode(processed_query,
                                           add_special_tokens=False,
                                           return_tensors="pt")
@@ -313,6 +355,7 @@ def __call__(
         placeholder_metadata = MultimodalPlaceholderMetadata(
             placeholder_map={
                 "image": IMG_CONTEXT_TOKEN,
+                "video": VIDEO_CONTEXT_TOKEN,
             },
             placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
             placeholders_separator="",