Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 196 additions & 27 deletions multi_llm_chatbot_backend/app/api/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,32 @@ class UserInput(BaseModel):
user_input: str
chat_session_id: Optional[str] = None

# Allowed values for ChatMessage.response_mode. In /chat-stream, "panel"
# streams each advisor reply as it completes; any other value ("aggregated")
# collects the replies and then requests one synthesized answer.
ResponseMode = Literal["panel", "aggregated"]


class ChatMessage(BaseModel):
    # Request body accepted by the /chat-stream endpoint.

    # The user's message text; appended to the in-memory session and, when
    # chat_session_id is set, persisted as a "user" message.
    user_input: str
    # NOTE(review): presumably identifies the in-memory session — usage is not
    # visible in this section; confirm against chat_stream.
    session_id: Optional[str] = None
    chat_session_id: Optional[str] = None  # MongoDB chat session ID
    # Free-form string here. NOTE(review): RequestAggregatedResponse restricts
    # the same setting to "short"/"medium"/"long"; this field is not validated.
    response_length: str = "medium"
    # NOTE(review): presumably limits which advisors respond — usage is not
    # visible in this section; confirm.
    active_advisors: Optional[List[str]] = None
    # Selects panel streaming vs. aggregated synthesis; defaults to "panel".
    response_mode: ResponseMode = "panel"

class PanelResult(BaseModel):
    # One advisor's reply as sent by the client in a RequestAggregatedResponse.
    # The endpoint converts each instance with model_dump() before passing the
    # list to chat_orchestrator.synthesize_aggregated_response.

    persona_id: str      # identifier of the advisor persona
    persona_name: str    # display name of the advisor persona
    response: str        # the advisor's reply text
    used_documents: bool = False    # defaults to False when the client omits it
    document_chunks_used: int = 0   # defaults to 0 when the client omits it


class RequestAggregatedResponse(BaseModel):
    # Request body for POST /request-aggregated-response.

    # The user question the panel responses were answering.
    user_input: str
    # Advisor replies to synthesize; validation requires at least one entry.
    panel_results: List[PanelResult] = Field(min_length=1)
    # Chat session the aggregated message is persisted to (required here).
    chat_session_id: str
    # Stored on the persisted aggregated message as its response_group_id.
    response_group_id: str
    # Forwarded to the synthesizer as response_length.
    response_length: Literal["short", "medium", "long"] = "medium"


class ReplyToAdvisor(BaseModel):
user_input: str
Expand Down Expand Up @@ -99,7 +119,7 @@ class ChatStreamLine(BaseModel):
def to_ndjson(self) -> str:
    """Serialize this line as one newline-terminated JSON document.

    The model is dumped in JSON mode and encoded with ``ensure_ascii=False``
    so non-ASCII characters are emitted as-is rather than escaped.
    """
    payload = self.model_dump(mode="json")
    encoded = json.dumps(payload, ensure_ascii=False)
    return f"{encoded}\n"


# TODO: Refactor this function into smaller composable helpers so it's more readable and maintainable.
@router.post("/chat-stream")
async def chat_stream(
message: ChatMessage,
Expand Down Expand Up @@ -136,11 +156,16 @@ async def _event_generator():
session = session_manager.get_session(sid)

# Append user message to in-memory session and persist to MongoDB
response_group_id = str(ObjectId())
session.append_message("user", message.user_input)
if message.chat_session_id:
await persist_message(
message.chat_session_id,
PersistMessage(type="user", content=message.user_input),
PersistMessage(
type="user",
content=message.user_input,
response_group_id=response_group_id,
),
)
yield ChatStreamLine(
type="progress", data={"phase": "received"},
Expand Down Expand Up @@ -275,33 +300,129 @@ async def _run(pid: str) -> None:

tasks = [asyncio.create_task(_run(pid)) for pid in top_personas]

for _ in range(len(tasks)):
result = await done_queue.get()
if message.chat_session_id:
await persist_message(
message.chat_session_id,
PersistMessage(
type="advisor",
persona_id=result["persona_id"],
advisorName=result["persona_name"],
content=result["response"],
used_documents=result.get("used_documents", False),
document_chunks_used=result.get("document_chunks_used", 0),
),
)
line = ChatStreamLine(
type="advisor",
data={
"persona_id": result["persona_id"],
"persona_name": result["persona_name"],
"content": result["response"],
"used_documents": result.get("used_documents", False),
"document_chunks_used": result.get("document_chunks_used", 0),
},
if message.response_mode == "panel":
# ---- Panel mode: yield each advisor response as it arrives ----
for _ in range(len(tasks)):
result = await done_queue.get()
if message.chat_session_id:
await persist_message(
message.chat_session_id,
PersistMessage(
type="advisor",
persona_id=result["persona_id"],
advisorName=result["persona_name"],
content=result["response"],
used_documents=result.get("used_documents", False),
document_chunks_used=result.get("document_chunks_used", 0),
response_group_id=response_group_id,
),
)
yield ChatStreamLine(
type="advisor",
data={
"persona_id": result["persona_id"],
"persona_name": result["persona_name"],
"content": result["response"],
"used_documents": result.get("used_documents", False),
"document_chunks_used": result.get("document_chunks_used", 0),
"response_group_id": response_group_id,
},
).to_ndjson()

await asyncio.gather(*tasks, return_exceptions=True)

else:
# ---- Aggregated mode: collect all, synthesize, yield one ----
yield ChatStreamLine(
type="progress",
data={"phase": "generating"},
).to_ndjson()

panel_results = []
for _ in range(len(tasks)):
result = await done_queue.get()
panel_results.append(result)
if message.chat_session_id:
await persist_message(
message.chat_session_id,
PersistMessage(
type="advisor",
persona_id=result["persona_id"],
advisorName=result["persona_name"],
content=result["response"],
used_documents=result.get("used_documents", False),
document_chunks_used=result.get("document_chunks_used", 0),
response_group_id=response_group_id,
),
)

await asyncio.gather(*tasks, return_exceptions=True)

for result in panel_results:
yield ChatStreamLine(
type="advisor",
data={
"persona_id": result["persona_id"],
"persona_name": result["persona_name"],
"content": result["response"],
"used_documents": result.get("used_documents", False),
"document_chunks_used": result.get("document_chunks_used", 0),
"response_group_id": response_group_id,
},
).to_ndjson()

yield ChatStreamLine(
type="progress",
data={"phase": "synthesizing"},
).to_ndjson()

synth_result = await chat_orchestrator.synthesize_aggregated_response(
user_input=message.user_input,
panel_results=panel_results,
llm_client=orchestrator_llm,
response_length=message.response_length or "medium",
)
yield line.to_ndjson()

await asyncio.gather(*tasks, return_exceptions=True)
if synth_result:
if message.chat_session_id:
await persist_message(
message.chat_session_id,
PersistMessage(
type="advisor",
persona_id="aggregated",
advisorName=synth_result["persona_name"],
content=synth_result["response"],
is_aggregated=True,
source_personas=synth_result["source_personas"],
response_group_id=response_group_id,
),
)
yield ChatStreamLine(
type="advisor",
data={
"persona_id": "aggregated",
"persona_name": synth_result["persona_name"],
"content": synth_result["response"],
"is_aggregated": True,
"source_personas": synth_result["source_personas"],
"response_group_id": response_group_id,
},
).to_ndjson()
else:
# Synthesis failed — fall back to yielding panel responses
logger.warning("Aggregated synthesis failed, falling back to panel")
for result in panel_results:
yield ChatStreamLine(
type="advisor",
data={
"persona_id": result["persona_id"],
"persona_name": result["persona_name"],
"content": result["response"],
"used_documents": result.get("used_documents", False),
"document_chunks_used": result.get("document_chunks_used", 0),
"response_group_id": response_group_id,
},
).to_ndjson()

yield ChatStreamLine(
type="progress",
Expand All @@ -326,6 +447,54 @@ async def _run(pid: str) -> None:
)


@router.post("/request-aggregated-response")
async def request_aggregated_response(
    request: RequestAggregatedResponse,
    current_user: User = Depends(get_current_active_user),
):
    """On-demand synthesis of panel advisor responses into a single aggregated answer.

    Called when a user toggles to the 'Generalized' view on a panel-mode
    exchange that doesn't yet have an aggregated response.

    Args:
        request: The original user input, the panel responses to merge, and
            the chat session / response group the result is stored under.
        current_user: Authenticated user; used to resolve the LLM clients.

    Returns:
        The dict produced by ``chat_orchestrator.synthesize_aggregated_response``.

    Raises:
        HTTPException: 502 when synthesis yields no usable result; 500 for
            any other failure (including a failure while persisting).
    """
    try:
        llm_clients = resolve_llm_clients(current_user)
        orchestrator_llm = llm_clients.get("orchestrator")

        # The orchestrator works on plain dicts, not PanelResult models.
        panel_dicts = [r.model_dump() for r in request.panel_results]

        result = await chat_orchestrator.synthesize_aggregated_response(
            user_input=request.user_input,
            panel_results=panel_dicts,
            llm_client=orchestrator_llm,
            response_length=request.response_length,
        )

        if not result:
            raise HTTPException(status_code=502, detail="Synthesis produced no usable response")

        # NOTE(review): chat_session_id is client-supplied and this handler does
        # not itself check that it belongs to current_user — confirm that
        # persist_message (or an upstream layer) enforces ownership.
        await persist_message(
            request.chat_session_id,
            PersistMessage(
                type="advisor",
                persona_id="aggregated",
                advisorName=result["persona_name"],
                content=result["response"],
                is_aggregated=True,
                source_personas=result["source_personas"],
                response_group_id=request.response_group_id,
            ),
        )

        return result

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 502 above) pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback; chaining keeps the root
        # cause attached to the generic 500 for debugging.
        logger.exception("Synthesis endpoint error: %s", e)
        raise HTTPException(status_code=500, detail="Synthesis failed") from e


@router.post("/switch-chat")
async def switch_to_chat(
request: SwitchChatRequest,
Expand Down
83 changes: 82 additions & 1 deletion multi_llm_chatbot_backend/app/core/improved_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List, Optional, Any
from app.models.persona import Persona
from app.models.persona import Persona, COMPACT_MARKDOWN_V1, STRUCTURE_HINTS, _ensure_compact_shape
from app.core.session_manager import ConversationContext, get_session_manager
from app.core.context_manager import get_context_manager
from app.core.rag_manager import get_rag_manager
Expand Down Expand Up @@ -488,6 +488,87 @@ async def generate_single_persona_response(self, session, persona,
"context_quality": "error"
}

async def synthesize_aggregated_response(
    self,
    user_input: str,
    panel_results: List[Dict[str, Any]],
    llm_client: Optional[LLMClient] = None,
    response_length: str = "medium",
) -> Optional[Dict[str, Any]]:
    """Merge multiple panel advisor responses into a single unified answer.

    Builds a synthesis prompt from the advisors' responses and sends it to
    ``llm_client`` (falling back to ``self.llm_client`` when none is given).

    Args:
        user_input: The question the advisors were answering.
        panel_results: One dict per advisor; each must contain
            ``persona_id``, ``persona_name`` and ``response``, and may
            contain ``used_documents`` / ``document_chunks_used``.
        llm_client: Client used for synthesis; defaults to ``self.llm_client``.
        response_length: ``"short"``, ``"medium"`` or ``"long"``. Any other
            value is treated as ``"medium"`` for the token budget and the
            structure hint.

    Returns:
        A dict describing the aggregated response, or ``None`` when
        ``panel_results`` is empty, the LLM returns nothing usable, or the
        LLM call raises — so the caller can fall back to the panel responses.
    """
    if not panel_results:
        return None

    token_limits = {"short": 800, "medium": 1500, "long": 2400}
    # Unknown lengths get the "medium" budget, mirroring the STRUCTURE_HINTS
    # fallback below (previously 700, which was tighter than "short").
    max_tokens = token_limits.get(response_length, token_limits["medium"])

    perspectives = "\n\n".join(
        f"### {r['persona_name']} ({r['persona_id']})\n{r['response']}"
        for r in panel_results
    )

    structure_hint = STRUCTURE_HINTS.get(response_length, STRUCTURE_HINTS["medium"])

    system_prompt = (
        "You are a synthesis assistant. You will receive multiple expert "
        "advisor perspectives on a user's question. Your job is to merge "
        "them into a single, cohesive answer that integrates the strongest "
        "points from each.\n\n"
        "Guidelines:\n"
        "- Produce ONE unified answer addressed directly to the user.\n"
        "- Do NOT list or label the individual perspectives.\n"
        "- Resolve contradictions by noting the trade-off briefly.\n"
        "- Keep the tone warm, clear, and actionable.\n\n"
        f"{COMPACT_MARKDOWN_V1}\n\n"
        f"{structure_hint}"
    )

    user_prompt = (
        f"The user asked:\n\"{user_input}\"\n\n"
        f"The following {len(panel_results)} advisors responded:\n\n"
        f"{perspectives}\n\n"
        "Synthesize these into a single best-answer response."
    )

    try:
        effective_llm = llm_client or self.llm_client
        raw = await effective_llm.generate(
            system_prompt=system_prompt,
            context=[{"role": "user", "content": user_prompt}],
            temperature=0.4,
            max_tokens=max_tokens,
        )

        stripped = raw.strip() if raw else ""
        if not stripped:
            logger.warning("Synthesis LLM returned empty response")
            return None
        content = _ensure_compact_shape(stripped, response_length)

        return {
            "persona_id": "aggregated",
            "persona_name": "Orchestrator",
            "response": content,
            "is_aggregated": True,
            "source_personas": [r["persona_id"] for r in panel_results],
            # Document usage is rolled up across all contributing advisors.
            "used_documents": any(r.get("used_documents") for r in panel_results),
            "document_chunks_used": sum(
                r.get("document_chunks_used", 0) for r in panel_results
            ),
            "response_length": response_length,
            "context_quality": "synthesized",
        }

    except Exception as e:
        # Best-effort by design: log with traceback and signal failure via
        # None so the caller can fall back instead of crashing.
        logger.exception("Aggregated synthesis failed: %s", e)
        return None

async def _retrieve_relevant_documents(self, user_input: str, session_id: str, persona_id: str = "") -> str:
"""
Enhanced document retrieval with document awareness and better attribution
Expand Down
17 changes: 17 additions & 0 deletions multi_llm_chatbot_backend/app/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ class PersistMessage(BaseModel):
isExpansion: bool = False
isExpandRequest: bool = False
replyTo: Optional[ReplyToRef] = None
# Response grouping — links a user message with its panel + aggregated responses
response_group_id: Optional[str] = None
is_aggregated: Optional[bool] = None
source_personas: Optional[List[str]] = None

@model_validator(mode='after')
def check_type_constraints(self):
Expand All @@ -140,6 +144,19 @@ def check_reply_metadata(self):
raise ValueError("replyTo is required when isReply is True")
return self

@model_validator(mode='after')
def check_aggregation_metadata(self):
    """Validate that aggregation fields are mutually consistent.

    A non-aggregated message must not carry ``source_personas``. An
    aggregated message must be of type ``'advisor'``, use the persona id
    ``'aggregated'``, and list at least one source persona.
    """
    if not self.is_aggregated:
        # Plain messages: the only rule is that no sources are attached.
        if self.source_personas:
            raise ValueError("source_personas should only be set on aggregated messages")
        return self

    # Aggregated messages: checks run in the same order as before so the
    # first failing rule determines the error message.
    if self.type != 'advisor':
        raise ValueError("is_aggregated can only be True for advisor messages")
    if self.persona_id != 'aggregated':
        raise ValueError("persona_id must be 'aggregated' when is_aggregated is True")
    if not self.source_personas:
        raise ValueError("source_personas is required when is_aggregated is True")
    return self


class ChatSession(BaseModel):
model_config = ConfigDict(
Expand Down
Loading
Loading