@@ -10,20 +10,18 @@
 from typing import Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
     CodeGenForCausalLM,
     CodeGenModel,
     apply_rotary_pos_emb,
-    logger,
 )
 
+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
@@ -133,7 +131,7 @@ def forward(
133131 "position_ids" : position_ids ,
134132 "batch_index" : batch_index ,
135133 }
136- pkv = DynamicCache ()
134+ pkv = QEffDynamicCache ()
137135 pkv .key_cache .append (past_key_value [0 ])
138136 pkv .value_cache .append (past_key_value [1 ])
139137 key , value = pkv .update (key , value , 0 , cache_kwargs )
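# Note (not part of the patch): a minimal sketch of the cache-update pattern the
# hunk above relies on, using transformers' stock DynamicCache as a stand-in for
# QEffDynamicCache; shapes and values here are illustrative only.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 4, 8, 16)    # (batch, heads, seq_len, head_dim), made-up shape
value = torch.randn(1, 4, 8, 16)
# update() stores the tensors for layer 0 and returns the merged key/value caches
key_all, value_all = cache.update(key, value, 0, cache_kwargs={})
assert key_all.shape[2] == 8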
@@ -261,14 +259,6 @@ def forward(
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                    "`use_cache=False`..."
-                )
-                use_cache = False
-
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
@@ -279,41 +269,17 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if self.gradient_checkpointing and self.training:
-                outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    position_ids,
-                    head_mask[i],
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                )
-            elif batch_index is not None:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    batch_index=batch_index,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
-            else:
-                outputs = block(
-                    hidden_states=hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    cache_position=cache_position,
-                )
+            outputs = block(
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                batch_index=batch_index,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
 
             hidden_states = outputs[0]
             if use_cache is True:
@@ -398,25 +364,8 @@ def forward(
         hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         lm_logits = self.lm_head(hidden_states)
 
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-            loss = loss.to(hidden_states.dtype)
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithPast(
-            loss=loss,
+            loss=None,
             logits=lm_logits,
             past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
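# Note (illustration only, not part of the patch): the hunk above gathers one
# position's hidden state per sequence via advanced indexing before lm_head;
# a minimal standalone sketch of that pattern, with made-up shapes and indices.
import torch

batch, seq_len, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, hidden)
# hypothetical per-sequence index of the last valid (non-padded) token
logit_index = torch.tensor([[4], [2]])
# (batch, 1) row indices broadcast against (batch, 1) position indices -> (batch, 1, hidden)
last_hidden = hidden_states[torch.arange(batch).view(-1, 1), logit_index]
print(last_hidden.shape)  # torch.Size([2, 1, 4])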