Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 232 additions & 33 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (is_fake_hpu, is_pin_memory_available,
make_tensor_with_pad)
from vllm.utils import (get_vllm_instance_id, is_fake_hpu,
is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
Expand Down Expand Up @@ -80,7 +80,7 @@ def subtuple(obj: object,
return _TYPE_CACHE[typename](**values)


def read_bucket_settings(phase: str, dim: str, **defaults):
def read_bucket_settings(phase: str, dim: str, from_file=False, **defaults):
"""Read bucketing configuration from env variables.

phase is either 'prompt' or 'decode'
Expand All @@ -94,8 +94,9 @@ def read_bucket_settings(phase: str, dim: str, **defaults):
values = [
int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
]
label = 'default' if not from_file else 'VLLM_HPU_BUCKET_CFG'
for e, v, d in zip(env_vars, values, default_values):
logger.info('%s=%s (default:%s)', e, v, d)
logger.info('%s=%s (%s:%s)', e, v, label, d)
return values


Expand Down Expand Up @@ -562,6 +563,22 @@ def __init__(
self.lora_manager: LRUCacheWorkerLoRAManager = None
self.model: torch.nn.Module = None
self.inc_initialized_successfully = False
self.calibrate_buckets = os.environ.get('VLLM_HPU_CALIBRATE_BUCKETS',
'false') in ['true', '1']
self.bucket_cfg_file = os.environ.get('VLLM_HPU_BUCKET_CFG', None)
# Set default filename only if bucket calibration is enabled
if self.calibrate_buckets and self.bucket_cfg_file is None:
vllm_instance_id = get_vllm_instance_id()
self.bucket_cfg_file = f'hpu-buckets-{vllm_instance_id}.yaml'
if self.calibrate_buckets:
msg = f"Calibration results will be saved to {self.bucket_cfg_file}"
logger.info(msg)
if self.bucket_cfg_file is not None:
self.bucket_cfg_file_format = os.path.splitext(
self.bucket_cfg_file)[1].strip('.').lower()
assert self.bucket_cfg_file_format in [
'yml', 'yaml', 'csv'
], 'Calibration file must be either YAML or CSV!'

# Profiler stats
self.profiler_counter_helper = HabanaProfilerCounterHelper()
Expand Down Expand Up @@ -593,8 +610,9 @@ def _set_gc_threshold(self) -> None:
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)

self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
'false').lower() == 'true'
self.skip_warmup = os.environ.get(
'VLLM_SKIP_WARMUP',
'false').lower() == 'true' or self.calibrate_buckets

def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
Expand Down Expand Up @@ -683,29 +701,61 @@ def _setup_buckets(self) -> None:
#FIXME: The default values should be max_model_len
max_prompt_seq = 1024
max_decode_seq = 2048
self.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt',
'bs',

DefaultBucketConfig = collections.namedtuple(
'DefaultBucketConfig', ['min', 'step', 'max', 'from_file'])
prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(
min=1,
step=align_bs(32),
max=align_bs(max_bucket_cfg))
self.decode_bs_bucket_cfg = read_bucket_settings('decode',
'bs',
min=1,
step=align_bs(32),
max=self.max_num_seqs)
self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=max_prompt_seq)
self.decode_block_bucket_cfg = read_bucket_settings(
'decode',
'block',
max=align_bs(max_bucket_cfg),
from_file=False)
prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(
min=self.block_size,
step=self.block_size,
max=max_prompt_seq,
from_file=False)
decode_bs_bucket_cfg_defaults = DefaultBucketConfig(
min=1,
step=align_bs(32),
max=align_bs(max_bucket_cfg),
from_file=False)
decode_block_bucket_cfg_defaults = DefaultBucketConfig(
min=self.block_size,
step=self.block_size,
max=max(self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size))
self.max_num_seqs * max_decode_seq // self.block_size),
from_file=False)

# Do not load bucket config from file during bucket calibration
if self.bucket_cfg_file is not None and not self.calibrate_buckets:
bucket_settings, (
prompt_buckets,
decode_buckets) = self.deserialize_bucket_settings(
self.bucket_cfg_file, self.bucket_cfg_file_format)
if bucket_settings is not None:
prompt_bs_bucket_cfg_defaults = DefaultBucketConfig(
**bucket_settings['prompt_bs_bucket_cfg'], from_file=True)
prompt_seq_bucket_cfg_defaults = DefaultBucketConfig(
**bucket_settings['prompt_seq_bucket_cfg'], from_file=True)
decode_bs_bucket_cfg_defaults = DefaultBucketConfig(
**bucket_settings['decode_bs_bucket_cfg'], from_file=True)
decode_block_bucket_cfg_defaults = DefaultBucketConfig(
**bucket_settings['decode_block_bucket_cfg'],
from_file=True)
if prompt_buckets is not None:
self.prompt_buckets = prompt_buckets
if decode_buckets is not None:
self.decode_buckets = decode_buckets

self.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt', 'bs', **prompt_bs_bucket_cfg_defaults._asdict())
self.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', **decode_bs_bucket_cfg_defaults._asdict())
self.prompt_seq_bucket_cfg = read_bucket_settings(
'prompt', 'seq', **prompt_seq_bucket_cfg_defaults._asdict())
self.decode_block_bucket_cfg = read_bucket_settings(
'decode', 'block', **decode_block_bucket_cfg_defaults._asdict())

self.graphed_buckets: Set[Any] = set()

msg = ("Prompt bucket config (min, step, max_warmup) "
Expand All @@ -718,6 +768,147 @@ def _setup_buckets(self) -> None:
f"block:{self.decode_block_bucket_cfg}")
logger.info(msg)

if getattr(self, 'prompt_buckets', None) is not None:
msg = (f"Loaded {len(self.prompt_buckets)} "
"prompt buckets from file [bs, seq]: "
f"{list(sorted(self.prompt_buckets))}")
logger.info(msg)

if getattr(self, 'decode_buckets', None) is not None:
msg = (f"Loaded {len(self.decode_buckets)} "
f"decode buckets from file [bs, block]: "
f"{list(sorted(self.decode_buckets))}")
logger.info(msg)

def serialize_bucket_settings(self, bucket_cfg_file, fmt='yaml'):
    """Write the bucket configurations recorded in ``self.seen_configs``.

    Each entry of ``self.seen_configs`` is unpacked into the columns
    ``batch_size``, ``seq_or_block`` and ``is_prefill``.

    * ``fmt='csv'`` writes the columns ``phase, batch_size, seq_or_block``
      where ``phase`` is ``'prefill'`` or ``'decode'``.
    * ``fmt='yaml'`` / ``'yml'`` writes a mapping with two keys:
      ``bucket_cfg`` (min/step/max per bucket config, with min and max
      replaced by the smallest/largest observed value of the matching
      phase) and ``buckets`` (per-phase lists of
      ``[batch_size, seq_or_block]`` sorted in descending order).

    Args:
        bucket_cfg_file: Path of the file to (over)write.
        fmt: ``'csv'``, ``'yaml'`` or ``'yml'``.

    Raises:
        NotImplementedError: If ``fmt`` is not one of the supported formats.
    """
    import pandas as pd

    if fmt not in ('csv', 'yaml', 'yml'):
        raise NotImplementedError(f"Unsupported format: {fmt}")

    # list(...) because seen_configs may be an unordered container;
    # the explicit sort below makes the output order deterministic.
    df = pd.DataFrame(
        list(self.seen_configs),
        columns=['batch_size', 'seq_or_block', 'is_prefill'])
    # Force a boolean dtype so that mask indexing and `~` are well defined
    # even when no configuration has been recorded (object dtype).
    df['is_prefill'] = df['is_prefill'].astype(bool)
    df = df.sort_values(['is_prefill', 'batch_size', 'seq_or_block'],
                        ascending=False)
    df['phase'] = df['is_prefill'].apply(lambda x: 'prefill'
                                         if x else 'decode')

    def csv_serializer(frame, path):
        frame[['phase', 'batch_size', 'seq_or_block']].to_csv(path,
                                                              index=False)

    def yaml_serializer(frame, path):
        # NOTE(review): PyYAML is only imported on this path; confirm that
        # it is listed in the project requirements.
        import yaml

        def bucket_cfg_to_dict(cfg):
            return {'min': cfg[0], 'step': cfg[1], 'max': cfg[2]}

        def fit_to_observed(cfg, observed):
            # Work on a copy: the configured (min, step, max) lists on
            # `self` must not be modified as a side effect of saving.
            fitted = list(cfg)
            # With no observations min()/max() are NaN and cannot be
            # converted to int, so keep the configured bounds instead.
            if not observed.empty:
                fitted[0] = int(observed.min())
                fitted[2] = int(observed.max())
            return fitted

        def as_bucket_list(phase_frame):
            pairs = zip(phase_frame['batch_size'],
                        phase_frame['seq_or_block'])
            return sorted(([int(bs), int(size)] for bs, size in pairs),
                          reverse=True)

        prefill_df = frame[frame['is_prefill']]
        decode_df = frame[~frame['is_prefill']]

        # Only phases that were actually observed get an entry.
        buckets = {}
        for phase, phase_frame in (('decode', decode_df),
                                   ('prefill', prefill_df)):
            if not phase_frame.empty:
                buckets[phase] = as_bucket_list(phase_frame)

        data = {
            'bucket_cfg': {
                'prompt_bs_bucket_cfg':
                bucket_cfg_to_dict(
                    fit_to_observed(self.prompt_bs_bucket_cfg,
                                    prefill_df['batch_size'])),
                'prompt_seq_bucket_cfg':
                bucket_cfg_to_dict(
                    fit_to_observed(self.prompt_seq_bucket_cfg,
                                    prefill_df['seq_or_block'])),
                'decode_bs_bucket_cfg':
                bucket_cfg_to_dict(
                    fit_to_observed(self.decode_bs_bucket_cfg,
                                    decode_df['batch_size'])),
                'decode_block_bucket_cfg':
                bucket_cfg_to_dict(
                    fit_to_observed(self.decode_block_bucket_cfg,
                                    decode_df['seq_or_block'])),
            },
            'buckets': buckets,
        }

        with open(path, 'w') as outfile:
            yaml.safe_dump(data,
                           outfile,
                           default_flow_style=None,
                           sort_keys=False)

    if fmt == 'csv':
        csv_serializer(df, bucket_cfg_file)
    else:
        yaml_serializer(df, bucket_cfg_file)

    logger.info("Bucket calibration settings saved to %s", bucket_cfg_file)

def deserialize_bucket_settings(
    self,
    bucket_cfg_file,
    fmt='yaml'
) -> Tuple[Optional[Dict[str, Dict[str, int]]], Tuple[Optional[List[Tuple[
        int, int]]], Optional[List[Tuple[int, int]]]]]:
    """Load bucket settings previously written to ``bucket_cfg_file``.

    Args:
        bucket_cfg_file: Path of the file to read.
        fmt: ``'csv'``, ``'yaml'`` or ``'yml'``.

    Returns:
        ``(bucket_cfg, (prompt_buckets, decode_buckets))``.

        * ``bucket_cfg`` is the ``bucket_cfg`` mapping of a YAML file, or
          ``None`` (CSV files cannot carry it).
        * ``prompt_buckets`` / ``decode_buckets`` are lists of
          ``(batch_size, seq_or_block)`` tuples for the ``prefill`` and
          ``decode`` phases, or ``None`` when the file holds no bucket for
          that phase.

        If the file cannot be opened or its contents are invalid, an error
        is logged and every element is ``None``.

    Raises:
        NotImplementedError: If ``fmt`` is not one of the supported formats.
    """
    if fmt not in ('csv', 'yaml', 'yml'):
        raise NotImplementedError(f"Unsupported format: {fmt}")

    def load_csv(path):
        # CSV does not support overriding bucket_cfg
        import csv
        with open(path, 'r', newline='') as f:
            rows = list(csv.DictReader(f, skipinitialspace=True))
        by_phase: Dict[str, List[Tuple[int, int]]] = {
            'prefill': [],
            'decode': [],
        }
        for row in rows:
            phase = row['phase']
            if phase in by_phase:
                by_phase[phase].append(
                    (int(row['batch_size']), int(row['seq_or_block'])))
        # An empty phase is reported as None rather than [] so that
        # "is not None" checks do not mistake it for loaded buckets.
        return None, by_phase['prefill'] or None, by_phase['decode'] or None

    def load_yaml(path):
        import yaml
        with open(path, 'r') as f:
            try:
                data = yaml.safe_load(f)
            except yaml.YAMLError as err:
                raise ValueError(f"malformed YAML: {err}") from err
        if not isinstance(data, dict):
            raise ValueError("top-level YAML node must be a mapping")
        # Load min,step,max from file
        cfg = data.get('bucket_cfg', None)
        prompt = None
        decode = None
        # Load pre-generated buckets, if any
        if 'buckets' in data:
            buckets = data['buckets']
            if not isinstance(buckets, dict):
                raise ValueError("'buckets' must be a mapping")
            prompt = [tuple(b) for b in buckets.get('prefill') or []] or None
            decode = [tuple(b) for b in buckets.get('decode') or []] or None
        return cfg, prompt, decode

    bucket_cfg = None
    prompt_buckets = None
    decode_buckets = None
    loader = load_csv if fmt == 'csv' else load_yaml
    try:
        # Unpacked in a single statement so that a failure never leaves a
        # partially loaded configuration behind.
        bucket_cfg, prompt_buckets, decode_buckets = loader(bucket_cfg_file)
    except OSError:
        logger.error(
            "Could not open file specified in VLLM_HPU_BUCKET_CFG: "
            "%s. Falling back to default config.", bucket_cfg_file)
    except (KeyError, TypeError, ValueError) as err:
        logger.error(
            "Invalid contents in file specified in VLLM_HPU_BUCKET_CFG: "
            "%s (%r). Falling back to default config.", bucket_cfg_file, err)
    return bucket_cfg, (prompt_buckets, decode_buckets)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
Expand Down Expand Up @@ -1536,10 +1727,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
return
self.profiler.start('internal', 'warmup')
max_blocks = kv_caches[0][0].size(0)

self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets(
self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)
prompt_omitted_buckets = []
if getattr(self, 'prompt_buckets', None) is None:
self.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)
if self.lora_config:
self.prompt_buckets[:] = [
bucket for bucket in self.prompt_buckets
Expand All @@ -1559,9 +1752,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

self.decode_buckets = generate_decode_buckets(
self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg,
max_blocks)
if getattr(self, 'decode_buckets', None) is None:
self.decode_buckets = generate_decode_buckets(
self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg,
max_blocks)
if self.lora_config:
self.decode_buckets[:] = [
bucket for bucket in self.decode_buckets
Expand Down Expand Up @@ -1850,8 +2044,10 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
self.seen_configs.add(cfg)
if not seen and not warmup_mode:
phase = 'prompt' if is_prompt else 'decode'
logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
phase, batch_size, seq_len)
if not self.calibrate_buckets:
logger.warning(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to divide the code into 3 lines? We now have wide displays and it doesn't make it more readable

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to divide code into 3 lines?

I wish we didn't, but format.sh made this into such an abomination.

"Configuration: (%s, %s, %s) was not warmed-up!", phase,
batch_size, seq_len)

@torch.inference_mode()
def execute_model(
Expand Down Expand Up @@ -1976,4 +2172,7 @@ def shutdown_inc(self):
self._is_inc_finalized = True

def __del__(self):
    """Save calibration results (if enabled) and shut down INC.

    When ``calibrate_buckets`` is set, the observed bucket settings are
    written to ``bucket_cfg_file`` first. ``shutdown_inc()`` runs in a
    ``finally`` clause so that a failure while saving (for example an
    unwritable path) cannot prevent the shutdown; the original exception
    still propagates afterwards.
    """
    try:
        # getattr: the attribute may be missing if __init__ did not finish.
        if getattr(self, 'calibrate_buckets', False):
            self.serialize_bucket_settings(self.bucket_cfg_file,
                                           self.bucket_cfg_file_format)
    finally:
        self.shutdown_inc()