diff --git a/tests/agents/core/test_history_processor.py b/tests/agents/core/test_history_processor.py index 22182b3..ebbdebb 100644 --- a/tests/agents/core/test_history_processor.py +++ b/tests/agents/core/test_history_processor.py @@ -116,6 +116,18 @@ def test_invocation_mode_filters_by_id(self, invocation_context): assert len(events) == 1 assert events[0].content.parts[0].text == "current" + def test_invocation_mode_includes_summary_events(self, invocation_context): + proc = HistoryProcessor(timeline_filter_mode=TimelineFilterMode.INVOCATION) + summary_event = _make_event("system", "Previous conversation summary", invocation_id="summary") + summary_event.set_summary_event(True) + current_event = _make_event("user", "current", invocation_id="inv-1") + + events = proc.filter_events(invocation_context, [summary_event, current_event]) + + assert len(events) == 2 + assert events[0].is_summary_event() + assert events[1].content.parts[0].text == "current" + # --------------------------------------------------------------------------- # HistoryProcessor.filter_events - Branch filtering diff --git a/tests/sessions/test_base_session_service.py b/tests/sessions/test_base_session_service.py index ef898f8..f15d373 100644 --- a/tests/sessions/test_base_session_service.py +++ b/tests/sessions/test_base_session_service.py @@ -170,7 +170,8 @@ def test_filter_by_num_recent_events(self): svc = ConcreteSessionService(session_config=config) session = _make_session() for i in range(10): - session.events.append(_make_event(text=f"msg{i}")) + author = "user" if i == 7 else "agent" + session.events.append(_make_event(author=author, text=f"msg{i}")) svc.filter_events(session) assert len(session.events) == 10 visible_events = [event for event in session.events if event.is_model_visible()] @@ -185,7 +186,7 @@ def test_filter_by_event_ttl(self): old_event.timestamp = time.time() - 100 session.events.append(old_event) - new_event = _make_event(text="new") + new_event = 
_make_event(author="user", text="new") new_event.timestamp = time.time() session.events.append(new_event) @@ -215,6 +216,23 @@ def test_filter_ttl_removes_all_old(self): assert len(session.events) == 5 assert all(not event.is_model_visible() for event in session.events) + def test_filter_by_num_recent_events_preserves_summary_anchor(self): + config = SessionServiceConfig(num_recent_events=3) + svc = ConcreteSessionService(session_config=config) + session = _make_session() + + summary_event = _make_event(author="system", text="summary") + summary_event.set_summary_event(True) + session.events.append(summary_event) + for i in range(5): + session.events.append(_make_event(text=f"agent{i}")) + + svc.filter_events(session) + + visible_events = [event for event in session.events if event.is_model_visible()] + assert len(visible_events) == 1 + assert visible_events[0].is_summary_event() + class TestBaseSessionServiceSetSummarizerManager: """Test set_summarizer_manager method.""" diff --git a/tests/sessions/test_session.py b/tests/sessions/test_session.py index 39c83bd..d0411bf 100644 --- a/tests/sessions/test_session.py +++ b/tests/sessions/test_session.py @@ -10,6 +10,7 @@ from google.genai.types import Part from trpc_agent_sdk.events import Event from trpc_agent_sdk.sessions import Session +from trpc_agent_sdk.sessions import is_summary_anchor class TestSession: @@ -35,15 +36,16 @@ def test_add_event(self): assert session.events[0].author == "user" assert session.last_update_time == event.timestamp - def test_is_user_message(self): - """Test checking if an event is a user message.""" - session = Session(id="test-session", app_name="test-app", user_id="test-user", save_key="test-key") - + def test_is_anchor_message(self): + """Test checking if an event can anchor visible conversation history.""" user_event = Event(author="user", content=Content(parts=[Part.from_text(text="Hello")])) agent_event = Event(author="agent-1", content=Content(parts=[Part.from_text(text="Hi")])) 
+ summary_event = Event(author="system", content=Content(parts=[Part.from_text(text="Summary")])) + summary_event.set_summary_event(True) - assert session._is_user_message(user_event) is True - assert session._is_user_message(agent_event) is False + assert is_summary_anchor(user_event) is True + assert is_summary_anchor(agent_event) is False + assert is_summary_anchor(summary_event) is True def test_apply_event_filtering_no_config(self): """Test event filtering with no configuration.""" diff --git a/tests/sessions/test_session_summarizer.py b/tests/sessions/test_session_summarizer.py index 5cfc98b..e8fc460 100644 --- a/tests/sessions/test_session_summarizer.py +++ b/tests/sessions/test_session_summarizer.py @@ -344,8 +344,136 @@ async def mock_generate(request, stream=False, ctx=None): visible_events = [event for event in result_events if event.is_model_visible()] assert len(visible_events) == 4 # 1 summary + 3 recent assert any(event.is_summary_event() for event in result_events) + summary_event = next(event for event in result_events if event.is_summary_event()) + assert summary_event.author == "system" + assert summary_event.content.role == "user" - async def test_summary_without_keep_recent(self): + async def test_summary_traces_back_to_invisible_user_before_first_visible_event(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="summary text")]) + captured_prompts = [] + + async def mock_generate(request, stream=False, ctx=None): + captured_prompts.append(request.contents[0].parts[0].text) + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, start_by_user_turn=True) + hidden_user = _make_event(author="user", text="hidden question") + hidden_user.set_model_visible(False) + old_answer = _make_event(author="agent", text="visible answer") + recent_user = _make_event(author="user", text="recent question") + system_preamble = 
_make_event(author="system", text="system preamble") + system_preamble.set_model_visible(False) + events = [ + system_preamble, + hidden_user, + old_answer, + recent_user, + ] + + summary_text, result_events = await summarizer.create_session_summary_by_events( + events, "s1", keep_recent_count=1) + + assert summary_text == "summary text" + assert result_events is events + assert captured_prompts + assert "hidden question" in captured_prompts[0] + assert "visible answer" in captured_prompts[0] + assert "system preamble" not in captured_prompts[0] + assert "recent question" not in captured_prompts[0] + assert old_answer.is_model_visible() is False + assert recent_user.is_model_visible() is True + assert any(event.is_summary_event() for event in result_events) + + async def test_summary_can_start_from_existing_summary_event(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="summary text")]) + captured_prompts = [] + + async def mock_generate(request, stream=False, ctx=None): + captured_prompts.append(request.contents[0].parts[0].text) + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, start_by_user_turn=True) + existing_summary = _make_event(author="system", text="previous summary") + existing_summary.set_summary_event(True) + system_preamble = _make_event(author="system", text="system preamble") + system_preamble.set_model_visible(False) + events = [ + system_preamble, + existing_summary, + _make_event(author="agent", text="old answer"), + _make_event(author="user", text="recent question"), + ] + + summary_text, result_events = await summarizer.create_session_summary_by_events( + events, "s1", keep_recent_count=1) + + assert summary_text == "summary text" + assert "previous summary" in captured_prompts[0] + assert "old answer" in captured_prompts[0] + assert "system preamble" not in captured_prompts[0] + assert 
result_events[3].is_summary_event() + + async def test_summary_falls_back_to_first_visible_event_and_ignores_large_keep_recent(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="summary text")]) + captured_prompts = [] + + async def mock_generate(request, stream=False, ctx=None): + captured_prompts.append(request.contents[0].parts[0].text) + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, start_by_user_turn=True) + events = [ + _make_event(author="agent", text="agent message 1"), + _make_event(author="agent", text="agent message 2"), + ] + + summary_text, result_events = await summarizer.create_session_summary_by_events( + events, "s1", keep_recent_count=10) + + assert summary_text == "summary text" + assert "agent message 1" in captured_prompts[0] + assert "agent message 2" in captured_prompts[0] + visible_events = [event for event in result_events if event.is_model_visible()] + assert len(visible_events) == 1 + assert visible_events[0].is_summary_event() + + async def test_summary_inserted_before_recent_user_turn_and_hides_prior_events(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="summary text")]) + captured_prompts = [] + + async def mock_generate(request, stream=False, ctx=None): + captured_prompts.append(request.contents[0].parts[0].text) + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, start_by_user_turn=True) + events = [_make_event(author="user" if idx in (8, 80, 92) else "agent", text=f"msg {idx}") for idx in range(100)] + for idx, event in enumerate(events): + event.set_model_visible(10 <= idx < 99) + + summary_text, result_events = await summarizer.create_session_summary_by_events( + events, "s1", keep_recent_count=10) + + assert summary_text == "summary text" + assert "msg 8" in 
captured_prompts[0] + assert "msg 91" in captured_prompts[0] + assert "msg 92" not in captured_prompts[0] + assert result_events[92].is_summary_event() + assert all(not event.is_model_visible() for event in result_events[:92]) + assert result_events[93].is_model_visible() + + async def test_summary_with_zero_keep_recent(self): model = _make_model_mock() llm_response = MagicMock() llm_response.content = Content(parts=[Part.from_text(text="summary text")]) @@ -357,11 +485,13 @@ async def mock_generate(request, stream=False, ctx=None): summarizer = SessionSummarizer(model=model) events = [_make_event(text=f"msg{i}") for i in range(5)] summary_text, result_events = await summarizer.create_session_summary_by_events( - events, "s1", keep_recent_count=None) + events, "s1", keep_recent_count=0) assert summary_text is not None assert len(result_events) == 6 # preserve all original events + 1 summary visible_events = [event for event in result_events if event.is_model_visible()] assert len(visible_events) == 1 # only summary event remains model-visible + assert visible_events[0].is_summary_event() + assert visible_events[0].content.role == "user" async def test_summary_no_events(self): model = _make_model_mock() @@ -407,6 +537,67 @@ async def mock_generate(request, stream=False, ctx=None): assert len(visible_events) == 3 # 1 summary + 2 recent assert any(event.is_summary_event() for event in session.events) + async def test_summary_traces_back_to_invisible_user_before_visible_events(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="session summary")]) + captured_prompts = [] + + async def mock_generate(request, stream=False, ctx=None): + captured_prompts.append(request.contents[0].parts[0].text) + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, keep_recent_count=1, start_by_user_turn=True) + hidden_user = _make_event(author="user", 
text="hidden question") + hidden_user.set_model_visible(False) + old_answer = _make_event(author="agent", text="visible answer") + recent_user = _make_event(author="user", text="recent question") + system_preamble = _make_event(author="system", text="system preamble") + system_preamble.set_model_visible(False) + session = _make_session(events=[ + system_preamble, + hidden_user, + old_answer, + recent_user, + ]) + + result = await summarizer.create_session_summary(session) + + assert result == "session summary" + assert captured_prompts + assert "hidden question" in captured_prompts[0] + assert "visible answer" in captured_prompts[0] + assert "system preamble" not in captured_prompts[0] + assert "recent question" not in captured_prompts[0] + assert old_answer.is_model_visible() is False + assert recent_user.is_model_visible() is True + assert any(event.is_summary_event() for event in session.events) + + async def test_summary_without_visible_user_falls_back_to_first_visible_event(self): + model = _make_model_mock() + llm_response = MagicMock() + llm_response.content = Content(parts=[Part.from_text(text="session summary")]) + + async def mock_generate(request, stream=False, ctx=None): + yield llm_response + + model.generate_async = mock_generate + summarizer = SessionSummarizer(model=model, keep_recent_count=10, start_by_user_turn=True) + events = [ + _make_event(author="system", text="system preamble"), + _make_event(author="agent", text="agent answer"), + ] + session = _make_session(events=events) + + result = await summarizer.create_session_summary(session) + + assert result == "session summary" + assert len(session.events) == 3 + visible_events = [event for event in session.events if event.is_model_visible()] + assert len(visible_events) == 1 + assert visible_events[0].is_summary_event() + async def test_summary_no_update_on_failure(self): model = _make_model_mock() diff --git a/tests/sessions/test_utils.py b/tests/sessions/test_utils.py index 44b2a15..bda07ba 
100644 --- a/tests/sessions/test_utils.py +++ b/tests/sessions/test_utils.py @@ -18,17 +18,25 @@ import pytest +from trpc_agent_sdk.events import Event from trpc_agent_sdk.sessions._utils import ( StateStorageEntry, app_state_key, extract_state_delta, + find_events_for_summary, merge_state, session_key, user_state_key, ) +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part from trpc_agent_sdk.types import State +def _make_event(author: str = "agent", text: str = "hello") -> Event: + return Event(author=author, content=Content(parts=[Part.from_text(text=text)])) + + class TestStateStorageEntry: """Test StateStorageEntry dataclass.""" @@ -160,6 +168,117 @@ def test_need_copy_false(self): assert "new_key" in original_session_state +class TestFindEventsForSummary: + """Test event selection for summary generation.""" + + def test_defaults_start_from_user_and_keep_recent(self): + events = [ + _make_event(author="system", text="system preamble"), + _make_event(author="user", text="question"), + _make_event(author="agent", text="answer"), + _make_event(author="user", text="recent question"), + ] + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=1) + + assert selected_events == events[:3] + assert insert_index == 3 + + def test_first_visible_summary_event_starts_summary_window(self): + summary_event = _make_event(author="system", text="previous summary") + summary_event.set_summary_event(True) + events = [ + summary_event, + _make_event(author="agent", text="answer"), + _make_event(author="user", text="recent question"), + ] + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=1) + + assert selected_events == events[:2] + assert insert_index == 2 + + def test_fallback_to_first_visible_event(self): + events = [ + _make_event(author="agent", text="answer 1"), + _make_event(author="agent", text="answer 2"), + ] + + selected_events, insert_index = find_events_for_summary(events, 
keep_recent_count=1) + + assert selected_events == events[:1] + assert insert_index == 1 + + def test_traces_back_to_invisible_user_before_first_visible_event(self): + hidden_user = _make_event(author="user", text="hidden") + hidden_user.set_model_visible(False) + events = [ + _make_event(author="system", text="hidden preamble"), + hidden_user, + _make_event(author="agent", text="visible answer"), + _make_event(author="agent", text="visible answer"), + ] + events[0].set_model_visible(False) + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=10) + + assert selected_events == events[1:] + assert insert_index == len(events) + + def test_aligns_recent_window_to_next_user_turn(self): + events = [_make_event(author="user" if idx in (8, 80, 92) else "agent", text=f"msg {idx}") for idx in range(100)] + for idx, event in enumerate(events): + event.set_model_visible(10 <= idx < 99) + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=10) + + assert selected_events == events[8:92] + assert insert_index == 92 + + def test_keep_recent_count_uses_model_visible_events(self): + events = [_make_event(author="agent", text=f"msg {idx}") for idx in range(20)] + for idx, event in enumerate(events): + event.set_model_visible(idx in (0, 1, 2, 10, 15, 19)) + + selected_events, insert_index = find_events_for_summary( + events, keep_recent_count=3, start_by_user_turn=False) + + assert selected_events == events[:10] + assert insert_index == 10 + + def test_ignores_keep_recent_when_it_would_empty_selection(self): + events = [ + _make_event(author="user", text="question"), + _make_event(author="agent", text="answer"), + ] + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=10) + + assert selected_events == events + assert insert_index == len(events) + + def test_zero_keep_recent_selects_all_matching_events(self): + events = [ + _make_event(author="system", text="system preamble"), + 
_make_event(author="user", text="question"), + _make_event(author="agent", text="answer"), + ] + + selected_events, insert_index = find_events_for_summary(events, keep_recent_count=0) + + assert selected_events == events + assert insert_index == len(events) + + def test_no_visible_events(self): + event = _make_event(author="user", text="hidden") + event.set_model_visible(False) + + selected_events, insert_index = find_events_for_summary([event]) + + assert selected_events == [] + assert insert_index == -1 + + class TestKeyFunctions: """Test key generation functions.""" diff --git a/trpc_agent_sdk/agents/_llm_agent.py b/trpc_agent_sdk/agents/_llm_agent.py index cb8b84d..a4dff94 100644 --- a/trpc_agent_sdk/agents/_llm_agent.py +++ b/trpc_agent_sdk/agents/_llm_agent.py @@ -517,20 +517,31 @@ def accumulate_content(event: Event) -> None: if event.is_final_response(): self._save_output_to_state(ctx, event) - # Process code execution responses if code executor is configured + # Process code execution responses if code executor is configured. + # We collect code execution events first (this mutates event.content in place, + # stripping executable_code parts but keeping text/function_call), then yield + # the main event BEFORE the code execution events so the causal order in + # session is preserved: assistant declaration → code execution → result. 
+ pending_code_events: list[Event] = [] if self.code_executor and event.content: async for code_event in CodeExecutionResponseProcessor.run_async(ctx, event): - # Check if this is a code execution result event if code_event.content and code_event.content.parts: for part in code_event.content.parts: if part.code_execution_result or part.executable_code: code_was_executed = True break - yield code_event - - # Yield LLM response events directly - yield event - accumulate_content(event) + pending_code_events.append(code_event) + + # Yield the main LLM response event first (now stripped of executable_code + # but still carrying text and function_call parts). + # Skip empty events (content became None after all parts were consumed). + if event.content is not None: + yield event + accumulate_content(event) + + # Then yield code execution events in order. + for code_event in pending_code_events: + yield code_event else: # Yield other events directly yield event @@ -628,7 +639,7 @@ def accumulate_content(event: Event) -> None: logger.debug("Tool execution completed, continuing conversation") continue - except RunCancelledException as ex: + except RunCancelledException: # raise to runner to handle raise @@ -652,7 +663,7 @@ def accumulate_content(event: Event) -> None: continue running = False - except RunCancelledException as ex: + except RunCancelledException: # raise to runner to handle raise except Exception as ex: # pylint: disable=broad-except diff --git a/trpc_agent_sdk/agents/core/_code_execution_processor.py b/trpc_agent_sdk/agents/core/_code_execution_processor.py index c643214..2287694 100644 --- a/trpc_agent_sdk/agents/core/_code_execution_processor.py +++ b/trpc_agent_sdk/agents/core/_code_execution_processor.py @@ -249,10 +249,6 @@ async def _run_post_processor( return code_executor_context = CodeExecutorContext(invocation_context.session.state) - if (code_executor.execute_once_per_invocation - and 
code_executor_context.has_executed_in_invocation(invocation_context.invocation_id)): - return - # Skip if the error count exceeds the max retry attempts. if code_executor_context.get_error_count(invocation_context.invocation_id) >= code_executor.error_retry_attempts: return @@ -285,7 +281,6 @@ async def _run_post_processor( code_blocks, code_execution_result, ) - code_executor_context.mark_executed_in_invocation(invocation_context.invocation_id) # Generate events for code execution results # Event 1: Code execution event diff --git a/trpc_agent_sdk/agents/core/_history_processor.py b/trpc_agent_sdk/agents/core/_history_processor.py index 3b45641..bd38ccd 100644 --- a/trpc_agent_sdk/agents/core/_history_processor.py +++ b/trpc_agent_sdk/agents/core/_history_processor.py @@ -175,6 +175,9 @@ def _should_include_event_by_timeline( if timeline_filter_mode == TimelineFilterMode.ALL: return True + if event.is_summary_event(): + return True + # INVOCATION mode: Filter by invocation_id (which represents a single runner.run_async() call) if timeline_filter_mode == TimelineFilterMode.INVOCATION: if ctx and event.invocation_id: diff --git a/trpc_agent_sdk/code_executors/_base_code_executor.py b/trpc_agent_sdk/code_executors/_base_code_executor.py index caa3a2c..887e36e 100644 --- a/trpc_agent_sdk/code_executors/_base_code_executor.py +++ b/trpc_agent_sdk/code_executors/_base_code_executor.py @@ -59,14 +59,6 @@ class BaseCodeExecutor(BaseModel): error_retry_attempts: int = 2 """The number of attempts to retry on consecutive code execution errors. Default to 2.""" - execute_once_per_invocation: bool = False - """Whether to execute model-extracted code at most once per invocation. - - When enabled, post-processing code execution runs only for the first - detected code block in a single ``invocation_id`` and skips subsequent - auto-execution attempts for that invocation. 
- """ - code_block_delimiters: list[CodeBlockDelimiter] = [ CodeBlockDelimiter(start="```tool_code\n", end="\n```"), CodeBlockDelimiter(start="```python\n", end="\n```"), diff --git a/trpc_agent_sdk/code_executors/_code_executor_context.py b/trpc_agent_sdk/code_executors/_code_executor_context.py index 11b6cca..1ceb6b8 100644 --- a/trpc_agent_sdk/code_executors/_code_executor_context.py +++ b/trpc_agent_sdk/code_executors/_code_executor_context.py @@ -39,7 +39,6 @@ def _ensure_code_execution_state(self) -> None: "execution_id": None, "error_counts": {}, "code_execution_results": {}, - "executed_invocations": {}, } def get_input_files(self) -> List[CodeFile]: @@ -139,14 +138,6 @@ def update_code_execution_result(self, invocation_id: str, code_blocks: List[Cod code_execution_result.model_dump(), }) - def has_executed_in_invocation(self, invocation_id: str) -> bool: - """Whether code has already been executed in a given invocation.""" - return bool(self.session_state["code_execution"]["executed_invocations"].get(invocation_id, False)) - - def mark_executed_in_invocation(self, invocation_id: str) -> None: - """Mark that code execution has happened in a given invocation.""" - self.session_state["code_execution"]["executed_invocations"][invocation_id] = True - def get_state_delta(self) -> Dict: """Get state delta for the current execution. diff --git a/trpc_agent_sdk/events/_event.py b/trpc_agent_sdk/events/_event.py index ed5537d..88f5b10 100644 --- a/trpc_agent_sdk/events/_event.py +++ b/trpc_agent_sdk/events/_event.py @@ -193,7 +193,12 @@ def is_summary_event(self) -> bool: return bool(self.model_flags & _EVENT_FLAG_SUMMARY) def set_model_visible(self, model_visible: bool) -> None: - """Set whether this event can be seen by model history builders.""" + """Set whether this event can be seen by model history builders. + + This is intended for session summarization/compression internals. 
User + code should not call it directly because it can break the model-visible + history window maintained by the session summarizer. + """ if model_visible: self.model_flags |= _EVENT_FLAG_MODEL_VISIBLE else: diff --git a/trpc_agent_sdk/sessions/__init__.py b/trpc_agent_sdk/sessions/__init__.py index f1a7acc..33f41d5 100644 --- a/trpc_agent_sdk/sessions/__init__.py +++ b/trpc_agent_sdk/sessions/__init__.py @@ -42,6 +42,8 @@ from ._utils import StateStorageEntry from ._utils import app_state_key from ._utils import extract_state_delta +from ._utils import find_events_for_summary +from ._utils import is_summary_anchor from ._utils import merge_state from ._utils import session_key from ._utils import user_state_key @@ -77,7 +79,9 @@ "StateStorageEntry", "app_state_key", "extract_state_delta", + "find_events_for_summary", "merge_state", + "is_summary_anchor", "session_key", "user_state_key", ] diff --git a/trpc_agent_sdk/sessions/_base_session_service.py b/trpc_agent_sdk/sessions/_base_session_service.py index cba2d8f..e962cda 100644 --- a/trpc_agent_sdk/sessions/_base_session_service.py +++ b/trpc_agent_sdk/sessions/_base_session_service.py @@ -24,8 +24,6 @@ """Base session service interface.""" from __future__ import annotations - -import time from typing import Optional from typing_extensions import override @@ -183,17 +181,10 @@ async def get_session_summary(self, session: Session) -> Optional[str]: def filter_events(self, session: Session) -> None: """Filter events based on the session config.""" - visible_events = [event for event in session.events if event.is_model_visible()] - if self._session_config.num_recent_events > 0: - if len(visible_events) > self._session_config.num_recent_events: - hide_count = len(visible_events) - self._session_config.num_recent_events - for event in visible_events[:hide_count]: - event.set_model_visible(False) - if self._session_config.event_ttl_seconds > 0: - cutoff_timestamp = time.time() - self._session_config.event_ttl_seconds 
- for event in visible_events: - if event.timestamp <= cutoff_timestamp: - event.set_model_visible(False) + session.apply_event_filtering( + event_ttl_seconds=self._session_config.event_ttl_seconds, + max_events=self._session_config.num_recent_events, + ) @override async def close(self) -> None: diff --git a/trpc_agent_sdk/sessions/_session.py b/trpc_agent_sdk/sessions/_session.py index d227309..73a057f 100644 --- a/trpc_agent_sdk/sessions/_session.py +++ b/trpc_agent_sdk/sessions/_session.py @@ -14,6 +14,8 @@ from trpc_agent_sdk.abc import SessionABC from trpc_agent_sdk.events import Event +from ._utils import is_summary_anchor + class Session(SessionABC): """Represents a series of interactions between a user and agents. @@ -85,7 +87,7 @@ def apply_event_filtering(self, event_ttl_seconds: float = 0.0, max_events: int retained_events = retained_events[-max_events:] for i, event in enumerate(retained_events): - if self._is_user_message(event): + if is_summary_anchor(event): retained_events = retained_events[i:] break else: @@ -94,7 +96,7 @@ def apply_event_filtering(self, event_ttl_seconds: float = 0.0, max_events: int # re-inserted, but only from the already-visible subset. retained_events = [] for event in reversed(visible_events): - if self._is_user_message(event): + if is_summary_anchor(event): retained_events.insert(0, event) break @@ -117,14 +119,3 @@ def insert_events(self, events: List[Event], idx: Optional[int] = None) -> None: if idx is None: idx = self.get_first_visible_event_idx() self.events[idx:idx] = events - - def _is_user_message(self, event: Event) -> bool: - """Check if an event is a user message. - - Args: - event: The event to check. - - Returns: - True if the event is from a user, False otherwise. 
- """ - return event.author.lower() == "user" diff --git a/trpc_agent_sdk/sessions/_session_summarizer.py b/trpc_agent_sdk/sessions/_session_summarizer.py index c1e7054..5f1dacf 100644 --- a/trpc_agent_sdk/sessions/_session_summarizer.py +++ b/trpc_agent_sdk/sessions/_session_summarizer.py @@ -50,6 +50,7 @@ from ._session import Session from ._summarizer_checker import CheckSummarizerFunction from ._summarizer_checker import set_summarizer_conversation_threshold +from ._utils import find_events_for_summary DEFAULT_SUMMARIZER_PROMPT = dedent("""\ Please summarize the following conversation, focusing on: @@ -124,12 +125,13 @@ class SessionSummarizer: """ def __init__( - self, - model: LLMModel, - summarizer_prompt: str = DEFAULT_SUMMARIZER_PROMPT, - check_summarizer_functions: Optional[List[CheckSummarizerFunction]] = None, - max_summary_length: int = 1000, - keep_recent_count: int = 10, + self, + model: LLMModel, + summarizer_prompt: str = DEFAULT_SUMMARIZER_PROMPT, + check_summarizer_functions: Optional[List[CheckSummarizerFunction]] = None, + max_summary_length: int = 1000, + keep_recent_count: int = 10, + start_by_user_turn: bool = True, # Whether to start summarization by user turn, default is True ): """Initialize the session summarizer. 
@@ -138,12 +140,13 @@ def __init__( check_summarizer_functions: List of check summarizer functions max_summary_length: Maximum length of generated summary keep_recent_count: Number of recent events to keep after compression + start_by_user_turn: Whether to start summarization by user turn, default is True """ self._summarizer_prompt = summarizer_prompt self.check_summarizer_functions = check_summarizer_functions or [set_summarizer_conversation_threshold()] self.max_summary_length = max_summary_length self.__keep_recent_count = keep_recent_count - + self.__start_by_user_turn = start_by_user_turn # Initialize LLM model for summarization self._model = model @@ -194,7 +197,7 @@ def _has_important_content(self, events: List[Event]) -> bool: async def _compress_session_to_summary(self, events: List[Event], session_id: str, - ctx: InvocationContext = None) -> Optional[str]: + ctx: InvocationContext | None = None) -> Optional[str]: """Generate a summary for a session. Args: @@ -247,8 +250,6 @@ def _extract_conversation_text(self, events: List[Event]) -> str: current_text = "" for event in events: - if not event.is_model_visible(): - continue if not event.content or not event.content.parts: continue @@ -312,7 +313,7 @@ def _extract_conversation_text(self, events: List[Event]) -> str: return "\n".join(conversation_parts) - async def _generate_summary(self, conversation_text: str, ctx: InvocationContext = None) -> str: + async def _generate_summary(self, conversation_text: str, ctx: InvocationContext | None = None) -> str: """Generate a summary using the LLM model. Args: @@ -361,8 +362,8 @@ def _create_summarization_prompt(self, conversation_text: str) -> str: async def create_session_summary_by_events(self, events: List[Event], session_id: str, - keep_recent_count: int | None = None, - ctx: InvocationContext = None) -> Optional[str]: + keep_recent_count: int = 10, + ctx: InvocationContext | None = None) -> Optional[str]: """Compress a session by summarizing old events. 
Args: @@ -375,13 +376,10 @@ async def create_session_summary_by_events(self, Summary text if successful, None otherwise Events after compression """ - if keep_recent_count is None: - old_events = events - else: - old_events = events[:-keep_recent_count] try: original_count = sum(1 for event in events if event.is_model_visible()) - old_visible_events = [event for event in old_events if event.is_model_visible()] + old_visible_events, insert_index = find_events_for_summary(events, keep_recent_count, + self.__start_by_user_turn) if not old_visible_events: return None, events @@ -394,17 +392,16 @@ async def create_session_summary_by_events(self, author="system", content=Content( parts=[Part.from_text(text=f"Previous conversation summary: {summary_text}")], - role="system"), + role="user"), timestamp=time.time()) summary_event.set_summary_event(True) summary_event.set_model_visible(True) - # Hide old visible events from model history without dropping raw data. - for event in old_visible_events: + # Hide all events before the summary insertion point without dropping raw data. + for event in events[:insert_index]: event.set_model_visible(False) - # Insert summary near the old/recent boundary while preserving all events. - insert_index = len(old_events) + # Insert summary before the recent complete conversation window. events.insert(insert_index, summary_event) compressed_count = sum(1 for event in events if event.is_model_visible()) @@ -416,7 +413,7 @@ async def create_session_summary_by_events(self, logger.error("Failed to compress session %s: %s", session_id, ex, exc_info=True) return None, events - async def create_session_summary(self, session: Session, ctx: InvocationContext = None) -> Optional[str]: + async def create_session_summary(self, session: Session, ctx: InvocationContext | None = None) -> Optional[str]: """Compress a session by summarizing old events. 
def is_summary_anchor(event: "Event") -> bool:
    """Return whether the event can anchor a summary window.

    An event is an anchor when its author equals ``"user"`` (compared
    case-insensitively) or when it is flagged as a summary event.

    Args:
        event: The event to check.

    Returns:
        True if the event is a user message or a summary event, False otherwise.
    """
    author = event.author
    # A missing or empty author can never be a user message.
    if author and author.lower() == "user":
        return True
    return bool(event.is_summary_event())


def find_events_for_summary(events: "list[Event]",
                            keep_recent_count: Optional[int] = 10,
                            start_by_user_turn: bool = True) -> "tuple[list[Event], int]":
    """Find events that should be summarized.

    The selected window is a contiguous slice of ``events``. It starts at the
    first model-visible event (or, when ``start_by_user_turn`` is set and that
    event is not an anchor, at the closest earlier anchor event, which may
    itself not be model-visible) and ends right before the insertion index.

    Args:
        events: Source events to inspect.
        keep_recent_count: Number of recent model-visible events to keep out of
            the summary window. ``None``, a non-positive value, or a value not
            smaller than the number of model-visible events makes the window
            extend through the last model-visible event.
        start_by_user_turn: Whether to align the summary window to user or
            summary events.

    Returns:
        A tuple of selected events and the summary insertion index in the
        original events. Returns ([], -1) when no model-visible events can be
        selected.
    """
    visible_indices = [idx for idx, event in enumerate(events) if event.is_model_visible()]
    if not visible_indices:
        return [], -1

    first_visible = visible_indices[0]
    window_end = visible_indices[-1] + 1

    # Optionally pull the window start back to the nearest preceding anchor so
    # the summarized text begins with a user (or summary) event.
    start_index = first_visible
    if start_by_user_turn and not is_summary_anchor(events[first_visible]):
        start_index = next(
            (idx for idx in range(first_visible - 1, -1, -1) if is_summary_anchor(events[idx])),
            first_visible,
        )

    # ``None`` is treated like a non-positive count instead of raising a
    # TypeError on the comparison below.
    summarize_everything = (keep_recent_count is None or keep_recent_count <= 0
                            or keep_recent_count >= len(visible_indices))
    if summarize_everything:
        insert_index = window_end
    else:
        boundary = visible_indices[-keep_recent_count]
        insert_index = boundary
        if start_by_user_turn:
            # Move the boundary forward to the first anchor so the kept window
            # starts on a user/summary event; keep the raw boundary if none exists.
            insert_index = next(
                (idx for idx in range(boundary, window_end) if is_summary_anchor(events[idx])),
                boundary,
            )

    selected_events = events[start_index:insert_index]
    if not selected_events:
        return [], -1

    return selected_events, insert_index