diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 047f699e4f27..be0631024f02 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -13,7 +13,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Optional, Set, Tuple +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop from fastapi import APIRouter, FastAPI, HTTPException, Request @@ -419,6 +419,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "use the Pooling API (`/pooling`) instead.") res = await fallback_handler.create_pooling(request, raw_request) + + generator: Union[ErrorResponse, EmbeddingResponse] if isinstance(res, PoolingResponse): generator = EmbeddingResponse( id=res.id, @@ -493,7 +495,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -TASK_HANDLERS = { +TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { "generate": { "messages": (ChatCompletionRequest, create_chat_completion), "default": (CompletionRequest, create_completion), @@ -651,7 +653,7 @@ async def add_request_id(request: Request, call_next): module_path, object_name = middleware.rsplit(".", 1) imported = getattr(importlib.import_module(module_path), object_name) if inspect.isclass(imported): - app.add_middleware(imported) + app.add_middleware(imported) # type: ignore[arg-type] elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: