Skip to content

Commit b74fac3

Browse files
committed
update
1 parent 35ef156 commit b74fac3

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

tools/who_what_benchmark/whowhatbench/model_loaders.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
66

7-
from .reranking_evaluator import DEF_TOP_K, DEF_MAX_LENGTH
7+
from .reranking_evaluator import DEF_TOP_K, DEF_MAX_LENGTH, reranking_base_on_causallm_arch
88
from .utils import mock_torch_cuda_is_available, mock_AwqQuantizer_validate_environment
99

1010

@@ -452,9 +452,14 @@ def load_reranking_genai_pipeline(model_dir, device="CPU", ov_config=None):
452452

453453

454454
def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False):
455+
try:
456+
config = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
457+
except Exception:
458+
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
459+
455460
if use_hf:
456461
logger.info("Using HF Transformers API")
457-
if 'qwen3' in model_id.lower():
462+
if reranking_base_on_causallm_arch(config):
458463
from transformers import AutoModelForCausalLM
459464
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
460465
else:
@@ -466,7 +471,7 @@ def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, u
466471
else:
467472
logger.info("Using Optimum API")
468473
model_cls = None
469-
if 'qwen3' in model_id.lower():
474+
if reranking_base_on_causallm_arch(config):
470475
from optimum.intel.openvino import OVModelForCausalLM
471476
model_cls = OVModelForCausalLM
472477
else:
@@ -478,7 +483,7 @@ def load_reranking_model(model_id, device="CPU", ov_config=None, use_hf=False, u
478483
model_id, device=device, ov_config=ov_config, safety_checker=None,
479484
)
480485
except ValueError as e:
481-
logger.error("Failed to load reranking pipeline. Details:\n", e)
486+
logger.error("Failed to load reranking pipeline, an attempt will be made again with updated parameters. Details:\n", e)
482487
model = model_cls.from_pretrained(
483488
model_id,
484489
trust_remote_code=True,

tools/who_what_benchmark/whowhatbench/reranking_evaluator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
DEF_MAX_LENGTH_QWEN = 8192
1818

1919

20+
def reranking_base_on_causallm_arch(config):
21+
return config.model_type == "qwen3" and "Qwen3ForCausalLM" in config.architectures
22+
23+
2024
def preprocess_fn(example):
2125
return {
2226
"query": example["query"],
@@ -27,7 +31,7 @@ def preprocess_fn(example):
2731
def prepare_default_data(num_samples=None):
2832
DATASET_NAME = "microsoft/ms_marco"
2933
NUM_SAMPLES = num_samples if num_samples else 24
30-
set_seed(70)
34+
set_seed(42)
3135
default_dataset = datasets.load_dataset(
3236
DATASET_NAME, 'v2.1', split="test", streaming=True
3337
).shuffle(42).take(NUM_SAMPLES)
@@ -65,7 +69,7 @@ def __init__(
6569
self.gt_data = pd.read_csv(gt_data, keep_default_na=False)
6670

6771
self.similarity = RerankingSimilarity()
68-
# self.last_cmp = None
72+
self.last_cmp = None
6973

7074
def get_generation_fn(self):
7175
return self.generation_fn
@@ -122,9 +126,7 @@ def default_gen_answer(model, tokenizer, query, passages):
122126
)
123127
for i, ele in enumerate(input_data["input_ids"]):
124128
input_data["input_ids"][i] = prefix_tokens + ele + suffix_tokens
125-
input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN)
126-
for key in input_data:
127-
input_data[key] = input_data[key].to(device)
129+
input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN).to(device)
128130
else:
129131
tokenizer_kwargs = {"truncation": True, "padding": True, "max_length": DEF_MAX_LENGTH}
130132
inputs = [query] * len(passages)
@@ -133,9 +135,8 @@ def default_gen_answer(model, tokenizer, query, passages):
133135
with torch.no_grad():
134136
outputs = model(**input_data).logits
135137

136-
if model.config.model_type == "qwen3":
138+
if reranking_base_on_causallm_arch(model.config):
137139
batch_scores = outputs[:, -1, :]
138-
139140
token_false_id = tokenizer.convert_tokens_to_ids("no")
140141
token_true_id = tokenizer.convert_tokens_to_ids("yes")
141142
true_vector = batch_scores[:, token_true_id]

tools/who_what_benchmark/whowhatbench/whowhat_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def evaluate(self, data_gold, data_prediction):
198198
# documets on the same position of top_n is different
199199
if i >= len(prediction_data) or int(score[0]) != int(prediction_data[i][0]):
200200
per_query_text.append(math.inf)
201-
mean_per_query_text.append(abs(score[1] - prediction_data[i][1]))
201+
mean_per_query_text.append(len(gold_data))
202202
else:
203203
per_query_text.append(abs(score[1] - prediction_data[i][1]))
204204
mean_per_query_text.append(abs(score[1] - prediction_data[i][1]))

0 commit comments

Comments
 (0)