Commit 60c1454

apply updates smolVLM (still needs workaround for chat template)
1 parent a31fa21

7 files changed: +325 −329 lines changed

src/transformers/models/smolvlm/processing_smolvlm.py

Lines changed: 57 additions & 156 deletions
@@ -16,26 +16,29 @@
 Processor class for SmolVLM.
 """
 
-import copy
 from datetime import timedelta
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
-import numpy as np
-
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import BatchEncoding, TextInput
 from ...utils import is_num2words_available, is_vision_available, logging
-from ...video_utils import VideoInput, load_video, make_batched_videos
+from ...video_utils import VideoInput
 
 
 if is_vision_available():
     from .video_processing_smolvlm import (
         DEFAULT_MEDIA_OUTTRO,
         DEFAULT_VIDEO_INTRO,
         FRAME_TIMESTAMP_MESSAGE,
-        smolvlm_sample_indices_fn,
+    )
+
+if is_vision_available():
+    from .video_processing_smolvlm import (
+        DEFAULT_MEDIA_OUTTRO,
+        DEFAULT_VIDEO_INTRO,
+        FRAME_TIMESTAMP_MESSAGE,
     )
 
 if TYPE_CHECKING:
@@ -141,9 +144,7 @@ class SmolVLMProcessor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer", "video_processor"]
     valid_kwargs = ["image_seq_len", "chat_template"]
     image_processor_class = "SmolVLMImageProcessor"
-    video_processor_class = (
-        "SmolVLMImageProcessor"  # TODO: raushan should be VideoProcessor when LANCZOS resizing is settled
-    )
+    video_processor_class = "SmolVLMVideoProcessor"  # NOTE: uses different interpolation than slow processors
     tokenizer_class = "AutoTokenizer"
 
     def __init__(
@@ -161,17 +162,7 @@ def __init__(
         self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
         self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
         self.image_seq_len = image_seq_len
-
-        self.video_size = video_processor.video_sampling["video_size"]
-        self.image_size = image_processor.size
-
-        self.do_image_splitting = image_processor.do_image_splitting
-        self.do_video_splitting = video_processor.video_sampling.get("do_image_splitting", False)
-
-        self.default_max_frames = video_processor.video_sampling["max_frames"]
-        self.default_fps = video_processor.video_sampling["fps"]
-        # Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits, optionally surrounded by newline characters
-        # self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
+        self.video_token = getattr(tokenizer, "video_token", "<video>")
 
         if not num2words:
             raise ImportError(
@@ -180,14 +171,12 @@ def __init__(
 
         super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
 
-    def process_vision(self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None):
+    def process_vision(self, text, images, output_kwargs):
         if text is not None:
             n_images_in_text = [sample.count(self.image_token) for sample in text]
 
         n_images_in_images = [len(sublist) for sublist in images]
-        image_inputs = self.image_processor(
-            images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs["images_kwargs"]
-        )
+        image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
 
         if text is None:
             return None, image_inputs
@@ -226,6 +215,50 @@ def process_vision(self, text, images, output_kwargs, do_image_splitting=False,
 
         return prompt_strings, image_inputs
 
+    def process_video(self, text, videos, output_kwargs):
+        if text is not None:
+            n_videos_in_text = [sample.count(self.video_token) for sample in text]
+
+        n_videos_in_videos = [len(sublist) for sublist in videos]
+        video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
+
+        num_frames = video_inputs["pixel_values"].shape[1]
+        batch_timestamps = iter(video_inputs.pop("timestamps"))
+        batch_durations = iter(video_inputs.pop("durations"))
+
+        if text is None:
+            return None, video_inputs
+
+        if n_videos_in_videos != n_videos_in_text:
+            raise ValueError(
+                f"The number of videos in the text {n_videos_in_text} and videos {n_videos_in_videos} should be the same."
+            )
+
+        prompt_strings = []
+        for sample in text:
+            while self.video_token in sample:
+                timestamps = next(batch_timestamps)
+                duration = next(batch_durations)
+                duration_td = timedelta(seconds=int(duration))
+                image_prompt_strings = DEFAULT_VIDEO_INTRO.format(
+                    frame_count=num2words(num_frames), video_duration=str(duration_td)
+                )
+                for timestamp in timestamps:
+                    image_prompt_string = _prompt_single_image(
+                        self.image_seq_len,
+                        image_token=self.image_token,
+                        fake_token_around_image=self.fake_image_token,
+                        global_image_token=self.global_image_token,
+                    )
+                    timestamp = f"{timestamp[0]:02d}:{timestamp[1]:02d}"
+                    image_prompt_string = FRAME_TIMESTAMP_MESSAGE.format(timestamp=timestamp) + image_prompt_string
+                    image_prompt_strings += image_prompt_string
+
+                image_prompt_strings += DEFAULT_MEDIA_OUTTRO
+                sample = sample.replace(self.video_token, image_prompt_strings, 1)
+            prompt_strings.append(sample)
+        return prompt_strings, video_inputs
+
     def __call__(
         self,
         images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
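The added process_video mirrors process_vision: it counts video tokens per prompt, runs the video processor, then replaces each token with a video intro, one timestamped image prompt per sampled frame, and an outro. A minimal usage sketch of the new path, assuming a SmolVLM2 checkpoint id, a dummy 4D frame array, and the nesting implied by the n_videos_in_videos check above (all illustrative, not part of this commit):

import numpy as np
from transformers import AutoProcessor

# Checkpoint id is an assumption; any processor shipping this SmolVLM update should work.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")

# One dummy video: 16 RGB frames of 384x384 as a single 4D array.
video = np.zeros((16, 384, 384, 3), dtype=np.uint8)

# One list of videos per prompt, matching the per-sample count in process_video().
text = [f"User: {processor.video_token} What happens in this clip?\nAssistant:"]
inputs = processor(text=text, videos=[[video]], return_tensors="pt")

# process_video() popped "timestamps"/"durations" and expanded the prompt, so the
# output carries only token ids, attention masks, and pixel values.
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)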
@@ -310,18 +343,13 @@ def __call__(
                 text,
                 images,
                 output_kwargs,
-                do_image_splitting=self.do_image_splitting,
-                image_processor_size=self.image_size,
             )
             inputs.update(vision_inputs)
         elif videos is not None:
-            videos = make_batched_videos(videos)
-            text, vision_inputs = self.process_vision(
+            text, vision_inputs = self.process_video(
                 text,
                 videos,
                 output_kwargs,
-                do_image_splitting=self.do_image_splitting,
-                image_processor_size=self.video_size,
             )
             inputs.update(vision_inputs)
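On the image side, the explicit do_image_splitting and image_processor_size arguments are gone; such options now reach the image processor through output_kwargs["images_kwargs"]. A sketch of the image path under the same checkpoint assumption as the video example above:

import numpy as np
from PIL import Image

# Reuses `processor` from the video sketch above.
image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
text = [f"User: {processor.image_token} Describe the image.\nAssistant:"]

# do_image_splitting is forwarded through the merged images kwargs instead of a
# dedicated process_vision() parameter (kwarg name per the SmolVLM image processor).
inputs = processor(images=[[image]], text=text, do_image_splitting=False)
print(len(inputs["input_ids"][0]))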

@@ -334,93 +362,6 @@ def __call__(
 
         return BatchFeature(inputs, tensor_type=return_tensors)
 
-    def _process_messages_for_chat_template(
-        self,
-        conversations: List[List[Dict[str, str]]],
-        batch_images: List[ImageInput],
-        batch_videos: List[VideoInput],
-        batch_video_metadata: List[List[Dict[str, any]]],
-        **chat_template_kwargs,
-    ):
-        """
-        Used within `apply_chat_template` when a model has special way to process conversation history. For example,
-        video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
-        were sampled. This information cannot be accessed before the video is loaded.
-        For most models it is a no-op, must be overridden by model processors which require special processing.
-        Args:
-            conversation (`List[Dict, str, str]`):
-                The conversation to process. Always comes in batched format.
-            batch_images (`List[List[ImageInput]]`):
-                Batch of images that were loaded from url/path defined in the conversation. The images
-                are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
-                per batch.
-            batch_videos (`List[List[ImageInput]]`):
-                Batch of videos that were loaded from url/path defined in the conversation. The videos
-                are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
-                per batch.
-            batch_video_metadata (`List[List[Dict[[str, any]]]]`):
-                Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video.
-                Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
-                per batch.
-        """
-        # We don't want to modify in-place the messages passed by user
-        # The user might want to add new turn on conv and continue generation
-        conversations = copy.deepcopy(conversations)
-        batch_num_frames, batch_timestamps = [], []
-        for metadata_list, video_list in zip(batch_video_metadata, batch_videos):
-            for metadata, video in zip(metadata_list, video_list):
-                duration_sec = getattr(metadata, "duration")
-                frames_idx = getattr(metadata, "frames_indices")
-                fps = getattr(metadata, "fps")
-
-                timestamps = []
-                for idx, frame_np in zip(frames_idx, video):
-                    sec = idx / fps
-                    mm = int(sec // 60)
-                    ss = int(sec % 60)
-                    timestamps.append(f"{mm:02d}:{ss:02d}")
-                batch_timestamps.append(timestamps)
-                batch_num_frames.append(len(video))
-
-        for conversation in conversations:
-            # For each message, scan content for {"type": "video"}
-            for msg in conversation:
-                if "content" not in msg:
-                    continue
-
-                new_content = []
-                for block in msg["content"]:
-                    if block.get("type") == "video":
-                        curr_timestamps = batch_timestamps.pop(0)
-                        curr_num_frames = batch_num_frames.pop(0)
-
-                        # Build the video intro texts
-                        td = timedelta(seconds=int(duration_sec))
-                        new_content.append(
-                            {
-                                "type": "text",
-                                "text": DEFAULT_VIDEO_INTRO.format(
-                                    frame_count=num2words(curr_num_frames), video_duration=str(td)
-                                ),
-                            }
-                        )
-
-                        # 2) Insert per-frame lines: "Frame from {timestamp}:", then an "image" block
-                        for i, ts in enumerate(curr_timestamps):
-                            new_content.append({"type": "text", "text": FRAME_TIMESTAMP_MESSAGE.format(timestamp=ts)})
-                            new_content.append({"type": "image"})
-
-                        # 3) Optionally add an outro (e.g. "Now answer the question:")
-                        new_content.append({"type": "text", "text": DEFAULT_MEDIA_OUTTRO})
-                        # Do NOT add the original block => we skip it (since we've replaced it)
-                    else:
-                        # keep original block
-                        new_content.append(block)
-
-                # update the content
-                msg["content"] = new_content
-        return conversations
-
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
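The deleted _process_messages_for_chat_template hook is what rewrote each {"type": "video"} block in a conversation into an intro text, per-frame "Frame from MM:SS:" texts paired with image blocks, and an outro; with it gone, that expansion happens in process_video, and the chat template itself still needs the workaround named in the commit message. A condensed, standalone sketch of the transformation the hook performed (the inlined template strings are assumptions, not the exact constants from video_processing_smolvlm):

from datetime import timedelta

# Assumed stand-ins for DEFAULT_VIDEO_INTRO / FRAME_TIMESTAMP_MESSAGE / DEFAULT_MEDIA_OUTTRO.
VIDEO_INTRO = "You are provided the following series of {frame_count} frames from a {video_duration} [H:MM:SS] video.\n"
FRAME_MESSAGE = "\nFrame from {timestamp}:"
MEDIA_OUTTRO = "\n\n"

def expand_video_block(num_frames, duration_sec, timestamps):
    """Rebuild the content blocks the removed hook substituted for one video block."""
    # The original spelled out frame_count with num2words(); plain digits here for brevity.
    blocks = [{
        "type": "text",
        "text": VIDEO_INTRO.format(
            frame_count=num_frames,
            video_duration=str(timedelta(seconds=int(duration_sec))),
        ),
    }]
    for ts in timestamps:  # ts like "00:05", computed as frame_index / fps in the removed code
        blocks.append({"type": "text", "text": FRAME_MESSAGE.format(timestamp=ts)})
        blocks.append({"type": "image"})
    blocks.append({"type": "text", "text": MEDIA_OUTTRO})
    return blocks

# Example: a 3-frame, 12-second clip sampled at 00:00, 00:06, and 00:12.
print(expand_video_block(3, 12, ["00:00", "00:06", "00:12"]))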
@@ -443,45 +384,5 @@ def model_input_names(self):
         image_processor_input_names = self.image_processor.model_input_names
         return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
 
-    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
-    def _load_video_for_model(
-        self,
-        video: Union[str, "VideoInput"],
-        num_frames: Optional[int] = None,
-        fps: Optional[int] = None,
-        backend: str = "opencv",
-        skip_secs: int = 0.0,
-        **kwargs,
-    ) -> np.array:
-        """
-        Loads `video` to a numpy array.
-
-        Args:
-            video (`str` or `VideoInput`):
-                The video to convert to the numpy array format. Can be a link to video or local path.
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If not passed, the whole video is loaded.
-            fps (`int`, *optional*):
-                Number of frames to sample per second. Should be passed only when `num_frames=None`.
-                If not specified and `num_frames==None`, all frames are sampled.
-            backend (`str`, *optional*, defaults to `"opencv"`):
-                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
-
-        Returns:
-            Tuple[`np.array`, Dict]: A tuple containing:
-                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
-                - Metadata dictionary.
-        """
-        max_frames = self.default_max_frames if num_frames is None else num_frames
-        target_fps = self.default_fps if fps is None else fps
-
-        def sample_indices_fn_func(metadata, **fn_kwargs):
-            return smolvlm_sample_indices_fn(
-                metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
-            )
-
-        video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
-        return video, metadata
-
 
 __all__ = ["SmolVLMProcessor"]
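With _load_video_for_model removed, video loading and frame selection move out of the processor and into the SmolVLMVideoProcessor pipeline. For reference, the deleted method only wrapped smolvlm_sample_indices_fn with the processor's max_frames/fps defaults; below is a rough, hypothetical re-creation of that kind of fps-capped uniform sampling. The real function lives in video_processing_smolvlm and may differ, for example in how skip_secs is applied:

import numpy as np

def sample_frame_indices(total_frames, video_fps, target_fps, max_frames, skip_secs=0.0):
    # Optionally skip the first `skip_secs` seconds of the clip.
    start = min(int(skip_secs * video_fps), max(total_frames - 1, 0))
    # Number of frames a `target_fps` sampling rate would yield, capped at `max_frames`.
    duration_sec = (total_frames - start) / video_fps
    num_samples = min(max_frames, max(1, round(duration_sec * target_fps)))
    # Uniformly spaced frame indices over the usable range.
    return np.linspace(start, total_frames - 1, num_samples).round().astype(int)

# Example: a 10 s clip at 30 fps, sampled at 1 fps with at most 64 frames -> 10 indices.
print(sample_frame_indices(total_frames=300, video_fps=30, target_fps=1, max_frames=64))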
