
Commit cff5c77

[wwb] Add text reranking pipeline
1 parent 4274a9a commit cff5c77

File tree

6 files changed: +471 -2 lines changed

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import subprocess  # nosec B404
import pytest
import logging
from test_cli_image import run_wwb


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.mark.parametrize(
    ("model_id", "model_type"),
    [
        ("cross-encoder/ms-marco-TinyBERT-L2-v2", "text-reranking"),
    ],
)
def test_reranking_basic(model_id, model_type, tmp_path):
    GT_FILE = tmp_path / "gt.csv"
    MODEL_PATH = tmp_path / model_id.replace("/", "--")

    result = subprocess.run(["optimum-cli", "export",
                             "openvino", "-m", model_id,
                             MODEL_PATH, "--task",
                             "text-classification",
                             "--trust-remote-code"],
                            capture_output=True,
                            text=True,
                            )
    assert result.returncode == 0

    # Collect reference with HF model
    run_wwb([
        "--base-model",
        model_id,
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
        "--hf",
    ])

    # test Optimum
    run_wwb([
        "--target-model",
        MODEL_PATH,
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
    ])

    # test GenAI
    run_wwb([
        "--target-model",
        MODEL_PATH,
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
        "--genai",
        "--output",
        tmp_path,
    ])

    # test w/o models
    run_wwb([
        "--target-data",
        tmp_path / "target.csv",
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
        "--genai",
    ])


@pytest.mark.parametrize(
    ("model_id", "model_type"),
    [
        ("Qwen/Qwen3-Reranker-0.6B", "text-reranking"),
    ],
)
def test_reranking_qwen(model_id, model_type, tmp_path):
    GT_FILE = tmp_path / "gt.csv"
    MODEL_PATH = tmp_path / model_id.replace("/", "--")

    result = subprocess.run(["optimum-cli", "export",
                             "openvino", "-m", model_id,
                             MODEL_PATH, "--task",
                             "text-generation",
                             "--trust-remote-code"],
                            capture_output=True,
                            text=True,
                            )
    assert result.returncode == 0

    # Collect reference with HF model
    run_wwb([
        "--base-model",
        model_id,
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
        "--hf",
    ])

    # test Optimum
    run_wwb([
        "--target-model",
        MODEL_PATH,
        "--num-samples",
        "1",
        "--gt-data",
        GT_FILE,
        "--device",
        "CPU",
        "--model-type",
        model_type,
    ])

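Taken together, the tests above exercise the full flow: export the reranker to OpenVINO IR with optimum-cli (task "text-classification" for cross-encoder models, "text-generation" for the Qwen3 reranker), collect reference outputs from the HF model, then score the Optimum and GenAI backends against that reference. The sketch below replays the same flow outside the test harness; it assumes the package exposes a `wwb` console script (the tests go through the `run_wwb` helper instead), and the export directory and GT file names are placeholders.

import subprocess  # nosec B404

model_id = "cross-encoder/ms-marco-TinyBERT-L2-v2"
model_path = "ms-marco-TinyBERT-L2-v2-ov"  # placeholder export directory

# 1) Export the reranker to OpenVINO IR (same optimum-cli call as in the test).
subprocess.run(["optimum-cli", "export", "openvino", "-m", model_id,
                model_path, "--task", "text-classification",
                "--trust-remote-code"], check=True)

# 2) Collect reference outputs with the original HF model.
subprocess.run(["wwb", "--base-model", model_id, "--gt-data", "gt.csv",
                "--model-type", "text-reranking", "--device", "CPU",
                "--num-samples", "1", "--hf"], check=True)

# 3) Compare the exported model through the OpenVINO GenAI backend.
subprocess.run(["wwb", "--target-model", model_path, "--gt-data", "gt.csv",
                "--model-type", "text-reranking", "--device", "CPU",
                "--num-samples", "1", "--genai"], check=True)
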
tools/who_what_benchmark/whowhatbench/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from .visualtext_evaluator import VisualTextEvaluator
 from .im2im_evaluator import Image2ImageEvaluator
 from .inpaint_evaluator import InpaintingEvaluator
+from .reranking_evaluator import RerankingEvaluator
 
 
 __all__ = [
@@ -15,5 +16,6 @@
     "VisualTextEvaluator",
     "Image2ImageEvaluator",
     "InpaintingEvaluator",
+    "RerankingEvaluator",
     "EVALUATOR_REGISTRY",
 ]

tools/who_what_benchmark/whowhatbench/model_loaders.py

Lines changed: 66 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
 
+from .reranking_evaluator import DEF_TOP_K, DEF_MAX_LENGTH
 from .utils import mock_torch_cuda_is_available, mock_AwqQuantizer_validate_environment
 
 
@@ -20,7 +21,7 @@ def __init__(self, model, model_dir, model_type):
         self.model = model
         self.model_type = model_type
 
-        if model_type == "text" or model_type == "visual-text":
+        if model_type in ["text", "visual-text", "text-reranking"]:
             try:
                 self.config = AutoConfig.from_pretrained(model_dir)
             except Exception:
@@ -428,6 +429,68 @@ def load_inpainting_model(
     return model
 
 
+def load_reranking_genai_pipeline(model_dir, device="CPU", ov_config=None):
+    try:
+        import openvino_genai
+    except ImportError as e:
+        logger.error("Failed to import openvino_genai package. Please install it. Details:\n", e)
+        exit(-1)
+
+    logger.info("Using OpenVINO GenAI TextRerankPipeline API")
+
+    config = openvino_genai.TextRerankPipeline.Config()
+    config.top_n = DEF_TOP_K
+    config.max_length = DEF_MAX_LENGTH
+
+    pipeline = openvino_genai.TextRerankPipeline(model_dir, device.upper(), config, **ov_config)
+
+    return GenAIModelWrapper(
+        pipeline,
+        model_dir,
+        "text-reranking"
+    )
+
+
+def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False):
+    if use_hf:
+        logger.info("Using HF Transformers API")
+        if 'qwen3' in model_id.lower():
+            from transformers import AutoModelForCausalLM
+            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
+        else:
+            from transformers import AutoModelForSequenceClassification
+            model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
+    elif use_genai:
+        logger.info("Using OpenVINO GenAI API")
+        model = load_reranking_genai_pipeline(model_id, device, ov_config)
+    else:
+        logger.info("Using Optimum API")
+        model_cls = None
+        if 'qwen3' in model_id.lower():
+            from optimum.intel.openvino import OVModelForCausalLM
+            model_cls = OVModelForCausalLM
+        else:
+            from optimum.intel.openvino import OVModelForSequenceClassification
+            model_cls = OVModelForSequenceClassification
+
+        try:
+            model = model_cls.from_pretrained(
+                model_id, device=device, ov_config=ov_config, safety_checker=None,
+            )
+        except ValueError as e:
+            logger.error("Failed to load reranking pipeline. Details:\n", e)
+            model = model_cls.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                use_cache=False,
+                device=device,
+                ov_config=ov_config,
+                safety_checker=None
+            )
+
+    return model
+
+
 def load_model(
     model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs
 ):
@@ -452,5 +515,7 @@ def load_model(
         return load_imagetext2image_model(model_id, device, ov_options, use_hf, use_genai)
     elif model_type == "image-inpainting":
         return load_inpainting_model(model_id, device, ov_options, use_hf, use_genai)
+    elif model_type == "text-reranking":
+        return load_reranking_model(model_id, device, ov_options, use_hf, use_genai)
     else:
         raise ValueError(f"Unsupported model type: {model_type}")

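With the dispatch branch above, load_model becomes the single entry point for every reranking backend. A minimal usage sketch, assuming the import path follows the file location (whowhatbench.model_loaders) and reusing the export directory name from the tests as a placeholder; ov_config is left at its default:

from whowhatbench.model_loaders import load_model

# Optimum path: resolves to OVModelForSequenceClassification for this model id
# (OVModelForCausalLM would be chosen for ids containing "qwen3").
optimum_model = load_model("text-reranking", "ms-marco-TinyBERT-L2-v2-ov",
                           device="CPU")

# GenAI path: wraps openvino_genai.TextRerankPipeline in GenAIModelWrapper,
# with top_n and max_length taken from the evaluator defaults.
genai_model = load_model("text-reranking", "ms-marco-TinyBERT-L2-v2-ov",
                         device="CPU", use_genai=True)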