Skip to content

Commit 5859175

Browse files
authored
[Refactor Followup] Fix type errors in scheduler package (#436)
## Summary and Details This PR fixes type errors in the scheduler package. Many of them are null-handling errors, and some fix type annotations on more complex types. One group of errors is corrected by switching the type annotations to reflect that the data structures store only string IDs rather than the type instances themselves. One type error remains, which can be found by running `mypy` on the scheduler folder. I believe it stems from re-using the typing from the worker group in the singular worker: `SchedulerState` is present in the group but not in the singular worker. This PR removes 44 type errors, bringing the current count down to 83. ## Test Plan I ran some basic benchmarks; run benchmarks to verify the scheduler is still working as intended. --- - [x] "I certify that all code in this PR is my own, except as noted below." ## Use of AI - [x] Includes AI-assisted code completion - [ ] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents dd219f1 + 6a0a468 commit 5859175

File tree

6 files changed

+124
-71
lines changed

6 files changed

+124
-71
lines changed

src/guidellm/scheduler/constraints.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,9 +1005,7 @@ def info(self) -> dict[str, Any]:
10051005
return self.model_dump()
10061006

10071007
def __call__(
1008-
self,
1009-
state: SchedulerState,
1010-
request_info: RequestInfo, # noqa: ARG002
1008+
self, state: SchedulerState, _request: RequestInfo
10111009
) -> SchedulerUpdateAction:
10121010
create_exceeded = state.created_requests >= self.num_requests
10131011
processed_exceeded = state.processed_requests >= self.num_requests

src/guidellm/scheduler/environments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def sync_run_start(self) -> float:
8484
async def update_run_iteration(
8585
self,
8686
response: ResponseT | None,
87-
request: RequestT,
87+
request: RequestT | MultiTurnRequestT[RequestT],
8888
request_info: RequestInfo,
8989
state: SchedulerState,
9090
):
@@ -201,7 +201,7 @@ async def sync_run_start(self) -> float:
201201
async def update_run_iteration(
202202
self,
203203
response: ResponseT | None,
204-
request: RequestT,
204+
request: RequestT | MultiTurnRequestT[RequestT],
205205
request_info: RequestInfo,
206206
state: SchedulerState,
207207
):

src/guidellm/scheduler/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def run(
6969
) -> AsyncIterator[
7070
tuple[
7171
ResponseT | None,
72-
RequestT,
72+
RequestT | MultiTurnRequestT[RequestT],
7373
RequestInfo,
7474
SchedulerState,
7575
]

src/guidellm/scheduler/strategies.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]:
7070
description="Number of worker processes to use for this strategy",
7171
ge=0,
7272
)
73-
max_concurrency: int = Field(
74-
default=0,
73+
max_concurrency: int | None = Field(
74+
default=None,
7575
description="Maximum number of concurrent requests to allow",
7676
ge=0,
7777
)
@@ -122,8 +122,8 @@ def init_processes_timings(
122122
self.startup_duration = startup_duration
123123

124124
self._processes_request_index = Value("i", 0)
125-
self._processes_lock = Lock()
126125
self._processes_start_time = Value("d", -1.0)
126+
self._processes_lock = Lock()
127127

128128
def init_processes_start(self, start_time: float):
129129
"""
@@ -137,6 +137,10 @@ def init_processes_start(self, start_time: float):
137137
"SchedulingStrategy init_processes_start called before "
138138
"init_processes_timings"
139139
)
140+
if self._processes_start_time is None:
141+
raise RuntimeError(
142+
"_processes_lock is not None but _processes_start_time is None"
143+
)
140144

141145
with self._processes_lock:
142146
self._processes_start_time.value = start_time
@@ -153,6 +157,10 @@ async def get_processes_start_time(self) -> float:
153157
"SchedulingStrategy get_processes_start_time called before "
154158
"init_processes_timings"
155159
)
160+
if self._processes_start_time is None:
161+
raise RuntimeError(
162+
"_processes_lock is not None but _processes_start_time is None"
163+
)
156164

157165
while self._cached_processes_start_time is None:
158166
with self._processes_lock:
@@ -175,6 +183,10 @@ def next_request_index(self) -> int:
175183
"SchedulingStrategy next_request_index called before "
176184
"init_processes_timings"
177185
)
186+
if self._processes_request_index is None:
187+
raise RuntimeError(
188+
"_processes_lock is not None but _processes_request_index is None"
189+
)
178190

179191
with self._processes_lock:
180192
self._processes_request_index.value += 1
@@ -369,7 +381,8 @@ async def next_request_time(self, offset: int) -> float:
369381
start_time = await self.get_processes_start_time()
370382

371383
if (
372-
self.startup_duration > 0
384+
self.max_concurrency is not None
385+
and self.startup_duration > 0
373386
and (time.time() - start_time) < self.startup_duration
374387
and (current_index := self.next_request_index()) <= self.max_concurrency
375388
):
@@ -477,6 +490,8 @@ def init_processes_timings(
477490
:param startup_duration: Duration in seconds for request startup ramping
478491
"""
479492
super().init_processes_timings(worker_count, max_concurrency, startup_duration)
493+
if self._processes_lock is None:
494+
raise RuntimeError("_processes_lock is None in init_processes_timings")
480495
with self._processes_lock:
481496
self._offset = Value("d", -1.0)
482497

@@ -487,6 +502,12 @@ def init_processes_start(self, start_time: float):
487502
:param start_time: Unix timestamp when request processing should begin
488503
"""
489504
ThroughputStrategy.init_processes_start(self, start_time)
505+
506+
if self._processes_lock is None:
507+
raise RuntimeError("_processes_lock is None in init_processes_start")
508+
if self._offset is None:
509+
raise RuntimeError("_offset is None in init_processes_start; was "
510+
"init_processes_timings not called?")
490511
with self._processes_lock:
491512
self._offset.value = start_time
492513

@@ -505,6 +526,12 @@ async def next_request_time(self, offset: int) -> float:
505526

506527
next_delay = self._random.expovariate(self.rate)
507528

529+
if self._processes_lock is None:
530+
raise RuntimeError("_processes_lock is None in next_request_time; was "
531+
"init_processes_timings not called?")
532+
if self._offset is None:
533+
raise RuntimeError("_offset is None in next_request_time; was "
534+
"init_processes_timings not called?")
508535
with self._processes_lock:
509536
self._offset.value += next_delay
510537

src/guidellm/scheduler/worker.py

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@
2323
bool, "Flag indicating uvloop availability for event loop optimization"
2424
] = True
2525
except ImportError:
26-
uvloop = None
26+
uvloop = None # type: ignore[assignment] # Optional dependency
2727

28-
HAS_UVLOOP: Annotated[
29-
bool, "Flag indicating uvloop availability for event loop optimization"
30-
] = False
28+
HAS_UVLOOP = False
3129

3230

3331
from guidellm.scheduler.schemas import (
@@ -84,6 +82,10 @@ def __init__(
8482
RequestT | MultiTurnRequestT[RequestT],
8583
RequestInfo,
8684
],
85+
tuple[
86+
RequestT | MultiTurnRequestT[RequestT],
87+
RequestInfo,
88+
],
8789
],
8890
backend: BackendInterface[RequestT, ResponseT],
8991
strategy: SchedulingStrategy,
@@ -201,8 +203,11 @@ async def run_async(self):
201203

202204
async def _stop_monitor(
203205
self,
204-
) -> Literal["error_event", "shutdown_event"]:
205-
"""Monitor shutdown and error events for worker termination."""
206+
) -> None:
207+
"""
208+
Monitor shutdown and error events for worker termination.
209+
:raises RuntimeError: if the worker process received an error signal.
210+
"""
206211
exit_key = await wait_for_sync_objects(
207212
{
208213
"error_event": self.error_event,
@@ -322,7 +327,7 @@ async def _cancel_requests_loop(self):
322327
"""Cancel all remaining queued requests until worker process terminates."""
323328
while True:
324329
try:
325-
request: RequestT
330+
request: RequestT | MultiTurnRequestT[RequestT]
326331
request_info: RequestInfo
327332
request, request_info = await self.messaging.get(
328333
timeout=self.messaging.poll_interval
@@ -350,31 +355,19 @@ async def _process_next_request(self, target_start: float):
350355

351356
try:
352357
# Pull request from the queue, update state, and send "pending" update
353-
request, request_info = await self.messaging.get()
354-
request_info.timings.dequeued = time.time()
355-
request_info.scheduler_node_id = self.messaging.worker_index or -1
356-
request_info.timings.targeted_start = target_start
357-
self._send_update("pending", response, request, request_info)
358-
359-
if request is None or request_info is None:
360-
raise RuntimeError("Received invalid request or request info")
361-
if isinstance(request, list | tuple):
362-
raise NotImplementedError("Multi-turn requests are not yet supported")
363-
364-
# Schedule the request
365-
current_time = time.time()
366-
request_info.timings.scheduled_at = current_time
367-
if target_start > current_time:
368-
await asyncio.sleep(target_start - current_time)
369-
# Adapt delay so that scheduled at reflects the sleep time
370-
request_info.timings.scheduled_at = target_start
371-
372-
# Process the request with the backend
373-
request_info.timings.resolve_start = time.time()
374-
self._send_update("in_progress", response, request, request_info)
375-
async for resp, info in self.backend.resolve(request, request_info, None):
358+
request, request_info = await self._dequeue_next_request(target_start)
359+
360+
# Schedule the request and send "in_progress" update
361+
await self._schedule_request(request, request_info, target_start)
362+
363+
async for resp, info in self.backend.resolve( # type: ignore[attr-defined]
364+
request, request_info, None
365+
):
366+
376367
response = resp
377368
request_info = info
369+
if request_info is None:
370+
raise RuntimeError("Received invalid request info from backend")
378371

379372
# Complete the request
380373
request_info.timings.resolve_end = time.time()
@@ -397,6 +390,39 @@ async def _process_next_request(self, target_start: float):
397390
if request_info is not None:
398391
self.strategy.request_completed(request_info)
399392

393+
async def _dequeue_next_request(
394+
self, target_start: float
395+
) -> tuple[RequestT, RequestInfo]:
396+
request, request_info = await self.messaging.get()
397+
dequeued_time = time.time() # Ensure accurate dequeue timing
398+
if request is None or request_info is None:
399+
raise RuntimeError("Received invalid request or request info")
400+
if isinstance(request, list | tuple):
401+
raise NotImplementedError("Multi-turn requests are not yet supported")
402+
403+
request_info.timings.dequeued = dequeued_time
404+
request_info.scheduler_node_id = self.messaging.worker_index or -1
405+
request_info.timings.targeted_start = target_start
406+
self._send_update("pending", None, request, request_info)
407+
return request, request_info
408+
409+
async def _schedule_request(
410+
self,
411+
request: RequestT,
412+
request_info: RequestInfo,
413+
target_start: float
414+
):
415+
current_time = time.time()
416+
request_info.timings.scheduled_at = current_time
417+
if target_start > current_time:
418+
await asyncio.sleep(target_start - current_time)
419+
# Adapt delay so that scheduled at reflects the sleep time
420+
request_info.timings.scheduled_at = target_start
421+
422+
# Process the request with the backend
423+
request_info.timings.resolve_start = time.time()
424+
self._send_update("in_progress", None, request, request_info)
425+
400426
def _send_update(
401427
self,
402428
new_status: Literal[

src/guidellm/scheduler/worker_group.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
backend: BackendInterface[RequestT, ResponseT],
8585
strategy: SchedulingStrategy,
8686
startup_duration: float,
87-
**constraints: dict[str, Constraint],
87+
**constraints: Constraint,
8888
):
8989
"""
9090
Initialize a worker process group for distributed request processing.
@@ -232,7 +232,7 @@ async def create_processes(self):
232232
worker_index=rank,
233233
max_buffer_send_size=None,
234234
max_buffer_receive_size=per_proc_max_buffer_size,
235-
),
235+
), # The non-group worker lacks the SchedulerState type. Type err.
236236
backend=self.backend,
237237
strategy=self.strategy,
238238
async_limit=async_limit,
@@ -478,9 +478,9 @@ def __init__(
478478
num_processes=len(processes),
479479
start_time=start_time,
480480
)
481-
self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
482-
self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
483-
self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
481+
self._queued_request_ids: set[str] = set()
482+
self._pending_request_ids: set[str] = set()
483+
self._processing_request_ids: set[str] = set()
484484

485485
def requests_generator(
486486
self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]]
@@ -517,11 +517,13 @@ def requests_generator(
517517
)
518518
state_update = self._locked_update(request_info)
519519
request_info.timings.queued = time.time()
520+
if self.messaging.buffer_receive_queue is None:
521+
raise RuntimeError("buffer receive queue is None")
520522
self.messaging.buffer_receive_queue.sync_put(
521523
(None, request, request_info, state_update.state)
522524
)
523525

524-
yield (request, request_info)
526+
yield request, request_info
525527

526528
if state_update.stop_queueing:
527529
self.stop_send_requests_event.set()
@@ -530,8 +532,8 @@ def requests_generator(
530532
# Reached the end, inject a RequestsExhaustedConstraint to record
531533
self._locked_update(
532534
info=None,
533-
requests_exhausted={
534-
"requests_exhausted": RequestsExhaustedConstraint(
535+
add_constraints={
536+
"requests_exhausted": RequestsExhaustedConstraint( # type: ignore[dict-item]
535537
num_requests=count
536538
)
537539
},
@@ -610,10 +612,10 @@ def received_callback(
610612
def _locked_update(
611613
self,
612614
info: RequestInfo | None = None,
613-
**add_constraints: dict[str, Constraint],
615+
add_constraints: dict[str, Constraint] | None = None,
614616
) -> _StateUpdate:
615617
with self._update_lock:
616-
if add_constraints:
618+
if add_constraints is not None:
617619
self.constraints.update(add_constraints)
618620

619621
if info is not None:
@@ -631,34 +633,34 @@ def _locked_update(
631633

632634
def _update_state_request_counts(self, info: RequestInfo):
633635
if info.status == "queued":
634-
self._queued_requests.add(info.request_id)
635-
self._state.queued_requests = len(self._queued_requests)
636+
self._queued_request_ids.add(info.request_id)
637+
self._state.queued_requests = len(self._queued_request_ids)
636638
self._state.created_requests += 1
637639
elif info.status == "pending":
638-
self._queued_requests.remove(info.request_id)
639-
self._state.queued_requests = len(self._queued_requests)
640-
self._pending_requests.add(info.request_id)
641-
self._state.pending_requests = len(self._pending_requests)
640+
self._queued_request_ids.remove(info.request_id)
641+
self._state.queued_requests = len(self._queued_request_ids)
642+
self._pending_request_ids.add(info.request_id)
643+
self._state.pending_requests = len(self._pending_request_ids)
642644
elif info.status == "in_progress":
643-
self._pending_requests.remove(info.request_id)
644-
self._state.pending_requests = len(self._pending_requests)
645-
self._processing_requests.add(info.request_id)
646-
self._state.processing_requests = len(self._processing_requests)
645+
self._pending_request_ids.remove(info.request_id)
646+
self._state.pending_requests = len(self._pending_request_ids)
647+
self._processing_request_ids.add(info.request_id)
648+
self._state.processing_requests = len(self._processing_request_ids)
647649
elif info.status == "completed":
648-
self._processing_requests.remove(info.request_id)
649-
self._state.processing_requests = len(self._processing_requests)
650+
self._processing_request_ids.remove(info.request_id)
651+
self._state.processing_requests = len(self._processing_request_ids)
650652
self._state.processed_requests += 1
651653
self._state.successful_requests += 1
652654
elif info.status in ("errored", "cancelled"):
653-
if info.request_id in self._queued_requests:
654-
self._queued_requests.remove(info.request_id)
655-
self._state.queued_requests = len(self._queued_requests)
656-
elif info.request_id in self._pending_requests:
657-
self._pending_requests.remove(info.request_id)
658-
self._state.pending_requests = len(self._pending_requests)
659-
elif info.request_id in self._processing_requests:
660-
self._processing_requests.remove(info.request_id)
661-
self._state.processing_requests = len(self._processing_requests)
655+
if info.request_id in self._queued_request_ids:
656+
self._queued_request_ids.remove(info.request_id)
657+
self._state.queued_requests = len(self._queued_request_ids)
658+
elif info.request_id in self._pending_request_ids:
659+
self._pending_request_ids.remove(info.request_id)
660+
self._state.pending_requests = len(self._pending_request_ids)
661+
elif info.request_id in self._processing_request_ids:
662+
self._processing_request_ids.remove(info.request_id)
663+
self._state.processing_requests = len(self._processing_request_ids)
662664

663665
self._state.processed_requests += 1
664666
self._state.errored_requests += 1 if info.status == "errored" else 0

0 commit comments

Comments
 (0)