Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 92 additions & 39 deletions services/dialog-engine/src/dialog_engine/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import json
import random
import time
from collections.abc import AsyncGenerator, Callable
Expand Down Expand Up @@ -32,6 +33,11 @@ def __init__(
self._memory_store = memory_store
self._ltm_client = ltm_client
self._state_store = state_store
prompt = getattr(self._settings, "prompts", None)
if prompt and getattr(prompt, "system_prompt", None):
self._system_prompt = prompt.system_prompt.strip()
else:
self._system_prompt = ""

self.last_token_count: int = 0
self.last_ttft_ms: Optional[float] = None
Expand Down Expand Up @@ -60,32 +66,48 @@ async def stream_reply(
meta=meta,
)
self._log_context_info(len(context_turns), len(ltm_snippets))
try:
async for delta in self._emit_with_metrics(
self._stream_llm(
session_id=session_id,
user_text=user_text,
meta=meta,
context=context_turns,
ltm_snippets=ltm_snippets,
),
source="llm",
):
yield delta
return
except LLMStreamEmptyError as exc:
self.last_error = "llm_empty_stream"
# Process tool calls if present
if exc.tool_calls and self._state_store:
await self._process_tool_calls(exc.tool_calls, session_id)
self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}")
except LLMNotConfiguredError as exc:
self.last_error = "llm_not_configured"
self._log_llm_fallback(reason=str(exc))
except Exception as exc: # pragma: no cover - defensive catch
self.last_error = exc.__class__.__name__
self._log_llm_fallback(reason=repr(exc))

extra_messages: List[Dict[str, Any]] = []
tool_retry = 0
while True:
try:
async for delta in self._emit_with_metrics(
self._stream_llm(
session_id=session_id,
user_text=user_text,
meta=meta,
context=context_turns,
ltm_snippets=ltm_snippets,
extra_messages=extra_messages,
),
source="llm",
):
yield delta
return
except LLMStreamEmptyError as exc:
self.last_error = "llm_empty_stream"
handled_tool_calls = False
if exc.tool_calls and self._state_store:
tool_messages = await self._process_tool_calls(exc.tool_calls, session_id)
handled_tool_calls = bool(tool_messages)
if handled_tool_calls and tool_retry < 3:
tool_retry += 1
extra_messages.extend(
[
{"role": "assistant", "content": None, "tool_calls": exc.tool_calls},
*tool_messages,
]
)
continue
self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}")
break
except LLMNotConfiguredError as exc:
self.last_error = "llm_not_configured"
self._log_llm_fallback(reason=str(exc))
break
except Exception as exc: # pragma: no cover - defensive catch
self.last_error = exc.__class__.__name__
self._log_llm_fallback(reason=repr(exc))
break
async for delta in self._emit_with_metrics(
self._stream_mock(user_text=user_text, meta=meta),
source="mock",
Expand Down Expand Up @@ -168,6 +190,7 @@ async def _stream_llm(
meta: Dict[str, Any],
context: List[MemoryTurn],
ltm_snippets: List[str],
extra_messages: Optional[List[Dict[str, Any]]] = None,
) -> AsyncGenerator[str, None]:
client = await self._ensure_llm_client()
meta_with_session = dict(meta)
Expand All @@ -177,6 +200,7 @@ async def _stream_llm(
meta=meta_with_session,
context=context,
ltm_snippets=ltm_snippets,
extra_messages=extra_messages,
)
extra_options: Dict[str, Any] = {
"extra_headers": {"x-session-id": session_id},
Expand Down Expand Up @@ -286,11 +310,18 @@ async def _compose_messages(
meta: Dict[str, Any],
context: List[MemoryTurn],
ltm_snippets: List[str],
) -> list[Dict[str, str]]:
system_prompt = meta.get("system_prompt")
messages: list[Dict[str, str]] = []
if system_prompt:
messages.append({"role": "system", "content": str(system_prompt)})
extra_messages: Optional[List[Dict[str, Any]]] = None,
) -> list[Dict[str, Any]]:
messages: list[Dict[str, Any]] = []
if "system_prompt" in meta and meta["system_prompt"] != self._system_prompt:
from logging import getLogger

getLogger(__name__).info(
"chat.system_prompt.meta_override_ignored",
extra={"sessionId": meta.get("session_id"), "override_length": len(str(meta["system_prompt"]))},
)
if self._system_prompt:
messages.append({"role": "system", "content": self._system_prompt})

# Inject internal states as context if available
if self._state_store:
Expand All @@ -311,6 +342,8 @@ async def _compose_messages(
if ltm_snippets:
messages.append({"role": "system", "content": self._format_ltm_snippets(ltm_snippets)})
messages.append({"role": "user", "content": user_text})
if extra_messages:
messages.extend(extra_messages)
return messages

def _estimate_tokens(self, chunk: str) -> int:
Expand All @@ -322,21 +355,41 @@ def _reset_metrics(self) -> None:
self.last_source = "mock"
self.last_error = None

async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> None:
"""Process tool calls from LLM to update internal states."""
async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> List[Dict[str, Any]]:
"""Process tool calls from LLM to update internal states and return tool response messages."""
if not self._state_store:
return
return []

tool_messages: List[Dict[str, Any]] = []
for tool_call in tool_calls:
try:
call_info = {
"name": getattr(tool_call, "function", {}).get("name"),
"arguments": getattr(tool_call, "function", {}).get("arguments", "{}")
}
await handle_tool_call(call_info, session_id, self._state_store)
if isinstance(tool_call, dict):
function = tool_call.get("function") or {}
name = function.get("name")
arguments = function.get("arguments", "{}")
tool_call_id = tool_call.get("id")
else:
function = getattr(tool_call, "function", {}) or {}
name = function.get("name")
arguments = function.get("arguments", "{}")
tool_call_id = getattr(tool_call, "id", None)

call_info = {"name": name, "arguments": arguments}
result = await handle_tool_call(call_info, session_id, self._state_store)
message_payload = json.dumps(result or {"success": True}, ensure_ascii=False)
tool_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id or name or "tool_call",
"name": name or "tool",
"content": message_payload,
}
)
except Exception as exc:
self._log_context_warning("tool_call.error", exc)

return tool_messages

async def get_internal_states(self, session_id: str) -> Dict[str, float]:
"""Get current internal states for a session."""
if not self._state_store:
Expand Down
30 changes: 28 additions & 2 deletions services/dialog-engine/src/dialog_engine/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ async def stream_chat(
},
)
has_content = False
collected_tool_calls: list[Any] = []
collected_tool_calls: list[Dict[str, Any]] = []
tool_call_acc: Dict[int, Dict[str, Any]] = {}
async for chunk in stream:
for choice in chunk.choices:
delta = choice.delta
Expand All @@ -116,7 +117,30 @@ async def stream_chat(
yield text_delta
tool_calls = getattr(delta, "tool_calls", None)
if tool_calls:
collected_tool_calls.extend(tool_calls)
for tool_call in tool_calls:
index = getattr(tool_call, "index", 0) or 0
entry = tool_call_acc.setdefault(
index,
{
"id": getattr(tool_call, "id", None),
"type": getattr(tool_call, "type", "function"),
"function": {"name": None, "arguments": ""},
},
)
if getattr(tool_call, "id", None):
entry["id"] = tool_call.id
tool_type = getattr(tool_call, "type", None)
if tool_type:
entry["type"] = tool_type
func = getattr(tool_call, "function", None)
if func:
func_name = getattr(func, "name", None)
if func_name:
entry["function"]["name"] = func_name
func_args = getattr(func, "arguments", None)
if func_args:
entry["function"]["arguments"] += func_args
collected_tool_calls = list(tool_call_acc.values())
if not has_content:
logger.warning(
"llm.stream.empty",
Expand Down Expand Up @@ -151,6 +175,8 @@ async def stream_chat(
except Exception: # pragma: no cover - best effort cleanup
logger.debug("llm.stream.close_failed", exc_info=True)

if isinstance(last_error, LLMStreamEmptyError):
raise last_error
raise RuntimeError("LLM streaming failed after retries") from last_error

async def generate_vision_reply(
Expand Down
17 changes: 17 additions & 0 deletions services/dialog-engine/src/dialog_engine/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

_BOOL_TRUTHY = {"1", "true", "yes", "on"}

DEFAULT_SYSTEM_PROMPT = (
"你是一位友好、专业的虚拟主播助手,以亲切的语气与用户互动,"
"善于引导对话并提供有趣、实用的信息。"
)
Comment thread
pixelsama marked this conversation as resolved.


def _env_bool(name: str, default: bool = False) -> bool:
value = os.getenv(name)
Expand Down Expand Up @@ -72,6 +77,11 @@ class LTMInlineSettings:
max_snippets: int


@dataclass(frozen=True)
class PromptSettings:
system_prompt: str


@dataclass(frozen=True)
class AsrSettings:
enabled: bool
Expand All @@ -92,6 +102,7 @@ class AsrSettings:
class Settings:
openai: OpenAISettings
llm: LLMSettings
prompts: PromptSettings
short_term: ShortTermMemorySettings
ltm_inline: LTMInlineSettings
asr: AsrSettings
Expand Down Expand Up @@ -119,6 +130,10 @@ def load_settings() -> Settings:
base_url=os.getenv("OPENAI_BASE_URL"),
)

prompt_settings = PromptSettings(
system_prompt=os.getenv("CHAT_SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT).strip(),
)

short_term_settings = ShortTermMemorySettings(
enabled=_env_bool("ENABLE_SHORT_TERM_MEMORY", True),
db_path=os.getenv("STM_DB_PATH", "/app/data/dialog_memory.sqlite"),
Expand Down Expand Up @@ -151,6 +166,7 @@ def load_settings() -> Settings:
return Settings(
openai=openai_settings,
llm=llm_settings,
prompts=prompt_settings,
short_term=short_term_settings,
ltm_inline=ltm_inline_settings,
asr=asr_settings,
Expand All @@ -165,6 +181,7 @@ def load_settings() -> Settings:
"OpenAISettings",
"ShortTermMemorySettings",
"LTMInlineSettings",
"PromptSettings",
"AsrSettings",
"settings",
"load_settings",
Expand Down
Loading
Loading