Skip to content

Commit b7fd7a9

Browse files
authored
Merge pull request #59 from EvolvingLMMs-Lab/add_idefics2
add idefics2
2 parents 986139a + c5a130b commit b7fd7a9

File tree

4 files changed

+231
-1
lines changed

4 files changed

+231
-1
lines changed

lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"gpt4v": "GPT4V",
1010
"instructblip": "InstructBLIP",
1111
"minicpm_v": "MiniCPM_V",
12+
"idefics2": "Idefics2",
1213
}
1314

1415
for model_name, model_class in AVAILABLE_MODELS.items():

lmms_eval/models/idefics2.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import torch
2+
import logging
3+
from tqdm import tqdm
4+
from lmms_eval import utils
5+
from lmms_eval.api.instance import Instance
6+
from lmms_eval.api.model import lmms
7+
from lmms_eval.api.registry import register_model
8+
from accelerate import Accelerator, DistributedType
9+
from accelerate.state import AcceleratorState
10+
from typing import List, Optional, Union, Tuple
11+
from transformers import Idefics2ForConditionalGeneration, AutoProcessor
12+
13+
import warnings
14+
15+
warnings.filterwarnings("ignore")
16+
17+
eval_logger = logging.getLogger("lmms-eval")
18+
19+
DEFAULT_IMAGE_TOKEN = "<image>"
20+
try:
21+
import flash_attn
22+
best_fit_attn_implementation = "flash_attention_2"
23+
except ImportError:
24+
best_fit_attn_implementation = "eager"
25+
26+
@register_model("idefics2")
27+
class Idefics2(lmms):
28+
"""
29+
Idefics2 Model for Hugging Face Transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
30+
31+
Example usage:
32+
33+
accelerate launch --num_processes=8 -m lmms_eval \
34+
--model idefics2 \
35+
--model_args pretrained=HuggingFaceM4/idefics2-8b \
36+
--tasks mme \
37+
--batch_size 1 \
38+
--output_path ./logs/ \
39+
--log_samples
40+
"""
41+
42+
def __init__(
43+
self,
44+
pretrained: str = "HuggingFaceM4/idefics2-8b",
45+
revision: str = "main",
46+
device: str = "cuda",
47+
dtype: Optional[Union[str, torch.dtype]] = "float16",
48+
batch_size: int = 1,
49+
trust_remote_code: Optional[bool] = False,
50+
attn_implementation: Optional[str] = best_fit_attn_implementation,
51+
device_map: str = "",
52+
use_cache: bool = True,
53+
do_image_splitting: bool =False,
54+
**kwargs,
55+
) -> None:
56+
super().__init__()
57+
# Do not use kwargs for now
58+
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
59+
60+
accelerator = Accelerator()
61+
if accelerator.num_processes > 1 and device_map == "":
62+
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
63+
self.device_map = f"cuda:{accelerator.local_process_index}"
64+
else:
65+
self._device = torch.device(device)
66+
self.device_map = device_map
67+
if isinstance(dtype, str) and dtype != "auto":
68+
dtype = getattr(torch, dtype)
69+
self._model = Idefics2ForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
70+
self._processor = AutoProcessor.from_pretrained(pretrained, do_image_splitting=do_image_splitting, revision=revision, trust_remote_code=trust_remote_code)
71+
72+
self._tokenizer = self._processor.tokenizer
73+
self._config = self._model.config
74+
self.batch_size_per_gpu = int(batch_size)
75+
self.use_cache = use_cache
76+
if accelerator.num_processes > 1 and device_map == "":
77+
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
78+
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
79+
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
80+
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
81+
if accelerator.distributed_type == DistributedType.DEEPSPEED:
82+
kwargs = {
83+
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
84+
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
85+
}
86+
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
87+
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
88+
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
89+
self._model = accelerator.prepare(self.model)
90+
else:
91+
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
92+
self.accelerator = accelerator
93+
if self.accelerator.is_local_main_process:
94+
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
95+
self._rank = self.accelerator.local_process_index
96+
self._world_size = self.accelerator.num_processes
97+
elif accelerator.num_processes == 1 and device_map == "auto":
98+
eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
99+
self._rank = 0
100+
self._word_size = 1
101+
else:
102+
eval_logger.info(f"Using single device: {self._device}")
103+
self.model.to(self._device)
104+
self._rank = 0
105+
self._word_size = 1
106+
107+
@property
108+
def config(self):
109+
# return the associated transformers.AutoConfig for the given pretrained model.
110+
return self._config
111+
112+
@property
113+
def tokenizer(self):
114+
return self._tokenizer
115+
116+
@property
117+
def model(self):
118+
# returns the model, unwrapping it if using Accelerate
119+
if hasattr(self, "accelerator"):
120+
return self.accelerator.unwrap_model(self._model)
121+
else:
122+
return self._model
123+
124+
@property
125+
def eot_token_id(self):
126+
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
127+
return self.tokenizer.eos_token_id
128+
129+
@property
130+
def max_length(self):
131+
return self._max_length
132+
133+
@property
134+
def batch_size(self):
135+
return self.batch_size_per_gpu
136+
137+
@property
138+
def device(self):
139+
return self._device
140+
141+
@property
142+
def rank(self):
143+
return self._rank
144+
145+
@property
146+
def world_size(self):
147+
return self._world_size
148+
149+
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
150+
""" """
151+
add_special_tokens = False if add_special_tokens is None else add_special_tokens
152+
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
153+
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
154+
if left_truncate_len:
155+
encoding = encoding[-left_truncate_len:]
156+
return encoding
157+
158+
def tok_decode(self, tokens):
159+
return self.tokenizer.decode(tokens)
160+
161+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
162+
raise NotImplementedError("Loglikelihood is not implemented for Idefics2 model")
163+
164+
def flatten(self, input):
165+
new_list = []
166+
for i in input:
167+
for j in i:
168+
new_list.append(j)
169+
return new_list
170+
171+
def generate_until(self, requests: List[Instance]) -> List[str]:
172+
res = []
173+
174+
def _collate(x):
175+
# the negative sign on len(toks) sorts descending - this has a few advantages:
176+
# - time estimates will always be over not underestimates, which is more useful for planning
177+
# - to know the size of a batch when going through the list, you know the first one is always the batch
178+
# padded context length. this is useful to simplify the batching logic and more importantly to make
179+
# automatic adaptive batches much much easier to implement
180+
# - any OOMs will happen right away rather than near the end
181+
toks = self.tok_encode(x[0])
182+
return -len(toks), x[0]
183+
184+
# we group requests by their generation_kwargs,
185+
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
186+
# in the same batch.
187+
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
188+
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
189+
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
190+
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
191+
for chunk in chunks:
192+
contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
193+
visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)]
194+
# we assume all gen kwargs in the batch are the same
195+
# this is safe to assume because the `grouper` object ensures it.
196+
gen_kwargs = all_gen_kwargs[0]
197+
#
198+
until = gen_kwargs.pop("until", None)
199+
prompts = []
200+
for context, visual in zip(contexts, visuals):
201+
content = []
202+
if DEFAULT_IMAGE_TOKEN not in context:
203+
for image in visual:
204+
content.append({"type": "image"})
205+
content.append({"type": "text", "text": context})
206+
message = [{"role": "user", "content": content}]
207+
prompt = self._processor.apply_chat_template(message, add_generation_prompt=True)
208+
prompts.append(prompt)
209+
inputs = self._processor(text=prompts, images=visuals, padding=True, return_tensors="pt")
210+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
211+
output_ids = self.model.generate(**inputs, **gen_kwargs)
212+
# only retain the generated text
213+
for output_id, input_id in zip(output_ids, inputs["input_ids"]):
214+
generated_id = output_id[len(input_id):]
215+
generated_text = self.tokenizer.decode(generated_id, skip_special_tokens=True)
216+
217+
res.append(generated_text)
218+
pbar.update(1)
219+
# reorder this group of results back to original unsorted form
220+
res = re_ords.get_original(res)
221+
222+
pbar.close()
223+
return res

lmms_eval/tasks/mmmu/mmmu_val.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ process_results: !function utils.mmmu_process_results
1010
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
1111
generation_kwargs:
1212
max_new_tokens: 16
13-
image_aspect_ratio: original
13+
model_specific_generation_kwargs:
14+
llava:
15+
image_aspect_ratio: original
1416
metric_list:
1517
- metric: mmmu_acc
1618
aggregation: !function utils.mmmu_aggregate_results

lmms_eval/tasks/scienceqa/scienceqa_img.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ model_specific_prompt_kwargs:
2929
post_prompt: "\nAnswer with the option's letter from the given choices directly."
3030
qwen_vl:
3131
format: qwen_vl
32+
idefics2:
33+
format: default
34+
pre_prompt: ""
35+
post_prompt: "\nAnswer:"
3236
model_specific_generation_kwargs:
3337
llava:
3438
image_aspect_ratio: original

0 commit comments

Comments
 (0)