import json
import logging
import re
from collections import Counter

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

PROMPT = """Question: {}
(A) {}
(B) {}
(C) {}
(D) {}
(E) {}
(F) {}"""
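
# Illustrative example (made-up values, not from the dataset): PROMPT.format(...)
# renders the question stem followed by the six lettered options, e.g.
#   Question: What emotion does the image convey?
#   (A) Hope
#   (B) Loss
#   ...
#   (F) Nostalgia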


def ii_bench_doc_to_text(doc, model_specific_prompt_kwargs):
    """Build the full prompt: pre-prompt, formatted question with its six options, post-prompt."""
    question = PROMPT.format(doc["question"], doc["option1"], doc["option2"], doc["option3"], doc["option4"], doc["option5"], doc["option6"])
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}{post_prompt}"

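# model_specific_prompt_kwargs is normally filled from the task YAML. Sketch of the
# expected keys (illustrative; the actual ii_bench config may differ):
#   model_specific_prompt_kwargs:
#     default:
#       pre_prompt: ""
#       post_prompt: "\nAnswer with the option's letter from the given choices directly."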


def ii_bench_doc_to_visual(doc):
    # The dataset provides a PIL image; convert to RGB so models receive three channels.
    return [doc["image"].convert("RGB")]


def extract_option_labels(text, options=None):
    # Some model wrappers return a dict on failure; flag it instead of crashing.
    if isinstance(text, dict):
        return "error"

    # First look for explicitly parenthesised labels such as "(B)".
    pattern = r"\(([A-F])\)"
    matches = re.findall(pattern, text)

    # Fall back to bare standalone letters, e.g. "Answer: B".
    if not matches:
        pattern = r"\b([A-F])\b"
        matches = re.findall(pattern, text)

    if matches:
        # Pick the most frequent label; on ties, take the last of the equally frequent ones.
        counter = Counter(matches)
        most_common = counter.most_common()
        max_count = most_common[0][1]
        candidates = [item for item in most_common if item[1] == max_count]
        return candidates[-1][0]
    else:
        # No letter found: try to match the response against the option texts themselves.
        if options:
            counter = Counter()
            for i, option in enumerate(options, start=1):
                label = chr(64 + i)  # 1 -> "A", 2 -> "B", ...
                option_stripped = option.strip()
                if option_stripped in text:
                    counter[label] += 1
                elif text in option:
                    counter[label] += 1
            if counter:
                most_common = counter.most_common()
                max_count = most_common[0][1]
                candidates = [item for item in most_common if item[1] == max_count]
                return candidates[-1][0]
    return None
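
# Rough behaviour of extract_option_labels (examples are illustrative, not taken
# from the benchmark):
#   extract_option_labels("The answer is (B).")            -> "B"
#   extract_option_labels("I would go with C here.")       -> "C"
#   extract_option_labels("Hope.", ["Hope", "Loss", ...])  -> "A"   (matched via option text)
#   extract_option_labels("no clear choice given")         -> None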


def ii_bench_process_results(doc, results):
    # results[0] is the raw model response for this doc; map it to an option letter.
    response = results[0]
    predict = extract_option_labels(response, [doc["option1"], doc["option2"], doc["option3"], doc["option4"], doc["option5"], doc["option6"]])
    return {"submission": {"id": doc["id"], "predict_answer": predict, "response": response}}
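
# The value stored under "submission" is what gets collected for aggregation; one
# record looks roughly like (values illustrative):
#   {"id": 42, "predict_answer": "B", "response": "The answer is (B)."}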


def ii_bench_aggregate_submissions(results, args):
    # Collect every per-doc prediction into a single JSON file for offline scoring.
    file = generate_submission_file("ii_bench_test_for_submission.json", args)
    with open(file, "w") as f:
        json.dump(results, f, indent=4)
    logging.getLogger("lmms-eval").info(f"Results saved to {file}")
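
# How these hooks are usually wired up in the task YAML (sketch only; this follows
# the common lmms-eval task layout, and the real ii_bench.yaml may differ):
#   doc_to_visual: !function utils.ii_bench_doc_to_visual
#   doc_to_text: !function utils.ii_bench_doc_to_text
#   process_results: !function utils.ii_bench_process_results
#   metric_list:
#     - metric: submission
#       aggregation: !function utils.ii_bench_aggregate_submissions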