
Commit 4014e70

Merge branch 'quic:main' into main
2 parents c3dc747 + a706a01 commit 4014e70

29 files changed: +2040 −4282 lines

QEfficient/cloud/finetune.py

Lines changed: 12 additions & 0 deletions
@@ -103,6 +103,18 @@ def main(**kwargs):
     # print the datatype of the model parameters
     # print(get_parameter_dtypes(model))
 
+    # Note: this must be called before PeftModel.from_pretrained or get_peft_model.
+    # Both set model.is_gradient_checkpointing = True, which the peft library uses to
+    # attach gradient-checkpointing hooks to the input embeddings. Without this we get a
+    # "No inf checks were recorded for this optimizer." error.
+    # Enable gradient checkpointing
+    if train_config.gradient_checkpointing:
+        # Note: the attribute and method below are only available on HuggingFace Transformers models.
+        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
+        else:
+            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
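
For context, a minimal sketch of the ordering this hunk enforces, assuming a plain HuggingFace Transformers model and the peft library (the model name and LoraConfig values are illustrative, not taken from this repo):

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model

    # 1) Enable gradient checkpointing on the base model first.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})

    # 2) Only then wrap with PEFT. As the note above explains, get_peft_model /
    #    PeftModel.from_pretrained see model.is_gradient_checkpointing and attach
    #    the input-embedding hooks; reversing the order can surface the
    #    "No inf checks were recorded for this optimizer." error.
    peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))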

QEfficient/finetune/configs/training.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ class train_config:
     batch_size_training: int = 1
     context_length: int = None
     gradient_accumulation_steps: int = 4
+    gradient_checkpointing: bool = False
     num_epochs: int = 1
     max_train_step: int = 0
     max_eval_step: int = 0
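
A small usage sketch (assuming a train_config instance is created and then read by QEfficient/cloud/finetune.py as in the hunk above):

    from QEfficient.finetune.configs.training import train_config

    cfg = train_config()
    cfg.gradient_checkpointing = True  # off by default; opt in per run

    # finetune.main() then routes this flag through the block shown earlier:
    #     if cfg.gradient_checkpointing:
    #         model.gradient_checkpointing_enable(...)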

QEfficient/finetune/dataset/dataset_config.py

Lines changed: 7 additions & 1 deletion
@@ -21,6 +21,9 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
+from QEfficient.finetune.dataset.samsum_dataset import (
+    get_samsum_collate_fn,
+)
 
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,
 }
-DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator,
+    "samsum_dataset": get_samsum_collate_fn,
+}
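
A hedged sketch of how this registry is typically consumed: the dataset name selects an optional collate-fn factory, and anything without an entry falls back to the default DataLoader collation (build_dataloader below is an illustrative helper, not this repo's exact wiring):

    from torch.utils.data import DataLoader, Dataset

    from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC


    def build_dataloader(dataset: Dataset, dataset_name: str, tokenizer, dataset_config, batch_size: int = 1):
        # Each registry value is a factory (tokenizer, dataset_config) -> collate_fn.
        factory = DATALOADER_COLLATE_FUNC.get(dataset_name)
        collate_fn = factory(tokenizer, dataset_config) if factory else None
        return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)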

QEfficient/finetune/dataset/samsum_dataset.py

Lines changed: 21 additions & 0 deletions
@@ -6,6 +6,8 @@
 # -----------------------------------------------------------------------------
 
 import datasets
+import torch
+from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
+
+
+def collate_fn(batch):
+    eos_token = batch[0]["input_ids"][-1]
+
+    input_ids = pad_sequence(
+        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    attn_mask = pad_sequence(
+        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
+    )
+    labels = pad_sequence(
+        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
+
+
+def get_samsum_collate_fn(dataset_processer, dataset_config):
+    return collate_fn
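
A self-contained toy run of the new collate function (token ids and shapes are made up), duplicating its logic so it can be executed with only torch installed:

    import torch
    from torch.nn.utils.rnn import pad_sequence


    def collate_fn(batch):
        # Pad every field to the longest sequence in the batch; the sample's
        # trailing EOS id doubles as the pad value for input_ids and labels.
        eos_token = batch[0]["input_ids"][-1]
        input_ids = pad_sequence(
            [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch],
            batch_first=True,
            padding_value=eos_token,
        )
        attn_mask = pad_sequence(
            [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch],
            batch_first=True,
            padding_value=0,
        )
        labels = pad_sequence(
            [torch.tensor(b["labels"], dtype=torch.long) for b in batch],
            batch_first=True,
            padding_value=eos_token,
        )
        return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}


    batch = [
        {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1], "labels": [5, 6, 2]},
        {"input_ids": [7, 2], "attention_mask": [1, 1], "labels": [7, 2]},
    ]
    out = collate_fn(batch)
    print(out["input_ids"])       # -> [[5, 6, 2], [7, 2, 2]] (padded with EOS id 2)
    print(out["attention_mask"])  # -> [[1, 1, 1], [1, 1, 0]] (padded with 0)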

QEfficient/finetune/utils/train_utils.py

Lines changed: 3 additions & 7 deletions
@@ -83,6 +83,7 @@ def train(
     best_val_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
+    device_type = device.split(":")[0]
 
     tensorboard_updates = None
     if train_config.enable_ddp:
@@ -95,7 +96,7 @@
     if device.startswith("qaic"):
         scaler = QAicGradScaler()
     else:
-        scaler = GradScaler()
+        scaler = GradScaler(device_type)
 
     loss_0_counter = torch.tensor([0]).to(device)
 
@@ -177,10 +178,7 @@
                         # adjust atol & rtol this as required
                         atol=1e-1,
                         use_ref_output_on_mismatch=True,
-                        # report all mismatches
-                        max_failures=None,
-                        # generate unittest for each op once
-                        repeat_same_op=True,
+                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
                         dump_root_dir=train_config.dump_root_dir + str(step),
                     ) as verifier:
                         loss = model(**batch).loss  # Forward call
@@ -296,8 +294,6 @@
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
                 model, train_config, eval_dataloader, local_rank, tokenizer, device
             )
-            dist.barrier()
-            dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
             if local_rank == 0:
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)

QEfficient/transformers/cache_utils.py

Lines changed: 26 additions & 1 deletion
@@ -9,7 +9,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache
 
 from QEfficient.customop import (
     CtxGatherFunc,
@@ -181,3 +181,28 @@ def update3D(
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
 
         return k_out, v_out
+
+
+class QEffEncoderDecoderCache(EncoderDecoderCache):
+    """
+    Updated the `EncoderDecoderCache` to use the `QEffDynamicCache` for both self-attention and cross-attention caches.
+    """
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "EncoderDecoderCache":
+        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(
+            self_attention_cache=QEffDynamicCache(),
+            cross_attention_cache=QEffDynamicCache(),
+        )
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx][:2]
+                cache.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(past_key_values[layer_idx]) > 2:
+                    key_states, value_states = past_key_values[layer_idx][2:]
+                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+                    cache.is_updated[layer_idx] = True
+        return cache
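
An illustrative, untested sketch of the new conversion path with made-up tensor shapes; it assumes QEffDynamicCache.update can append a brand-new layer the way transformers' DynamicCache.update does:

    import torch

    from QEfficient.transformers.cache_utils import QEffEncoderDecoderCache

    # Legacy format: one tuple per layer, (self_k, self_v[, cross_k, cross_v]).
    layer_0 = (
        torch.zeros(1, 8, 4, 64),   # self-attention keys
        torch.zeros(1, 8, 4, 64),   # self-attention values
        torch.zeros(1, 8, 16, 64),  # cross-attention keys
        torch.zeros(1, 8, 16, 64),  # cross-attention values
    )

    cache = QEffEncoderDecoderCache.from_legacy_cache(past_key_values=(layer_0,))
    # Both cache.self_attention_cache and cache.cross_attention_cache are now
    # QEffDynamicCache instances, so downstream QEfficient code can use the same
    # update() machinery for encoder-decoder and decoder-only models.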

QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 15 additions & 66 deletions
@@ -10,20 +10,18 @@
 from typing import Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
     CodeGenForCausalLM,
     CodeGenModel,
     apply_rotary_pos_emb,
-    logger,
 )
 
+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
@@ -133,7 +131,7 @@ def forward(
                 "position_ids": position_ids,
                 "batch_index": batch_index,
             }
-            pkv = DynamicCache()
+            pkv = QEffDynamicCache()
             pkv.key_cache.append(past_key_value[0])
             pkv.value_cache.append(past_key_value[1])
             key, value = pkv.update(key, value, 0, cache_kwargs)
@@ -261,14 +259,6 @@ def forward(
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                    "`use_cache=False`..."
-                )
-                use_cache = False
-
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
@@ -279,41 +269,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            elif batch_index is not None:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    batch_index=batch_index,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                batch_index=batch_index,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
 
             hidden_states = outputs[0]
             if use_cache is True:
@@ -398,25 +364,8 @@ def forward(
         hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         lm_logits = self.lm_head(hidden_states)
 
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-            loss = loss.to(hidden_states.dtype)
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithPast(
-            loss=loss,
+            loss=None,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
