|
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict

 import torch
@@ -160,6 +161,112 @@ def forward(
|
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+class CacheUpdateStrategy(Enum):
+    RING_BUFFER = "RingBuffer"
+    INVALID = "Invalid"
+
+
+class CachePositionsManager(nn.Module):
+    def __init__(
+        self,
+        max_context_length: int,
+        cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER,
+    ):
+        super().__init__()
+        assert (
+            cache_update_strategy == CacheUpdateStrategy.RING_BUFFER
+        ), "Only RingBuffer is supported"
+        self.max_context_length = max_context_length
+        self.register_buffer(
+            "cache_positions",
+            torch.zeros((self.max_context_length,), dtype=torch.long, device="cpu"),
+        )
+
+    def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len: int):
+        """
+        Calculate the indices into k_cache/v_cache at which the k_val tensor
+        should be written. Given input_pos and the length of k_val along the
+        sequence dim, the positions may have to wrap around when the update
+        fits within the cache capacity. If the update is larger than the cache
+        capacity, only the last self.max_context_length entries are kept.
+
+        Additionally, update the cache_positions buffer with the new indices:
+        given the cache positions along the sequence dim, indicated by indices,
+        we can update the cache_positions buffer using orig_indices.
+        For example, given a cache capacity of 4 and an update of length 3
+        with start_pos = 2, we have
+            indices = [2, 3, 0]
+            orig_indices = [2, 3, 4]
+        so cache_positions after the update will be [4, 1, 2, 3].
+        Note that cache_positions[1] = 1 comes from a previous write to the cache.
+        The corner case is the state of cache_positions before the cache rolls
+        over. For example, when start_pos = 0 and the update is of length 2,
+        positions 0 and 1 in the buffer are filled while the rest are invalid.
+        In this case we have
+            indices = [0, 1]
+            orig_indices = [0, 1]
+        But cache_positions = [0, 1, 0, 0] would not be valid, so invalid
+        positions must hold a sentinel value of -1.
+        """
+        start_pos = input_pos[0].item()
+        torch._check_is_size(start_pos)
+        orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
+        indices = orig_indices % self.max_context_length
+
+        # Slots at or beyond start_pos have not been written yet; reset them to
+        # the -1 sentinel before recording the new positions.
+        full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
+        arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
+        cache_positions = torch.where(
+            arange_tensor < start_pos, self.cache_positions, full_t
+        )
+        self.cache_positions.copy_(cache_positions)
+        self.cache_positions.index_copy_(0, indices, orig_indices)
+
+        return indices
+
+
+class RingKVCache(KVCache):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_context_length: int,
+        n_heads: int,
+        head_dim: int,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size,
+            max_context_length,
+            n_heads,
+            head_dim,
+            enable_dynamic_shape,
+            dtype,
+        )
+        self.cache_positions_manager = CachePositionsManager(max_context_length)
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, H, S, D]
+        seq_len = k_val.size(2)
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+
+            self.k_cache.index_copy_(2, indices, k_val)
+            self.v_cache.index_copy_(2, indices, v_val)
+        else:
+            self.k_cache[:, :, indices] = k_val
+            self.v_cache[:, :, indices] = v_val
+
+        return self.k_cache, self.v_cache
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
|
|
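To make the docstring's worked example concrete, here is a minimal standalone sketch of the same index arithmetic using plain torch ops rather than the class itself. The starting buffer state [0, 1, -1, -1] is an assumption matching what a previous length-2 write at start_pos = 0 would leave behind; the capacity-4, start_pos = 2, length-3 update is taken directly from the docstring:

    import torch

    max_context_length = 4                            # ring capacity from the example
    cache_positions = torch.tensor([0, 1, -1, -1])    # assumed state after a length-2 write at start_pos = 0

    start_pos, seq_len = 2, 3                         # new update wraps past the end of the ring
    orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos   # tensor([2, 3, 4])
    indices = orig_indices % max_context_length                          # tensor([2, 3, 0])

    # Reset slots at or beyond start_pos to the -1 sentinel, then record the new positions.
    keep_mask = torch.arange(max_context_length) < start_pos
    cache_positions = torch.where(
        keep_mask, cache_positions, torch.full((max_context_length,), -1)
    )
    cache_positions.index_copy_(0, indices, orig_indices)

    print(indices.tolist())          # [2, 3, 0]
    print(cache_positions.tolist())  # [4, 1, 2, 3]

Running this prints exactly the indices and cache_positions values quoted in the docstring.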
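A hypothetical usage sketch of RingKVCache, showing the ring buffer wrapping around. It assumes the classes above are importable from this module and that the KVCache base class allocates k_cache/v_cache with shape [max_batch_size, n_heads, max_context_length, head_dim], as the [B, H, S, D] comment in update() suggests:

    import torch

    # Capacity 4, so the second length-3 update must evict the oldest entries.
    cache = RingKVCache(
        max_batch_size=1,
        max_context_length=4,
        n_heads=2,
        head_dim=8,
        enable_dynamic_shape=False,
    )

    # First update fills ring slots 0-2 (token positions 0, 1, 2).
    cache.update(torch.tensor([0]), torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8))

    # Second update starts at token position 3 and wraps to ring slots [3, 0, 1],
    # overwriting the entries for token positions 0 and 1.
    cache.update(torch.tensor([3]), torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8))

    print(cache.cache_positions_manager.cache_positions)  # tensor([4, 5, 2, 3])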