|
6 | 6 | import time |
7 | 7 | from collections.abc import AsyncGenerator, Iterable, Mapping |
8 | 8 | from copy import copy |
9 | | -from typing import Any |
| 9 | +from typing import Any, cast |
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import torch |
@@ -131,10 +131,9 @@ def __init__( |
131 | 131 | self.output_processor = OutputProcessor( |
132 | 132 | self.tokenizer, log_stats=self.log_stats |
133 | 133 | ) |
134 | | - if self.observability_config.otlp_traces_endpoint is not None: |
135 | | - tracer = init_tracer( |
136 | | - "vllm.llm_engine", self.observability_config.otlp_traces_endpoint |
137 | | - ) |
| 134 | + endpoint = self.observability_config.otlp_traces_endpoint |
| 135 | + if endpoint is not None: |
| 136 | + tracer = init_tracer("vllm.llm_engine", endpoint) |
138 | 137 | self.output_processor.tracer = tracer |
139 | 138 |
|
140 | 139 | # EngineCore (starts the engine in background process). |
@@ -266,7 +265,9 @@ def shutdown(self): |
266 | 265 | if engine_core := getattr(self, "engine_core", None): |
267 | 266 | engine_core.shutdown() |
268 | 267 |
|
269 | | - cancel_task_threadsafe(getattr(self, "output_handler", None)) |
| 268 | + handler = getattr(self, "output_handler", None) |
| 269 | + if handler is not None: |
| 270 | + cancel_task_threadsafe(handler) |
270 | 271 |
|
271 | 272 | async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: |
272 | 273 | return await self.engine_core.get_supported_tasks_async() |
@@ -314,7 +315,10 @@ async def add_request( |
314 | 315 | priority, |
315 | 316 | data_parallel_rank, |
316 | 317 | ) |
317 | | - prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") |
| 318 | + if isinstance(prompt, str): |
| 319 | + prompt_text = prompt |
| 320 | + elif isinstance(prompt, Mapping): |
| 321 | + prompt_text = cast(str | None, prompt.get("prompt")) |
318 | 322 |
|
319 | 323 | if is_pooling or params.n == 1: |
320 | 324 | await self._add_request(request, prompt_text, None, 0, queue) |
@@ -436,6 +440,7 @@ async def generate( |
436 | 440 | # Note: both OutputProcessor and EngineCore handle their |
437 | 441 | # own request cleanup based on finished. |
438 | 442 | finished = out.finished |
| 443 | + assert isinstance(out, RequestOutput) |
439 | 444 | yield out |
440 | 445 |
|
441 | 446 | # If the request is disconnected by the client, generate() |
@@ -653,7 +658,7 @@ async def get_tokenizer(self) -> AnyTokenizer: |
653 | 658 | return self.tokenizer |
654 | 659 |
|
655 | 660 | async def is_tracing_enabled(self) -> bool: |
656 | | - return self.observability_config.otlp_traces_endpoint is not None |
| 661 | + return self.observability_config.otlp_traces_endpoint is not None # type: ignore |
657 | 662 |
|
658 | 663 | async def do_log_stats(self) -> None: |
659 | 664 | if self.logger_manager: |
|
0 commit comments