@@ -131,8 +131,10 @@ def __init__(self, cpu_block_allocator: PrefixCachingBlockAllocator,
131
131
self .num_gpu_blocks = gpu_block_allocator .get_num_total_blocks ()
132
132
self .num_cpu_blocks = cpu_block_allocator .get_num_total_blocks ()
133
133
134
- def allocate_mutable_block (self , prev_block : Optional [Block ],
135
- device : Device ) -> Block :
134
+ def allocate_mutable_block (self ,
135
+ prev_block : Optional [Block ],
136
+ device : Device ,
137
+ extra_hash : Optional [int ] = None ) -> Block :
136
138
"""Allocates a new mutable block on the specified device.
137
139
138
140
Args:
@@ -148,13 +150,17 @@ def allocate_mutable_block(self, prev_block: Optional[Block],
148
150
"handles CPU offloading internally." \
149
151
# mark this block as uncached
150
152
151
- block = self ._allocators [device ].allocate_mutable_block (prev_block )
153
+ block = self ._allocators [device ].allocate_mutable_block (
154
+ prev_block , extra_hash = extra_hash )
152
155
self ._uncached_blocks .append (block )
153
156
return block
154
157
155
- def allocate_immutable_blocks (self , prev_block : Optional [Block ],
156
- block_token_ids : List [List [int ]],
157
- device : Device ) -> List [Block ]:
158
+ def allocate_immutable_blocks (
159
+ self ,
160
+ prev_block : Optional [Block ],
161
+ block_token_ids : List [List [int ]],
162
+ device : Device ,
163
+ extra_hash : Optional [int ] = None ) -> List [Block ]:
158
164
"""Allocates a new group of immutable blocks with the provided block
159
165
token IDs on the specified device.
160
166
@@ -179,13 +185,16 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block],
179
185
for token_ids in block_token_ids :
180
186
prev_block = self .allocate_immutable_block (prev_block = prev_block ,
181
187
token_ids = token_ids ,
182
- device = device )
188
+ device = device ,
189
+ extra_hash = extra_hash )
183
190
blocks .append (prev_block )
184
191
return blocks
185
192
186
- def allocate_immutable_block (self , prev_block : Optional [Block ],
193
+ def allocate_immutable_block (self ,
194
+ prev_block : Optional [Block ],
187
195
token_ids : List [int ],
188
- device : Device ) -> Block :
196
+ device : Device ,
197
+ extra_hash : Optional [int ] = None ) -> Block :
189
198
"""Allocates a new immutable block with the provided token IDs on the
190
199
specified device.
191
200
@@ -207,7 +216,7 @@ def allocate_immutable_block(self, prev_block: Optional[Block],
207
216
208
217
# allocate a GPU block
209
218
block = self ._allocators [device ].allocate_immutable_block (
210
- prev_block , token_ids )
219
+ prev_block , token_ids , extra_hash = extra_hash )
211
220
block_id = block .block_id
212
221
assert block_id is not None
213
222
block_computed = self ._allocators [device ].block_is_computed (block_id )
@@ -222,7 +231,7 @@ def allocate_immutable_block(self, prev_block: Optional[Block],
222
231
else :
223
232
# check if we can hit cache on CPU by trying to allocate CPU block
224
233
cpu_block = self ._allocators [Device .CPU ].allocate_immutable_block (
225
- prev_block , token_ids )
234
+ prev_block , token_ids , extra_hash = extra_hash )
226
235
cpu_block_id = cpu_block .block_id
227
236
assert cpu_block_id is not None
228
237
cpu_block_computed = self ._allocators [
@@ -329,7 +338,10 @@ def get_and_reset_swaps(self,
329
338
if computed : # This block is computed, copy it to CPU
330
339
# allocate a block on CPU
331
340
cpu_block = cpu_allocator .allocate_immutable_block (
332
- prev_block = block .prev_block , token_ids = block .token_ids )
341
+ prev_block = block .prev_block ,
342
+ token_ids = block .token_ids ,
343
+ extra_hash = block .extra_hash ,
344
+ )
333
345
assert cpu_block .block_id is not None
334
346
self ._allocated_cpu_blocks .append (cpu_block )
335
347
0 commit comments