Skip to content

Commit 1e3f73e

Browse files
committed
fix(live): keep streaming tool yields open
1 parent 59f7bdf commit 1e3f73e

10 files changed

Lines changed: 94 additions & 16 deletions

File tree

src/google/adk/agents/live_request_queue.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class LiveRequest(BaseModel):
5555
close: bool = False
5656
"""If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""
5757

58+
turn_complete: bool = True
59+
"""If set, content messages complete the current model turn."""
60+
5861

5962
class LiveRequestQueue:
6063
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""
@@ -65,8 +68,10 @@ def __init__(self):
6568
def close(self):
6669
self._queue.put_nowait(LiveRequest(close=True))
6770

68-
def send_content(self, content: types.Content):
69-
self._queue.put_nowait(LiveRequest(content=content))
71+
def send_content(self, content: types.Content, turn_complete: bool = True):
72+
self._queue.put_nowait(
73+
LiveRequest(content=content, turn_complete=turn_complete)
74+
)
7075

7176
def send_realtime(self, blob: types.Blob):
7277
self._queue.put_nowait(LiveRequest(blob=blob))

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,9 @@ async def _send_to_model(
770770
is_function_response = content.parts and any(
771771
part.function_response for part in content.parts
772772
)
773-
if not is_function_response:
774-
if not content.role:
775-
content.role = 'user'
773+
if not is_function_response and not content.role:
774+
content.role = 'user'
775+
if not is_function_response and live_request.turn_complete:
776776
user_content_event = Event(
777777
id=Event.new_id(),
778778
invocation_id=invocation_context.invocation_id,
@@ -783,7 +783,9 @@ async def _send_to_model(
783783
session=invocation_context.session,
784784
event=user_content_event,
785785
)
786-
await llm_connection.send_content(live_request.content)
786+
await llm_connection.send_content(
787+
live_request.content, turn_complete=live_request.turn_complete
788+
)
787789

788790
async def _receive_from_model(
789791
self,

src/google/adk/flows/llm_flows/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,9 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
985985
)
986986
],
987987
)
988-
invocation_context.live_request_queue.send_content(updated_content)
988+
invocation_context.live_request_queue.send_content(
989+
updated_content, turn_complete=False
990+
)
989991
except asyncio.CancelledError:
990992
raise # Re-raise to properly propagate the cancellation
991993

src/google/adk/models/base_llm_connection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,18 @@ async def send_history(self, history: list[types.Content]):
3939
pass
4040

4141
@abstractmethod
42-
async def send_content(self, content: types.Content):
42+
async def send_content(
43+
self, content: types.Content, turn_complete: bool = True
44+
):
4345
"""Sends a user content to the model.
4446
45-
The model will respond immediately upon receiving the content.
47+
By default, the model will respond upon receiving the content.
4648
If you send function responses, all parts in the content should be function
4749
responses.
4850
4951
Args:
5052
content: The content to send to the model.
53+
turn_complete: Whether this content completes the model turn.
5154
"""
5255
pass
5356

src/google/adk/models/gemini_llm_connection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,18 @@ async def send_history(self, history: list[types.Content]):
117117
else:
118118
logger.info('no content is sent')
119119

120-
async def send_content(self, content: types.Content):
120+
async def send_content(
121+
self, content: types.Content, turn_complete: bool = True
122+
):
121123
"""Sends a user content to the gemini model.
122124
123-
The model will respond immediately upon receiving the content.
125+
By default, the model will respond upon receiving the content.
124126
If you send function responses, all parts in the content should be function
125127
responses.
126128
127129
Args:
128130
content: The content to send to the model.
131+
turn_complete: Whether this content completes the model turn.
129132
"""
130133
assert content.parts
131134
if content.parts[0].function_response:
@@ -138,7 +141,8 @@ async def send_content(self, content: types.Content):
138141
else:
139142
logger.debug('Sending LLM new content %s', content)
140143
if (
141-
self._is_gemini_3_1_flash_live
144+
turn_complete
145+
and self._is_gemini_3_1_flash_live
142146
and len(content.parts) == 1
143147
and content.parts[0].text
144148
):
@@ -150,7 +154,7 @@ async def send_content(self, content: types.Content):
150154
await self._gemini_session.send(
151155
input=types.LiveClientContent(
152156
turns=[content],
153-
turn_complete=True,
157+
turn_complete=turn_complete,
154158
)
155159
)
156160

tests/unittests/agents/test_live_request_queue.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ def test_send_content():
4040
mock_put_nowait.assert_called_once_with(LiveRequest(content=content))
4141

4242

43+
def test_send_content_sets_turn_complete():
44+
queue = LiveRequestQueue()
45+
content = MagicMock(spec=types.Content)
46+
47+
with patch.object(queue._queue, "put_nowait") as mock_put_nowait:
48+
queue.send_content(content, turn_complete=False)
49+
mock_put_nowait.assert_called_once_with(
50+
LiveRequest(content=content, turn_complete=False)
51+
)
52+
53+
4354
def test_send_realtime():
4455
queue = LiveRequestQueue()
4556
blob = MagicMock(spec=types.Blob)

tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,5 +197,36 @@ async def test_send_to_model_with_text_content(mock_llm_connection):
197197
await flow._send_to_model(mock_llm_connection, invocation_context)
198198

199199
# Verify send_content was called instead of send_realtime
200-
mock_llm_connection.send_content.assert_called_once_with(content)
200+
mock_llm_connection.send_content.assert_called_once_with(
201+
content, turn_complete=True
202+
)
201203
mock_llm_connection.send_realtime.assert_not_called()
204+
205+
206+
@pytest.mark.asyncio
207+
async def test_send_to_model_with_intermediate_text_content(
208+
mock_llm_connection,
209+
):
210+
agent = Agent(name='test_agent', model='mock')
211+
invocation_context = await testing_utils.create_invocation_context(
212+
agent=agent, user_content=''
213+
)
214+
invocation_context.live_request_queue = LiveRequestQueue()
215+
invocation_context.session_service.append_event = mock.AsyncMock()
216+
217+
flow = TestBaseLlmFlow()
218+
219+
content = types.Content(
220+
role='user', parts=[types.Part.from_text(text='progress')]
221+
)
222+
invocation_context.live_request_queue.send(
223+
LiveRequest(content=content, turn_complete=False)
224+
)
225+
invocation_context.live_request_queue.close()
226+
227+
await flow._send_to_model(mock_llm_connection, invocation_context)
228+
229+
mock_llm_connection.send_content.assert_called_once_with(
230+
content, turn_complete=False
231+
)
232+
invocation_context.session_service.append_event.assert_not_called()

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ async def test_send_content_text(gemini_connection, mock_gemini_session):
124124
assert call_args['input'].turn_complete is True
125125

126126

127+
@pytest.mark.asyncio
128+
async def test_send_content_text_can_keep_turn_open(
129+
gemini_connection, mock_gemini_session
130+
):
131+
content = types.Content(
132+
role='user', parts=[types.Part.from_text(text='progress')]
133+
)
134+
135+
await gemini_connection.send_content(content, turn_complete=False)
136+
137+
mock_gemini_session.send.assert_called_once()
138+
call_args = mock_gemini_session.send.call_args[1]
139+
assert call_args['input'].turns == [content]
140+
assert call_args['input'].turn_complete is False
141+
142+
127143
@pytest.mark.asyncio
128144
async def test_send_content_function_response(
129145
gemini_connection, mock_gemini_session

tests/unittests/testing_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def __init__(self, llm_responses: list[LlmResponse]):
420420
async def send_history(self, history: list[types.Content]):
421421
pass
422422

423-
async def send_content(self, content: types.Content):
423+
async def send_content(
424+
self, content: types.Content, turn_complete: bool = True
425+
):
424426
pass
425427

426428
async def send(self, data):

tests/unittests/workflow/testing_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,9 @@ def __init__(
484484
async def send_history(self, history: list[types.Content]):
485485
pass
486486

487-
async def send_content(self, content: types.Content):
487+
async def send_content(
488+
self, content: types.Content, turn_complete: bool = True
489+
):
488490
self.mock_model.live_contents.append(content)
489491
self._input_event.set()
490492

0 commit comments

Comments
 (0)