Skip to content

Commit ce8f674

Browse files
ayam04 and Raman Mangla
authored and committed
fix: Allow LLM request to override the model used in the generate content async method in LiteLLM
Merge #3066 Close #3065 Co-authored-by: Raman Mangla <[email protected]> PiperOrigin-RevId: 825880794
1 parent 3814d8b commit ce8f674

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
384384

385385

386386
def _schema_to_dict(schema: types.Schema) -> dict:
387-
"""
388-
Recursively converts a types.Schema to a pure-python dict
387+
"""Recursively converts a types.Schema to a pure-python dict
388+
389389
with all enum values written as lower-case strings.
390390
391391
Args:
@@ -631,7 +631,8 @@ def _get_completion_inputs(
631631
llm_request: The LlmRequest to convert.
632632
633633
Returns:
634-
The litellm inputs (message list, tool dictionary, response format and generation params).
634+
The litellm inputs (message list, tool dictionary, response format and
635+
generation params).
635636
"""
636637
# 1. Construct messages
637638
messages: List[Message] = []
@@ -905,7 +906,7 @@ async def generate_content_async(
905906
tools = None
906907

907908
completion_args = {
908-
"model": self.model,
909+
"model": llm_request.model or self.model,
909910
"messages": messages,
910911
"tools": tools,
911912
"response_format": response_format,

tests/unittests/models/test_litellm.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,53 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance):
549549
)
550550

551551

552+
@pytest.mark.asyncio
553+
async def test_generate_content_async_with_model_override(
554+
mock_acompletion, lite_llm_instance
555+
):
556+
llm_request = LlmRequest(
557+
model="overridden_model",
558+
contents=[
559+
types.Content(
560+
role="user", parts=[types.Part.from_text(text="Test prompt")]
561+
)
562+
],
563+
)
564+
565+
async for response in lite_llm_instance.generate_content_async(llm_request):
566+
assert response.content.role == "model"
567+
assert response.content.parts[0].text == "Test response"
568+
569+
mock_acompletion.assert_called_once()
570+
571+
_, kwargs = mock_acompletion.call_args
572+
assert kwargs["model"] == "overridden_model"
573+
assert kwargs["messages"][0]["role"] == "user"
574+
assert kwargs["messages"][0]["content"] == "Test prompt"
575+
576+
577+
@pytest.mark.asyncio
578+
async def test_generate_content_async_without_model_override(
579+
mock_acompletion, lite_llm_instance
580+
):
581+
llm_request = LlmRequest(
582+
model=None,
583+
contents=[
584+
types.Content(
585+
role="user", parts=[types.Part.from_text(text="Test prompt")]
586+
)
587+
],
588+
)
589+
590+
async for response in lite_llm_instance.generate_content_async(llm_request):
591+
assert response.content.role == "model"
592+
593+
mock_acompletion.assert_called_once()
594+
595+
_, kwargs = mock_acompletion.call_args
596+
assert kwargs["model"] == "test_model"
597+
598+
552599
@pytest.mark.asyncio
553600
async def test_generate_content_async_adds_fallback_user_message(
554601
mock_acompletion, lite_llm_instance

0 commit comments

Comments (0)