43
43
_TEST_PROMPTS = [os .path .join (_TEST_DIR , "prompts" , "example.txt" )]
44
44
_LONG_PROMPTS = [os .path .join (_TEST_DIR , "prompts" , "summary.txt" )]
45
45
46
- PromptImageInput = Union [List [Image .Image ], List [List [Image .Image ]]]
47
- PromptAudioInput = Union [List [Tuple [np .ndarray , int ]],
48
- List [List [Tuple [np .ndarray , int ]]]]
49
- PromptVideoInput = Union [List [np .ndarray ], List [List [np .ndarray ]]]
46
+ _M = TypeVar ("_M" )
47
+ _PromptMultiModalInput = Union [List [_M ], List [List [_M ]]]
48
+
49
+ PromptImageInput = _PromptMultiModalInput [Image .Image ]
50
+ PromptAudioInput = _PromptMultiModalInput [Tuple [np .ndarray , int ]]
51
+ PromptVideoInput = _PromptMultiModalInput [np .ndarray ]
50
52
51
53
52
54
def _read_prompts (filename : str ) -> List [str ]:
@@ -318,12 +320,12 @@ def get_inputs(
318
320
"text" : prompt ,
319
321
"return_tensors" : "pt" ,
320
322
}
321
- if images is not None and images [i ] is not None :
322
- processor_kwargs ["images" ] = images [ i ]
323
- if videos is not None and videos [i ] is not None :
324
- processor_kwargs ["videos" ] = videos [ i ]
325
- if audios is not None and audios [i ] is not None :
326
- audio , sr = audios [ i ]
323
+ if images is not None and ( image := images [i ]) is not None :
324
+ processor_kwargs ["images" ] = image
325
+ if videos is not None and ( video := videos [i ]) is not None :
326
+ processor_kwargs ["videos" ] = video
327
+ if audios is not None and ( audio_tuple := audios [i ]) is not None :
328
+ audio , sr = audio_tuple
327
329
processor_kwargs ["audio" ] = audio
328
330
processor_kwargs ["sampling_rate" ] = sr
329
331
@@ -338,7 +340,7 @@ def generate(
338
340
self ,
339
341
prompts : List [str ],
340
342
images : Optional [PromptImageInput ] = None ,
341
- videos : Optional [List [ np . ndarray ] ] = None ,
343
+ videos : Optional [PromptVideoInput ] = None ,
342
344
audios : Optional [PromptAudioInput ] = None ,
343
345
** kwargs : Any ,
344
346
) -> List [Tuple [List [List [int ]], List [str ]]]:
@@ -368,7 +370,7 @@ def generate_greedy(
368
370
prompts : List [str ],
369
371
max_tokens : int ,
370
372
images : Optional [PromptImageInput ] = None ,
371
- videos : Optional [List [ np . ndarray ] ] = None ,
373
+ videos : Optional [PromptVideoInput ] = None ,
372
374
audios : Optional [PromptAudioInput ] = None ,
373
375
** kwargs : Any ,
374
376
) -> List [Tuple [List [int ], str ]]:
@@ -409,7 +411,7 @@ def generate_greedy_logprobs(
409
411
prompts : List [str ],
410
412
max_tokens : int ,
411
413
images : Optional [PromptImageInput ] = None ,
412
- videos : Optional [List [ np . ndarray ] ] = None ,
414
+ videos : Optional [PromptVideoInput ] = None ,
413
415
audios : Optional [PromptAudioInput ] = None ,
414
416
** kwargs : Any ,
415
417
) -> List [List [torch .Tensor ]]:
@@ -488,7 +490,7 @@ def generate_greedy_logprobs_limit(
488
490
num_logprobs : int ,
489
491
images : Optional [PromptImageInput ] = None ,
490
492
audios : Optional [PromptAudioInput ] = None ,
491
- videos : Optional [List [ np . ndarray ] ] = None ,
493
+ videos : Optional [PromptVideoInput ] = None ,
492
494
** kwargs : Any ,
493
495
) -> List [TokensTextLogprobs ]:
494
496
all_inputs = self .get_inputs (prompts ,
@@ -657,15 +659,18 @@ def get_inputs(
657
659
inputs = [TextPrompt (prompt = prompt ) for prompt in prompts ]
658
660
if images is not None :
659
661
for i , image in enumerate (images ):
660
- inputs [i ]["multi_modal_data" ] = {"image" : image }
662
+ if image is not None :
663
+ inputs [i ]["multi_modal_data" ] = {"image" : image }
661
664
662
665
if videos is not None :
663
666
for i , video in enumerate (videos ):
664
- inputs [i ]["multi_modal_data" ] = {"video" : video }
667
+ if video is not None :
668
+ inputs [i ]["multi_modal_data" ] = {"video" : video }
665
669
666
670
if audios is not None :
667
671
for i , audio in enumerate (audios ):
668
- inputs [i ]["multi_modal_data" ] = {"audio" : audio }
672
+ if audio is not None :
673
+ inputs [i ]["multi_modal_data" ] = {"audio" : audio }
669
674
670
675
return inputs
671
676
@@ -837,13 +842,20 @@ def generate_beam_search(
837
842
returned_outputs .append ((token_ids , texts ))
838
843
return returned_outputs
839
844
840
- def encode (self , prompts : List [str ]) -> List [List [float ]]:
841
- req_outputs = self .model .encode (prompts )
842
- outputs = []
843
- for req_output in req_outputs :
844
- embedding = req_output .outputs .embedding
845
- outputs .append (embedding )
846
- return outputs
845
+ def encode (
846
+ self ,
847
+ prompts : List [str ],
848
+ images : Optional [PromptImageInput ] = None ,
849
+ videos : Optional [PromptVideoInput ] = None ,
850
+ audios : Optional [PromptAudioInput ] = None ,
851
+ ) -> List [List [float ]]:
852
+ inputs = self .get_inputs (prompts ,
853
+ images = images ,
854
+ videos = videos ,
855
+ audios = audios )
856
+
857
+ req_outputs = self .model .encode (inputs )
858
+ return [req_output .outputs .embedding for req_output in req_outputs ]
847
859
848
860
def __enter__ (self ):
849
861
return self
0 commit comments