Skip to content

Commit 9bca441

Browse files
committed
Add code to enable compilation of submission for WebSRC test split
1 parent 7687495 commit 9bca441

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

lmms_eval/tasks/websrc/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,28 @@ def websrc_process_results(doc, results):
3939
pred = results[0]
4040
parsed_pred = pred
4141
id = doc["page_id"]
42-
websrc_ans = {"id": id, "domain": doc['domain'], "answer": doc["answer"], "parsed_pred": parsed_pred}
42+
websrc_ans = {"id": id, "domain": doc['domain'], "parsed_pred": parsed_pred}
43+
if "answer" in doc:
44+
websrc_ans["answer"] = doc["answer"]
45+
46+
if 'id' in doc:
47+
websrc_ans['question_id'] = doc['id']
48+
4349
return {
4450
"websrc_squad_f1": websrc_ans,
4551
"submission": {
46-
id: pred,
52+
websrc_ans['question_id']: pred,
4753
},
4854
}
4955

5056

5157
def websrc_test_aggregate_results_for_submission(results, args):
5258
path = generate_submission_file("websrc_test_for_submission.json", args)
5359
with open(path, "w") as f:
54-
json.dump(results, f)
60+
out = {}
61+
for result in results:
62+
out.update(result)
63+
json.dump(out, f, indent=4)
5564
lmms_logger.info(f"Results saved to {path}.")
5665

5766

lmms_eval/tasks/websrc/websrc.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
group: websrc
22
task:
33
- websrc_val
4+
- websrc_test
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
dataset_path: rootsautomation/websrc-test
2+
task: "websrc_test"
3+
test_split: test
4+
output_type: generate_until
5+
doc_to_visual: !function utils.websrc_doc_to_visual
6+
doc_to_text: !function utils.websrc_doc_to_text
7+
doc_to_target: "answer"
8+
# The return value of process_results will be used by metrics
9+
process_results: !function utils.websrc_process_results
10+
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
11+
generation_kwargs:
12+
max_new_tokens: 16
13+
image_aspect_ratio: pad
14+
metric_list:
15+
- metric: submission
16+
aggregation: !function utils.websrc_test_aggregate_results_for_submission
17+
higher_is_better: true
18+
metadata:
19+
- version: 0.0

0 commit comments

Comments
 (0)