Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tools/who_what_benchmark/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ datasets>=3.6.0
auto-gptq; sys_platform == "linux"
autoawq<0.2.8; sys_platform == "linux"
sentencepiece
jinja2>=3.1.0
jinja2>=3.1.0
scipy
140 changes: 140 additions & 0 deletions tools/who_what_benchmark/tests/test_cli_reranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import subprocess # nosec B404
import pytest
import logging
from test_cli_image import run_wwb


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.mark.parametrize(
("model_id", "model_type"),
[
("cross-encoder/ms-marco-TinyBERT-L2-v2", "text-reranking"),
],
)
def test_reranking_basic(model_id, model_type, tmp_path):
GT_FILE = tmp_path / "gt.csv"
MODEL_PATH = tmp_path / model_id.replace("/", "--")

result = subprocess.run(["optimum-cli", "export",
"openvino", "-m", model_id,
MODEL_PATH, "--task",
"text-classification",
"--trust-remote-code"],
capture_output=True,
text=True,
)
assert result.returncode == 0

# Collect reference with HF model
run_wwb([
"--base-model",
model_id,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--hf",
])

# test Optimum
run_wwb([
"--target-model",
MODEL_PATH,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
])

# test GenAI
run_wwb([
"--target-model",
MODEL_PATH,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--genai",
"--output",
tmp_path,
])

# test w/o models
run_wwb([
"--target-data",
tmp_path / "target.csv",
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--genai",
])


@pytest.mark.parametrize(
("model_id", "model_type"),
[
("Qwen/Qwen3-Reranker-0.6B", "text-reranking"),
],
)
def test_reranking_qwen(model_id, model_type, tmp_path):
GT_FILE = tmp_path / "gt.csv"
MODEL_PATH = tmp_path / model_id.replace("/", "--")

result = subprocess.run(["optimum-cli", "export",
"openvino", "-m", model_id,
MODEL_PATH, "--task",
"text-generation",
"--trust-remote-code"],
capture_output=True,
text=True,
)
assert result.returncode == 0

# Collect reference with HF model
run_wwb([
"--base-model",
model_id,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
"--hf",
])

# test Optimum
run_wwb([
"--target-model",
MODEL_PATH,
"--num-samples",
"1",
"--gt-data",
GT_FILE,
"--device",
"CPU",
"--model-type",
model_type,
])
2 changes: 2 additions & 0 deletions tools/who_what_benchmark/whowhatbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .visualtext_evaluator import VisualTextEvaluator
from .im2im_evaluator import Image2ImageEvaluator
from .inpaint_evaluator import InpaintingEvaluator
from .reranking_evaluator import RerankingEvaluator


__all__ = [
Expand All @@ -15,5 +16,6 @@
"VisualTextEvaluator",
"Image2ImageEvaluator",
"InpaintingEvaluator",
"RerankingEvaluator",
"EVALUATOR_REGISTRY",
]
72 changes: 71 additions & 1 deletion tools/who_what_benchmark/whowhatbench/model_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer

from .reranking_evaluator import DEF_TOP_K, DEF_MAX_LENGTH, reranking_base_on_causallm_arch
from .utils import mock_torch_cuda_is_available, mock_AwqQuantizer_validate_environment


Expand All @@ -20,7 +21,7 @@ def __init__(self, model, model_dir, model_type):
self.model = model
self.model_type = model_type

if model_type == "text" or model_type == "visual-text":
if model_type in ["text", "visual-text", "text-reranking"]:
try:
self.config = AutoConfig.from_pretrained(model_dir)
except Exception:
Expand Down Expand Up @@ -428,6 +429,73 @@ def load_inpainting_model(
return model


def load_reranking_genai_pipeline(model_dir, device="CPU", ov_config=None):
try:
import openvino_genai
except ImportError as e:
logger.error("Failed to import openvino_genai package. Please install it. Details:\n", e)
exit(-1)

logger.info("Using OpenVINO GenAI TextRerankPipeline API")

config = openvino_genai.TextRerankPipeline.Config()
config.top_n = DEF_TOP_K
config.max_length = DEF_MAX_LENGTH

pipeline = openvino_genai.TextRerankPipeline(model_dir, device.upper(), config, **ov_config)

return GenAIModelWrapper(
pipeline,
model_dir,
"text-reranking"
)


def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False):
try:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
except Exception:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

if use_hf:
logger.info("Using HF Transformers API")
if reranking_base_on_causallm_arch(config):
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
else:
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
elif use_genai:
logger.info("Using OpenVINO GenAI API")
model = load_reranking_genai_pipeline(model_id, device, ov_config)
else:
logger.info("Using Optimum API")
model_cls = None
if reranking_base_on_causallm_arch(config):
from optimum.intel.openvino import OVModelForCausalLM
model_cls = OVModelForCausalLM
else:
from optimum.intel.openvino import OVModelForSequenceClassification
model_cls = OVModelForSequenceClassification

try:
model = model_cls.from_pretrained(
model_id, device=device, ov_config=ov_config, safety_checker=None,
)
except ValueError as e:
logger.error("Failed to load reranking pipeline, an attempt will be made again with updated parameters. Details:\n", e)
model = model_cls.from_pretrained(
model_id,
trust_remote_code=True,
use_cache=False,
device=device,
ov_config=ov_config,
safety_checker=None
)

return model


def load_model(
model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs
):
Expand All @@ -452,5 +520,7 @@ def load_model(
return load_imagetext2image_model(model_id, device, ov_options, use_hf, use_genai)
elif model_type == "image-inpainting":
return load_inpainting_model(model_id, device, ov_options, use_hf, use_genai)
elif model_type == "text-reranking":
return load_reranking_model(model_id, device, ov_options, use_hf, use_genai)
else:
raise ValueError(f"Unsupported model type: {model_type}")
Loading
Loading