diff --git a/vllm/config.py b/vllm/config.py
index ebdcc5e0de93..ad436a1e65ee 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1277,6 +1277,8 @@ class LoadConfig:
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.
+        use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
+            loading. Default to True
     """
 
     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1284,6 +1286,7 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
     ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = True
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0e572a6f07bd..351ac175e3e9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -217,6 +217,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -751,6 +752,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             default=1,
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            dest='use_tqdm_on_load',
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to enable/disable progress bar '
+            'when loading model weights.',
+        )
 
         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -1179,6 +1188,7 @@ def create_load_config(self) -> LoadConfig:
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=self.use_tqdm_on_load,
         )
 
     def create_engine_config(self,
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 4f1092f68f50..bf226f661126 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -354,11 +354,18 @@ def _get_weights_iterator(
                 self.load_config.download_dir,
                 hf_folder,
                 hf_weights_files,
+                self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
 
         if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ def _prepare_weights(self, model_name_or_path: str,
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         for org_name, param in iterator:
             # mapping weight names from transformers to vllm while preserving
             # original names.
@@ -1396,7 +1409,10 @@ def _get_weights_iterator(
             revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )
 
     def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index d184079fb25d..926172a1daab 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
 
 
+def enable_tqdm(use_tqdm_on_load: bool):
+    return use_tqdm_on_load and (not torch.distributed.is_initialized()
+                                 or torch.distributed.get_rank() == 0)
+
+
 def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model np files.
 
     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
             for bin_file in tqdm(
                     hf_weights_files,
                     desc="Loading np_cache checkpoint shards",
-                    disable=not enable_tqdm,
+                    disable=not enable_tqdm(use_tqdm_on_load),
                     bar_format=_BAR_FORMAT,
             ):
                 state = torch.load(bin_file,
@@ -414,15 +420,14 @@
 
 
 def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         with safe_open(st_file, framework="pt") as f:
@@ -432,16 +437,15 @@
 
 
 def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
         for st_file in tqdm(
                 hf_weights_files,
                 desc="Loading safetensors using Runai Model Streamer",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(use_tqdm_on_load),
                 bar_format=_BAR_FORMAT,
         ):
             streamer.stream_file(st_file)
@@ -449,15 +453,14 @@
 
 
 def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file, map_location="cpu", weights_only=True)