From 18840781581aaf34be83fb4c24e6da0e319afdaf Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Fri, 27 Sep 2024 22:04:12 +0300 Subject: [PATCH 1/7] Draft: Add bucket calibration, allow reading/writing bucketing configs to file --- vllm/worker/habana_model_runner.py | 90 +++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index bfbe4085ddd3..fa0ea92bb970 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -80,7 +80,7 @@ def subtuple(obj: object, return _TYPE_CACHE[typename](**values) -def read_bucket_settings(phase: str, dim: str, **defaults): +def read_bucket_settings(phase: str, dim: str, from_file=False, **defaults): """Read bucketing configuration from env variables. phase is either 'prompt' or 'decode' @@ -94,8 +94,9 @@ def read_bucket_settings(phase: str, dim: str, **defaults): values = [ int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] + label = 'default' if not from_file else 'VLLM_HPU_BUCKET_CFG' for e, v, d in zip(env_vars, values, default_values): - logger.info('%s=%s (default:%s)', e, v, d) + logger.info('%s=%s (%s:%s)', e, v, label, d) return values @@ -562,6 +563,7 @@ def __init__( self.lora_manager: LRUCacheWorkerLoRAManager = None self.model: torch.nn.Module = None self.inc_initialized_successfully = False + self.loaded_bucketing_config_from_file = False # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() @@ -683,29 +685,48 @@ def _setup_buckets(self) -> None: #FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 + + DefaultBucketConfig = collections.namedtuple('DefaultBucketConfig', ['min','step','max','from_file']) + prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(min=1,max=align_bs(32),step=align_bs(max_bucket_cfg), from_file=False) + prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(min=self.block_size,max=self.block_size,step=max_prompt_seq, from_file=False) + decode_bs_bucket_cfg_defaults = DefaultBucketConfig(min=1,max=align_bs(32),step=align_bs(max_bucket_cfg), from_file=False) + decode_block_bucket_cfg_defaults = DefaultBucketConfig(min=self.block_size, step=self.block_size, max=max(self.block_size,self.max_num_seqs * max_decode_seq // self.block_size), from_file=False) + + bucket_cfg_file = os.environ.get('VLLM_HPU_BUCKET_CFG', None) + if bucket_cfg_file is not None: + import pandas as pd + import yaml + try: + with open(bucket_cfg_file, 'r') as f: + data = yaml.safe_load(f) + prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['prompt_bs_bucket_cfg'], from_file=True) + prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['prompt_seq_bucket_cfg'], from_file=True) + decode_bs_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['decode_bs_bucket_cfg'], from_file=True) + decode_block_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['decode_block_bucket_cfg'], from_file=True) + df = pd.DataFrame.from_dict(data['records']) + self.prompt_buckets = df[df['is_prefill'] == True][['batch_size','seq_or_block']].values.tolist() + self.decode_buckets = df[df['is_prefill'] == False][['batch_size','seq_or_block']].values.tolist() + self.loaded_bucketing_config_from_file = True + except (FileNotFoundError, IOError, PermissionError): + msg = "Could not open file specified in VLLM_HPU_BUCKET_CFG: {bucket_cfg_file}. Falling back to default config." 
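+                # On any read failure, the DefaultBucketConfig defaults
+                # constructed above (from_file=False) remain in effect, so
+                # bucketing falls back to the env-var/default path below.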
+ logger.error(msg) + self.prompt_bs_bucket_cfg = read_bucket_settings( 'prompt', 'bs', - min=1, - step=align_bs(32), - max=align_bs(max_bucket_cfg)) + **prompt_bs_bucket_cfg_defaults._asdict()) self.decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', - min=1, - step=align_bs(32), - max=self.max_num_seqs) + **decode_bs_bucket_cfg_defaults._asdict()) self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) + **prompt_seq_bucket_cfg_defaults._asdict()) self.decode_block_bucket_cfg = read_bucket_settings( 'decode', - 'block', - min=self.block_size, - step=self.block_size, - max=max(self.block_size, - self.max_num_seqs * max_decode_seq // self.block_size)) + 'block', + **decode_block_bucket_cfg_defaults._asdict() + ) + self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " @@ -1536,10 +1557,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: return self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) - - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) + prompt_omitted_buckets = [] + if not self.loaded_bucketing_config_from_file: + self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ bucket for bucket in self.prompt_buckets @@ -1552,16 +1574,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(msg) msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") logger.info(msg) msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) - self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, - max_blocks) + if not self.loaded_bucketing_config_from_file: + self.decode_buckets = generate_decode_buckets( + self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, + max_blocks) if self.lora_config: self.decode_buckets[:] = [ bucket for bucket in self.decode_buckets @@ -1976,4 +1999,21 @@ def shutdown_inc(self): self._is_inc_finalized = True def __del__(self): + calibrate_buckets = os.environ.get('VLLM_HPU_CALIBRATE_BUCKETS', 'false') in ['true', '1'] + if calibrate_buckets: + import pandas as pd + import yaml + df = pd.DataFrame(self.seen_configs, columns=['batch_size', 'seq_or_block', 'is_prefill']).sort_values(['is_prefill', 'batch_size','seq_or_block'], ascending=False) + data = {} + data['buckets'] = df.to_dict(orient='records') + data['bucket_cfg'] = { + 'prompt_bs_bucket_cfg': self.prompt_bs_bucket_cfg, + 'prompt_seq_bucket_cfg': self.prompt_seq_bucket_cfg, + 'decode_bs_bucket_cfg': self.decode_bs_bucket_cfg, + 'decode_block_bucket_cfg': self.decode_block_bucket_cfg, + } + with open('data.yml', 'w') as outfile: + yaml.dump(data, outfile, default_flow_style=False) + self.shutdown_inc() + From 699e1061d5c2db8f4442c10bdb0ad6751efa6fd3 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 13:55:28 +0300 Subject: [PATCH 2/7] refine overall implementation --- vllm/worker/habana_model_runner.py | 211 ++++++++++++++++++++--------- 1 file changed, 144 
insertions(+), 67 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index fa0ea92bb970..9de67912babb 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -42,7 +42,7 @@ from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.utils import (is_fake_hpu, is_pin_memory_available, - make_tensor_with_pad) + make_tensor_with_pad, get_vllm_instance_id) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -563,7 +563,15 @@ def __init__( self.lora_manager: LRUCacheWorkerLoRAManager = None self.model: torch.nn.Module = None self.inc_initialized_successfully = False - self.loaded_bucketing_config_from_file = False + self.calibrate_buckets = os.environ.get('VLLM_HPU_CALIBRATE_BUCKETS', + 'false') in ['true', '1'] + self.bucket_cfg_file = os.environ.get('VLLM_HPU_BUCKET_CFG', None) + # Set default filename only if bucket calibration is enabled + if self.calibrate_buckets and self.bucket_cfg_file is None: + vllm_instance_id = get_vllm_instance_id() + self.bucket_cfg_file = f'hpu-buckets-{vllm_instance_id}.yaml' + msg = f"Calibration results will be saved to {self.bucket_cfg_file}" + logger.info(msg) # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() @@ -595,8 +603,9 @@ def _set_gc_threshold(self) -> None: self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ .create_input_mapper(self.model_config) - self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', - 'false').lower() == 'true' + self.skip_warmup = os.environ.get( + 'VLLM_SKIP_WARMUP', + 'false').lower() == 'true' or self.calibrate_buckets def load_model(self) -> None: import habana_frameworks.torch.core as htcore @@ -685,47 +694,60 @@ def _setup_buckets(self) -> None: #FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 - - DefaultBucketConfig = collections.namedtuple('DefaultBucketConfig', ['min','step','max','from_file']) - prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(min=1,max=align_bs(32),step=align_bs(max_bucket_cfg), from_file=False) - prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(min=self.block_size,max=self.block_size,step=max_prompt_seq, from_file=False) - decode_bs_bucket_cfg_defaults = DefaultBucketConfig(min=1,max=align_bs(32),step=align_bs(max_bucket_cfg), from_file=False) - decode_block_bucket_cfg_defaults = DefaultBucketConfig(min=self.block_size, step=self.block_size, max=max(self.block_size,self.max_num_seqs * max_decode_seq // self.block_size), from_file=False) - - bucket_cfg_file = os.environ.get('VLLM_HPU_BUCKET_CFG', None) - if bucket_cfg_file is not None: - import pandas as pd - import yaml - try: - with open(bucket_cfg_file, 'r') as f: - data = yaml.safe_load(f) - prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['prompt_bs_bucket_cfg'], from_file=True) - prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['prompt_seq_bucket_cfg'], from_file=True) - decode_bs_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['decode_bs_bucket_cfg'], from_file=True) - decode_block_bucket_cfg_defaults = DefaultBucketConfig(*data['bucket_cfg']['decode_block_bucket_cfg'], from_file=True) - df = pd.DataFrame.from_dict(data['records']) - self.prompt_buckets = df[df['is_prefill'] == True][['batch_size','seq_or_block']].values.tolist() - self.decode_buckets = df[df['is_prefill'] == 
False][['batch_size','seq_or_block']].values.tolist() - self.loaded_bucketing_config_from_file = True - except (FileNotFoundError, IOError, PermissionError): - msg = "Could not open file specified in VLLM_HPU_BUCKET_CFG: {bucket_cfg_file}. Falling back to default config." - logger.error(msg) - + + DefaultBucketConfig = collections.namedtuple( + 'DefaultBucketConfig', ['min', 'step', 'max', 'from_file']) + prompt_bs_bucket_cfg_defaults = DefaultBucketConfig( + min=1, + step=align_bs(32), + max=align_bs(max_bucket_cfg), + from_file=False) + prompt_seq_bucket_cfg_defaults = DefaultBucketConfig( + min=self.block_size, + step=self.block_size, + max=max_prompt_seq, + from_file=False) + decode_bs_bucket_cfg_defaults = DefaultBucketConfig( + min=1, + step=align_bs(32), + max=align_bs(max_bucket_cfg), + from_file=False) + decode_block_bucket_cfg_defaults = DefaultBucketConfig( + min=self.block_size, + step=self.block_size, + max=max(self.block_size, + self.max_num_seqs * max_decode_seq // self.block_size), + from_file=False) + + # Do not load bucket config from file during bucket calibration + if self.bucket_cfg_file is not None and not self.calibrate_buckets: + bucket_settings, ( + prompt_buckets, + decode_buckets) = self.deserialize_bucket_settings( + self.bucket_cfg_file) + if bucket_settings is not None: + prompt_bs_bucket_cfg_defaults = DefaultBucketConfig( + **bucket_settings['prompt_bs_bucket_cfg'], from_file=True) + prompt_seq_bucket_cfg_defaults = DefaultBucketConfig( + **bucket_settings['prompt_seq_bucket_cfg'], from_file=True) + decode_bs_bucket_cfg_defaults = DefaultBucketConfig( + **bucket_settings['decode_bs_bucket_cfg'], from_file=True) + decode_block_bucket_cfg_defaults = DefaultBucketConfig( + **bucket_settings['decode_block_bucket_cfg'], + from_file=True) + if prompt_buckets is not None: + self.prompt_buckets = prompt_buckets + if decode_buckets is not None: + self.decode_buckets = decode_buckets + self.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', - **prompt_bs_bucket_cfg_defaults._asdict()) - self.decode_bs_bucket_cfg = read_bucket_settings('decode', - 'bs', - **decode_bs_bucket_cfg_defaults._asdict()) - self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', - 'seq', - **prompt_seq_bucket_cfg_defaults._asdict()) + 'prompt', 'bs', **prompt_bs_bucket_cfg_defaults._asdict()) + self.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', **decode_bs_bucket_cfg_defaults._asdict()) + self.prompt_seq_bucket_cfg = read_bucket_settings( + 'prompt', 'seq', **prompt_seq_bucket_cfg_defaults._asdict()) self.decode_block_bucket_cfg = read_bucket_settings( - 'decode', - 'block', - **decode_block_bucket_cfg_defaults._asdict() - ) + 'decode', 'block', **decode_block_bucket_cfg_defaults._asdict()) self.graphed_buckets: Set[Any] = set() @@ -739,6 +761,73 @@ def _setup_buckets(self) -> None: f"block:{self.decode_block_bucket_cfg}") logger.info(msg) + if getattr(self, 'prompt_buckets', None) is not None: + msg = (f"Loaded {len(self.prompt_buckets)} " + "prompt buckets from file [bs, seq]: " + f"{list(sorted(self.prompt_buckets))}") + logger.info(msg) + + if getattr(self, 'decode_buckets', None) is not None: + msg = (f"Loaded {len(self.decode_buckets)} " + f"decode buckets from file [bs, block]: " + f"{list(sorted(self.decode_buckets))}") + logger.info(msg) + + def serialize_bucket_settings(self, bucket_cfg_file): + import pandas as pd + import yaml + + def bucket_cfg_to_dict(cfg): + return {'min': cfg[0], 'step': cfg[1], 'max': cfg[2]} + + df = pd.DataFrame( + 
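+            # self.seen_configs holds the (batch_size, seq_or_block,
+            # is_prefill) tuples recorded by _check_config at runtime.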
self.seen_configs, + columns=['batch_size', 'seq_or_block', 'is_prefill']).sort_values( + ['is_prefill', 'batch_size', 'seq_or_block'], ascending=False) + data = {} + data['buckets'] = df.to_dict(orient='records') + data['bucket_cfg'] = { + 'prompt_bs_bucket_cfg': + bucket_cfg_to_dict(self.prompt_bs_bucket_cfg), + 'prompt_seq_bucket_cfg': + bucket_cfg_to_dict(self.prompt_seq_bucket_cfg), + 'decode_bs_bucket_cfg': + bucket_cfg_to_dict(self.decode_bs_bucket_cfg), + 'decode_block_bucket_cfg': + bucket_cfg_to_dict(self.decode_block_bucket_cfg), + } + with open(bucket_cfg_file, 'w') as outfile: + yaml.dump(data, outfile, default_flow_style=False) + msg = f"Bucket calibration settings saved to {bucket_cfg_file}" + logger.info(msg) + + def deserialize_bucket_settings(self, bucket_cfg_file): + import pandas as pd + import yaml + prompt_buckets = None + decode_buckets = None + try: + with open(bucket_cfg_file, 'r') as f: + data = yaml.safe_load(f) + # Load min,step,max from file + bucket_cfg = data['bucket_cfg'] + # Load pre-generated buckets, if any + if 'buckets' in data: + df = pd.DataFrame.from_dict(data['buckets']) + prompt_buckets = df[df['is_prefill']][[ + 'batch_size', 'seq_or_block' + ]].values.tolist() + prompt_buckets = [tuple(b) for b in prompt_buckets] + decode_buckets = df[~df['is_prefill']][[ + 'batch_size', 'seq_or_block' + ]].values.tolist() + decode_buckets = [tuple(b) for b in decode_buckets] + except (FileNotFoundError, IOError, PermissionError): + msg = "Could not open file specified in VLLM_HPU_BUCKET_CFG: " + f"{bucket_cfg_file}. Falling back to default config." + logger.error(msg) + return bucket_cfg, (prompt_buckets, decode_buckets) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1558,8 +1647,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) prompt_omitted_buckets = [] - if not self.loaded_bucketing_config_from_file: - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( + if getattr(self, 'prompt_buckets', None) is None: + self.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, self.max_num_batched_tokens) if self.lora_config: @@ -1574,14 +1664,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(msg) msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") logger.info(msg) msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) - if not self.loaded_bucketing_config_from_file: + if getattr(self, 'decode_buckets', None) is None: self.decode_buckets = generate_decode_buckets( self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, max_blocks) @@ -1873,8 +1963,10 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): self.seen_configs.add(cfg) if not seen and not warmup_mode: phase = 'prompt' if is_prompt else 'decode' - logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", - phase, batch_size, seq_len) + if not self.calibrate_buckets: + logger.warning( + "Configuration: (%s, %s, %s) was not warmed-up!", phase, + batch_size, seq_len) @torch.inference_mode() def execute_model( @@ -1999,21 +2091,6 @@ def shutdown_inc(self): self._is_inc_finalized = True 
def __del__(self): - calibrate_buckets = os.environ.get('VLLM_HPU_CALIBRATE_BUCKETS', 'false') in ['true', '1'] - if calibrate_buckets: - import pandas as pd - import yaml - df = pd.DataFrame(self.seen_configs, columns=['batch_size', 'seq_or_block', 'is_prefill']).sort_values(['is_prefill', 'batch_size','seq_or_block'], ascending=False) - data = {} - data['buckets'] = df.to_dict(orient='records') - data['bucket_cfg'] = { - 'prompt_bs_bucket_cfg': self.prompt_bs_bucket_cfg, - 'prompt_seq_bucket_cfg': self.prompt_seq_bucket_cfg, - 'decode_bs_bucket_cfg': self.decode_bs_bucket_cfg, - 'decode_block_bucket_cfg': self.decode_block_bucket_cfg, - } - with open('data.yml', 'w') as outfile: - yaml.dump(data, outfile, default_flow_style=False) - + if self.calibrate_buckets: + self.serialize_bucket_settings(self.bucket_cfg_file) self.shutdown_inc() - From b4f0fea1c6f4c67f716664a3b12c24318069eef7 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 13:57:16 +0300 Subject: [PATCH 3/7] getattr in destructor --- vllm/worker/habana_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 9de67912babb..b81624723ad1 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -2091,6 +2091,6 @@ def shutdown_inc(self): self._is_inc_finalized = True def __del__(self): - if self.calibrate_buckets: + if getattr(self, 'calibrate_buckets', False): self.serialize_bucket_settings(self.bucket_cfg_file) self.shutdown_inc() From 00831bc025b8c970eed2154b227ca480e5a16997 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 13:58:41 +0300 Subject: [PATCH 4/7] format.sh --- vllm/worker/habana_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index b81624723ad1..14b9ebec8d22 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -41,8 +41,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import (is_fake_hpu, is_pin_memory_available, - make_tensor_with_pad, get_vllm_instance_id) +from vllm.utils import (get_vllm_instance_id, is_fake_hpu, + is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, From 086f97acd971e9d7b6d22b737c75a9576561d8ac Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 17:32:20 +0300 Subject: [PATCH 5/7] change data format --- vllm/worker/habana_model_runner.py | 75 +++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 14b9ebec8d22..a9bcc095b3cb 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -570,6 +570,7 @@ def __init__( if self.calibrate_buckets and self.bucket_cfg_file is None: vllm_instance_id = get_vllm_instance_id() self.bucket_cfg_file = f'hpu-buckets-{vllm_instance_id}.yaml' + if self.calibrate_buckets: msg = f"Calibration results will be saved to {self.bucket_cfg_file}" logger.info(msg) @@ -784,26 +785,70 @@ def bucket_cfg_to_dict(cfg): self.seen_configs, columns=['batch_size', 'seq_or_block', 'is_prefill']).sort_values( ['is_prefill', 'batch_size', 'seq_or_block'], ascending=False) - data = {} - 
data['buckets'] = df.to_dict(orient='records') + df['phase'] = df['is_prefill'].apply(lambda x: 'prefill' + if x else 'decode') + data: Dict[str, Any] = {} # type: ignore + #data['buckets'] = df.to_dict(orient='records') + buckets_dict = df.groupby('phase').apply(lambda dd: sorted( + list([list(a) for a in zip(dd['batch_size'], dd['seq_or_block'])]), + reverse=True)).to_dict() + prefill_df = df[df['is_prefill']] + decode_df = df[~df['is_prefill']] + + updated_prompt_bs_bucket_cfg = self.prompt_bs_bucket_cfg + updated_prompt_bs_bucket_cfg[0] = int(prefill_df['batch_size'].min()) + updated_prompt_bs_bucket_cfg[2] = int(prefill_df['batch_size'].max()) + updated_prompt_seq_bucket_cfg = self.prompt_seq_bucket_cfg + updated_prompt_seq_bucket_cfg[0] = int( + prefill_df['seq_or_block'].min()) + updated_prompt_seq_bucket_cfg[2] = int( + prefill_df['seq_or_block'].max()) + updated_decode_bs_bucket_cfg = self.decode_bs_bucket_cfg + updated_decode_bs_bucket_cfg[0] = int(decode_df['batch_size'].min()) + updated_decode_bs_bucket_cfg[2] = int(decode_df['batch_size'].max()) + updated_decode_block_bucket_cfg = self.decode_block_bucket_cfg + updated_decode_block_bucket_cfg[0] = int( + decode_df['seq_or_block'].min()) + updated_decode_block_bucket_cfg[2] = int( + decode_df['seq_or_block'].max()) + data['bucket_cfg'] = { 'prompt_bs_bucket_cfg': - bucket_cfg_to_dict(self.prompt_bs_bucket_cfg), + bucket_cfg_to_dict(updated_prompt_bs_bucket_cfg), 'prompt_seq_bucket_cfg': - bucket_cfg_to_dict(self.prompt_seq_bucket_cfg), + bucket_cfg_to_dict(updated_prompt_seq_bucket_cfg), 'decode_bs_bucket_cfg': - bucket_cfg_to_dict(self.decode_bs_bucket_cfg), + bucket_cfg_to_dict(updated_decode_bs_bucket_cfg), 'decode_block_bucket_cfg': - bucket_cfg_to_dict(self.decode_block_bucket_cfg), + bucket_cfg_to_dict(updated_decode_block_bucket_cfg), } + data['buckets'] = buckets_dict + + class PSS(str): + pass + + csv_df = df[['phase', 'batch_size', 'seq_or_block']] + data['buckets_csv'] = PSS(csv_df.to_csv(index=False)) + + def pss_representer(dumper, data): + style = '|' + tag = u'tag:yaml.org,2002:str' + return dumper.represent_scalar(tag, data, style=style) + + yaml.add_representer(PSS, pss_representer, Dumper=yaml.SafeDumper) + import pdb + pdb.set_trace() with open(bucket_cfg_file, 'w') as outfile: - yaml.dump(data, outfile, default_flow_style=False) + yaml.safe_dump(data, + outfile, + default_flow_style=None, + sort_keys=False) msg = f"Bucket calibration settings saved to {bucket_cfg_file}" logger.info(msg) def deserialize_bucket_settings(self, bucket_cfg_file): - import pandas as pd import yaml + bucket_cfg = None prompt_buckets = None decode_buckets = None try: @@ -813,19 +858,15 @@ def deserialize_bucket_settings(self, bucket_cfg_file): bucket_cfg = data['bucket_cfg'] # Load pre-generated buckets, if any if 'buckets' in data: - df = pd.DataFrame.from_dict(data['buckets']) - prompt_buckets = df[df['is_prefill']][[ - 'batch_size', 'seq_or_block' - ]].values.tolist() + prompt_buckets = data['buckets']['prefill'] prompt_buckets = [tuple(b) for b in prompt_buckets] - decode_buckets = df[~df['is_prefill']][[ - 'batch_size', 'seq_or_block' - ]].values.tolist() + decode_buckets = data['buckets']['decode'] decode_buckets = [tuple(b) for b in decode_buckets] except (FileNotFoundError, IOError, PermissionError): - msg = "Could not open file specified in VLLM_HPU_BUCKET_CFG: " - f"{bucket_cfg_file}. Falling back to default config." + msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: " + f"{bucket_cfg_file}. 
Falling back to default config.") logger.error(msg) + return bucket_cfg, (prompt_buckets, decode_buckets) def _prepare_prompt( From efd05ff4b2f19aa36d55a7e19a5445d0a8ba0bb4 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 18:26:10 +0300 Subject: [PATCH 6/7] add csv serializer/deserializer --- vllm/worker/habana_model_runner.py | 197 +++++++++++++++++------------ 1 file changed, 117 insertions(+), 80 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index a9bcc095b3cb..52fce7535e63 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -573,6 +573,12 @@ def __init__( if self.calibrate_buckets: msg = f"Calibration results will be saved to {self.bucket_cfg_file}" logger.info(msg) + if self.bucket_cfg_file is not None: + self.bucket_cfg_file_format = os.path.splitext( + self.bucket_cfg_file)[1].strip('.').lower() + assert self.bucket_cfg_file_format in [ + 'yml', 'yaml', 'csv' + ], 'Calibration file must be either YAML or CSV!' # Profiler stats self.profiler_counter_helper = HabanaProfilerCounterHelper() @@ -725,7 +731,7 @@ def _setup_buckets(self) -> None: bucket_settings, ( prompt_buckets, decode_buckets) = self.deserialize_bucket_settings( - self.bucket_cfg_file) + self.bucket_cfg_file, self.bucket_cfg_file_format) if bucket_settings is not None: prompt_bs_bucket_cfg_defaults = DefaultBucketConfig( **bucket_settings['prompt_bs_bucket_cfg'], from_file=True) @@ -774,12 +780,67 @@ def _setup_buckets(self) -> None: f"{list(sorted(self.decode_buckets))}") logger.info(msg) - def serialize_bucket_settings(self, bucket_cfg_file): + def serialize_bucket_settings(self, bucket_cfg_file, fmt='yaml'): import pandas as pd - import yaml - def bucket_cfg_to_dict(cfg): - return {'min': cfg[0], 'step': cfg[1], 'max': cfg[2]} + def yaml_serializer(df, bucket_cfg_file): + import yaml + + def bucket_cfg_to_dict(cfg): + return {'min': cfg[0], 'step': cfg[1], 'max': cfg[2]} + + data: Dict[str, Any] = {} # type: ignore + #data['buckets'] = df.to_dict(orient='records') + buckets_dict = df.groupby('phase').apply( + lambda dd: sorted(list([ + list(a) for a in zip(dd['batch_size'], dd['seq_or_block']) + ]), + reverse=True)).to_dict() + prefill_df = df[df['is_prefill']] + decode_df = df[~df['is_prefill']] + + updated_prompt_bs_bucket_cfg = self.prompt_bs_bucket_cfg + updated_prompt_bs_bucket_cfg[0] = int( + prefill_df['batch_size'].min()) + updated_prompt_bs_bucket_cfg[2] = int( + prefill_df['batch_size'].max()) + updated_prompt_seq_bucket_cfg = self.prompt_seq_bucket_cfg + updated_prompt_seq_bucket_cfg[0] = int( + prefill_df['seq_or_block'].min()) + updated_prompt_seq_bucket_cfg[2] = int( + prefill_df['seq_or_block'].max()) + updated_decode_bs_bucket_cfg = self.decode_bs_bucket_cfg + updated_decode_bs_bucket_cfg[0] = int( + decode_df['batch_size'].min()) + updated_decode_bs_bucket_cfg[2] = int( + decode_df['batch_size'].max()) + updated_decode_block_bucket_cfg = self.decode_block_bucket_cfg + updated_decode_block_bucket_cfg[0] = int( + decode_df['seq_or_block'].min()) + updated_decode_block_bucket_cfg[2] = int( + decode_df['seq_or_block'].max()) + + data['bucket_cfg'] = { + 'prompt_bs_bucket_cfg': + bucket_cfg_to_dict(updated_prompt_bs_bucket_cfg), + 'prompt_seq_bucket_cfg': + bucket_cfg_to_dict(updated_prompt_seq_bucket_cfg), + 'decode_bs_bucket_cfg': + bucket_cfg_to_dict(updated_decode_bs_bucket_cfg), + 'decode_block_bucket_cfg': + bucket_cfg_to_dict(updated_decode_block_bucket_cfg), + } + data['buckets'] = 
buckets_dict + + with open(bucket_cfg_file, 'w') as outfile: + yaml.safe_dump(data, + outfile, + default_flow_style=None, + sort_keys=False) + + def csv_serializer(df, bucket_cfg_file): + csv_df = df[['phase', 'batch_size', 'seq_or_block']] + csv_df.to_csv(bucket_cfg_file, index=False) df = pd.DataFrame( self.seen_configs, @@ -787,86 +848,61 @@ def bucket_cfg_to_dict(cfg): ['is_prefill', 'batch_size', 'seq_or_block'], ascending=False) df['phase'] = df['is_prefill'].apply(lambda x: 'prefill' if x else 'decode') - data: Dict[str, Any] = {} # type: ignore - #data['buckets'] = df.to_dict(orient='records') - buckets_dict = df.groupby('phase').apply(lambda dd: sorted( - list([list(a) for a in zip(dd['batch_size'], dd['seq_or_block'])]), - reverse=True)).to_dict() - prefill_df = df[df['is_prefill']] - decode_df = df[~df['is_prefill']] - - updated_prompt_bs_bucket_cfg = self.prompt_bs_bucket_cfg - updated_prompt_bs_bucket_cfg[0] = int(prefill_df['batch_size'].min()) - updated_prompt_bs_bucket_cfg[2] = int(prefill_df['batch_size'].max()) - updated_prompt_seq_bucket_cfg = self.prompt_seq_bucket_cfg - updated_prompt_seq_bucket_cfg[0] = int( - prefill_df['seq_or_block'].min()) - updated_prompt_seq_bucket_cfg[2] = int( - prefill_df['seq_or_block'].max()) - updated_decode_bs_bucket_cfg = self.decode_bs_bucket_cfg - updated_decode_bs_bucket_cfg[0] = int(decode_df['batch_size'].min()) - updated_decode_bs_bucket_cfg[2] = int(decode_df['batch_size'].max()) - updated_decode_block_bucket_cfg = self.decode_block_bucket_cfg - updated_decode_block_bucket_cfg[0] = int( - decode_df['seq_or_block'].min()) - updated_decode_block_bucket_cfg[2] = int( - decode_df['seq_or_block'].max()) - - data['bucket_cfg'] = { - 'prompt_bs_bucket_cfg': - bucket_cfg_to_dict(updated_prompt_bs_bucket_cfg), - 'prompt_seq_bucket_cfg': - bucket_cfg_to_dict(updated_prompt_seq_bucket_cfg), - 'decode_bs_bucket_cfg': - bucket_cfg_to_dict(updated_decode_bs_bucket_cfg), - 'decode_block_bucket_cfg': - bucket_cfg_to_dict(updated_decode_block_bucket_cfg), - } - data['buckets'] = buckets_dict - - class PSS(str): - pass - - csv_df = df[['phase', 'batch_size', 'seq_or_block']] - data['buckets_csv'] = PSS(csv_df.to_csv(index=False)) - - def pss_representer(dumper, data): - style = '|' - tag = u'tag:yaml.org,2002:str' - return dumper.represent_scalar(tag, data, style=style) - - yaml.add_representer(PSS, pss_representer, Dumper=yaml.SafeDumper) - import pdb - pdb.set_trace() - with open(bucket_cfg_file, 'w') as outfile: - yaml.safe_dump(data, - outfile, - default_flow_style=None, - sort_keys=False) + if fmt == 'csv': + csv_serializer(df, bucket_cfg_file) + elif fmt in ['yaml', 'yml']: + yaml_serializer(df, bucket_cfg_file) + else: + raise NotImplementedError(f"Unsupported format: {fmt}") + msg = f"Bucket calibration settings saved to {bucket_cfg_file}" logger.info(msg) - def deserialize_bucket_settings(self, bucket_cfg_file): - import yaml + def deserialize_bucket_settings( + self, + bucket_cfg_file, + fmt='yaml' + ) -> Tuple[Optional[Dict[str, int]], Tuple[Optional[List[Tuple[ + int, int]]], Optional[List[Tuple[int, int]]]]]: bucket_cfg = None prompt_buckets = None decode_buckets = None - try: - with open(bucket_cfg_file, 'r') as f: - data = yaml.safe_load(f) - # Load min,step,max from file - bucket_cfg = data['bucket_cfg'] - # Load pre-generated buckets, if any - if 'buckets' in data: - prompt_buckets = data['buckets']['prefill'] - prompt_buckets = [tuple(b) for b in prompt_buckets] - decode_buckets = data['buckets']['decode'] - decode_buckets = 
[tuple(b) for b in decode_buckets] - except (FileNotFoundError, IOError, PermissionError): - msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: " - f"{bucket_cfg_file}. Falling back to default config.") - logger.error(msg) - + # CSV does not support overriding bucket_cfg + if fmt == 'csv': + import csv + try: + with open(bucket_cfg_file, 'r') as f: + reader = csv.DictReader(f, skipinitialspace=True) + data = list(reader) + prompt_buckets = [(int(b['batch_size']), + int(b['seq_or_block'])) for b in data + if b['phase'] == 'prefill'] + decode_buckets = [(int(b['batch_size']), + int(b['seq_or_block'])) for b in data + if b['phase'] == 'decode'] + except (FileNotFoundError, IOError, PermissionError): + msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: " + f"{bucket_cfg_file}. Falling back to default config.") + logger.error(msg) + elif fmt in ['yaml', 'yml']: + try: + import yaml + with open(bucket_cfg_file, 'r') as f: + data = yaml.safe_load(f) + # Load min,step,max from file + bucket_cfg = data['bucket_cfg'] + # Load pre-generated buckets, if any + if 'buckets' in data: + prompt_buckets = data['buckets']['prefill'] + prompt_buckets = [tuple(b) for b in prompt_buckets] + decode_buckets = data['buckets']['decode'] + decode_buckets = [tuple(b) for b in decode_buckets] + except (FileNotFoundError, IOError, PermissionError): + msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: " + f"{bucket_cfg_file}. Falling back to default config.") + logger.error(msg) + else: + raise NotImplementedError(f"Unsupported format: {fmt}") return bucket_cfg, (prompt_buckets, decode_buckets) def _prepare_prompt( @@ -2133,5 +2169,6 @@ def shutdown_inc(self): def __del__(self): if getattr(self, 'calibrate_buckets', False): - self.serialize_bucket_settings(self.bucket_cfg_file) + self.serialize_bucket_settings(self.bucket_cfg_file, + self.bucket_cfg_file_format) self.shutdown_inc() From 6c6bc62cd82eb143c3e4afaab18844343d459134 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 30 Sep 2024 18:43:21 +0300 Subject: [PATCH 7/7] format.sh --- vllm/worker/habana_model_runner.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 52fce7535e63..82610a2b3a02 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -862,7 +862,7 @@ def deserialize_bucket_settings( self, bucket_cfg_file, fmt='yaml' - ) -> Tuple[Optional[Dict[str, int]], Tuple[Optional[List[Tuple[ + ) -> Tuple[Optional[Dict[str, Dict[str, int]]], Tuple[Optional[List[Tuple[ int, int]]], Optional[List[Tuple[int, int]]]]]: bucket_cfg = None prompt_buckets = None @@ -889,15 +889,19 @@ def deserialize_bucket_settings( import yaml with open(bucket_cfg_file, 'r') as f: data = yaml.safe_load(f) + assert isinstance(data, dict), "Invalid YAML contents" # Load min,step,max from file - bucket_cfg = data['bucket_cfg'] + bucket_cfg = data.get('bucket_cfg', None) # Load pre-generated buckets, if any if 'buckets' in data: + assert isinstance(data['buckets'], dict), \ + "Invalid YAML contents" prompt_buckets = data['buckets']['prefill'] prompt_buckets = [tuple(b) for b in prompt_buckets] decode_buckets = data['buckets']['decode'] decode_buckets = [tuple(b) for b in decode_buckets] - except (FileNotFoundError, IOError, PermissionError): + except (FileNotFoundError, IOError, PermissionError, + AssertionError): msg = ("Could not open file specified in VLLM_HPU_BUCKET_CFG: " f"{bucket_cfg_file}. 
Falling back to default config.") logger.error(msg)
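
Notes on the resulting workflow (illustrative sketches; file names, bucket
values, and the serving entrypoint below are made-up examples, not output
from a real run):

1. Calibration run. VLLM_HPU_CALIBRATE_BUCKETS skips warmup, records every
   executed (batch_size, seq_or_block, is_prefill) configuration via
   _check_config, and serializes the result on shutdown to the file named by
   VLLM_HPU_BUCKET_CFG (or to an auto-generated
   hpu-buckets-<vllm-instance-id>.yaml):

       VLLM_HPU_CALIBRATE_BUCKETS=true VLLM_HPU_BUCKET_CFG=my-buckets.yaml \
           python my_serving_workload.py  # placeholder entrypoint

2. Replay run. With calibration disabled, VLLM_HPU_BUCKET_CFG is read back
   and only the recorded buckets are warmed up:

       VLLM_HPU_BUCKET_CFG=my-buckets.yaml python my_serving_workload.py

A YAML file as consumed by deserialize_bucket_settings after patch 7 (an
optional 'bucket_cfg' mapping plus an optional 'buckets' dict keyed by
phase) could look like:

    bucket_cfg:
      prompt_bs_bucket_cfg: {min: 1, step: 32, max: 64}
      prompt_seq_bucket_cfg: {min: 128, step: 128, max: 1024}
      decode_bs_bucket_cfg: {min: 1, step: 32, max: 64}
      decode_block_bucket_cfg: {min: 128, step: 128, max: 2048}
    buckets:
      prefill:
      - [64, 1024]
      - [32, 512]
      decode:
      - [64, 2048]
      - [32, 1024]

The CSV format carries only the bucket list (per the comment in patch 6, it
cannot override bucket_cfg) and uses a 'phase' column in place of
'is_prefill':

    phase,batch_size,seq_or_block
    prefill,64,1024
    prefill,32,512
    decode,64,2048
    decode,32,1024

A standalone reader for the YAML side, as a minimal sketch assuming PyYAML
and mirroring the logic of deserialize_bucket_settings:

    import yaml

    with open('my-buckets.yaml') as f:  # hypothetical file name
        data = yaml.safe_load(f)
    assert isinstance(data, dict), "Invalid YAML contents"
    # Optional (min, step, max) overrides per bucket config
    bucket_cfg = data.get('bucket_cfg', None)
    # Pre-generated buckets, keyed by phase
    prompt_buckets = [tuple(b) for b in data['buckets']['prefill']]
    decode_buckets = [tuple(b) for b in data['buckets']['decode']]

Individual min/step/max values can still be overridden through the
environment variables documented in read_bucket_settings (e.g.
VLLM_DECODE_BS_BUCKET_STEP=128); values taken from a file are logged with
the 'VLLM_HPU_BUCKET_CFG' label rather than 'default'.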