@@ -6,7 +6,7 @@

 import logging
 from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
        )
        return quantized_value, scales, zero_points

-    def _quantize_and_update(self, input_pos, k_val, v_val):
+    def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
@@ -104,26 +104,57 @@ def _quantize_and_update(self, input_pos, k_val, v_val):

        if self.use_custom_update_cache_op:
            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
        else:
+            assert indices is None, "Indices not supported for this path"
+            # The following is also broken because in prefill input_pos = [0],
+            # but we need to update some slice of the cache.
            self.k_cache[:, input_pos] = quantized_k_val
            self.k_cache_scales[:, input_pos] = k_scales
            self.k_cache_zero_points[:, input_pos] = k_zero_points
            self.v_cache[:, input_pos] = quantized_v_val
            self.v_cache_scales[:, input_pos] = v_scales
            self.v_cache_zero_points[:, input_pos] = v_zero_points

-    def _update_and_return_float_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)

        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.k_cache,
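
For context on the two ops used above: below is a minimal pure-PyTorch sketch of the semantics the diff appears to assume, inferred only from the call sites (the real `torch.ops.llama` kernels mutate the cache in place and return a dummy value, hence the `_ =` pattern). The `_ref` helpers and the 1-D `indices` layout are assumptions for illustration, not the actual kernel implementation:

```python
import torch

def update_cache_ref(value, cache, start_pos):
    # Contiguous write: rows [start_pos, start_pos + seq_len) of the cache
    # are overwritten. value: [B, S, ...], cache: [B, MaxS, ...].
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value
    return cache

def update_cache_with_indices_ref(value, cache, start_pos, indices):
    # Scattered write: token i of value lands at row indices[i], which need
    # not be contiguous. Assumes indices is a 1-D LongTensor of length S.
    cache[:, indices] = value
    return cache

# Tiny smoke test of the assumed semantics.
cache = torch.zeros(1, 8, 2)
value = torch.ones(1, 3, 2)
update_cache_ref(value, cache, start_pos=2)          # fills rows 2..4
update_cache_with_indices_ref(
    2 * value, cache, start_pos=0, indices=torch.tensor([0, 5, 7])
)                                                    # fills rows 0, 5, 7
```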
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
            self.cache_fp_type,
        )

-        # When returning float values we jsut use the last value
+        # When returning float values we just use the last value
        # instead of dequantized value.
        start_pos = input_pos[0].item()
        if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
        else:
            k_out[:, input_pos] = k_val
            v_out[:, input_pos] = v_val

        return k_out, v_out

-    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_quantized_values(
+        self, input_pos, k_val, v_val, indices=None
+    ):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)

        return self.k_cache, self.v_cache

-    def update(self, input_pos, k_val, v_val):
+    def update(self, input_pos, k_val, v_val, indices=None):
        """
        k_val, v_val: [B, H, S, D]
        return: [B, H, S, D]
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
        v_val = v_val.transpose(1, 2)

        if self.return_float_values:
-            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+            k_out, v_out = self._update_and_return_float_values(
+                input_pos, k_val, v_val, indices
+            )
        else:
            k_out, v_out = self._update_and_return_quantized_values(
-                input_pos, k_val, v_val
+                input_pos, k_val, v_val, indices
            )
        return k_out.transpose(1, 2), v_out.transpose(1, 2)
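
On the shape contract of `update`: per the docstring, keys and values arrive and return as `[B, H, S, D]`, while the caches are stored with the sequence dimension second, hence the transposes on entry and exit. A small sketch with made-up sizes:

```python
import torch

B, H, S, D = 1, 8, 4, 64          # illustrative sizes only
k_val = torch.randn(B, H, S, D)   # layout the attention module produces

k_for_cache = k_val.transpose(1, 2)   # [B, S, H, D]: matches cache layout,
                                      # so one cache row is one token position
# ... cache write happens here, at start_pos or at `indices` ...
k_out = k_for_cache.transpose(1, 2)   # transposed back to [B, H, S, D]
assert k_out.shape == (B, H, S, D)
```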
@@ -277,14 +320,28 @@ def __init__(
        )

    def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D]
        k_val = k_val.transpose(1, 2)
        v_val = v_val.transpose(1, 2)
        start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
        return (
            self.k_cache.transpose(1, 2),
            self.v_cache.transpose(1, 2),
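
Taken together, the change lets a caller pass explicit destination rows when they are not contiguous from `start_pos`, e.g. the prefill case flagged in the in-diff comment where `input_pos = [0]` but only a slice of the cache should be written. A toy stand-in for the eager behavior (hypothetical class and sizes; only the `update(..., indices=...)` signature comes from the diff):

```python
import torch

class ToyKVCache:
    # Hypothetical stand-in mirroring the eager cache-update behavior, just
    # to demonstrate the indices-based call; not the ExecuTorch implementation.
    def __init__(self, batch, max_seq, heads, dim):
        self.k_cache = torch.zeros(batch, max_seq, heads, dim)
        self.v_cache = torch.zeros(batch, max_seq, heads, dim)

    def update(self, input_pos, k_val, v_val, indices=None):
        k_val = k_val.transpose(1, 2)  # [B, H, S, D] -> [B, S, H, D]
        v_val = v_val.transpose(1, 2)
        if indices is not None:
            # Scatter to explicit rows instead of start_pos + arange(S).
            self.k_cache[:, indices] = k_val
            self.v_cache[:, indices] = v_val
        else:
            start = input_pos[0].item()
            self.k_cache[:, start : start + k_val.size(1)] = k_val
            self.v_cache[:, start : start + v_val.size(1)] = v_val
        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)

cache = ToyKVCache(batch=1, max_seq=16, heads=2, dim=4)
k = torch.randn(1, 2, 3, 4)  # [B, H, S, D], three tokens
v = torch.randn(1, 2, 3, 4)
# Non-contiguous destinations, e.g. writing a slice during prefill even
# though input_pos is [0]:
k_out, v_out = cache.update(
    torch.tensor([0]), k, v, indices=torch.tensor([4, 5, 9])
)
```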