@@ -241,6 +241,7 @@ def __init__(
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
         self.load_config = load_config
+        self.cache_config = cache_config
         self.is_driver_worker = is_driver_worker
         self.profiler = Profiler()
 
@@ -267,6 +268,9 @@ def __init__(
         self.lora_manager: LRUCacheWorkerLoRAManager = None
         self.model: torch.nn.Module = None
 
+        # Profiler stats
+        self.profiler_counter_helper = HabanaProfilerCounterHelper()
+
         self._setup_buckets()
 
     def load_model(self) -> None:
@@ -876,19 +880,18 @@ def execute_model(
         output.outputs = output.outputs[:real_batch_size]
         htorch.core.mark_step()
 
-        if self.is_driver_worker:
+        if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
             self.profiler.end()
             event_end = self.profiler.get_timestamp_us()
-            duration = event_end - event_start
-            throughput = batch_size_padded / (duration / 1e6)
-            throughput_effective = real_batch_size / (duration / 1e6)
-            counters = {
-                'batch_size': batch_size_padded,
-                'batch_size_effective': real_batch_size,
-                'throughput': throughput,
-                'throughput_effective': throughput_effective
-            }
+            counters = self.profiler_counter_helper.get_counter_dict(
+                cache_config=self.cache_config,
+                duration=event_end - event_start,
+                seq_len=seq_len,
+                batch_size_padded=batch_size_padded,
+                real_batch_size=real_batch_size,
+                seq_group_metadata_list=seq_group_metadata_list,
+                is_prompt=is_prompt)
             self.profiler.record_counter(event_start, counters)
 
         return output
@@ -1014,3 +1017,62 @@ def vocab_size(self) -> int:
 
 def _maybe_wrap_in_hpu_graph(model):
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(model)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(model)
+
+
+class HabanaProfilerCounterHelper():
+    def __init__(self):
+        self.niter = 0
+        self.average_real_throughput = None
+        self.logged_once = False
+
+    def get_counter_dict(self, cache_config, duration, seq_len, batch_size_padded, real_batch_size, seq_group_metadata_list, is_prompt):
+        throughput = batch_size_padded / (duration / 1e6)
+        throughput_effective = real_batch_size / (duration / 1e6)
+        real_seq_lens = [len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values()]
+        real_max_seq_len = max(real_seq_lens)
+        real_num_tokens = sum(real_seq_lens)
+        padded_num_tokens = batch_size_padded * seq_len
+        batch_token_utilization = real_num_tokens / padded_num_tokens
+        if self.average_real_throughput is None:
+            self.average_real_throughput = throughput_effective
+        else:  # https://www.heikohoffmann.de/htmlthesis/node134.html
+            self.average_real_throughput = self.average_real_throughput + 1 / (self.niter + 1) * (throughput_effective - self.average_real_throughput)
+        phase = "prompt" if is_prompt else "decode"
+        counters = {
+            f'{phase}_bucket_batch_size': batch_size_padded,
+            f'{phase}_batch_size': real_batch_size,
+            f'{phase}_bucket_seq_len': seq_len,
+            f'{phase}_seq_len': real_max_seq_len,
+            f'{phase}_bucket_gen_throughput': throughput,
+            f'{phase}_real_gen_throughput': throughput_effective,
+            f'{phase}_batch_token_utilization': batch_token_utilization,
+            'average_real_throughput': self.average_real_throughput,
+            'engine_iteration': self.niter,
+        }
+        self.niter += 1
+        if is_prompt:
+            prompt_seq_lens = [len(seq_data.prompt_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values()]
+            prompt_bucket_in_throughput = (seq_len * batch_size_padded) / (duration / 1e6)
+            prompt_real_in_throughput = sum(prompt_seq_lens) / (duration / 1e6)
+            counters[f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
+            counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput
+
+        # KV cache might not be created yet (e.g. for profiling run)
+        if cache_config.num_gpu_blocks is not None and cache_config.num_gpu_blocks != 0:
+            cache_num_blocks_used = [math.ceil(sl / cache_config.block_size) for sl in real_seq_lens]
+            cache_total_num_blocks_used = sum(cache_num_blocks_used)
+            num_cache_blocks = cache_config.num_gpu_blocks
+            cache_total_num_free_blocks = num_cache_blocks - cache_total_num_blocks_used
+            cache_computed_utilization = cache_total_num_blocks_used / num_cache_blocks
+            max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size)
+            batch_block_utilization = cache_total_num_blocks_used / (batch_size_padded * max_blocks_per_seq)
+            counters['cache_num_blocks_used'] = cache_total_num_blocks_used
+            counters['cache_num_free_blocks'] = cache_total_num_free_blocks
+            counters['cache_computed_utilization'] = cache_computed_utilization
+            counters[f'{phase}_batch_block_utilization'] = batch_block_utilization
+            if not self.logged_once:
+                counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
+                counters['const_gpu_memory_utilization'] = cache_config.gpu_memory_utilization
+                counters['const_block_size'] = cache_config.block_size
+                self.logged_once = True
+        return counters
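
The `average_real_throughput` update in the patch is the standard incremental-mean formula mean_{n+1} = mean_n + (x_{n+1} - mean_n) / (n + 1) (see the linked reference in the code), which tracks the running average without storing past samples. A minimal standalone sketch, separate from the patch, with made-up throughput values:

# Standalone sketch of the incremental mean used for 'average_real_throughput'.
# The samples below are illustrative, not measured values.
def update_mean(mean, new_value, niter):
    # mean_{n+1} = mean_n + (x_{n+1} - mean_n) / (n + 1)
    return mean + 1 / (niter + 1) * (new_value - mean)

samples = [100.0, 120.0, 80.0]  # per-iteration effective throughputs
mean = None
for niter, x in enumerate(samples):
    mean = x if mean is None else update_mean(mean, x, niter)
print(mean)  # 100.0, identical to sum(samples) / len(samples)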
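To see how the helper's utilization counters come out, the sketch below calls `get_counter_dict` with stub objects standing in for vLLM's `CacheConfig` and `SequenceGroupMetadata`; all field values are hypothetical and it assumes the class (and its `math` import) is in scope:

# Hypothetical smoke test for HabanaProfilerCounterHelper using stubs.
from types import SimpleNamespace

helper = HabanaProfilerCounterHelper()

# One sequence: 30 prompt tokens + 3 generated tokens = 33 real tokens.
seq_data = SimpleNamespace(prompt_token_ids=list(range(30)),
                           output_token_ids=[1, 2, 3])
seq_group = SimpleNamespace(seq_data={0: seq_data})
cache_config = SimpleNamespace(num_gpu_blocks=1024,
                               gpu_memory_utilization=0.9,
                               block_size=128)

counters = helper.get_counter_dict(cache_config=cache_config,
                                   duration=5000,  # microseconds
                                   seq_len=128,    # padded bucket length
                                   batch_size_padded=4,
                                   real_batch_size=1,
                                   seq_group_metadata_list=[seq_group],
                                   is_prompt=True)
# counters['prompt_batch_token_utilization'] == 33 / (4 * 128)
# counters['cache_num_blocks_used'] == 1 (ceil(33 / 128))
# counters['prompt_batch_block_utilization'] == 1 / (4 * 1)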