diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 6a9cb6f066ea..3c8022f20f4a 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -308,6 +308,10 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) def _use_graphs(self, batch_size, seq_len, is_prompt): + if is_prompt: + if os.environ.get('VLLM_LIMIT_HPU_GRAPH', 'false').lower() == 'true': + return False + if self.enforce_eager: return False return (batch_size, seq_len, is_prompt) in self.graphed_buckets @@ -999,7 +1003,9 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: decode_available_memory = free_mem - prompt_available_memory prompt_strategy = 'min_tokens' decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', 'max_bs') - self.warmup_graphs(prompt_strategy, self.prompt_buckets, True, kv_caches, prompt_available_memory) + + if os.environ.get('VLLM_LIMIT_HPU_GRAPH', 'false').lower() != 'true': + self.warmup_graphs(prompt_strategy, self.prompt_buckets, True, kv_caches, prompt_available_memory) self.warmup_graphs(decode_strategy, self.decode_buckets, False, kv_caches, decode_available_memory) end_time = time.perf_counter()