@@ -361,3 +361,208 @@ def test_sliding_window_multi_seq():
361361
362362 # assert all blocks are free now
363363 assert block_manager .get_num_free_gpu_blocks () == num_gpu_blocks
364+
365+
366+ def test_seq_cached_blocks_num ():
367+ # Initialize the block manager
368+ block_size = 16
369+ num_gpu_blocks = 64
370+ num_cpu_blocks = 32
371+ block_manager = BlockSpaceManager (block_size ,
372+ num_gpu_blocks ,
373+ num_cpu_blocks ,
374+ enable_caching = True )
375+ assert block_manager .get_num_free_gpu_blocks () == num_gpu_blocks
376+ assert block_manager .get_num_free_cpu_blocks () == num_cpu_blocks
377+
378+ seq_prompt_length = 64
379+ seq = Sequence (seq_id = 0 ,
380+ prompt = "zero to sixty three" ,
381+ block_size = block_size ,
382+ prompt_token_ids = list (range (seq_prompt_length )))
383+ seq_group = SequenceGroup (request_id = 0 ,
384+ seqs = [seq ],
385+ sampling_params = SamplingParams (),
386+ arrival_time = time .time ())
387+ block_manager .allocate (seq_group )
388+
389+ seq_num_cached_blocks = block_manager .get_num_cached_blocks (seq )
390+ assert seq_num_cached_blocks == 4
391+ # 64 - 4 = 60
392+ assert block_manager .get_num_free_gpu_blocks () == 60
393+
394+ seq1_prompt_length = 32
395+ seq1 = Sequence (seq_id = 1 ,
396+ prompt = "zero to thirty one" ,
397+ block_size = block_size ,
398+ prompt_token_ids = list (range (seq1_prompt_length )))
399+ seq1_num_cached_blocks = block_manager .get_num_cached_blocks (seq1 )
400+ assert seq1_num_cached_blocks == 2
401+ seq1_group = SequenceGroup (request_id = 1 ,
402+ seqs = [seq1 ],
403+ sampling_params = SamplingParams (),
404+ arrival_time = time .time ())
405+ block_manager .allocate (seq1_group )
406+ # 64 - 4 - (2 - 2) = 60
407+ assert block_manager .get_num_free_gpu_blocks () == 60
408+
409+ seq2_prompt_length = 47
410+ seq2 = Sequence (seq_id = 2 ,
411+ prompt = "zero to forty six" ,
412+ block_size = block_size ,
413+ prompt_token_ids = list (range (seq2_prompt_length )))
414+ seq2_num_cached_blocks = block_manager .get_num_cached_blocks (seq2 )
415+ assert seq2_num_cached_blocks == 2
416+ seq2_group = SequenceGroup (request_id = 2 ,
417+ seqs = [seq2 ],
418+ sampling_params = SamplingParams (),
419+ arrival_time = time .time ())
420+ block_manager .allocate (seq2_group )
421+ # 64 - 4 - (2 - 2) - (3 - 2) = 59
422+ assert block_manager .get_num_free_gpu_blocks () == 59
423+
424+ seq3_prompt_length = 96
425+ seq3 = Sequence (seq_id = 3 ,
426+ prompt = "zero to ninety five" ,
427+ block_size = block_size ,
428+ prompt_token_ids = list (range (seq3_prompt_length )))
429+ seq3_num_cached_blocks = block_manager .get_num_cached_blocks (seq3 )
430+ assert seq3_num_cached_blocks == 4
431+ seq3_group = SequenceGroup (request_id = 3 ,
432+ seqs = [seq3 ],
433+ sampling_params = SamplingParams (),
434+ arrival_time = time .time ())
435+ block_manager .allocate (seq3_group )
436+ # 64 - 4 - (2 - 2) - (3 - 2) - (6 - 4) = 57
437+ assert block_manager .get_num_free_gpu_blocks () == 57
438+
439+
440+ def test_seq_computed_blocks_num ():
441+ # Initialize the block manager
442+ block_size = 16
443+ num_gpu_blocks = 64
444+ num_cpu_blocks = 32
445+ block_manager = BlockSpaceManager (block_size ,
446+ num_gpu_blocks ,
447+ num_cpu_blocks ,
448+ enable_caching = True )
449+ assert block_manager .get_num_free_gpu_blocks () == num_gpu_blocks
450+ assert block_manager .get_num_free_cpu_blocks () == num_cpu_blocks
451+
452+ seq_prompt_length = 64
453+ seq = Sequence (seq_id = 0 ,
454+ prompt = "zero to sixty three" ,
455+ block_size = block_size ,
456+ prompt_token_ids = list (range (seq_prompt_length )))
457+ seq_group = SequenceGroup (request_id = 0 ,
458+ seqs = [seq ],
459+ sampling_params = SamplingParams (),
460+ arrival_time = time .time ())
461+ block_manager .allocate (seq_group )
462+ block_manager .mark_blocks_as_computed (seq_group )
463+
464+ seq_num_computed_blocks = block_manager .get_num_computed_blocks (seq )
465+ assert seq_num_computed_blocks == 3
466+ # Ensure the computed blocks number aligns with the real allocation behavior
467+ seq_computed_blocks = block_manager .get_all_computed_blocks (seq )
468+ assert seq_num_computed_blocks == len (seq_computed_blocks )
469+
470+ seq1_prompt_length = 48
471+ seq1 = Sequence (seq_id = 1 ,
472+ prompt = "zero to forty seven" ,
473+ block_size = block_size ,
474+ prompt_token_ids = list (range (seq1_prompt_length )))
475+ seq1_num_computed_blocks = block_manager .get_num_computed_blocks (seq1 )
476+ assert seq1_num_computed_blocks == 2
477+ seq1_group = SequenceGroup (request_id = 1 ,
478+ seqs = [seq1 ],
479+ sampling_params = SamplingParams (),
480+ arrival_time = time .time ())
481+ # Ensure the computed blocks number aligns with the real allocation behavior
482+ block_manager .allocate (seq1_group )
483+ block_manager .mark_blocks_as_computed (seq1_group )
484+ seq1_computed_blocks = block_manager .get_all_computed_blocks (seq1 )
485+ assert seq1_num_computed_blocks == len (seq1_computed_blocks )
486+
487+ seq2_prompt_length = 55
488+ seq2 = Sequence (seq_id = 2 ,
489+ prompt = "zero to fifty four" ,
490+ block_size = block_size ,
491+ prompt_token_ids = list (range (seq2_prompt_length )))
492+ seq2_num_computed_blocks = block_manager .get_num_computed_blocks (seq2 )
493+ assert seq2_num_computed_blocks == 3
494+ seq2_group = SequenceGroup (request_id = 1 ,
495+ seqs = [seq2 ],
496+ sampling_params = SamplingParams (),
497+ arrival_time = time .time ())
498+ # Ensure the computed blocks number aligns with the real allocation behavior
499+ block_manager .allocate (seq2_group )
500+ block_manager .mark_blocks_as_computed (seq2_group )
501+ seq2_computed_blocks = block_manager .get_all_computed_blocks (seq2 )
502+ assert seq2_num_computed_blocks == len (seq2_computed_blocks )
503+
504+ seq3_prompt_length = 81
505+ seq3 = Sequence (seq_id = 3 ,
506+ prompt = "zero to eighty" ,
507+ block_size = block_size ,
508+ prompt_token_ids = list (range (seq3_prompt_length )))
509+ seq3_num_computed_blocks = block_manager .get_num_computed_blocks (seq3 )
510+ assert seq3_num_computed_blocks == 4
511+ seq3_group = SequenceGroup (request_id = 3 ,
512+ seqs = [seq3 ],
513+ sampling_params = SamplingParams (),
514+ arrival_time = time .time ())
515+ # Ensure the computed blocks number aligns with the real allocation behavior
516+ block_manager .allocate (seq3_group )
517+ block_manager .mark_blocks_as_computed (seq3_group )
518+ seq3_computed_blocks = block_manager .get_all_computed_blocks (seq3 )
519+ assert seq3_num_computed_blocks == len (seq3_computed_blocks )
520+
521+ # Test the computed blocks num after the sequences are freed
522+ # Free operation doesn't influence the computed blocks number
523+ block_manager .free (seq )
524+ block_manager .free (seq1 )
525+ block_manager .free (seq2 )
526+ block_manager .free (seq3 )
527+ seq_num_computed_blocks = block_manager .get_num_computed_blocks (seq )
528+ assert seq_num_computed_blocks == 3
529+ seq1_num_computed_blocks = block_manager .get_num_computed_blocks (seq1 )
530+ assert seq1_num_computed_blocks == 2
531+ seq2_num_computed_blocks = block_manager .get_num_computed_blocks (seq2 )
532+ assert seq2_num_computed_blocks == 3
533+ seq3_num_computed_blocks = block_manager .get_num_computed_blocks (seq3 )
534+ assert seq3_num_computed_blocks == 4
535+
536+ # Test the computed blocks num
537+ # after the second block(token 15~31) are evicted
538+ # Since the second block is evicted, the caches are not continuous
539+ # from the second block. Therefore, all seqs' computed blocks numbers
540+ # are reduced to 1.
541+ evicted_block_hash = seq .hash_of_block (1 )
542+ evicted_block = block_manager .gpu_allocator .evictor .remove (
543+ evicted_block_hash )
544+ seq_num_computed_blocks = block_manager .get_num_computed_blocks (seq )
545+ assert seq_num_computed_blocks == 1
546+ seq1_num_computed_blocks = block_manager .get_num_computed_blocks (seq1 )
547+ assert seq1_num_computed_blocks == 1
548+ seq2_num_computed_blocks = block_manager .get_num_computed_blocks (seq2 )
549+ assert seq2_num_computed_blocks == 1
550+ seq3_num_computed_blocks = block_manager .get_num_computed_blocks (seq3 )
551+ assert seq3_num_computed_blocks == 1
552+
553+ # Test the computed blocks num
554+ # after the second block(token 15~31) are marked as not computed
555+ # Since the second block is marked as not computed, the caches are not
556+ # continuous from the second block. Therefore, all seqs' computed blocks
557+ # numbers are reduced to 1.
558+ evicted_block .computed = False
559+ block_manager .gpu_allocator .cached_blocks [
560+ evicted_block .block_hash ] = evicted_block
561+ seq_num_computed_blocks = block_manager .get_num_computed_blocks (seq )
562+ assert seq_num_computed_blocks == 1
563+ seq1_num_computed_blocks = block_manager .get_num_computed_blocks (seq1 )
564+ assert seq1_num_computed_blocks == 1
565+ seq2_num_computed_blocks = block_manager .get_num_computed_blocks (seq2 )
566+ assert seq2_num_computed_blocks == 1
567+ seq3_num_computed_blocks = block_manager .get_num_computed_blocks (seq3 )
568+ assert seq3_num_computed_blocks == 1
0 commit comments