From 9f5a8bc112bf5453c1be668ea4ae262abc45af52 Mon Sep 17 00:00:00 2001 From: pixelsama Date: Sat, 11 Oct 2025 21:08:12 +0800 Subject: [PATCH 1/2] feat: backend-manage system prompt --- .../src/dialog_engine/chat_service.py | 17 ++++++++++--- .../src/dialog_engine/settings.py | 17 +++++++++++++ .../tests/unit/test_chat_service.py | 25 +++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/services/dialog-engine/src/dialog_engine/chat_service.py b/services/dialog-engine/src/dialog_engine/chat_service.py index cddaf2d..d12ba90 100644 --- a/services/dialog-engine/src/dialog_engine/chat_service.py +++ b/services/dialog-engine/src/dialog_engine/chat_service.py @@ -32,6 +32,11 @@ def __init__( self._memory_store = memory_store self._ltm_client = ltm_client self._state_store = state_store + prompt = getattr(self._settings, "prompts", None) + if prompt and getattr(prompt, "system_prompt", None): + self._system_prompt = prompt.system_prompt.strip() + else: + self._system_prompt = "" self.last_token_count: int = 0 self.last_ttft_ms: Optional[float] = None @@ -287,10 +292,16 @@ async def _compose_messages( context: List[MemoryTurn], ltm_snippets: List[str], ) -> list[Dict[str, str]]: - system_prompt = meta.get("system_prompt") messages: list[Dict[str, str]] = [] - if system_prompt: - messages.append({"role": "system", "content": str(system_prompt)}) + if "system_prompt" in meta and meta["system_prompt"] != self._system_prompt: + from logging import getLogger + + getLogger(__name__).info( + "chat.system_prompt.meta_override_ignored", + extra={"sessionId": meta.get("session_id"), "override_length": len(str(meta["system_prompt"]))}, + ) + if self._system_prompt: + messages.append({"role": "system", "content": self._system_prompt}) # Inject internal states as context if available if self._state_store: diff --git a/services/dialog-engine/src/dialog_engine/settings.py b/services/dialog-engine/src/dialog_engine/settings.py index 6505240..a0454a1 100644 --- a/services/dialog-engine/src/dialog_engine/settings.py +++ b/services/dialog-engine/src/dialog_engine/settings.py @@ -7,6 +7,11 @@ _BOOL_TRUTHY = {"1", "true", "yes", "on"} +DEFAULT_SYSTEM_PROMPT = ( + "你是一位友好、专业的虚拟主播助手,以亲切的语气与用户互动," + "善于引导对话并提供有趣、实用的信息。" +) + def _env_bool(name: str, default: bool = False) -> bool: value = os.getenv(name) @@ -72,6 +77,11 @@ class LTMInlineSettings: max_snippets: int +@dataclass(frozen=True) +class PromptSettings: + system_prompt: str + + @dataclass(frozen=True) class AsrSettings: enabled: bool @@ -92,6 +102,7 @@ class AsrSettings: class Settings: openai: OpenAISettings llm: LLMSettings + prompts: PromptSettings short_term: ShortTermMemorySettings ltm_inline: LTMInlineSettings asr: AsrSettings @@ -119,6 +130,10 @@ def load_settings() -> Settings: base_url=os.getenv("OPENAI_BASE_URL"), ) + prompt_settings = PromptSettings( + system_prompt=os.getenv("CHAT_SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT).strip(), + ) + short_term_settings = ShortTermMemorySettings( enabled=_env_bool("ENABLE_SHORT_TERM_MEMORY", True), db_path=os.getenv("STM_DB_PATH", "/app/data/dialog_memory.sqlite"), @@ -151,6 +166,7 @@ def load_settings() -> Settings: return Settings( openai=openai_settings, llm=llm_settings, + prompts=prompt_settings, short_term=short_term_settings, ltm_inline=ltm_inline_settings, asr=asr_settings, @@ -165,6 +181,7 @@ def load_settings() -> Settings: "OpenAISettings", "ShortTermMemorySettings", "LTMInlineSettings", + "PromptSettings", "AsrSettings", "settings", "load_settings", diff --git a/services/dialog-engine/tests/unit/test_chat_service.py b/services/dialog-engine/tests/unit/test_chat_service.py index 1400be5..89d3f0f 100644 --- a/services/dialog-engine/tests/unit/test_chat_service.py +++ b/services/dialog-engine/tests/unit/test_chat_service.py @@ -11,6 +11,7 @@ AsrSettings, LLMSettings, LTMInlineSettings, + PromptSettings, OpenAISettings, Settings, ShortTermMemorySettings, @@ -23,6 +24,7 @@ def _make_settings( stm_enabled: bool = False, ltm_enabled: bool = False, base_url: Optional[str] = None, + system_prompt: str = "你是一位后端定义的虚拟主播助手。", ) -> Settings: return Settings( openai=OpenAISettings(api_key=None, organization=None, base_url=None), @@ -38,6 +40,7 @@ def _make_settings( retry_limit=0, retry_backoff_seconds=0.0, ), + prompts=PromptSettings(system_prompt=system_prompt), short_term=ShortTermMemorySettings( enabled=stm_enabled, db_path=":memory:", @@ -240,6 +243,28 @@ async def test_stream_reply_llm_includes_ltm_snippets(): assert any("Relevant memories" in m["content"] for m in system_blocks) +@pytest.mark.asyncio +async def test_stream_reply_llm_uses_backend_system_prompt(): + backend_prompt = "后端维护的系统提示词。" + stub_llm = _StubLLMClient(["Hi"]) + service = ChatService( + settings=_make_settings(enabled=True, system_prompt=backend_prompt), + llm_client_factory=lambda: stub_llm, + ) + + async for _ in service.stream_reply( + "sess-system", + "hello", + meta={"system_prompt": "前端尝试覆盖"}, + ): + pass + + sent_messages = stub_llm.calls[0] + assert sent_messages[0]["role"] == "system" + assert sent_messages[0]["content"] == backend_prompt + assert all(msg.get("content") != "前端尝试覆盖" for msg in sent_messages if isinstance(msg, dict)) + + @pytest.mark.asyncio async def test_stream_reply_llm_logs_context_counts(caplog): stub_llm = _StubLLMClient(["Done"]) From c7da07bd464bb56016fbde8d80e43198452da5da Mon Sep 17 00:00:00 2001 From: pixelsama Date: Sat, 11 Oct 2025 21:50:27 +0800 Subject: [PATCH 2/2] feat: support llm tool call retries --- .../src/dialog_engine/chat_service.py | 114 ++++++++++++------ .../src/dialog_engine/llm_client.py | 30 ++++- .../tests/unit/test_chat_service.py | 79 +++++++++++- 3 files changed, 184 insertions(+), 39 deletions(-) diff --git a/services/dialog-engine/src/dialog_engine/chat_service.py b/services/dialog-engine/src/dialog_engine/chat_service.py index d12ba90..56c0ad6 100644 --- a/services/dialog-engine/src/dialog_engine/chat_service.py +++ b/services/dialog-engine/src/dialog_engine/chat_service.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import random import time from collections.abc import AsyncGenerator, Callable @@ -65,32 +66,48 @@ async def stream_reply( meta=meta, ) self._log_context_info(len(context_turns), len(ltm_snippets)) - try: - async for delta in self._emit_with_metrics( - self._stream_llm( - session_id=session_id, - user_text=user_text, - meta=meta, - context=context_turns, - ltm_snippets=ltm_snippets, - ), - source="llm", - ): - yield delta - return - except LLMStreamEmptyError as exc: - self.last_error = "llm_empty_stream" - # Process tool calls if present - if exc.tool_calls and self._state_store: - await self._process_tool_calls(exc.tool_calls, session_id) - self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}") - except LLMNotConfiguredError as exc: - self.last_error = "llm_not_configured" - self._log_llm_fallback(reason=str(exc)) - except Exception as exc: # pragma: no cover - defensive catch - self.last_error = exc.__class__.__name__ - self._log_llm_fallback(reason=repr(exc)) - + extra_messages: List[Dict[str, Any]] = [] + tool_retry = 0 + while True: + try: + async for delta in self._emit_with_metrics( + self._stream_llm( + session_id=session_id, + user_text=user_text, + meta=meta, + context=context_turns, + ltm_snippets=ltm_snippets, + extra_messages=extra_messages, + ), + source="llm", + ): + yield delta + return + except LLMStreamEmptyError as exc: + self.last_error = "llm_empty_stream" + handled_tool_calls = False + if exc.tool_calls and self._state_store: + tool_messages = await self._process_tool_calls(exc.tool_calls, session_id) + handled_tool_calls = bool(tool_messages) + if handled_tool_calls and tool_retry < 3: + tool_retry += 1 + extra_messages.extend( + [ + {"role": "assistant", "content": None, "tool_calls": exc.tool_calls}, + *tool_messages, + ] + ) + continue + self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}") + break + except LLMNotConfiguredError as exc: + self.last_error = "llm_not_configured" + self._log_llm_fallback(reason=str(exc)) + break + except Exception as exc: # pragma: no cover - defensive catch + self.last_error = exc.__class__.__name__ + self._log_llm_fallback(reason=repr(exc)) + break async for delta in self._emit_with_metrics( self._stream_mock(user_text=user_text, meta=meta), source="mock", @@ -173,6 +190,7 @@ async def _stream_llm( meta: Dict[str, Any], context: List[MemoryTurn], ltm_snippets: List[str], + extra_messages: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[str, None]: client = await self._ensure_llm_client() meta_with_session = dict(meta) @@ -182,6 +200,7 @@ async def _stream_llm( meta=meta_with_session, context=context, ltm_snippets=ltm_snippets, + extra_messages=extra_messages, ) extra_options: Dict[str, Any] = { "extra_headers": {"x-session-id": session_id}, @@ -291,8 +310,9 @@ async def _compose_messages( meta: Dict[str, Any], context: List[MemoryTurn], ltm_snippets: List[str], - ) -> list[Dict[str, str]]: - messages: list[Dict[str, str]] = [] + extra_messages: Optional[List[Dict[str, Any]]] = None, + ) -> list[Dict[str, Any]]: + messages: list[Dict[str, Any]] = [] if "system_prompt" in meta and meta["system_prompt"] != self._system_prompt: from logging import getLogger @@ -322,6 +342,8 @@ async def _compose_messages( if ltm_snippets: messages.append({"role": "system", "content": self._format_ltm_snippets(ltm_snippets)}) messages.append({"role": "user", "content": user_text}) + if extra_messages: + messages.extend(extra_messages) return messages def _estimate_tokens(self, chunk: str) -> int: @@ -333,21 +355,41 @@ def _reset_metrics(self) -> None: self.last_source = "mock" self.last_error = None - async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> None: - """Process tool calls from LLM to update internal states.""" + async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> List[Dict[str, Any]]: + """Process tool calls from LLM to update internal states and return tool response messages.""" if not self._state_store: - return + return [] + tool_messages: List[Dict[str, Any]] = [] for tool_call in tool_calls: try: - call_info = { - "name": getattr(tool_call, "function", {}).get("name"), - "arguments": getattr(tool_call, "function", {}).get("arguments", "{}") - } - await handle_tool_call(call_info, session_id, self._state_store) + if isinstance(tool_call, dict): + function = tool_call.get("function") or {} + name = function.get("name") + arguments = function.get("arguments", "{}") + tool_call_id = tool_call.get("id") + else: + function = getattr(tool_call, "function", {}) or {} + name = function.get("name") + arguments = function.get("arguments", "{}") + tool_call_id = getattr(tool_call, "id", None) + + call_info = {"name": name, "arguments": arguments} + result = await handle_tool_call(call_info, session_id, self._state_store) + message_payload = json.dumps(result or {"success": True}, ensure_ascii=False) + tool_messages.append( + { + "role": "tool", + "tool_call_id": tool_call_id or name or "tool_call", + "name": name or "tool", + "content": message_payload, + } + ) except Exception as exc: self._log_context_warning("tool_call.error", exc) + return tool_messages + async def get_internal_states(self, session_id: str) -> Dict[str, float]: """Get current internal states for a session.""" if not self._state_store: diff --git a/services/dialog-engine/src/dialog_engine/llm_client.py b/services/dialog-engine/src/dialog_engine/llm_client.py index 57a3c0c..dfb4521 100644 --- a/services/dialog-engine/src/dialog_engine/llm_client.py +++ b/services/dialog-engine/src/dialog_engine/llm_client.py @@ -104,7 +104,8 @@ async def stream_chat( }, ) has_content = False - collected_tool_calls: list[Any] = [] + collected_tool_calls: list[Dict[str, Any]] = [] + tool_call_acc: Dict[int, Dict[str, Any]] = {} async for chunk in stream: for choice in chunk.choices: delta = choice.delta @@ -116,7 +117,30 @@ async def stream_chat( yield text_delta tool_calls = getattr(delta, "tool_calls", None) if tool_calls: - collected_tool_calls.extend(tool_calls) + for tool_call in tool_calls: + index = getattr(tool_call, "index", 0) or 0 + entry = tool_call_acc.setdefault( + index, + { + "id": getattr(tool_call, "id", None), + "type": getattr(tool_call, "type", "function"), + "function": {"name": None, "arguments": ""}, + }, + ) + if getattr(tool_call, "id", None): + entry["id"] = tool_call.id + tool_type = getattr(tool_call, "type", None) + if tool_type: + entry["type"] = tool_type + func = getattr(tool_call, "function", None) + if func: + func_name = getattr(func, "name", None) + if func_name: + entry["function"]["name"] = func_name + func_args = getattr(func, "arguments", None) + if func_args: + entry["function"]["arguments"] += func_args + collected_tool_calls = list(tool_call_acc.values()) if not has_content: logger.warning( "llm.stream.empty", @@ -151,6 +175,8 @@ async def stream_chat( except Exception: # pragma: no cover - best effort cleanup logger.debug("llm.stream.close_failed", exc_info=True) + if isinstance(last_error, LLMStreamEmptyError): + raise last_error raise RuntimeError("LLM streaming failed after retries") from last_error async def generate_vision_reply( diff --git a/services/dialog-engine/tests/unit/test_chat_service.py b/services/dialog-engine/tests/unit/test_chat_service.py index 89d3f0f..5bccdf0 100644 --- a/services/dialog-engine/tests/unit/test_chat_service.py +++ b/services/dialog-engine/tests/unit/test_chat_service.py @@ -1,6 +1,6 @@ import asyncio import base64 -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Dict import pytest @@ -132,6 +132,55 @@ async def retrieve(self, *, session_id: str, user_text: str, meta, limit=None): return list(self.snippets) +class _ToolCallLLMClient: + def __init__(self, responses: Iterable[str]) -> None: + self._responses = list(responses) + self.calls: List[list[Dict[str, object]]] = [] + self.invocations = 0 + + async def stream_chat(self, messages, **kwargs): + self.calls.append(list(messages)) + self.invocations += 1 + if self.invocations == 1: + if False: + yield "" # pragma: no cover + raise LLMStreamEmptyError( + "tool", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": { + "name": "update_internal_state", + "arguments": '{"state_key":"emotion","value":80}', + }, + } + ], + ) + for token in self._responses: + await asyncio.sleep(0) + yield token + + async def generate_vision_reply(self, messages, **kwargs): + if False: + return "" # pragma: no cover + raise RuntimeError("not used") + + +class _StubStateStore: + def __init__(self): + self.states: Dict[str, Dict[str, float]] = {} + + async def update_state(self, session_id: str, state_key: str, new_value: float): + self.states.setdefault(session_id, {})[state_key] = new_value + + async def get_state(self, session_id: str, state_key: str): + return self.states.get(session_id, {}).get(state_key) + + async def list_states(self, session_id: str): + return self.states.get(session_id, {}) + + @pytest.mark.asyncio async def test_stream_reply_mock_path(): service = ChatService(settings=_make_settings(enabled=False)) @@ -243,6 +292,34 @@ async def test_stream_reply_llm_includes_ltm_snippets(): assert any("Relevant memories" in m["content"] for m in system_blocks) +@pytest.mark.asyncio +async def test_stream_reply_llm_handles_tool_call_then_continues(): + stub_llm = _ToolCallLLMClient(["情绪已经同步调整,感谢你的分享!"]) + state_store = _StubStateStore() + service = ChatService( + settings=_make_settings(enabled=True), + llm_client_factory=lambda: stub_llm, + state_store=state_store, + ) + + chunks: List[str] = [] + async for delta in service.stream_reply( + "sess-tool", + "我今天特别开心,你也调整一下情绪吧", + meta={}, + ): + chunks.append(delta) + + reply = "".join(chunks) + assert "情绪" in reply + assert state_store.states["sess-tool"]["emotion"] == 80 + assert stub_llm.invocations == 2 + assert any( + isinstance(msg, dict) and msg.get("tool_calls") + for msg in stub_llm.calls[1] + ) + + @pytest.mark.asyncio async def test_stream_reply_llm_uses_backend_system_prompt(): backend_prompt = "后端维护的系统提示词。"