diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml
index 4992dfbdde..a820da6283 100644
--- a/.github/workflows/linux.yml
+++ b/.github/workflows/linux.yml
@@ -167,6 +167,8 @@ jobs:
       - name: CMake Build
         run: |
+          apt update
+          apt install -y libgtk2.0-dev pkg-config ffmpeg libavformat-dev libavcodec-dev libswscale-dev libavutil-dev
           source ${{ env.OV_INSTALL_DIR }}/setupvars.sh
           cmake -DOpenVINODeveloperPackage_DIR=${{ env.OV_INSTALL_DIR }}/developer_package/cmake \
@@ -379,7 +381,7 @@ jobs:
       matrix:
         build-type: [Release]
     needs: [ openvino_download, genai_build_cmake ]
-    timeout-minutes: 10
+    timeout-minutes: 30
     defaults:
       run:
         shell: bash
diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml
index 8bebd2dcf9..7a1c52efd0 100644
--- a/.github/workflows/mac.yml
+++ b/.github/workflows/mac.yml
@@ -318,7 +318,7 @@ jobs:
       matrix:
         build-type: [Release]
     needs: [ openvino_download, genai_build_cmake ]
-    timeout-minutes: 10
+    timeout-minutes: 30
    defaults:
       run:
         shell: bash
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index c559a0e1c1..2151abffed 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -100,7 +100,7 @@ jobs:
       matrix:
         build-type: [Release, Debug]
     needs: [ openvino_download ]
-    timeout-minutes: 45
+    timeout-minutes: 80
     defaults:
       run:
         shell: pwsh
@@ -487,7 +487,7 @@ jobs:
       matrix:
         build-type: [Release, Debug]
     needs: [ openvino_download, genai_build_cpack ]
-    timeout-minutes: 10
+    timeout-minutes: 70
     defaults:
       run:
         shell: pwsh
diff --git a/samples/cpp/visual_language_chat/CMakeLists.txt b/samples/cpp/visual_language_chat/CMakeLists.txt
index 54795351a6..37a14faef0 100644
--- a/samples/cpp/visual_language_chat/CMakeLists.txt
+++ b/samples/cpp/visual_language_chat/CMakeLists.txt
@@ -1,6 +1,10 @@
 # Copyright (C) 2023-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+if (MSVC)
+    set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")
+endif()
+
 find_package(OpenVINOGenAI REQUIRED
     PATHS
         "${CMAKE_BINARY_DIR}"  # Reuse the package from the build.
@@ -55,3 +59,43 @@ install(TARGETS benchmark_vlm
     RUNTIME DESTINATION samples_bin/
     COMPONENT samples_bin
     EXCLUDE_FROM_ALL)
+
+
+include(FetchContent)
+
+set(BUILD_SHARED_LIBS ON)
+set(WITH_FFMPEG ON)
+FetchContent_Declare(
+    opencv
+    GIT_REPOSITORY https://github.com/opencv/opencv.git
+    GIT_TAG 4.11.0
+    GIT_SHALLOW TRUE
+)
+FetchContent_MakeAvailable(opencv)
+
+
+add_executable(video_to_text_chat video_to_text_chat.cpp)
+
+target_include_directories(video_to_text_chat PRIVATE
+    ${OPENCV_CONFIG_FILE_INCLUDE_DIR}
+    ${OPENCV_MODULE_opencv_core_LOCATION}/include
+    ${OPENCV_MODULE_opencv_videoio_LOCATION}/include
+    )
+target_link_libraries(video_to_text_chat opencv_core opencv_imgcodecs opencv_imgproc opencv_videoio openvino::genai cxxopts::cxxopts)
+
+if(LINUX)
+    set_target_properties(video_to_text_chat opencv_core opencv_imgcodecs opencv_imgproc opencv_videoio PROPERTIES
+        INSTALL_RPATH "$ORIGIN/../lib"
+        # Ensure out of box LC_RPATH on macOS with SIP
+        INSTALL_RPATH_USE_LINK_PATH ON)
+elseif(APPLE)
+    set_target_properties(video_to_text_chat opencv_core opencv_imgcodecs opencv_imgproc opencv_videoio PROPERTIES
+        INSTALL_RPATH "@loader_path/../lib"
+        # Ensure out of box LC_RPATH on macOS with SIP
+        INSTALL_RPATH_USE_LINK_PATH ON)
+endif()
+
+install(TARGETS video_to_text_chat
+    RUNTIME DESTINATION samples_bin/
+    COMPONENT samples_bin
+    EXCLUDE_FROM_ALL)
diff --git a/samples/cpp/visual_language_chat/README.md b/samples/cpp/visual_language_chat/README.md
index 59819e45cc..95d432436d 100644
--- a/samples/cpp/visual_language_chat/README.md
+++ b/samples/cpp/visual_language_chat/README.md
@@ -3,8 +3,9 @@
 This example showcases inference of Visual language models (VLMs). The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::VLMPipeline` and runs the simplest deterministic greedy sampling algorithm. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/minicpm-v-multimodal-chatbot) which provides an example of Visual-language assistant.
 
-There are two sample files:
+There are three sample files:
  - [`visual_language_chat.cpp`](./visual_language_chat.cpp) demonstrates basic usage of the VLM pipeline.
+ - [`video_to_text_chat.cpp`](./video_to_text_chat.cpp) demonstrates video-to-text usage of the VLM pipeline.
  - [`benchmark_vlm.cpp`](./benchmark_vlm.cpp) shows how to benchmark a VLM in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text and calculating various performance metrics.
@@ -19,9 +20,9 @@
 pip install --upgrade-strategy eager -r ../../requirements.txt
 optimum-cli export openvino --model openbmb/MiniCPM-V-2_6 --trust-remote-code MiniCPM-V-2_6
 ```
 
-## Run
+Follow [Get Started with Samples](https://docs.openvino.ai/2025/get-started/learn-openvino/openvino-samples/get-started-demos.html) to run samples.
 
-Follow [Get Started with Samples](https://docs.openvino.ai/2025/get-started/learn-openvino/openvino-samples/get-started-demos.html) to run the sample.
+## Run image-to-text chat sample:
 
 [This image](https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11) can be used as a sample image.
 
@@ -31,6 +32,17 @@ Discrete GPUs (dGPUs) usually provide better performance compared to CPUs.
 It is recommended to run larger models on a dGPU with 32GB+ RAM.
 
 Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#visual-language-models-vlms) for more details.
+
+## Run video-to-text chat sample:
+
+A model that supports video input is required to run this sample, for example `llava-hf/LLaVA-NeXT-Video-7B-hf`.
+
+[This video](https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4) can be used as a sample video.
+
+`video_to_text_chat ./LLaVA-NeXT-Video-7B-hf/ sample_demo_1.mp4`
+
+Supported models with video input are listed in [this section](https://openvinotoolkit.github.io/openvino.genai/docs/use-cases/image-processing/#use-image-or-video-tags-in-prompt).
+
 ## Run benchmark:
 
 ```sh
diff --git a/samples/cpp/visual_language_chat/video_to_text_chat.cpp b/samples/cpp/visual_language_chat/video_to_text_chat.cpp
new file mode 100644
index 0000000000..b41e72a622
--- /dev/null
+++ b/samples/cpp/visual_language_chat/video_to_text_chat.cpp
@@ -0,0 +1,125 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include <algorithm>
+#include <cstring>
+#include <filesystem>
+#include <iostream>
+#include <set>
+#include <opencv2/videoio.hpp>
+#include <openvino/genai/visual_language/pipeline.hpp>
+
+namespace fs = std::filesystem;
+
+std::vector<size_t> make_indices(size_t total_frames, size_t num_frames) {
+    std::vector<size_t> indices;
+    indices.reserve(num_frames);
+
+    auto step = float(total_frames) / num_frames;
+
+    for (size_t i = 0; i < num_frames; ++i) {
+        size_t idx = std::min(size_t(i * step), total_frames - 1);
+        indices.push_back(idx);
+    }
+
+    return indices;
+}
+
+ov::Tensor load_video(const std::filesystem::path& video_path, size_t num_frames = 8) {
+    cv::VideoCapture cap(video_path.string());
+
+    if (!cap.isOpened()) {
+        OPENVINO_THROW("Could not open the video file.");
+    }
+    size_t total_num_frames = cap.get(cv::CAP_PROP_FRAME_COUNT);
+    auto indices = make_indices(total_num_frames, num_frames);
+
+    std::vector<cv::Mat> frames;
+    cv::Mat frame;
+    size_t width = cap.get(cv::CAP_PROP_FRAME_WIDTH);
+    size_t height = cap.get(cv::CAP_PROP_FRAME_HEIGHT);
+    ov::Tensor video_tensor(ov::element::u8, ov::Shape{num_frames, height, width, 3});
+    auto video_tensor_data = video_tensor.data<uint8_t>();
+
+    size_t frame_idx = 0;
+    while (cap.read(frame)) {
+        OPENVINO_ASSERT(frame.cols == width && frame.rows == height && frame.channels() == 3);
+        if (std::find(indices.begin(), indices.end(), frame_idx) != indices.end()) {
+            memcpy(video_tensor_data, frame.data, frame.total() * 3 * sizeof(uint8_t));
+            video_tensor_data += frame.total() * 3;
+        }
+        frame_idx++;
+    }
+    OPENVINO_ASSERT(frame_idx == total_num_frames, "Frame count mismatch: expected " + std::to_string(total_num_frames) + ", got " + std::to_string(frame_idx));
+
+    return video_tensor;
+}
+
+std::vector<ov::Tensor> load_videos(const std::filesystem::path& input_path) {
+    if (input_path.empty() || !fs::exists(input_path)) {
+        OPENVINO_THROW("Path to videos is empty or does not exist.");
+    }
+    if (fs::is_directory(input_path)) {
+        std::set<fs::path> sorted_videos{fs::directory_iterator(input_path), fs::directory_iterator()};
+        std::vector<ov::Tensor> videos;
+        for (const fs::path& dir_entry : sorted_videos) {
+            videos.push_back(load_video(dir_entry));
+        }
+        return videos;
+    }
+    return {load_video(input_path)};
+}
+
+bool print_subword(std::string&& subword) {
+    return !(std::cout << subword << std::flush);
+}
+
+int main(int argc, char* argv[]) try {
+    if (argc < 3 || argc > 4) {
+        OPENVINO_THROW(std::string{"Usage "} + argv[0] + " <MODEL_DIR> <VIDEO_FILE OR DIR_WITH_VIDEOS> [DEVICE]");
+    }
+    std::vector<ov::Tensor> videos = load_videos(argv[2]);
+
+    // GPU and NPU can be used as well.
+    // Note: If NPU is selected, only the language model will be run on the NPU.
+    std::string device = (argc == 4) ? argv[3] : "CPU";
+    ov::AnyMap enable_compile_cache;
+    if (device == "GPU") {
+        // Cache compiled models on disk for GPU to save time on the
+        // next run. It's not beneficial for CPU.
+        enable_compile_cache.insert({ov::cache_dir("vlm_cache")});
+    }
+    ov::genai::VLMPipeline pipe(argv[1], device, enable_compile_cache);
+
+    ov::genai::GenerationConfig generation_config;
+    generation_config.max_new_tokens = 100;
+
+    std::string prompt;
+
+    pipe.start_chat();
+    std::cout << "question:\n";
+
+    std::getline(std::cin, prompt);
+    pipe.generate(prompt,
+                  ov::genai::videos(videos),
+                  ov::genai::generation_config(generation_config),
+                  ov::genai::streamer(print_subword));
+    std::cout << "\n----------\n"
+                 "question:\n";
+    while (std::getline(std::cin, prompt)) {
+        pipe.generate(prompt,
+                      ov::genai::generation_config(generation_config),
+                      ov::genai::streamer(print_subword));
+        std::cout << "\n----------\n"
+                     "question:\n";
+    }
+    pipe.finish_chat();
+} catch (const std::exception& error) {
+    try {
+        std::cerr << error.what() << '\n';
+    } catch (const std::ios_base::failure&) {}
+    return EXIT_FAILURE;
+} catch (...) {
+    try {
+        std::cerr << "Non-exception object thrown\n";
+    } catch (const std::ios_base::failure&) {}
+    return EXIT_FAILURE;
+}
\ No newline at end of file
diff --git a/samples/deployment-requirements.txt b/samples/deployment-requirements.txt
index ce69c196e9..ac52780229 100644
--- a/samples/deployment-requirements.txt
+++ b/samples/deployment-requirements.txt
@@ -4,3 +4,4 @@ librosa==0.11.0 # For Whisper
 pillow==12.0.0 # Image processing for VLMs
 json5==0.12.1 # For ReAct
 pydantic==2.12.5 # For Structured output json schema
+opencv-python==4.12.0.88 # For video-to-text VLM sample
diff --git a/samples/python/visual_language_chat/README.md b/samples/python/visual_language_chat/README.md
index f08b1a46d7..25b95b9758 100644
--- a/samples/python/visual_language_chat/README.md
+++ b/samples/python/visual_language_chat/README.md
@@ -2,8 +2,9 @@
 This example showcases inference of text-generation Vision Language Models (VLMs): `miniCPM-V-2_6` and other models with the same signature. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.VLMPipeline` and configures it for the chat scenario. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/minicpm-v-multimodal-chatbot) which provides an example of Visual-language assistant.
 
-There are two sample files:
+There are three sample files:
  - [`visual_language_chat.py`](./visual_language_chat.py) demonstrates basic usage of the VLM pipeline.
+ - [`video_to_text_chat.py`](./video_to_text_chat.py) demonstrates video-to-text usage of the VLM pipeline.
  - [`benchmark_vlm.py`](./benchmark_vlm.py) shows how to benchmark a VLM in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text and calculating various performance metrics.
  - [`milebench_eval_vlm.py`](./milebench_eval_vlm.py) provides MileBench validation for VLMs, enabling evaluation of image–text reasoning and visual QA tasks across multiple subsets designed to assess the MultImodal Long-contExt capabilities of MLLMs.
@@ -39,20 +40,29 @@
 tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6")
 export_tokenizer(tokenizer, output_dir)
 ```
 
-## Run:
+Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` to run VLM samples.
 
-[This image](https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11) can be used as a sample image.
+## Run image-to-text chat sample:
 
-Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` and then, run a sample:
+[This image](https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11) can be used as a sample image.
 
 `python visual_language_chat.py ./miniCPM-V-2_6/ 319483352-d5fbbd1a-d484-415c-88cb-9986625b7b11.jpg`
+See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.
+
+## Run video-to-text chat sample:
+
+A model that supports video input is required to run this sample, for example `llava-hf/LLaVA-NeXT-Video-7B-hf`.
+
+[This video](https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4) can be used as a sample video.
+
+`python video_to_text_chat.py ./LLaVA-NeXT-Video-7B-hf/ sample_demo_1.mp4`
+
+Supported models with video input are listed in [this section](https://openvinotoolkit.github.io/openvino.genai/docs/use-cases/image-processing/#use-image-or-video-tags-in-prompt).
 
 Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. # TODO: examples of larger models
 Modify the source code to change the device for inference to the GPU.
 
-See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.
-
 ## Run benchmark:
 
 ```sh
diff --git a/samples/python/visual_language_chat/video_to_text_chat.py b/samples/python/visual_language_chat/video_to_text_chat.py
new file mode 100644
index 0000000000..b21a83bbb1
--- /dev/null
+++ b/samples/python/visual_language_chat/video_to_text_chat.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import argparse
+import numpy as np
+import cv2
+import openvino_genai
+from openvino import Tensor
+from pathlib import Path
+
+
+def streamer(subword: str) -> bool:
+    '''
+
+    Args:
+        subword: sub-word of the generated text.
+
+    Returns: Flag that indicates whether generation should be stopped.
+
+    '''
+    print(subword, end='', flush=True)
+
+    # No value is returned as in this example we don't want to stop the generation in this method.
+    # "return None" will be treated the same as "return openvino_genai.StreamingStatus.RUNNING".
+
+
+def read_video(path: str, num_frames: int = 8) -> Tensor:
+    '''
+
+    Args:
+        path: The path to the video.
+        num_frames: Number of frames sampled from the video.
+
+    Returns: the ov.Tensor containing the video.
+
+    '''
+    cap = cv2.VideoCapture(path)
+
+    frames = []
+    total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
+
+    idx = 0
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if idx in indices:
+            frames.append(np.array(frame))
+        idx += 1
+    assert idx == total_num_frames, "Frame count mismatch: expected {}, got {}".format(total_num_frames, idx)
+
+    return Tensor(np.array(frames))
+
+
+def read_videos(path: str) -> list[Tensor]:
+    entry = Path(path)
+    if entry.is_dir():
+        return [read_video(str(file)) for file in sorted(entry.iterdir())]
+    return [read_video(path)]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('model_dir', help="Path to the model directory")
+    parser.add_argument('video_dir', help="Path to a video file or a directory with videos.")
+    parser.add_argument('device', nargs='?', default='CPU', help="Device to run the model on (default: CPU)")
+    args = parser.parse_args()
+
+    videos = read_videos(args.video_dir)
+
+    # GPU and NPU can be used as well.
+    # Note: If NPU is selected, only the language model will be run on the NPU.
+    enable_compile_cache = dict()
+    if args.device == "GPU":
+        # Cache compiled models on disk for GPU to save time on the next run.
+        # It's not beneficial for CPU.
+        enable_compile_cache["CACHE_DIR"] = "vlm_cache"
+
+    pipe = openvino_genai.VLMPipeline(args.model_dir, args.device, **enable_compile_cache)
+
+    config = openvino_genai.GenerationConfig()
+    config.max_new_tokens = 100
+
+    pipe.start_chat()
+    prompt = input('question:\n')
+    pipe.generate(prompt, videos=videos, generation_config=config, streamer=streamer)
+
+    while True:
+        try:
+            prompt = input("\n----------\n" + "question:\n")
+        except EOFError:
+            break
+        pipe.generate(prompt, generation_config=config, streamer=streamer)
+    pipe.finish_chat()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/python_tests/samples/test_video_to_text_chat.py b/tests/python_tests/samples/test_video_to_text_chat.py
new file mode 100644
index 0000000000..9a19c0853f
--- /dev/null
+++ b/tests/python_tests/samples/test_video_to_text_chat.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import pytest
+import subprocess # nosec B404
+import sys
+
+from conftest import SAMPLES_PY_DIR, SAMPLES_CPP_DIR, SAMPLES_C_DIR
+from test_utils import run_sample
+
+class TestVisualLanguageChat:
+    @pytest.mark.vlm
+    @pytest.mark.samples
+    @pytest.mark.parametrize(
+        "convert_model, download_test_content, questions",
+        [
+            pytest.param("tiny-random-llava-next-video", "video0.mp4", 'What is unusual on this video?\nGo on.')
+        ],
+        indirect=["convert_model", "download_test_content"],
+    )
+    def test_sample_visual_language_chat(self, convert_model, download_test_content, questions):
+        # Test CPP sample
+        cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'video_to_text_chat')
+        cpp_command = [cpp_sample, convert_model, download_test_content]
+        cpp_result = run_sample(cpp_command, questions)
+
+        # Test Python sample
+        py_script = os.path.join(SAMPLES_PY_DIR, "visual_language_chat/video_to_text_chat.py")
+        py_command = [sys.executable, py_script, convert_model, download_test_content]
+        py_result = run_sample(py_command, questions)
+
+        # Compare results
+        assert py_result.stdout == cpp_result.stdout, "Results should match"