@@ -50,7 +50,7 @@ def websrc_process_results(doc, results):
5050 "websrc_squad_f1" : websrc_ans ,
5151 "submission" : {
5252 websrc_ans ['question_id' ]: pred ,
53- },
53+ } if 'question_id' in websrc_ans else None
5454 }
5555
5656
@@ -122,27 +122,39 @@ def _normalize_str(string):
         # lower it
         string = string.lower()

-        # strip non-alphanumeric characters
-        string = re.sub(r"[^a-zA-Z0-9]", "", string)
-
         # strip leading and trailing whitespaces
         string = string.strip()

         return string

+    def _tokenize(text):
+        # Regex pattern to match words and isolate punctuation
+        pattern = r'\w+|[^\w\s]'
+        tokens = re.findall(pattern, text)
+        return tokens
+
+    def _compute_f1(sa, sb):
+        sa = _normalize_str(sa)
+        sb = _normalize_str(sb)
+
+        sa = _tokenize(sa)
+        sb = _tokenize(sb)
+
+        sa = set(sa)
+        sb = set(sb)
+
+        if len(sa) == 0 or len(sb) == 0:
+            return 0.0
+
+        comm = sa.intersection(sb)
+        prec = len(comm) / len(sb)
+        rec = len(comm) / len(sa)
+        f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
+        return f1
+
     judge_list = []
     for sample in samples:
-        gold_i = set(_normalize_str(sample["answer"]))
-        pred_i = set(_normalize_str(sample["parsed_pred"]))
-        if len(pred_i) == 0:
-            judge_list.append(0.0)
-            continue
-
-        comm_i = gold_i.intersection(pred_i)
-        prec_i = len(comm_i) / len(pred_i)
-        rec_i = len(comm_i) / len(gold_i)
-        f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
-        judge_list.append(f1_i)
+        judge_list.append(_compute_f1(sample["answer"], sample["parsed_pred"]))

     f1 = np.mean(judge_list)
     return judge_list, {"f1": f1}
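For reference, a minimal standalone sketch of the token-level F1 this patch switches to: the old code compared character sets after stripping non-alphanumerics, while the new _compute_f1 compares sets of word/punctuation tokens. The helper names mirror the patch; the example strings are made up for illustration.

import re


def _normalize_str(string):
    # lower-case and trim; punctuation is kept so the tokenizer can isolate it
    return string.lower().strip()


def _tokenize(text):
    # split into word tokens, keeping each punctuation mark as its own token
    return re.findall(r"\w+|[^\w\s]", text)


def _compute_f1(gold, pred):
    # set-based F1 over normalized tokens, mirroring the patched _compute_f1
    gold_tokens = set(_tokenize(_normalize_str(gold)))
    pred_tokens = set(_tokenize(_normalize_str(pred)))
    if len(gold_tokens) == 0 or len(pred_tokens) == 0:
        return 0.0
    comm = gold_tokens.intersection(pred_tokens)
    prec = len(comm) / len(pred_tokens)
    rec = len(comm) / len(gold_tokens)
    return 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0


# illustrative strings only, not from the benchmark
print(_compute_f1("$12.5 million", "12.5 million dollars"))  # -> 0.8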