 from lmms_eval.models.model_utils.load_video import read_video_pyav
 
 eval_logger = logging.getLogger("lmms-eval")
-import sys
 
-sys.path.append("llava-video")
 try:
-    from llavavid.model.language_model.llava_llama import LlavaConfig
-
-    # from llavavid.model.language_model.llava_qwen import LlavaQwenConfig
     from llavavid.model.builder import load_pretrained_model
     from llavavid.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
-    from llavavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+    from llavavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
     from llavavid.conversation import conv_templates, SeparatorStyle
-
-    # AutoConfig.register("llava_qwen", LlavaQwenConfig)
-    AutoConfig.register("llava_llama", LlavaConfig)
-
+    from llavavid.mm_utils import tokenizer_image_token_qwen_merge, preprocess_qwen, preprocess_llama3
 except ImportError:
     eval_logger.debug("LLaVA-Video is not installed. Please install LLaVA-Video to use this model.")
 
-try:
-    from llavavid.model.language_model.llava_qwen import LlavaQwenConfig
+from llavavid.model.language_model.llava_qwen import LlavaQwenConfig
+from llavavid.model.language_model.llava_llama import LlavaConfig
 
-    AutoConfig.register("llava_qwen", LlavaQwenConfig)
-except:
-    eval_logger.debug("")
+AutoConfig.register("llava_qwen", LlavaQwenConfig)
+AutoConfig.register("llava_llama", LlavaConfig)
 
 
 @register_model("llavavid")
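
Moving the config registration to module level follows the standard transformers custom-config pattern. A minimal standalone sketch of that pattern; the class name and model type below are illustrative stand-ins, not the llavavid definitions:

# Minimal sketch of the AutoConfig registration pattern used above.
# MyLlavaConfig and "my_llava" are illustrative, not llavavid code.
from transformers import AutoConfig, PretrainedConfig


class MyLlavaConfig(PretrainedConfig):
    model_type = "my_llava"  # must match the "model_type" field in the checkpoint's config.json


AutoConfig.register("my_llava", MyLlavaConfig)

# After registration, AutoConfig.from_pretrained() on a checkpoint whose config.json
# declares model_type == "my_llava" resolves to MyLlavaConfig instead of raising.
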
@@ -96,7 +87,6 @@ def __init__(
         self.mm_spatial_pool_out_channels = int(mm_spatial_pool_out_channels)
         self.mm_spatial_pool_mode = mm_spatial_pool_mode
         self.max_frames_num = int(max_frames_num)
-        print(self.max_frames_num)
         if self.overwrite == True:
             overwrite_config = {}
             overwrite_config["mm_resampler_type"] = self.mm_resampler_type
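
The `overwrite_config` dict built here is a set of attribute overrides applied to the model config before loading. A generic sketch of that pattern, with an illustrative helper and illustrative values rather than the llavavid loader itself:

# Generic sketch of applying an "overwrite_config" dict to a model config.
# apply_overrides() and the override values are illustrative, not llavavid code.
def apply_overrides(config, overrides: dict):
    for key, value in overrides.items():
        setattr(config, key, value)  # e.g. config.mm_spatial_pool_stride = 2
    return config


overwrite_config = {
    "mm_resampler_type": "spatial_pool",
    "mm_spatial_pool_stride": 2,
    "mm_spatial_pool_out_channels": 1024,
    "mm_spatial_pool_mode": "average",
}
# A loader would then do: config = apply_overrides(loaded_config, overwrite_config)
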
@@ -361,7 +351,7 @@ def generate_until(self, requests) -> List[str]:
             if self.model.config.mm_use_im_start_end:
                 qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
             else:
-                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+                qs = DEFAULT_IMAGE_TOKEN * len(videos) + "\n" + qs
 
             # This is much safer for llama3, as we now have some object type in it
             if "llama_3" in self.conv_template:
@@ -379,11 +369,6 @@ def generate_until(self, requests) -> List[str]:
             pad_token_ids = 0  # lmms-lab/llama3-llava-8b is trained on this pad token id. You may need to customize this for other models.
             attention_masks = input_ids.ne(pad_token_ids).long().cuda()
 
-            # input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input]
-            # pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            # input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
-            # attention_masks = input_ids.ne(pad_token_ids).to(self.device)
-
             stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
             keywords = [stop_str]
             stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
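
The deleted comments described an alternative batched-tokenization path. For reference, a generic sketch of padding variable-length token-id tensors and deriving an attention mask; the pad id of 0 is an assumption, as the surviving comment notes:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical per-prompt token-id tensors of unequal length.
input_ids_list = [torch.tensor([101, 5, 7, 9]), torch.tensor([101, 42])]

pad_token_id = 0  # assumption; use whatever pad id the model was trained with
input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_id)
attention_masks = input_ids.ne(pad_token_id).long()
# input_ids has shape (2, 4); attention_masks marks real tokens with 1 and padding with 0.
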
@@ -393,7 +378,7 @@ def generate_until(self, requests) -> List[str]:
393378 if "max_new_tokens" not in gen_kwargs :
394379 gen_kwargs ["max_new_tokens" ] = 1024
395380 if "temperature" not in gen_kwargs :
396- gen_kwargs ["temperature" ] = 0.2
381+ gen_kwargs ["temperature" ] = 0
397382 if "top_p" not in gen_kwargs :
398383 gen_kwargs ["top_p" ] = None
399384 if "num_beams" not in gen_kwargs :
@@ -417,4 +402,4 @@ def generate_until(self, requests) -> List[str]:
             outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
             res.append(outputs)
             pbar.update(1)
-        return res
+        return res