Skip to content

Commit 8dd5a79

Browse files
ChrisQlastyhangfei
authored andcommitted
fix: Fix transcript finish
Merge #3324 **Problem:** ADK seems to not pass output/input transcription `finished` flag from the Gemini. **Solution:** Relaxation of checking conditions of valid llm_response for input/output transcription. ### Testing Plan Unit test `test_receive_transcript_finished` checks if input/output transcription message with no text but `finished` flag is received. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. <img width="785" height="373" alt="image" src="https://github.com/user-attachments/assets/6d870e9f-1372-4808-91a9-38578c1b3729" /> **Manual End-to-End (E2E) Tests:** Configure Gemini Agent to produce input & output transcriptions. Observe incoming transcript messages - to see if the finished flag appears at the end agents & users statement. ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. ### Additional context The mentioned `finished` flag can be obtained in native Gemini APIs like via Websocket but not via ADK. Co-authored-by: Hangfei Lin <[email protected]> COPYBARA_INTEGRATE_REVIEW=#3324 from ChrisQlasty:fix/transcript_finish e7b8e5e PiperOrigin-RevId: 828987996
1 parent f167890 commit 8dd5a79

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

src/google/adk/models/gemini_llm_connection.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,18 +165,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
165165
yield self.__build_full_text_response(text)
166166
text = ''
167167
yield llm_response
168-
if (
169-
message.server_content.input_transcription
170-
and message.server_content.input_transcription.text
171-
):
168+
if message.server_content.input_transcription:
172169
llm_response = LlmResponse(
173170
input_transcription=message.server_content.input_transcription,
174171
)
175172
yield llm_response
176-
if (
177-
message.server_content.output_transcription
178-
and message.server_content.output_transcription.text
179-
):
173+
if message.server_content.output_transcription:
180174
llm_response = LlmResponse(
181175
output_transcription=message.server_content.output_transcription
182176
)

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,46 @@ async def test_close(gemini_connection, mock_gemini_session):
112112

113113

114114
@pytest.mark.asyncio
115+
@pytest.mark.parametrize('tx_direction', ['input', 'output'])
116+
async def test_receive_transcript_finished(
117+
gemini_connection, mock_gemini_session, tx_direction
118+
):
119+
"""Test receive_transcript_finished for input and output transcription."""
120+
121+
finished_tx = types.Transcription(finished=True)
122+
123+
msg = mock.Mock()
124+
msg.tool_call = None
125+
msg.usage_metadata = None
126+
msg.session_resumption_update = None
127+
msg.server_content.model_turn = None
128+
msg.server_content.interrupted = False
129+
msg.server_content.turn_complete = False
130+
msg.server_content.input_transcription = (
131+
finished_tx if tx_direction == 'input' else None
132+
)
133+
msg.server_content.output_transcription = (
134+
finished_tx if tx_direction == 'output' else None
135+
)
136+
137+
async def gen():
138+
yield msg
139+
140+
mock_gemini_session.receive = mock.Mock(return_value=gen())
141+
142+
responses = []
143+
async for r in gemini_connection.receive():
144+
responses.append(r)
145+
146+
attr_name = f'{tx_direction}_transcription'
147+
tx_resps = [r for r in responses if getattr(r, attr_name)]
148+
assert tx_resps, f'Expected {tx_direction} transcription response'
149+
150+
transcription = getattr(tx_resps[0], attr_name)
151+
assert transcription.finished is True
152+
assert not transcription.text
153+
154+
115155
async def test_receive_usage_metadata_and_server_content(
116156
gemini_connection, mock_gemini_session
117157
):

0 commit comments

Comments
 (0)