 DEF_MAX_LENGTH_QWEN = 8192
 
 
+def reranking_base_on_causallm_arch(config):
+    return config.model_type == "qwen3" and "Qwen3ForCausalLM" in config.architectures
+
+
 def preprocess_fn(example):
     return {
         "query": example["query"],
@@ -27,7 +31,7 @@ def preprocess_fn(example):
 def prepare_default_data(num_samples=None):
     DATASET_NAME = "microsoft/ms_marco"
     NUM_SAMPLES = num_samples if num_samples else 24
-    set_seed(70)
+    set_seed(42)
     default_dataset = datasets.load_dataset(
         DATASET_NAME, 'v2.1', split="test", streaming=True
     ).shuffle(42).take(NUM_SAMPLES)
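
Note on the seed change: `set_seed(42)` now matches the `.shuffle(42)` seed, so the global RNG state and the streaming shuffle agree. A minimal standalone sketch of the resulting determinism (assuming the `datasets` and `transformers` packages; `take_subset` is a hypothetical helper, not part of this PR):

```python
import datasets
from transformers import set_seed

def take_subset(num_samples=24, seed=42):
    # Fix the global RNG and reuse the same seed for the streaming shuffle,
    # mirroring prepare_default_data() above.
    set_seed(seed)
    ds = datasets.load_dataset(
        "microsoft/ms_marco", "v2.1", split="test", streaming=True
    ).shuffle(seed).take(num_samples)
    return [ex["query"] for ex in ds]

assert take_subset() == take_subset()  # same seed -> identical subset
```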
@@ -65,7 +69,7 @@ def __init__(
         self.gt_data = pd.read_csv(gt_data, keep_default_na=False)
 
         self.similarity = RerankingSimilarity()
-        # self.last_cmp = None
+        self.last_cmp = None
 
     def get_generation_fn(self):
         return self.generation_fn
@@ -122,9 +126,7 @@ def default_gen_answer(model, tokenizer, query, passages):
         )
         for i, ele in enumerate(input_data["input_ids"]):
             input_data["input_ids"][i] = prefix_tokens + ele + suffix_tokens
-        input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN)
-        for key in input_data:
-            input_data[key] = input_data[key].to(device)
+        input_data = tokenizer.pad(input_data, padding=True, return_tensors="pt", max_length=DEF_MAX_LENGTH_QWEN).to(device)
     else:
         tokenizer_kwargs = {"truncation": True, "padding": True, "max_length": DEF_MAX_LENGTH}
         inputs = [query] * len(passages)
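
The consolidated `+` line works because `tokenizer.pad(..., return_tensors="pt")` returns a `transformers.BatchEncoding`, whose `.to(device)` moves every contained tensor at once, replacing the removed per-key loop. A small sketch under that assumption (the checkpoint name and inputs are illustrative):

```python
import torch
from transformers import AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left")

features = tokenizer(["query: a", "query: b c"], padding=False)
batch = tokenizer.pad(features, padding=True, return_tensors="pt").to(device)
# input_ids and attention_mask now live on the same device.
print({k: v.device for k, v in batch.items()})
```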
@@ -133,9 +135,8 @@ def default_gen_answer(model, tokenizer, query, passages):
     with torch.no_grad():
         outputs = model(**input_data).logits
 
-    if model.config.model_type == "qwen3":
+    if reranking_base_on_causallm_arch(model.config):
         batch_scores = outputs[:, -1, :]
-
         token_false_id = tokenizer.convert_tokens_to_ids("no")
         token_true_id = tokenizer.convert_tokens_to_ids("yes")
         true_vector = batch_scores[:, token_true_id]
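
For readers tracing the new `reranking_base_on_causallm_arch` branch to its end: causal-LM rerankers of this style score each query/passage pair from the last-position logits of the "yes" and "no" tokens. A sketch of how the vectors above are typically combined (the `log_softmax` normalization is an assumption from common Qwen3-reranker usage, not shown in this diff):

```python
import torch

def yes_no_scores(batch_scores: torch.Tensor, tokenizer) -> list[float]:
    # batch_scores: [batch, vocab] logits at the final position.
    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    # Normalize over the {no, yes} pair; P("yes") is the relevance score.
    pair = torch.stack([false_vector, true_vector], dim=1)
    log_probs = torch.nn.functional.log_softmax(pair, dim=1)
    return log_probs[:, 1].exp().tolist()
```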