From 848af00287e93e30d01a3cbd31c69bbcee2e0190 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 12 Dec 2023 10:38:38 +0000
Subject: [PATCH] fix adaption prompt bug with transformers 4.36.0

---
 src/peft/tuners/adaption_prompt/utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py
index 921982fbb7..8c10cab51b 100644
--- a/src/peft/tuners/adaption_prompt/utils.py
+++ b/src/peft/tuners/adaption_prompt/utils.py
@@ -73,7 +73,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
 
     seq_len = q_len
     if past_key_value is not None:
-        seq_len += past_key_value[0].shape[-2]
+        # Newer transformers versions (>= 4.36.0) have a different cache mechanism,
+        # therefore you need to unpack it from the tuple of tuples.
+        if isinstance(past_key_value[0], tuple):
+            seq_len += past_key_value[0][0].shape[-2]
+        else:
+            seq_len += past_key_value[0].shape[-2]
     cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
 
     return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
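
Note (not part of the patch): the minimal sketch below illustrates why the extra level of indexing is needed. The tensor shapes and the nested tuple layout are assumptions used only to mimic the two cache formats the patched branch distinguishes, not the exact objects transformers passes at runtime.

import torch

q_len = 4          # length of the new query block
cached_len = 16    # tokens already stored in the KV cache
# assumed (batch, num_heads, seq_len, head_dim) layout for the cached key/value tensors
k = torch.zeros(1, 8, cached_len, 64)
v = torch.zeros(1, 8, cached_len, 64)

# Older transformers: the per-layer cache entry is the (key, value) pair itself,
# so index 0 is already the key tensor.
legacy_layer_cache = (k, v)

# Newer transformers (>= 4.36.0), as the patch comment describes: the entry can be a
# tuple of tuples, so index 0 yields another (key, value) tuple that must be unpacked.
new_layer_cache = ((k, v),)

for past_key_value in (legacy_layer_cache, new_layer_cache):
    seq_len = q_len
    if isinstance(past_key_value[0], tuple):
        seq_len += past_key_value[0][0].shape[-2]  # unpack nested (key, value) tuple
    else:
        seq_len += past_key_value[0].shape[-2]     # key tensor sits directly at index 0
    print(seq_len)  # both branches yield 20 = q_len + cached_len

Either way, seq_len ends up as the total number of positions (cached plus new), which is what model.rotary_emb needs to produce cos/sin tables long enough for the rotary position embedding applied in the last two lines of the patched function.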