DEF_MAX_LENGTH_QWEN = 8192


+def reranking_base_on_causallm_arch(config):
+    return config.model_type == "qwen3" and "Qwen3ForCausalLM" in config.architectures
+
+
def preprocess_fn(example):
    return {
        "query": example["query"],
@@ -27,7 +31,7 @@ def preprocess_fn(example):
def prepare_default_data(num_samples=None):
    DATASET_NAME = "microsoft/ms_marco"
    NUM_SAMPLES = num_samples if num_samples else 24
-    set_seed(70)
+    set_seed(42)
    default_dataset = datasets.load_dataset(
        DATASET_NAME, 'v2.1', split="test", streaming=True
    ).shuffle(42).take(NUM_SAMPLES)
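For reference, the streaming pattern above can be exercised standalone; a sketch assuming network access to the Hugging Face Hub (same dataset name, config, and split as in the diff):

import datasets

# streaming=True avoids downloading all of MS MARCO; shuffle(42) fixes the sample
# order so repeated runs evaluate the same queries.
ds = datasets.load_dataset("microsoft/ms_marco", 'v2.1', split="test", streaming=True)
for example in ds.shuffle(42).take(2):
    print(example["query"])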
@@ -65,7 +69,7 @@ def __init__(
        self.gt_data = pd.read_csv(gt_data, keep_default_na=False)

        self.similarity = RerankingSimilarity()
-        # self.last_cmp = None
+        self.last_cmp = None

    def get_generation_fn(self):
        return self.generation_fn
@@ -106,6 +110,9 @@ def default_gen_answer(model, tokenizer, query, passages):
    device = "cpu"
    if hasattr(model, "device"):
        device = model.device
+
+    # pre-/post-processing for qwen models added per the transformers Qwen3-Reranker-0.6B model card:
+    # https://huggingface.co/Qwen/Qwen3-Reranker-0.6B#transformers-usage
    if model.config.model_type == "qwen3":
        prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the ' \
            + 'Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
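For context, the prefix shown here is the first half of the chat template from the linked model card; a sketch of the full prompt it frames, where the suffix, instruction text, and query/document values follow the card's example and are not part of this diff:

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the ' \
    + 'Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Per the model card, each query/passage pair is rendered between prefix and suffix as:
instruction = "Given a web search query, retrieve relevant passages that answer the query"
pair = f"<Instruct>: {instruction}\n<Query>: What is the capital of China?\n<Document>: The capital of China is Beijing."
print(prefix + pair + suffix)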
@@ -122,9 +129,7 @@ def default_gen_answer(model, tokenizer, query, passages):
        )
        for i, ele in enumerate(input_data["input_ids"]):
            input_data["input_ids"][i] = prefix_tokens + ele + suffix_tokens
-        input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN)
-        for key in input_data:
-            input_data[key] = input_data[key].to(device)
+        input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN).to(device)
    else:
        tokenizer_kwargs = {"truncation": True, "padding": True, "max_length": DEF_MAX_LENGTH}
        inputs = [query] * len(passages)
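The one-liner introduced above works because tokenizer.pad returns a transformers BatchEncoding, whose .to() moves every contained tensor in one call; a sketch, assuming the tokenizer from the linked model card is available locally:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left")
batch = tok.pad({"input_ids": [tok.encode("short"), tok.encode("a longer input")]},
                padding=True, return_tensors="pt")
# BatchEncoding.to() relocates input_ids and attention_mask together,
# replacing the removed per-key loop.
batch = batch.to("cpu")
print(batch["input_ids"].shape, batch["attention_mask"].shape)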
@@ -133,9 +138,8 @@ def default_gen_answer(model, tokenizer, query, passages):
    with torch.no_grad():
        outputs = model(**input_data).logits

-    if model.config.model_type == "qwen3":
+    if reranking_base_on_causallm_arch(model.config):
        batch_scores = outputs[:, -1, :]
-
        token_false_id = tokenizer.convert_tokens_to_ids("no")
        token_true_id = tokenizer.convert_tokens_to_ids("yes")
        true_vector = batch_scores[:, token_true_id]
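Beyond the end of this hunk, the model card turns the two token logits into a relevance score with a log-softmax over just the "yes"/"no" pair; a toy sketch of that step with stand-in logits and token ids (nothing here comes from the diff itself):

import torch

# Toy last-token logits for a 2-passage batch over a 3-token vocabulary.
batch_scores = torch.tensor([[0.1, 2.0, -1.0], [0.3, -0.5, 1.5]])
token_true_id, token_false_id = 1, 2  # stand-ins for the ids of "yes" / "no"

true_vector = batch_scores[:, token_true_id]
false_vector = batch_scores[:, token_false_id]
# Softmax over the two candidate answers; column 1 is P("yes"), used as the score.
pair = torch.stack([false_vector, true_vector], dim=1)
scores = torch.nn.functional.log_softmax(pair, dim=1)[:, 1].exp()
print(scores)  # higher = more relevant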