@@ -1,5 +1,6 @@
 #!/usr/bin/env python3

+import copy
 from collections import namedtuple
 from typing import cast, List, Optional, Union

@@ -53,9 +54,11 @@ def __init__(self):

     def forward(self, input_ids, *args, **kwargs):
         emb = self.emb(input_ids)
+        if "past_key_values" in kwargs:
+            emb = torch.cat((kwargs["past_key_values"], emb), dim=1)
         logits = self.linear(self.trans(emb))
-        Result = namedtuple("Result", ["logits"])
-        return Result(logits=logits)
+        Result = namedtuple("Result", ["logits", "past_key_values"])
+        return Result(logits=logits, past_key_values=emb)

     def generate(self, input_ids, *args, mock_response=None, **kwargs):
         assert mock_response, "must mock response to use DummyLLM to generate"
@@ -64,16 +67,35 @@ def generate(self, input_ids, *args, mock_response=None, **kwargs):
             [input_ids, torch.tensor([response], device=self.device)], dim=1
         )

+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
+        new_kwargs = copy.deepcopy(model_kwargs)
+        if hasattr(outputs, "past_key_values"):
+            new_kwargs["past_key_values"] = outputs.past_key_values
+        return new_kwargs
+
+    def prepare_inputs_for_generation(self, model_inp, **model_kwargs):
+        if "past_key_values" in model_kwargs:
+            emb_len = model_kwargs["past_key_values"].shape[1]
+            return {
+                "input_ids": model_inp[:, emb_len:],
+                "past_key_values": model_kwargs["past_key_values"],
+            }
+        return {"input_ids": model_inp}
+
     @property
     def device(self):
         return next(self.parameters()).device


 @parameterized_class(
-    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+    ("device", "use_cached_outputs"),
+    [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
+    if torch.cuda.is_available()
+    else [("cpu", True), ("cpu", False)],
 )
 class TestLLMAttr(BaseTest):
     device: str
+    use_cached_outputs: bool

     @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
     def test_llm_attr(self, AttrClass) -> None:
@@ -83,7 +105,9 @@ def test_llm_attr(self, AttrClass) -> None:
         llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -100,7 +124,11 @@ def test_llm_attr_without_target(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer)

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_fa.attribute(
+            inp,
+            gen_args={"mock_response": "x y z"},
+            use_cached_outputs=self.use_cached_outputs,
+        )

         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
@@ -117,7 +145,9 @@ def test_llm_attr_fa_log_prob(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

         # With FeatureAblation, the seq attr in log_prob
         # equals the sum of each token attr
@@ -132,7 +162,9 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(res.seq_attr.device.type, self.device)
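
The two hooks added above mirror the interface that Hugging Face-style generation loops use when a key/value cache is active: prepare_inputs_for_generation trims the running sequence to the tokens the cache has not covered yet, and _update_model_kwargs_for_generation carries the cache into the next decoding step (DummyLLM simply reuses its concatenated embeddings as a stand-in for real past key/values). Below is a minimal sketch, not HuggingFace's or Captum's actual code, of the greedy decoding loop these hooks plug into; greedy_generate is a hypothetical helper, and model is assumed to implement the three methods shown in this diff.

import torch

def greedy_generate(model, input_ids, max_new_tokens=3, **model_kwargs):
    # Hypothetical driver loop for illustration; not part of Captum or HuggingFace.
    for _ in range(max_new_tokens):
        # Drop the prefix already covered by the cached embeddings.
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model.forward(**model_inputs)
        # Take the highest-scoring token at the last position.
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # Stash outputs.past_key_values so the next step can reuse it.
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs)
    return input_ids

On the test side, use_cached_outputs is threaded through attribute so each parameterized case exercises both the cached and uncached paths; presumably the cached path lets the repeated forward passes over the target tokens reuse the prompt's state instead of re-encoding it on every call.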