Skip to content
Closed
10 changes: 2 additions & 8 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
if (
message.server_content.input_transcription
and message.server_content.input_transcription.text
):
if message.server_content.input_transcription:
llm_response = LlmResponse(
input_transcription=message.server_content.input_transcription,
)
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
if message.server_content.output_transcription:
llm_response = LlmResponse(
output_transcription=message.server_content.output_transcription
)
Expand Down
40 changes: 40 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,46 @@ async def test_close(gemini_connection, mock_gemini_session):


@pytest.mark.asyncio
@pytest.mark.parametrize('tx_direction', ['input', 'output'])
async def test_receive_transcript_finished(
gemini_connection, mock_gemini_session, tx_direction
):
"""Test receive_transcript_finished for input and output transcription."""

finished_tx = types.Transcription(finished=True)

msg = mock.Mock()
msg.tool_call = None
msg.usage_metadata = None
msg.session_resumption_update = None
msg.server_content.model_turn = None
msg.server_content.interrupted = False
msg.server_content.turn_complete = False
msg.server_content.input_transcription = (
finished_tx if tx_direction == 'input' else None
)
msg.server_content.output_transcription = (
finished_tx if tx_direction == 'output' else None
)

async def gen():
yield msg

mock_gemini_session.receive = mock.Mock(return_value=gen())

responses = []
async for r in gemini_connection.receive():
responses.append(r)

attr_name = f'{tx_direction}_transcription'
tx_resps = [r for r in responses if getattr(r, attr_name)]
assert tx_resps, f'Expected {tx_direction} transcription response'

transcription = getattr(tx_resps[0], attr_name)
assert transcription.finished is True
assert not transcription.text


async def test_receive_usage_metadata_and_server_content(
gemini_connection, mock_gemini_session
):
Expand Down