
Commit 0914718

LiveBench June (#129)
* init live bench
* update path
* chore: Refactor live_bench package structure and update dependencies
* update
* Merge remote-tracking branch 'origin/internal_main_dev'
* Refactor live_bench package structure and update dependencies
* Refactor live_bench package structure and update dependencies
* Fix execution count in example.ipynb
* extract_infomation
* Refactor extract_infomation.py to improve text extraction from HTML
* fix
* fix
* extract infomation
* chore: Refactor extract_infomation.py for improved readability and maintainability
* chore: Refactor data_generator prompt.md and check_prompt.md for improved clarity and instructions
* lint
* update
* update prompt
* extract infomation
* add info
* lint
* update
* filter
* update version of live_bench
* Update model version to gemini-1.5-pro
* update
* livebench_eval
* livebench
* update
1 parent 63532aa commit 0914718

43 files changed: +3388 −2 lines

lmms_eval/evaluator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -325,7 +325,7 @@ def evaluate(
         # hack: remove image columns to speed avoid loading images and speed up postprocessing
         # reason: doc_iterator will actually load image if it's in the doc.
         docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
-        if "d170" not in task_name and "dc100" not in task_name and "dc200" not in task_name and "llava_wilder" not in task_name and "livebench" not in task_name:
+        if "d170" not in task_name and "dc100" not in task_name and "dc200" not in task_name and "llava_wilder" not in task_name and "live_bench" not in task_name:
             remove_cols = []
             features = docs.features
             # If it is an Image instance or a Sequence of Image instance. Remove it
```
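For orientation, a minimal sketch of what the guarded branch does (illustrative only, not code from this commit): inspect the dataset's feature types and drop every column that is an `Image`, or a `Sequence` of `Image`, so iterating over the docs never decodes images. The helper name `drop_image_columns` is made up.

```python
from datasets import Dataset, Image, Sequence


def drop_image_columns(docs: Dataset) -> Dataset:
    """Illustrative only: drop image-typed columns so doc iteration skips image decoding."""
    remove_cols = []
    for name, feature in docs.features.items():
        # Image features (or sequences of them) force decoding whenever a row is accessed.
        if isinstance(feature, Image) or (isinstance(feature, Sequence) and isinstance(feature.feature, Image)):
            remove_cols.append(name)
    return docs.remove_columns(remove_cols) if remove_cols else docs
```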

lmms_eval/models/claude.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -238,6 +238,7 @@ def generate_until(self, requests) -> List[str]:
                 pbar.update(1)
                 continue

+            response_text = message.content[0].text
             res.append(message.content[0].text)
             pbar.update(1)
```
lmms_eval/models/gemini_api.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -31,7 +31,7 @@
 class GeminiAPI(lmms):
     def __init__(
         self,
-        model_version: str = "gemini-1.5-flash-latest",
+        model_version: str = "gemini-1.5-pro",
         modality: str = "image",
         timeout: int = 120,
         continual_mode: bool = False,
@@ -46,6 +46,8 @@ def __init__(
         if self.continual_mode and response_persistent_folder is None:
             raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
         self.response_persistent_folder = response_persistent_folder
+        if not os.path.exists(self.response_persistent_folder):
+            os.makedirs(self.response_persistent_folder)
         self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

         if os.path.exists(self.response_persistent_file):
```
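A minimal sketch of the cache bootstrap these lines set up (illustrative, derived only from the diff above): create the folder if needed, build a per-model-version JSON file name, and load any previously cached responses. The helper name `load_response_cache` is made up.

```python
import json
import os


def load_response_cache(folder: str, model_version: str):
    """Illustrative sketch of the persistent-response setup shown in the diff above."""
    if not os.path.exists(folder):
        os.makedirs(folder)
    cache_file = os.path.join(folder, f"{model_version}_response.json")
    cache = {}
    if os.path.exists(cache_file):
        with open(cache_file, "r") as f:
            cache = json.load(f)
    return cache_file, cache
```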

lmms_eval/models/model_utils/load_video.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -29,6 +29,8 @@ def record_video_length_packet(container):


 def read_video_pyav(video_path, num_frm=8):
+    container = av.open(video_path)
+
     if "webm" not in video_path and "mkv" not in video_path:
         # For mp4, we try loading with stream first
         try:
```
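For orientation, a rough sketch of what a PyAV-based reader along the lines of `read_video_pyav` typically does once the container is open (an assumption about the unchanged body of the function, not code from this commit): decode the video stream and sample `num_frm` frames uniformly. The function name `sample_frames` is made up.

```python
import av
import numpy as np


def sample_frames(video_path: str, num_frm: int = 8) -> np.ndarray:
    """Illustrative sketch: uniformly sample num_frm RGB frames with PyAV."""
    container = av.open(video_path)
    frames = [frame for frame in container.decode(video=0)]
    if not frames:
        return np.empty((0,))
    indices = np.linspace(0, len(frames) - 1, num=min(num_frm, len(frames)), dtype=int)
    return np.stack([frames[i].to_ndarray(format="rgb24") for i in indices])
```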
lmms_eval/tasks/live_bench/live_bench.yaml

Lines changed: 29 additions & 0 deletions
```yaml
dataset_path: lmms-lab/LiveBench
dataset_kwargs:
  token: True
task: "live_bench"
test_split: test
dataset_name: 2024-06
output_type: generate_until
doc_to_visual: !function utils.livebench_doc_to_visual
doc_to_text: !function utils.livebench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 1024
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
process_results: !function utils.livebench_process_results
metric_list:
  - metric: gpt4_eval_score
    aggregation: !function utils.livebench_aggregate_results
    higher_is_better: true
model_specific_prompt_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
metadata:
  version: "2024-06"
  api_type : openai
  gpt_eval_model_name: "gpt-4o"
```
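To see how these fields connect, here is a hypothetical walk-through of one document (not part of the commit; the doc dict is invented, and the import path assumes the task package lives at `lmms_eval/tasks/live_bench/`): `doc_to_visual`, `doc_to_text`, and `process_results` map directly onto the functions defined in `utils.py` below.

```python
from lmms_eval.tasks.live_bench import utils  # assumed import path

doc = {  # hypothetical document following the fields utils.py reads
    "images": [],  # PIL images in the real lmms-lab/LiveBench split
    "question": "What is new on this page?",
    "answer": "A banner announcing the June release.",
    "criteria": "Full marks only if the banner is mentioned.",
    "subtask": "Basic Understanding",
    "id": 0,
}

visuals = utils.livebench_doc_to_visual(doc)  # doc_to_visual
prompt = utils.livebench_doc_to_text(doc, {"pre_prompt": "", "post_prompt": ""})  # doc_to_text
# After the model answers, process_results would send the answer to the GPT-4o judge:
# score = utils.livebench_process_results(doc, [model_answer])
```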
lmms_eval/tasks/live_bench/utils.py

Lines changed: 197 additions & 0 deletions
````python
from pathlib import Path
import yaml
import os
import requests
import logging
import time
import base64
import openai
import json
from io import BytesIO
from tqdm import tqdm
import pandas as pd
import numpy as np


eval_logger = logging.getLogger("lmms-eval")


with open(Path(__file__).parent / "live_bench.yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))

GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
API_TYPE = config["metadata"]["api_type"]

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

elif API_TYPE == "azure":
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    headers = {
        "api-key": API_KEY,
        "Content-Type": "application/json",
    }

_PROMPT_WITH_IMAGE = """\
[Question]

{prompt}

[Assistant Response]

{generation}

[Ground Truth Response]

{reference}

[System]

Rate whether the assistant response correctly matches the ground truth, in regards to the image above.

The rating should be 0-10, where 0 is incorrect and 10 is correct.

Below is the specific criteria for rating:

{criteria}

Total score is out of 10. If the model's answer cannot be provided due to political reasons, please assign a score of 0 for further processing. If the model's response is biased due to political factors, please score it based on its understanding of the image, but reduce the objectivity score accordingly.

Your response should be in the JSON format:
```json
{{
    "Explanation": "(your explanation)",
    "Rating": "(int)"
}}
```
"""


def format_prompt(question, ground_truth_answer, answer, criteria):
    return _PROMPT_WITH_IMAGE.format(prompt=question, generation=answer, reference=ground_truth_answer, criteria=criteria)


def get_chat_response(base64_images, question, ground_truth_answer, answer, criteria, max_retries=5, wait_time=10):
    client = openai.OpenAI(api_key=API_KEY)

    content = []
    for base64_image in base64_images:
        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
    prompt = format_prompt(question, ground_truth_answer, answer, criteria)
    content.append(
        {
            "type": "text",
            "text": prompt,
        }
    )

    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]

    # payload = {
    #     "model": GPT_EVAL_MODEL_NAME,
    #     "response_format": {"type": "json_object"},
    #     "max_tokens": 1024,
    #     "temperature": 0.0,
    # }

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(model=GPT_EVAL_MODEL_NAME, messages=messages, max_tokens=1024, response_format={"type": "json_object"}, temperature=0.0)
            response_data = response.choices[0].message.content
            # print(response_data)
            response_data = json.loads(response_data)
            rating = response_data["Rating"]
            explanation = response_data["Explanation"]
            return rating, explanation, GPT_EVAL_MODEL_NAME
        except requests.exceptions.RequestException as e:
            eval_logger.warning(f"Request failed on attempt {attempt + 1}: {e}")
            time.sleep(wait_time)
            if attempt == max_retries - 1:
                eval_logger.error(f"Failed to get response after {max_retries} attempts")
                return -1, str(e), GPT_EVAL_MODEL_NAME
        except Exception as e:
            eval_logger.error(f"Error on attempt {attempt + 1}: {e}")
            return -1, str(e), GPT_EVAL_MODEL_NAME


def image_to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


_images = {}

dataset = None


def livebench_doc_to_visual(doc):
    img_list = [image.convert("RGB") for image in doc["images"]]
    return img_list


def livebench_doc_to_text(doc, model_specific_prompt_kwargs=None):
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = model_specific_prompt_kwargs.get("pre_prompt", "")
    post_prompt = model_specific_prompt_kwargs.get("post_prompt", "")
    return f"{pre_prompt}{doc['question']}{post_prompt}"


SUBTASKS = ("Basic Understanding", "Contextual Analysis", "Deeper Implications", "Broader Implications", "Further Insights")


def livebench_process_results(doc, results):
    base64_images = [image_to_base64(image) for image in livebench_doc_to_visual(doc)]
    subtask = doc["subtask"]
    criteria = doc["criteria"]
    if subtask not in SUBTASKS:
        subtask = "further insights"
    if not results:
        return {"gpt4_eval_score": {"rating": -1, "explanation": "No response", "model_name": "N/A", "subtask": subtask}}
    rating, explanation, model_name = get_chat_response(base64_images=base64_images, question=doc["question"], ground_truth_answer=doc["answer"], answer=results[0] if results else "", criteria=criteria)
    if rating >= 0:
        return {"gpt4_eval_score": {"rating": rating, "explanation": explanation, "model_name": model_name, "subtask": subtask, "id": doc["id"]}}
    else:
        return {"gpt4_eval_score": {"rating": -1, "explanation": explanation, "model_name": "N/A", "subtask": subtask, "id": doc["id"]}}


def livebench_aggregate_results(results):
    sum_score, count = 0, 0
    score = {}
    for subtask in SUBTASKS:
        score[subtask] = []
    for result in results:
        if result["rating"] == -1:
            continue
        sum_score += result["rating"] / 10
        count += 1
        subtask = result["subtask"]
        if subtask not in SUBTASKS:
            subtask = "further insights"
        score[result["subtask"]].append(result["rating"] / 10)
    res = pd.DataFrame([(subtask, len(score[subtask]), np.mean(score[subtask]) * 100) for subtask in SUBTASKS], columns=["Subtask", "Count", "Average Score"])
    print("=" * 50)
    print(res)
    print("=" * 50)
    if count == 0:
        eval_logger.warning("No valid scores to aggregate")
    return sum_score / count * 100 if count > 0 else None
````
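A hypothetical usage example of the aggregator above (not part of the commit; the ratings are invented, and the import path assumes the task package lives at `lmms_eval/tasks/live_bench/`): each entry is the dict stored under `gpt4_eval_score`, and entries rated `-1` are skipped.

```python
from lmms_eval.tasks.live_bench.utils import livebench_aggregate_results  # assumed import path

sample_scores = [  # invented ratings; subtask names must be members of SUBTASKS
    {"rating": 8, "explanation": "...", "model_name": "gpt-4o", "subtask": "Basic Understanding", "id": 0},
    {"rating": 5, "explanation": "...", "model_name": "gpt-4o", "subtask": "Deeper Implications", "id": 1},
    {"rating": -1, "explanation": "No response", "model_name": "N/A", "subtask": "Contextual Analysis", "id": 2},
]

overall = livebench_aggregate_results(sample_scores)  # prints the per-subtask table (NaN for empty subtasks)
print(overall)  # (8/10 + 5/10) / 2 * 100, i.e. roughly 65.0
```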

tools/live_bench/create_dataset.py

Lines changed: 12 additions & 0 deletions
```python
from live_bench.websites import load_websites, load_websites_from_file
from live_bench import LiveBench


if __name__ == "__main__":
    website = load_websites()
    dataset = LiveBench(force_clear=False, name="2024-06")
    dataset.capture(websites=website, driver_kwargs={"headless": True}, screen_shoter="single_screen", shoter_kwargs={"screen_size": (1024, 1024)}, qa_generator="gpt4v", scorer="gpt4v", checker="gemini")

    website = load_websites_from_file("/data/pufanyi/project/lmms-eval/temp/images")
    dataset.capture(websites=website, screen_shoter="human", qa_generator="gpt4v", scorer="gpt4v", checker="gemini", driver_kwargs={}, shoter_kwargs={}, generator_kwargs={})
    dataset.upload()
```
