
Commit 05fd241

yewentao256 authored and ilmarkov committed
[CI] Fix mypy for vllm/v1/core and vllm/v1/engine (vllm-project#27108)
Signed-off-by: yewentao256 <[email protected]>
1 parent c301d82 commit 05fd241

File tree

12 files changed: +91 −61 lines


tools/pre_commit/mypy.py

Lines changed: 13 additions & 1 deletion
@@ -36,12 +36,15 @@
     "vllm/transformers_utils",
     "vllm/triton_utils",
     "vllm/usage",
+    "vllm/v1/core",
+    "vllm/v1/engine",
 ]

 # After fixing errors resulting from changing follow_imports
 # from "skip" to "silent", move the following directories to FILES
 SEPARATE_GROUPS = [
     "tests",
+    # v0 related
     "vllm/attention",
     "vllm/compilation",
     "vllm/engine",
@@ -50,7 +53,16 @@
     "vllm/model_executor",
     "vllm/plugins",
     "vllm/worker",
-    "vllm/v1",
+    # v1 related
+    "vllm/v1/attention",
+    "vllm/v1/executor",
+    "vllm/v1/kv_offload",
+    "vllm/v1/metrics",
+    "vllm/v1/pool",
+    "vllm/v1/sample",
+    "vllm/v1/spec_decode",
+    "vllm/v1/structured_output",
+    "vllm/v1/worker",
 ]

 # TODO(woosuk): Include the code from Megatron and HuggingFace.
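
Note on the change above: vllm/v1/core and vllm/v1/engine move into the fully type-checked FILES list, while the remaining v1 subpackages stay in SEPARATE_GROUPS and are checked one directory at a time. The sketch below shows, purely as an illustration, how a pre-commit wrapper could drive mypy over such groups; the FILES and SEPARATE_GROUPS names come from the diff, but the command construction and follow-imports handling here are assumptions, not the actual tools/pre_commit/mypy.py.

    # Illustrative sketch only -- not the real tools/pre_commit/mypy.py.
    import subprocess
    import sys

    FILES = ["vllm/usage", "vllm/v1/core", "vllm/v1/engine"]   # checked together
    SEPARATE_GROUPS = ["vllm/v1/attention", "vllm/v1/worker"]  # checked one by one

    def run_mypy(targets: list[str], follow_imports: str) -> int:
        # --follow-imports controls how imported-but-unlisted modules are handled
        # ("silent" still type-checks them, "skip" ignores them).
        cmd = [sys.executable, "-m", "mypy", f"--follow-imports={follow_imports}", *targets]
        return subprocess.run(cmd).returncode

    if __name__ == "__main__":
        rc = run_mypy(FILES, "silent")
        for group in SEPARATE_GROUPS:
            rc |= run_mypy([group], "skip")
        sys.exit(rc)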

vllm/config/vllm.py

Lines changed: 4 additions & 5 deletions
@@ -84,7 +84,9 @@ class VllmConfig:
         default_factory=StructuredOutputsConfig
     )
     """Structured outputs configuration."""
-    observability_config: ObservabilityConfig | None = None
+    observability_config: ObservabilityConfig = Field(
+        default_factory=ObservabilityConfig
+    )
     """Observability configuration."""
     quant_config: QuantizationConfig | None = None
     """Quantization configuration."""
@@ -170,10 +172,7 @@ def compute_hash(self) -> str:
             vllm_factors.append(self.structured_outputs_config.compute_hash())
         else:
             vllm_factors.append("None")
-        if self.observability_config:
-            vllm_factors.append(self.observability_config.compute_hash())
-        else:
-            vllm_factors.append("None")
+        vllm_factors.append(self.observability_config.compute_hash())
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
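
The config change above swaps an Optional field for one that is always constructed, which is why compute_hash can drop its None branch. A minimal, self-contained sketch of the same pattern, using plain pydantic models and a stand-in compute_hash rather than vLLM's actual config classes:

    from pydantic import BaseModel, Field

    class ObservabilityConfigSketch(BaseModel):
        otlp_traces_endpoint: str | None = None

        def compute_hash(self) -> str:
            # Stand-in for the real hashing logic.
            return str(hash(self.otlp_traces_endpoint))

    class VllmConfigSketch(BaseModel):
        # Before: observability_config: ObservabilityConfigSketch | None = None
        # After: the field always holds an instance, so there is no None case to check.
        observability_config: ObservabilityConfigSketch = Field(
            default_factory=ObservabilityConfigSketch
        )

    cfg = VllmConfigSketch()
    print(cfg.observability_config.compute_hash())  # no None-check required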

vllm/engine/protocol.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def encode(
         lora_request: LoRARequest | None = None,
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
+        truncate_prompt_tokens: int | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model."""

vllm/v1/core/sched/scheduler.py

Lines changed: 9 additions & 7 deletions
@@ -167,7 +167,7 @@ def __init__(
         self.kv_cache_manager = KVCacheManager(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
-            enable_caching=self.cache_config.enable_prefix_caching,
+            enable_caching=bool(self.cache_config.enable_prefix_caching),
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
@@ -407,20 +407,22 @@ def schedule(self) -> SchedulerOutput:

             # Get externally-cached tokens if using a KVConnector.
             if self.connector is not None:
-                num_external_computed_tokens, load_kv_async = (
+                ext_tokens, load_kv_async = (
                     self.connector.get_num_new_matched_tokens(
                         request, num_new_local_computed_tokens
                     )
                 )

-                if num_external_computed_tokens is None:
+                if ext_tokens is None:
                     # The request cannot be scheduled because
                     # the KVConnector couldn't determine
                     # the number of matched tokens.
                     self.waiting.pop_request()
                     skipped_waiting_requests.prepend_request(request)
                     continue

+                num_external_computed_tokens = ext_tokens
+
             # Total computed tokens (local + external).
             num_computed_tokens = (
                 num_new_local_computed_tokens + num_external_computed_tokens
@@ -905,13 +907,13 @@ def update_from_output(

         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
-        kv_connector_stats = (
+        kv_connector_stats: KVConnectorStats | None = (
             kv_connector_output.kv_connector_stats if kv_connector_output else None
         )
         if kv_connector_stats and self.connector:
-            stats = self.connector.get_kv_connector_stats()
-            if stats:
-                kv_connector_stats = kv_connector_stats.aggregate(stats)
+            kv_stats = self.connector.get_kv_connector_stats()
+            if kv_stats:
+                kv_connector_stats = kv_connector_stats.aggregate(kv_stats)

         failed_kv_load_req_ids = None
         if kv_connector_output and kv_connector_output.invalid_block_ids:
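
The scheduler change above is a pure typing refactor: get_num_new_matched_tokens returns an Optional count, and unpacking it into a fresh name (ext_tokens) lets mypy narrow the value to int after the None check before it is stored back into the int-typed num_external_computed_tokens. A stand-alone sketch of the pattern, with a dummy connector call in place of the real one:

    def get_num_new_matched_tokens() -> tuple[int | None, bool]:
        # Stand-in for the KVConnector call; may report "unknown yet" as None.
        return 4, False

    def compute_total_tokens(num_new_local_computed_tokens: int) -> int | None:
        num_external_computed_tokens = 0  # declared as int, as in the scheduler

        ext_tokens, load_kv_async = get_num_new_matched_tokens()
        if ext_tokens is None:
            # The real scheduler re-queues the request and `continue`s here.
            return None

        num_external_computed_tokens = ext_tokens  # ext_tokens is narrowed to int
        return num_new_local_computed_tokens + num_external_computed_tokens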

vllm/v1/engine/async_llm.py

Lines changed: 13 additions & 8 deletions
@@ -6,7 +6,7 @@
 import time
 from collections.abc import AsyncGenerator, Iterable, Mapping
 from copy import copy
-from typing import Any
+from typing import Any, cast

 import numpy as np
 import torch
@@ -131,10 +131,9 @@ def __init__(
         self.output_processor = OutputProcessor(
             self.tokenizer, log_stats=self.log_stats
         )
-        if self.observability_config.otlp_traces_endpoint is not None:
-            tracer = init_tracer(
-                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
-            )
+        endpoint = self.observability_config.otlp_traces_endpoint
+        if endpoint is not None:
+            tracer = init_tracer("vllm.llm_engine", endpoint)
             self.output_processor.tracer = tracer

         # EngineCore (starts the engine in background process).
@@ -266,7 +265,9 @@ def shutdown(self):
         if engine_core := getattr(self, "engine_core", None):
             engine_core.shutdown()

-        cancel_task_threadsafe(getattr(self, "output_handler", None))
+        handler = getattr(self, "output_handler", None)
+        if handler is not None:
+            cancel_task_threadsafe(handler)

     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return await self.engine_core.get_supported_tasks_async()
@@ -314,7 +315,10 @@ async def add_request(
             priority,
             data_parallel_rank,
         )
-        prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
+        if isinstance(prompt, str):
+            prompt_text = prompt
+        elif isinstance(prompt, Mapping):
+            prompt_text = cast(str | None, prompt.get("prompt"))

         if is_pooling or params.n == 1:
             await self._add_request(request, prompt_text, None, 0, queue)
@@ -436,6 +440,7 @@ async def generate(
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
+                assert isinstance(out, RequestOutput)
                 yield out

                 # If the request is disconnected by the client, generate()
@@ -653,7 +658,7 @@ async def get_tokenizer(self) -> AnyTokenizer:
         return self.tokenizer

     async def is_tracing_enabled(self) -> bool:
-        return self.observability_config.otlp_traces_endpoint is not None
+        return self.observability_config.otlp_traces_endpoint is not None  # type: ignore

     async def do_log_stats(self) -> None:
         if self.logger_manager:
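
Two of the async_llm.py edits follow the same narrowing idea: binding self.observability_config.otlp_traces_endpoint to a local endpoint before the None check, and replacing the one-line prompt_text conditional with explicit isinstance branches plus a cast, since .get() on a Mapping[str, Any] is typed as Any. A rough sketch of the prompt-text part, using illustrative types rather than vLLM's real prompt classes:

    from collections.abc import Mapping
    from typing import Any, cast

    def extract_prompt_text(prompt: str | Mapping[str, Any]) -> str | None:
        if isinstance(prompt, str):
            return prompt                                # narrowed to str
        if isinstance(prompt, Mapping):
            # Mapping[str, Any].get() returns Any; the cast records the intended type.
            return cast(str | None, prompt.get("prompt"))
        return None

    print(extract_prompt_text("hello"))
    print(extract_prompt_text({"prompt": "hi"}))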

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions
@@ -1075,6 +1075,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
         local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

         assert dp_size > 1
+        assert local_dp_rank is not None
         assert 0 <= local_dp_rank <= dp_rank < dp_size

         if vllm_config.kv_transfer_config is not None:
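
The extra assert in core.py exists only for the type checker: data_parallel_rank_local is Optional, and comparing an int | None with <= is a mypy error, so the None case is asserted away first. In isolation the pattern looks like this:

    def check_ranks(local_dp_rank: int | None, dp_rank: int, dp_size: int) -> None:
        assert dp_size > 1
        assert local_dp_rank is not None          # narrows int | None -> int
        assert 0 <= local_dp_rank <= dp_rank < dp_size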

vllm/v1/engine/core_client.py

Lines changed: 7 additions & 7 deletions
@@ -385,10 +385,11 @@ def close_sockets_and_tasks():
            with contextlib.suppress(Exception):
                task.cancel()

-    if in_loop(loop):
-        close_sockets_and_tasks()
-    elif loop and not loop.is_closed():
-        loop.call_soon_threadsafe(close_sockets_and_tasks)
+    if loop is not None:
+        if in_loop(loop):
+            close_sockets_and_tasks()
+        elif not loop.is_closed():
+            loop.call_soon_threadsafe(close_sockets_and_tasks)
     else:
         # Loop has been closed, try to clean up directly.
         del tasks
@@ -1044,6 +1045,7 @@ def _ensure_stats_update_task(self):
             return

         assert self.stats_update_address is not None
+        stats_addr: str = self.stats_update_address
         assert len(self.engine_ranks_managed) > 0
         # NOTE: running and waiting counts are all global from
         # the Coordinator include all global EngineCores. This
@@ -1054,9 +1056,7 @@ def _ensure_stats_update_task(self):

         async def run_engine_stats_update_task():
             with (
-                make_zmq_socket(
-                    self.ctx, self.stats_update_address, zmq.XSUB, linger=0
-                ) as socket,
+                make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
                 make_zmq_socket(
                     self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
                 ) as first_req_rcv_socket,
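
In _ensure_stats_update_task the address is copied into a str-typed local because mypy does not carry the assert-based narrowing of self.stats_update_address into the nested coroutine; inside the closure the attribute would be Optional again. A minimal sketch of the idea, with the real make_zmq_socket call replaced by a print:

    import asyncio

    class StatsClientSketch:
        def __init__(self) -> None:
            self.stats_update_address: str | None = "tcp://127.0.0.1:5555"

        def ensure_stats_update_task(self) -> None:
            assert self.stats_update_address is not None
            stats_addr: str = self.stats_update_address  # stays `str` inside the closure

            async def run_engine_stats_update_task() -> None:
                # Reading self.stats_update_address here would be `str | None` again;
                # the captured local stats_addr is plain `str`.
                print(f"subscribing to {stats_addr}")

            asyncio.run(run_engine_stats_update_task())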

vllm/v1/engine/detokenizer.py

Lines changed: 10 additions & 3 deletions
@@ -69,14 +69,21 @@ def __init__(self, request: EngineCoreRequest):
         # Stop strings
         params = request.sampling_params
         assert params is not None
-        self.stop = stop = params.stop
+        stop_list: list[str]
+        if params.stop is None:
+            stop_list = []
+        elif isinstance(params.stop, str):
+            stop_list = [params.stop]
+        else:
+            stop_list = params.stop
+        self.stop = stop_list
         self.min_tokens = params.min_tokens
         self.include_stop_str_in_output = params.include_stop_str_in_output

         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stop and not self.include_stop_str_in_output:
-            self.stop_buffer_length = max(len(s) for s in stop) - 1
+        if self.stop and not self.include_stop_str_in_output:
+            self.stop_buffer_length = max(len(s) for s in self.stop) - 1
         else:
             self.stop_buffer_length = 0
         self._last_output_text_offset: int = 0
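
SamplingParams.stop can be None, a single string, or a list of strings, so the detokenizer now normalizes it to list[str] up front instead of keeping whatever shape it arrived in. The normalization boils down to:

    def normalize_stop(stop: str | list[str] | None) -> list[str]:
        if stop is None:
            return []
        if isinstance(stop, str):
            return [stop]
        return stop

    assert normalize_stop(None) == []
    assert normalize_stop("</s>") == ["</s>"]
    assert normalize_stop(["</s>", "\n\n"]) == ["</s>", "\n\n"]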

vllm/v1/engine/llm_engine.py

Lines changed: 9 additions & 7 deletions
@@ -4,7 +4,7 @@
 import time
 from collections.abc import Callable, Mapping
 from copy import copy
-from typing import Any
+from typing import Any, cast

 import torch.nn as nn
 from typing_extensions import TypeVar
@@ -112,10 +112,9 @@ def __init__(
         self.output_processor = OutputProcessor(
             self.tokenizer, log_stats=self.log_stats
         )
-        if self.observability_config.otlp_traces_endpoint is not None:
-            tracer = init_tracer(
-                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
-            )
+        endpoint = self.observability_config.otlp_traces_endpoint
+        if endpoint is not None:
+            tracer = init_tracer("vllm.llm_engine", endpoint)
             self.output_processor.tracer = tracer

         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
@@ -259,7 +258,10 @@ def add_request(
             trace_headers,
             priority,
         )
-        prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
+        if isinstance(prompt, str):
+            prompt_text = prompt
+        elif isinstance(prompt, Mapping):
+            prompt_text = cast(str | None, prompt.get("prompt"))

         n = params.n if isinstance(params, SamplingParams) else 1

@@ -285,7 +287,7 @@ def add_request(
         # Add the request to EngineCore.
         self.engine_core.add_request(child_request)

-    def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
+    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
         if self.should_execute_dummy_batch:
             self.should_execute_dummy_batch = False
             self.engine_core.execute_dummy_batch()

vllm/v1/engine/output_processor.py

Lines changed: 8 additions & 2 deletions
@@ -44,10 +44,16 @@ def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
         if self.output is None or isinstance(output, Exception):
             self.output = output
             self.ready.set()
-        elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
+        elif isinstance(self.output, RequestOutput) and isinstance(
+            output, RequestOutput
+        ):
             # This ensures that request outputs with different request indexes
             # (if n > 1) do not override each other.
             self.output.add(output, aggregate=self.aggregate)
+        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
+            output, PoolingRequestOutput
+        ):
+            self.output = output

     async def get(self) -> RequestOutput | PoolingRequestOutput:
         """Get operation blocks on put event."""
@@ -408,7 +414,7 @@ def process_outputs(
         within the loop below.
         """

-        request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = []
+        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
         reqs_to_abort: list[str] = []
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
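
The annotation change in process_outputs (mirrored by the LLMEngine.step return type above) matters to mypy because list[RequestOutput] | list[PoolingRequestOutput] means one homogeneous list of either kind, while list[RequestOutput | PoolingRequestOutput] allows both kinds in the same list, which is what the loop actually builds. A toy illustration with placeholder classes:

    class RequestOutputSketch: ...
    class PoolingRequestOutputSketch: ...

    # Rejected by mypy when annotated as
    # `list[RequestOutputSketch] | list[PoolingRequestOutputSketch]`,
    # accepted with the union element type below:
    request_outputs: list[RequestOutputSketch | PoolingRequestOutputSketch] = []
    request_outputs.append(RequestOutputSketch())
    request_outputs.append(PoolingRequestOutputSketch())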
