
Commit 2860ce5

DPO Llava 1.5 and PaliGemma support (huggingface#1797)
* llava support dpo
* add_special_tokens=False only when possible
* format
* pali gemma
* refactor size
* remove image resize

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 30e33bd commit 2860ce5

2 files changed: +72 -49 lines changed


examples/scripts/dpo_visual.py

Lines changed: 29 additions & 27 deletions
@@ -16,12 +16,13 @@
 accelerate launch examples/scripts/dpo_visual.py \
     --dataset_name HuggingFaceH4/rlaif-v_formatted \
     --model_name_or_path HuggingFaceM4/idefics2-8b \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 16 \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 32 \
     --dataset_num_proc 32 \
     --output_dir dpo_idefics_rlaif-v \
     --bf16 \
     --torch_dtype bfloat16 \
+    --gradient_checkpointing \
     --use_peft \
     --lora_target_modules=all-linear
 """
@@ -82,21 +83,40 @@

     model_kwargs = dict(
         revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
         attn_implementation=model_config.attn_implementation,
         torch_dtype=torch_dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
-    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        **model_kwargs,
+    )
     peft_config = get_peft_config(model_config)
     if peft_config is None:
-        model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+        model_ref = AutoModelForVision2Seq.from_pretrained(
+            model_config.model_name_or_path,
+            trust_remote_code=model_config.trust_remote_code,
+            **model_kwargs,
+        )
     else:
         model_ref = None
-    processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        do_image_splitting=False,
+    )
     tokenizer = processor.tokenizer
+
+    # Set up the chat template
+    if model.config.model_type == "idefics2":
+        pass  # the processor already has a valid chat template
+    elif model.config.model_type == "paligemma":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+    elif model.config.model_type == "llava":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
     if args.ignore_bias_buffers:
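To see the prompt format the new llava template produces, here is a minimal sketch that renders the template string from the diff with plain jinja2. The two-turn conversation and the "</s>" eos_token below are made-up illustration values; in the script they come from the processor and tokenizer.

# A minimal sketch: rendering the llava chat template from the diff with jinja2.
# Expected output: "USER: <image>What is in the image? ASSISTANT: It is a cat.</s>"
from jinja2 import Template

LLAVA_CHAT_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
    "{% for item in message['content'] %}"
    "{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}"
    "{% endfor %}"
    "{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}ASSISTANT: {% endif %}"
)

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the image?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "It is a cat."}]},
]

print(Template(LLAVA_CHAT_TEMPLATE).render(messages=messages, eos_token="</s>"))

The paligemma template in the diff differs in that it keeps only text items and wraps them with <|im_start|>/<|im_end|> markers instead of emitting an <image> placeholder.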
@@ -124,27 +144,9 @@
             ds[key] = ds[key].select(range(50))

     def process(row):
-        # The prompt can be either a string or a list. In some datasets, the prompt is just a common string
-        # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
-        # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
-        # and in such cases, it is properly formatted as a list with keys "role" and "content".
-        # Example 1:
-        # row = {"prompt": "What does detox mean?",
-        #        "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #        "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]}
-        # Example 2:
-        # row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
-        #        "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #        "rejected": [{"content": "I don't know.", "role": "user"}]}
-        if "prompt" in row and isinstance(row["prompt"], list):
-            row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
-
+        row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
         row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
         row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
-
-        if "images" in row:
-            for img in row["images"]:  # Resize each image so the largest side is 640 pixels
-                img.thumbnail((640, 640))  # Resize the image to at most 640x640 pixels
         return row

     with PartialState().local_main_process_first():
@@ -168,6 +170,6 @@ def process(row):
     )

     trainer.train()
-    trainer.push_to_hub
+
     with save_context:
         trainer.save_model(training_args.output_dir)
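A side note on the usage-example change at the top of this script: the effective per-device batch size grows from 1 × 16 = 16 to 2 × 32 = 64 (times the number of processes for the global batch), which is presumably what the newly added --gradient_checkpointing flag helps fit in memory. A minimal sketch of that arithmetic:

# A minimal sketch (not part of the commit) of the effective-batch-size arithmetic
# behind the flag change. Numbers are per device; multiply by the number of
# processes launched by accelerate for the global batch size.
old_cfg = {"per_device_train_batch_size": 1, "gradient_accumulation_steps": 16}
new_cfg = {"per_device_train_batch_size": 2, "gradient_accumulation_steps": 32}

for name, cfg in (("old", old_cfg), ("new", new_cfg)):
    effective = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
    print(f"{name}: effective batch per device = {effective}")  # old: 16, new: 64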

trl/trainer/dpo_trainer.py

Lines changed: 43 additions & 22 deletions
@@ -723,9 +723,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
         if self.is_vision_model:
             if answer.count("<image>") > 0:
                 raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
-            full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
             full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # Unbatch, not done when using idefics
-            prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0]
+            if not isinstance(full_tokenized["input_ids"], list):  # llava processor returns tensors
+                full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
+                full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
+            prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
+            if not isinstance(prompt_input_ids, list):  # llava processor returns tensors
+                prompt_input_ids = prompt_input_ids.tolist()
         else:
             full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
             prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
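The hunk above introduces two small compatibility patterns. A self-contained sketch of both, using toy stand-ins rather than the real transformers processors: pass add_special_tokens=False only when the processor's call signature accepts it, and convert tensor outputs (as the llava processor returns) into plain Python lists.

# A minimal sketch (toy stand-ins, assuming torch is installed) of the two patterns above.
import inspect
import torch

def idefics2_like_processor(text, images=None, add_special_tokens=True):
    return {"input_ids": [[101, 7592, 102]], "attention_mask": [[1, 1, 1]]}

def llava_like_processor(text, images=None):  # no add_special_tokens kwarg, returns tensors
    return {"input_ids": torch.tensor([[101, 7592, 102]]), "attention_mask": torch.tensor([[1, 1, 1]])}

for processor in (idefics2_like_processor, llava_like_processor):
    # 1) only pass add_special_tokens=False if the signature accepts it
    if "add_special_tokens" in inspect.signature(processor).parameters:
        processor_kwargs = {"add_special_tokens": False}
    else:
        processor_kwargs = {}
    full_tokenized = processor("a prompt", images=None, **processor_kwargs)
    full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # unbatch
    # 2) normalize tensor outputs to plain Python lists
    if not isinstance(full_tokenized["input_ids"], list):
        full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
        full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
    print(processor.__name__, processor_kwargs, full_tokenized["input_ids"])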
@@ -762,22 +771,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
         answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
         answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

+        return_dict = dict(
+            prompt_input_ids=prompt_input_ids,
+            prompt_attention_mask=prompt_attention_mask,
+            input_ids=answer_input_ids,
+            attention_mask=answer_attention_mask,
+        )
         if "pixel_values" in full_tokenized:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                prompt_pixel_values=full_tokenized["pixel_values"],
-                prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"],
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
-        else:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
+            return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
+            if "pixel_attention_mask" in full_tokenized:
+                return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]
+
+        return return_dict

     def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         """Tokenize a single row from a DPO specific dataset.
@@ -805,8 +810,15 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
         if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
         if self.is_vision_model:
-            prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
             prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
+            if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
+                prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
+                prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
         else:
             prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

@@ -1037,10 +1049,13 @@ def concatenated_inputs(
         )

         if is_vision_model:
-            concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
-            concatenated_batch["pixel_attention_mask"] = (
-                batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
+            concatenated_batch["pixel_values"] = torch.cat(
+                [batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0
             )
+            if "prompt_pixel_attention_mask" in batch:
+                concatenated_batch["pixel_attention_mask"] = torch.cat(
+                    [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0
+                )
         return concatenated_batch

     def dpo_loss(
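A plausible reading of the change above (the exact shapes below are toy values, not taken from the commit): idefics2-style processors return 5D pixel_values of shape (batch, num_images, channels, height, width), while llava-style processors return 4D (batch, channels, height, width), so the hard-coded .repeat(2, 1, 1, 1, 1) only fits the former; concatenating the batch with itself along dim 0 duplicates the prompt images for the chosen/rejected halves in both layouts.

# A minimal sketch (toy shapes, assuming torch) showing that cat along dim 0 works for
# both pixel_values layouts, whereas .repeat(2, 1, 1, 1, 1) assumes a 5D tensor.
import torch

idefics2_like = torch.randn(2, 1, 3, 224, 224)  # (batch, num_images, C, H, W)
llava_like = torch.randn(2, 3, 336, 336)        # (batch, C, H, W)

for pixel_values in (idefics2_like, llava_like):
    doubled = torch.cat([pixel_values, pixel_values], dim=0)  # chosen and rejected share the prompt images
    print(tuple(pixel_values.shape), "->", tuple(doubled.shape))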
@@ -1262,7 +1277,8 @@ def concatenated_forward(

         if self.is_vision_model:
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
-            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+            if "pixel_attention_mask" in concatenated_batch:
+                model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]

         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
@@ -1275,6 +1291,11 @@ def concatenated_forward(
         )
         all_logits = outputs.logits

+        if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
+            # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
+            seq_len = concatenated_batch["concatenated_labels"].shape[1]
+            all_logits = all_logits[:, -seq_len:]
+
         all_logps, size_completion = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
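A minimal sketch of the logit-slicing step added above (toy tensors; the 576 image tokens and the vocabulary size are illustrative numbers, not taken from the commit): when a model such as llava returns logits for image tokens placed before the text tokens, keeping only the last seq_len positions realigns the logits with the text-only labels.

# A minimal sketch (toy tensors, assuming torch) of the slicing added above.
import torch

batch_size, text_len, num_image_tokens, vocab_size = 4, 10, 576, 32000
all_logits = torch.randn(batch_size, num_image_tokens + text_len, vocab_size)
concatenated_labels = torch.full((batch_size, text_len), -100)

if all_logits.shape[:2] != concatenated_labels.shape[:2]:
    seq_len = concatenated_labels.shape[1]
    all_logits = all_logits[:, -seq_len:]  # drop logits for the image positions

print(all_logits.shape)  # torch.Size([4, 10, 32000])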
