+from contextlib import nullcontext
 from functools import partial
 from typing import cast
+from unittest.mock import MagicMock

 import numpy as np
 import pytest
@@ -526,6 +528,100 @@ def _rand_audio(
     return rng.rand(audio_len), sr


+@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+@pytest.mark.parametrize(
+    ("limit", "num_supported", "is_valid"),
+    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
+     (2, 1, False), (2, 2, True)],
+)
+def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
+    limit_mm_per_prompt = {"image": limit}
+
+    model_config = ModelConfig(
+        model=model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="half",
+        revision=None,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    ctx = InputProcessingContext(
+        model_config,
+        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+    )
+
+    processor = processor_factory(ctx, cache=None)
+
+    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
+    processor.get_supported_mm_limits = mock_supported_mm_limits
+
+    if is_valid:
+        exc_ctx = nullcontext()
+    else:
+        exc_ctx = pytest.raises(ValueError, match="this model only supports")
+
+    with exc_ctx:
+        processor._get_and_validate_dummy_mm_counts()
+
+
+@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+@pytest.mark.parametrize(
+    ("num_images", "limit", "is_valid"),
+    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
+     (2, 1, False), (2, 2, True)],
+)
+def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
+    limit_mm_per_prompt = {"image": limit}
+
+    model_config = ModelConfig(
+        model=model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="half",
+        revision=None,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+
+    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    ctx = InputProcessingContext(
+        model_config,
+        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+    )
+
+    processor = processor_factory(ctx, cache=None)
+
+    rng = np.random.RandomState(0)
+    image = _rand_img(rng, min_wh=128, max_wh=256)
+    if num_images == 0:
+        mm_data = {}
+    elif num_images == 1:
+        mm_data = {"image": image}
+    else:
+        mm_data = {"image": [image] * num_images}
+
+    if is_valid:
+        exc_ctx = nullcontext()
+    else:
+        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
+
+    with exc_ctx:
+        processor.apply(
+            "<image>" * num_images,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+
 def _test_processing_cache_correctness(
     model_id: str,
     modalities: dict[str, bool],
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
     ("facebook/chameleon-7b", {"image": False}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
+    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
     ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
     ("mistral-community/pixtral-12b", {"image": True}),
     ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),