@@ -67,11 +117,12 @@ const AdvisorStatusDropdown = ({ advisors, thinkingAdvisors, getAdvisorColors, i
const IconComponent = advisor.icon;
const colors = getAdvisorColors(id, isDark);
const isThinking = Array.isArray(thinkingAdvisors) && thinkingAdvisors.includes(id);
-
+ const enabled = isAdvisorEnabled(id);
+
return (
{advisor.name}
-
{advisor.description}
-
-
- {isThinking ? (
-
- ) : (
-
Ready
- )}
+
+ {!enabled
+ ? Off — won't reply
+ : isThinking
+ ? Thinking…
+ : advisor.description}
+
+
handleAdvisorToggle(id, next)}
+ size="sm"
+ label={`Toggle ${advisor.name}`}
+ />
);
})}
@@ -240,6 +290,25 @@ const AdvisorStatusDropdown = ({ advisors, thinkingAdvisors, getAdvisorColors, i
.advisor-item.thinking {
background: var(--advisor-bg);
}
+
+ .advisor-item.disabled .advisor-icon,
+ .advisor-item.disabled .advisor-name {
+ opacity: 0.45;
+ }
+
+ .advisor-item.disabled .advisor-description {
+ opacity: 0.7;
+ }
+
+ .advisor-off-label {
+ color: var(--text-tertiary, #9ca3af);
+ font-style: italic;
+ }
+
+ .advisor-thinking-label {
+ color: var(--advisor-color);
+ font-weight: 500;
+ }
.advisor-icon {
width: 32px;
@@ -317,6 +386,72 @@ const AdvisorStatusDropdown = ({ advisors, thinkingAdvisors, getAdvisorColors, i
0%, 100% { opacity: 1; }
50% { opacity: 0.7; }
}
+
+ .disable-all-overlay {
+ position: fixed;
+ inset: 0;
+ background: rgba(0,0,0,0.5);
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ z-index: 1100;
+ }
+
+ .disable-all-modal {
+ background: var(--bg-primary);
+ border-radius: 16px;
+ width: 420px;
+ max-width: 95vw;
+ box-shadow: var(--shadow-xl, 0 24px 48px rgba(0,0,0,0.25));
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+ }
+
+ .disable-all-header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ padding: 16px 20px;
+ border-bottom: 1px solid var(--border-primary);
+ }
+
+ .disable-all-body {
+ padding: 20px;
+ font-size: 14px;
+ color: var(--text-primary);
+ line-height: 1.5;
+ }
+
+ .disable-all-actions {
+ display: flex;
+ justify-content: flex-end;
+ gap: 8px;
+ padding: 0 20px 20px;
+ }
+
+ .disable-all-secondary {
+ background: transparent;
+ border: 1px solid var(--border-primary);
+ color: var(--text-secondary);
+ font-size: 13px;
+ padding: 8px 14px;
+ border-radius: 8px;
+ cursor: pointer;
+ font-family: inherit;
+ }
+
+ .disable-all-danger {
+ background: #dc2626;
+ color: #fff;
+ border: none;
+ font-size: 13px;
+ padding: 8px 14px;
+ border-radius: 8px;
+ cursor: pointer;
+ font-family: inherit;
+ font-weight: 500;
+ }
/* Responsive Design */
@media (max-width: 768px) {
diff --git a/phd-advisor-frontend/src/components/SettingsModal.js b/phd-advisor-frontend/src/components/SettingsModal.js
index 5d418613..e05cafc3 100644
--- a/phd-advisor-frontend/src/components/SettingsModal.js
+++ b/phd-advisor-frontend/src/components/SettingsModal.js
@@ -1,6 +1,8 @@
-import React, { useState, useRef } from 'react';
+import React, { useState, useRef, useEffect } from 'react';
import ReactDOM from 'react-dom';
-import { X, User as UserIcon, Lock, Trash2, AlertTriangle } from 'lucide-react';
+import { X, User as UserIcon, Lock, Trash2, AlertTriangle, Users } from 'lucide-react';
+import Toggle from './Toggle';
+import { useAppConfig } from '../contexts/AppConfigContext';
const overlay = {
position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.5)',
@@ -54,8 +56,33 @@ const dangerBtn = {
cursor: 'pointer', fontSize: 14, fontWeight: 500,
};
+// Compact outlined button style; applied to the "Enable all" bulk action in
+// the Advisors tab.
+const miniBtn = {
+ background: 'transparent',
+ border: '1px solid var(--border-primary)',
+ color: 'var(--text-secondary)',
+ fontSize: 12,
+ padding: '5px 10px',
+ borderRadius: 6,
+ cursor: 'pointer',
+ fontFamily: 'inherit',
+};
+
const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => {
const [activeTab, setActiveTab] = useState('profile');
+ const {
+ advisors,
+ isAdvisorEnabled,
+ setAdvisorEnabled,
+ setAllAdvisorsEnabled,
+ hydrateAdvisorPreferences,
+ } = useAppConfig();
+
+ // Reconcile with the backend whenever the user opens Settings (covers fresh
+ // logins and changes made on another device).
+ useEffect(() => {
+ hydrateAdvisorPreferences();
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, []);
// Track where the mouse went DOWN so we don't close the modal when a user
// drags to select text inside an input and the mouseup happens outside the modal.
@@ -224,12 +251,43 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) =>
}`,
});
+ const advisorEntries = Object.entries(advisors || {});
+ const enabledCount = advisorEntries.filter(([id]) => isAdvisorEnabled(id)).length;
+ const setAll = (enabled) => setAllAdvisorsEnabled(enabled);
+
+ // pendingDisable: { type: 'all' } | { type: 'single', id } — set when the
+ // user is about to leave zero advisors enabled. Confirming runs the action;
+ // "Go back" leaves state untouched.
+ const [pendingDisable, setPendingDisable] = useState(null);
+
+ const handleDisableAllClick = () => {
+   // With nothing enabled there is nothing to confirm, so the dialog is
+   // only requested when at least one advisor is still on.
+   if (enabledCount !== 0) {
+     setPendingDisable({ type: 'all' });
+   }
+ };
+
+ const handleAdvisorToggle = (id, next) => {
+   // True when this toggle would switch off the only advisor still enabled.
+   const wouldDisableLast = !next && enabledCount === 1 && isAdvisorEnabled(id);
+   if (wouldDisableLast) {
+     // Defer the change until the user confirms it.
+     setPendingDisable({ type: 'single', id });
+   } else {
+     setAdvisorEnabled(id, next);
+   }
+ };
+
+ const confirmPendingDisable = () => {
+   // Run whichever action the confirmation was guarding, then clear it.
+   switch (pendingDisable?.type) {
+     case 'all':
+       setAll(false);
+       break;
+     case 'single':
+       setAdvisorEnabled(pendingDisable.id, false);
+       break;
+     default:
+       break;
+   }
+   setPendingDisable(null);
+ };
+
return ReactDOM.createPortal(
-
Account Settings
-
@@ -241,6 +299,9 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) =>
{ setActiveTab('password'); setMessage(null); }}>
Password
+
{ setActiveTab('advisors'); setMessage(null); }}>
+ Advisors
+
{ setActiveTab('danger'); setMessage(null); }}>
Delete Account
@@ -291,6 +352,77 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) =>
)}
+ {activeTab === 'advisors' && (
+ <>
+
+
+
+ Active advisors
+
+
+ {enabledCount} of {advisorEntries.length} active · turn an advisor off to keep them out of your conversations
+
+
+
+ setAll(true)} style={miniBtn}>Enable all
+ Disable all
+
+
+
+
+ {advisorEntries.length === 0 && (
+
+ No advisors configured.
+
+ )}
+ {advisorEntries.map(([id, advisor]) => {
+ const IconComponent = advisor.icon;
+ const enabled = isAdvisorEnabled(id);
+ return (
+
+
+ {advisor.avatarUrl
+ ?

+ :
}
+
+
+
+ {advisor.name}
+
+
+ {advisor.description || advisor.role || ''}
+
+
+
handleAdvisorToggle(id, next)}
+ label={`Toggle ${advisor.name}`}
+ />
+
+ );
+ })}
+
+ >
+ )}
+
{activeTab === 'danger' && (
,
document.body
diff --git a/phd-advisor-frontend/src/components/Toggle.js b/phd-advisor-frontend/src/components/Toggle.js
new file mode 100644
index 00000000..f63b3a9b
--- /dev/null
+++ b/phd-advisor-frontend/src/components/Toggle.js
@@ -0,0 +1,65 @@
+import React from 'react';
+
+/**
+ * Clean iOS-style toggle switch. Single reusable component so toggles
+ * across the app (advisor on/off, settings preferences) look identical.
+ *
+ * Props:
+ * checked — boolean
+ * onChange — (next: boolean) => void
+ * size — 'sm' (32×18) | 'md' (38×22, default)
+ * disabled — boolean
+ * label — optional aria-label for screen readers
+ */
+const Toggle = ({ checked, onChange, size = 'md', disabled = false, label }) => {
+ const dims = size === 'sm'
+ ? { w: 32, h: 18, knob: 14, off: 2, on: 16 }
+ : { w: 38, h: 22, knob: 18, off: 2, on: 18 };
+
+ const handleClick = (e) => {
+ e.stopPropagation();
+ if (!disabled) onChange(!checked);
+ };
+
+ return (
+
+
+
+ );
+};
+
+export default Toggle;
diff --git a/phd-advisor-frontend/src/contexts/AppConfigContext.js b/phd-advisor-frontend/src/contexts/AppConfigContext.js
index c68bfa31..d205e089 100644
--- a/phd-advisor-frontend/src/contexts/AppConfigContext.js
+++ b/phd-advisor-frontend/src/contexts/AppConfigContext.js
@@ -3,6 +3,23 @@ import * as LucideIcons from 'lucide-react';
const AppConfigContext = createContext(null);
+// Fall back to a same-origin path when REACT_APP_API_URL is not set, so the
+// URL never begins with the literal string "undefined".
+const ADVISOR_PREFS_URL = `${process.env.REACT_APP_API_URL ?? ''}/api/me/advisor-preferences`;
+
+// The frontend tracks disabled advisors as an object keyed by id
+// ({ critic: true }) for fast lookups; the backend speaks a flat string[].
+// These two helpers translate between the shapes. A null/undefined array
+// from the backend means "no preferences set" → nothing disabled.
+// Collect the ids whose flag is truthy; a null/undefined map yields [].
+const disabledObjToArray = (obj) =>
+  Object.entries(obj || {})
+    .filter(([, flag]) => flag)
+    .map(([id]) => id);
+// Build an { id: true } lookup from a string[]; any non-array input
+// (including null/undefined) maps to an empty object.
+const disabledArrayToObj = (arr) => {
+  const lookup = {};
+  if (!Array.isArray(arr)) return lookup;
+  arr.forEach((id) => {
+    lookup[id] = true;
+  });
+  return lookup;
+};
+
+// Read the stored session token. Any error while accessing storage is
+// treated the same as having no token.
+const getAuthToken = () => {
+  let token = null;
+  try {
+    token = localStorage.getItem('authToken');
+  } catch {
+    token = null;
+  }
+  return token;
+};
+
/**
* Resolve a Lucide icon name string (e.g. "BookOpen") to the actual React
* component. Falls back to HelpCircle if the name isn't found.
@@ -89,6 +106,14 @@ export const AppConfigProvider = ({ children }) => {
try { return JSON.parse(localStorage.getItem('myCustomAvatars') || '[]'); }
catch { return []; }
});
+ // Per-user enable/disable for each advisor. Missing key = enabled by default
+ // so new advisors light up automatically when added on the backend.
+ const [disabledAdvisors, setDisabledAdvisors] = useState(() => {
+   // Only trust a cached plain object. A stored "null" would make the
+   // disabledAdvisors[id] lookup throw, and an array or primitive would be
+   // spread into a malformed map on the next toggle.
+   try {
+     const cached = JSON.parse(localStorage.getItem('disabledAdvisors') || '{}');
+     const isPlainMap = cached !== null && typeof cached === 'object' && !Array.isArray(cached);
+     return isPlainMap ? cached : {};
+   } catch {
+     // Corrupt JSON or inaccessible storage — start with nothing disabled.
+     return {};
+   }
+ });
+ // Advisor ids the backend considers selectable (system-level allow list).
+ const [availableAdvisors, setAvailableAdvisors] = useState([]);
useEffect(() => {
const fetchConfig = async () => {
@@ -125,6 +150,97 @@ export const AppConfigProvider = ({ children }) => {
localStorage.setItem('myCustomAvatars', JSON.stringify(next));
};
+ // Advisor enable/disable. Disabled advisors are filtered out of orchestrator
+ // calls (server-side, per user) and visually dimmed in the UI.
+ // An advisor counts as enabled unless its id maps to a truthy value in
+ // disabledAdvisors (so ids with no entry are enabled).
+ const isAdvisorEnabled = (id) => !disabledAdvisors[id];
+
+ // Apply a disabled map locally + cache it. localStorage keeps the last known
+ // state so the UI is correct instantly on reload before the backend answers.
+ const applyDisabled = (obj) => {
+   // Update React state first; caching below is best-effort.
+   setDisabledAdvisors(obj);
+   try {
+     const serialized = JSON.stringify(obj);
+     localStorage.setItem('disabledAdvisors', serialized);
+   } catch {
+     // Storage full or unavailable — non-fatal, the in-memory state stands.
+   }
+ };
+
+ // Reconcile local state with whatever the backend returns (it is the source
+ // of truth; it also distinguishes "no prefs / null" from an explicit list).
+ const applyServerResponse = (data) => {
+   const disabledList = data?.disabled_advisors;
+   const availableList = data?.available_advisors;
+   // A missing or non-array disabled list converts to {} (nothing disabled).
+   applyDisabled(disabledArrayToObj(disabledList));
+   // Only replace the available list when the response actually carries one.
+   if (Array.isArray(availableList)) {
+     setAvailableAdvisors(availableList);
+   }
+ };
+
+ // Pull the authenticated user's preferences from the backend.
+ const hydrateAdvisorPreferences = async () => {
+   const token = getAuthToken();
+   // No stored token — nothing to fetch.
+   if (!token) return;
+   try {
+     const requestInit = { headers: { Authorization: `Bearer ${token}` } };
+     const res = await fetch(ADVISOR_PREFS_URL, requestInit);
+     if (res.ok) {
+       const payload = await res.json();
+       applyServerResponse(payload);
+     } else {
+       console.error('Failed to load advisor preferences:', res.status);
+     }
+   } catch (err) {
+     // Request or body parsing failed — local state is left untouched.
+     console.error('Failed to load advisor preferences:', err);
+   }
+ };
+
+ // Persist the full disabled set to the backend. We send the whole array
+ // (not a delta) so the PUT is idempotent and the server stays authoritative.
+ const persistAdvisorPreferences = async (obj) => {
+   const token = getAuthToken();
+   // No stored token — nothing to save against.
+   if (!token) return;
+   try {
+     // The full disabled list is sent every time, not a delta.
+     const payload = JSON.stringify({ disabled_advisors: disabledObjToArray(obj) });
+     const requestInit = {
+       method: 'PUT',
+       headers: {
+         Authorization: `Bearer ${token}`,
+         'Content-Type': 'application/json',
+       },
+       body: payload,
+     };
+     const res = await fetch(ADVISOR_PREFS_URL, requestInit);
+     if (res.ok) {
+       // Reconcile local state with what the server reports back.
+       applyServerResponse(await res.json());
+     } else {
+       console.error('Failed to save advisor preferences:', res.status);
+     }
+   } catch (err) {
+     // The optimistic local state stays in place; the failure is only logged.
+     console.error('Failed to save advisor preferences:', err);
+   }
+ };
+
+ const setAdvisorEnabled = (id, enabled) => {
+   // NOTE(review): `next` is derived from the render-time disabledAdvisors
+   // snapshot, so two toggles within the same render would both start from
+   // the same map — confirm this is acceptable.
+   const next = enabled
+     ? Object.fromEntries(
+         Object.entries(disabledAdvisors).filter(([key]) => key !== id))
+     : { ...disabledAdvisors, [id]: true };
+   applyDisabled(next); // optimistic local update
+   persistAdvisorPreferences(next); // not awaited; reconciles on response
+ };
+
+ // Bulk enable/disable in one shot — a single state update and one PUT,
+ // instead of N racing requests when toggling every advisor.
+ const setAllAdvisorsEnabled = (enabled) => {
+   // Enabling everything is an empty map; disabling everything flags every
+   // id currently present in `advisors`.
+   const next = {};
+   if (!enabled) {
+     Object.keys(advisors || {}).forEach((id) => {
+       next[id] = true;
+     });
+   }
+   applyDisabled(next);
+   persistAdvisorPreferences(next);
+ };
+
+ // Load preferences once on mount when a session token is already present
+ // (returning user). Fresh logins reconcile when the Settings modal opens.
+ useEffect(() => {
+ hydrateAdvisorPreferences();
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, []);
+
// Inject the primary colour as a CSS custom property so it is
// available everywhere without prop-drilling.
useEffect(() => {
@@ -156,6 +272,12 @@ export const AppConfigProvider = ({ children }) => {
setAdvisorAvatar,
addMyAvatar,
myCustomAvatars,
+ disabledAdvisors,
+ availableAdvisors,
+ isAdvisorEnabled,
+ setAdvisorEnabled,
+ setAllAdvisorsEnabled,
+ hydrateAdvisorPreferences,
};
if (loading) {
diff --git a/phd_config.yaml b/phd_config.yaml
index 4adffb26..92df8162 100644
--- a/phd_config.yaml
+++ b/phd_config.yaml
@@ -115,6 +115,10 @@ personas:
# Individual persona files are loaded from this directory (relative to this file).
personas_dir: "personas/phd_advisors"
+ # Optional allowlist of advisor IDs. When set, only these advisors are
+ # available. Omit the key or leave it empty to allow all enabled personas.
+ allowed_advisors:
+
# ── Orchestrator / Clarification ───────────────────────────────────────────
orchestrator:
From 60b85919dc3f8654685bab81378515926e47cf4a Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Thu, 28 May 2026 23:17:14 +0000
Subject: [PATCH 08/31] Increment Version to 2.0.1a4
---
multi_llm_chatbot_backend/app/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/multi_llm_chatbot_backend/app/version.py b/multi_llm_chatbot_backend/app/version.py
index 11716805..ad015959 100644
--- a/multi_llm_chatbot_backend/app/version.py
+++ b/multi_llm_chatbot_backend/app/version.py
@@ -1,4 +1,4 @@
-__version__ = "2.0.1a3"
+__version__ = "2.0.1a4"
if __name__ == "__main__":
print(__version__)
From 7988be0ae371bd39794f67e03c1b16ddb2466cc0 Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Thu, 28 May 2026 23:17:39 +0000
Subject: [PATCH 09/31] Update Changelog
---
CHANGELOG.md | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe766171..705186c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,18 @@
# Changelog
+## [2.0.1a4](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a4) (2026-05-28)
+
+[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a3...2.0.1a4)
+
+**Implemented enhancements:**
+
+- \[FEAT\] Disable Specific Personas [\#64](https://github.com/NeonGeckoCom/CCAI-Demo/issues/64)
+- \[FEAT\] User Tutorial [\#29](https://github.com/NeonGeckoCom/CCAI-Demo/issues/29)
+
+**Merged pull requests:**
+
+- Disable Specific User Personas [\#65](https://github.com/NeonGeckoCom/CCAI-Demo/pull/65) ([NeonRyan](https://github.com/NeonRyan))
+
## [2.0.1a3](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a3) (2026-05-21)
[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a2...2.0.1a3)
From a7d41c7646f2fe8e762a25be92086d5eea541a08 Mon Sep 17 00:00:00 2001
From: NeonCharlie-24
Date: Fri, 29 May 2026 13:47:07 -0700
Subject: [PATCH 10/31] deprecated old_routes.py (#77)
* deprecated old_routes.py
* fixed the race condition in ChatPage.js which resulted in null chat_session_id being passed to backend.
---
.../app/api/old_routes.py | 1001 -----------------
.../app/api/routes/debug.py | 8 +-
phd-advisor-frontend/src/pages/ChatPage.js | 2 +-
3 files changed, 5 insertions(+), 1006 deletions(-)
delete mode 100644 multi_llm_chatbot_backend/app/api/old_routes.py
diff --git a/multi_llm_chatbot_backend/app/api/old_routes.py b/multi_llm_chatbot_backend/app/api/old_routes.py
deleted file mode 100644
index 001905b1..00000000
--- a/multi_llm_chatbot_backend/app/api/old_routes.py
+++ /dev/null
@@ -1,1001 +0,0 @@
-import os
-from fastapi import APIRouter, Body, HTTPException, Header, UploadFile, File, Request
-from fastapi import Query
-from typing import Optional, List
-import httpx
-from app.llm.llm_client import LLMClient
-from app.llm.improved_gemini_client import ImprovedGeminiClient
-from app.llm.improved_ollama_client import ImprovedOllamaClient
-from app.models.persona import Persona
-from app.core.improved_orchestrator import ImprovedChatOrchestrator
-from app.core.session_manager import get_session_manager
-from app.core.rag_manager import get_rag_manager
-from app.models.default_personas import get_default_personas
-from app.utils.document_extractor import extract_text_from_file
-from app.utils.file_limits import is_within_upload_limit
-from pydantic import BaseModel
-
-from fastapi.responses import StreamingResponse
-from fastapi import Query
-from app.utils.file_export import export_chat_as_file
-
-from app.utils.chat_summary import generate_summary_from_messages, parse_summary_to_blocks
-from app.utils.file_export import prepare_export_response, generate_pdf_file_from_blocks
-from app.version import __version__
-
-import hashlib
-import logging
-
-logger = logging.getLogger(__name__)
-
-router = APIRouter()
-
-# Provider management (same as before)
-current_provider = "gemini"
-available_providers = ["ollama", "gemini"]
-
-def create_llm_client(provider: str = None) -> LLMClient:
- """Create LLM client based on provider"""
- if provider is None:
- provider = current_provider
-
- if provider == "gemini":
- try:
- return ImprovedGeminiClient(model_name=os.getenv("GEMINI_MODEL"))
- except ValueError as e:
- logger.warning(f"Gemini API key not found, falling back to Ollama: {e}")
- return ImprovedOllamaClient(model_name="llama3.2:1b")
- elif provider == "ollama":
- return ImprovedOllamaClient(model_name="llama3.2:1b")
- else:
- raise ValueError(f"Unknown provider: {provider}")
-
-# Initialize with default provider
-llm = create_llm_client()
-chat_orchestrator = ImprovedChatOrchestrator()
-session_manager = get_session_manager()
-
-# Initialize personas
-DEFAULT_PERSONAS = get_default_personas(llm)
-for persona in DEFAULT_PERSONAS:
- chat_orchestrator.register_persona(persona)
-
-# Keep all the same data models as before
-class UserInput(BaseModel):
- user_input: str
-
-class PersonaInput(BaseModel):
- id: str
- name: str
- system_prompt: str
-
-class ChatMessage(BaseModel):
- user_input: str
- session_id: Optional[str] = None
- response_length: Optional[str] = "medium"
-
-class ReplyToAdvisor(BaseModel):
- user_input: str
- advisor_id: str
- original_message_id: Optional[str] = None
-
-class ProviderSwitch(BaseModel):
- provider: str
-
-# ==============================================================
-# SESSION MANAGEMENT COMPATIBILITY LAYER
-# ==============================================================
-
-def get_or_create_session_for_request(request: Request,
- session_id_override: Optional[str] = None) -> str:
- """
- Get or create session for request using multiple strategies:
- 1. Use provided session_id if given
- 2. Use X-Session-ID header if present
- 3. Use client IP as fallback for backward compatibility
- 4. Create new session if nothing available
-
- This allows the old stateless API to work with session management
- """
- # Strategy 1: Explicit session ID (for new clients)
- if session_id_override:
- return session_id_override
-
- # Strategy 2: Check for session header (optional for frontend)
- session_header = request.headers.get("X-Session-ID")
- if session_header:
- return session_header
-
- # Strategy 3: Use client IP for backward compatibility
- # This gives each client IP their own persistent session
- client_ip = request.client.host if request.client else "unknown"
- ip_session_id = f"ip_{client_ip}"
-
- # Get or create session for this IP
- session = session_manager.get_session(ip_session_id)
- return session.session_id
-
-
-
-# Helper functions (same as before)
-def _is_valid_response(response: str, persona_id: str) -> bool:
- """Validate response quality"""
- if len(response) < 2 or len(response) > 5000:
- return False
-
- confusion_indicators = [
- f"Thank you, Dr. {persona_id.title()}",
- "Assistant:",
- f"Dr. {persona_id.title()}",
- "Assistant:",
- f"Dr. {persona_id.title()} Advisor:",
- "excellent discussion, Assistant"
- ]
-
- return not any(indicator in response for indicator in confusion_indicators)
-
-def _get_persona_fallback(persona_id: str) -> str:
- """Get persona-specific fallback responses"""
- fallbacks = {
- "methodologist": "Focus on ensuring your methodology aligns with your research question. What specific method are you considering?",
- "theorist": "Consider the theoretical framework underlying your approach. What assumptions guide your thinking?",
- "pragmatist": "Let's break this down into actionable steps. What's the most important thing you need to decide today?"
- }
- return fallbacks.get(persona_id, "I'd be happy to help. Could you provide more details?")
-
-# Provider management endpoints (EXACTLY THE SAME)
-@router.get("/current-provider")
-async def get_current_provider():
- return {
- "current_provider": current_provider,
- "available_providers": available_providers,
- "model_info": {
- "name": llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash",
- "provider": current_provider
- }
- }
-
-@router.post("/switch-provider")
-async def switch_provider(provider_data: ProviderSwitch):
- global current_provider, llm
-
- if provider_data.provider not in available_providers:
- raise HTTPException(
- status_code=400,
- detail=f"Unknown provider: {provider_data.provider}. Available: {available_providers}"
- )
-
- try:
- current_provider = provider_data.provider
- new_llm = create_llm_client(current_provider)
- llm = new_llm
-
- new_personas = get_default_personas(new_llm)
- chat_orchestrator.personas.clear()
- for persona in new_personas:
- chat_orchestrator.register_persona(persona)
-
- return {
- "message": f"Successfully switched to {current_provider}",
- "current_provider": current_provider,
- "model_info": {
- "name": new_llm.model_name if hasattr(new_llm, 'model_name') else "gemini-2.0-flash",
- "provider": current_provider
- }
- }
-
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"Failed to switch to {provider_data.provider}: {str(e)}"
- )
-
-# Main chat endpoint
-@router.post("/chat-sequential")
-async def chat_sequential_enhanced(message: ChatMessage, request: Request):
- """
- Enhanced sequential chat with intelligent persona ordering.
- Returns responses in the order determined by LLM-based relevance ranking.
- """
- try:
- # Get or create session
- session_id = get_or_create_session_for_request(request, message.session_id)
-
- # Add user message to session first (needed for persona ranking)
- session = session_manager.get_session(session_id)
- session.append_message("user", message.user_input)
-
- # Get intelligently ordered personas based on context
- top_personas = await chat_orchestrator.get_top_personas(
- session_id=session_id,
- k=3 # Get top 3 most relevant personas
- )
-
- logger.info(f"Intelligent persona order for session {session_id}: {top_personas}")
-
- # Generate responses from personas in the intelligent order
- responses = []
-
- for persona_id in top_personas:
- try:
- # Generate response from this persona
- persona_result = await chat_orchestrator.chat_with_persona(
- user_input=message.user_input,
- persona_id=persona_id,
- session_id=session_id,
- response_length=message.response_length or "medium"
- )
-
-
- if "persona_name" in persona_result and "response" in persona_result:
- responses.append({
- "persona": persona_result["persona_name"],
- "persona_id": persona_result["persona_id"],
- "response": persona_result["response"]
- })
- elif persona_result.get("type") == "single_persona_response" and "persona" in persona_result:
- persona_data = persona_result["persona"]
- responses.append({
- "persona": persona_data["persona_name"],
- "persona_id": persona_data["persona_id"],
- "response": persona_data["response"]
- })
- else:
- # Fallback response
- responses.append({
- "persona": chat_orchestrator.personas[persona_id].name,
- "persona_id": persona_id,
- "response": "I'm having trouble processing your question right now. Please try again."
- })
-
- except Exception as e:
- logger.error(f"Error generating response for persona {persona_id}: {str(e)}")
- # Error fallback
- responses.append({
- "persona": chat_orchestrator.personas[persona_id].name,
- "persona_id": persona_id,
- "response": "I encountered an error while processing your question. Please try again."
- })
-
- # response format
- return {
- "type": "sequential_responses",
- "responses": responses
- }
-
- except Exception as e:
- logger.error(f"Error in enhanced sequential chat: {str(e)}")
- return {
- "type": "error",
- "responses": [{
- "persona": "System",
- "response": "I'm having trouble processing your request. Could you please try again?"
- }]
- }
-
-@router.post("/chat/{persona_id}")
-async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: Request):
- """Chat with a specific advisor - SAME INTERFACE"""
- try:
- if persona_id not in chat_orchestrator.personas:
- raise HTTPException(status_code=404, detail=f"Persona '{persona_id}' not found")
-
- # Get session using compatibility layer
- session_id = get_or_create_session_for_request(request)
-
- # Use new orchestrator
- result = await chat_orchestrator.chat_with_persona(
- user_input=input.user_input,
- persona_id=persona_id,
- session_id=session_id
- )
-
- # FIX: Handle the actual response structure from orchestrator
- if result.get("type") == "single_persona_response" and "persona" in result:
- # New expected structure
- persona_data = result["persona"]
- return {
- "persona": persona_data["persona_name"],
- "persona_id": persona_data["persona_id"],
- "response": persona_data["response"]
- }
- elif "persona_id" in result and "response" in result:
- # Current actual structure from orchestrator
- return {
- "persona": result["persona_name"],
- "persona_id": result["persona_id"],
- "response": result["response"]
- }
- elif result.get("type") == "error" or "error" in result:
- # Error handling
- return {
- "persona": "System",
- "response": result.get("error", "I'm having trouble generating a response right now. Please try again.")
- }
- else:
- # Fallback
- return {
- "persona": "System",
- "response": "I'm having trouble generating a response right now. Please try again."
- }
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error in chat_with_specific_advisor: {e}")
- return {
- "persona": "System",
- "response": "I'm having trouble generating a response right now. Please try again."
- }
-
-# Reply to advisor endpoint (SAME INTERFACE)
-@router.post("/reply-to-advisor")
-async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
- """Reply to a specific advisor - SAME INTERFACE"""
- try:
- if reply.advisor_id not in chat_orchestrator.personas:
- raise HTTPException(status_code=404, detail=f"Advisor '{reply.advisor_id}' not found")
-
- # Get session using compatibility layer
- session_id = get_or_create_session_for_request(request)
-
- # Use new orchestrator
- result = await chat_orchestrator.chat_with_persona(
- user_input=reply.user_input,
- persona_id=reply.advisor_id,
- session_id=session_id
- )
-
- if result["type"] == "single_persona_response":
- persona_data = result["persona"]
- return {
- "type": "advisor_reply",
- "persona": persona_data["persona_name"],
- "persona_id": persona_data["persona_id"],
- "response": persona_data["response"],
- "original_message_id": reply.original_message_id
- }
- else:
- return {
- "type": "error",
- "persona": "System",
- "response": result.get("message", "I'm having trouble generating a reply right now. Please try again.")
- }
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error in reply_to_advisor: {e}")
- return {
- "type": "error",
- "persona": "System",
- "response": "I'm having trouble generating a reply right now. Please try again."
- }
-
-@router.post("/upload-document")
-async def upload_document(file: UploadFile = File(...), request: Request = None):
- """
- Enhanced document upload with better metadata tracking and user feedback
- """
- try:
- # Get or create session
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- # Validate file
- MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
- if file.size and file.size > MAX_FILE_SIZE:
- raise HTTPException(status_code=413, detail="File size exceeds 10MB limit")
-
-
- # Read and validate file content
- file_bytes = await file.read()
- content = extract_text_from_file(file_bytes, file.content_type)
- if not content.strip():
- raise HTTPException(status_code=400, detail="Document is empty or unreadable.")
-
- # Get enhanced RAG manager
- rag_manager = get_rag_manager()
-
- # Determine file type for metadata
- file_type_map = {
- "application/pdf": "pdf",
- "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
- "text/plain": "txt"
- }
- file_type = file_type_map.get(file.content_type, "unknown")
-
- # Add document to enhanced vector database
- rag_result = rag_manager.add_document(
- content=content,
- filename=file.filename,
- session_id=session_id,
- file_type=file_type
- )
-
- if not rag_result["success"]:
- raise HTTPException(
- status_code=500,
- detail=f"Failed to process document: {rag_result.get('error', 'Unknown error')}"
- )
-
- # Update session tracking
- session.uploaded_files.append(file.filename)
- session.total_upload_size += len(file_bytes)
-
- # Add enhanced document reference to session messages
- doc_metadata = rag_result.get("document_metadata", {})
- doc_title = doc_metadata.get("title", file.filename)
-
- session.append_message(
- "system",
- f"Document uploaded: '{doc_title}' ({file.filename}) - "
- f"{rag_result['chunks_created']} sections processed, "
- f"~{rag_result['total_tokens']} tokens analyzed. "
- f"You can now ask questions about this document by referencing it by name."
- )
-
- return {
- "message": f"Document '{file.filename}' uploaded and processed successfully.",
- "filename": file.filename,
- "document_title": doc_title,
- "chunks_created": rag_result['chunks_created'],
- "total_tokens": rag_result['total_tokens'],
- "file_type": file_type,
- "can_reference_by_name": True,
- "suggestions": [
- f"Try asking: 'What methodology does my {file.filename} propose?'",
- f"Or: 'What are the key findings in {doc_title}?'",
- f"Or: 'Compare the approach in my document with current best practices'"
- ]
- }
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error processing document upload: {str(e)}")
- raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
-
-
-@router.get("/export-chat")
-async def export_chat(request: Request, format: str = Query(..., regex="^(txt|pdf|docx)$")):
- """
- Export the current chat context in the requested format.
- """
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- if not session.messages:
- return {"error": "No messages in this session."}
-
- return prepare_export_response(session.messages, format)
-
- except Exception as e:
- logger.error(f"Error exporting chat: {str(e)}")
- return {"error": "Failed to export chat.", "detail": str(e)}
-
-
-@router.get("/chat-summary")
-async def chat_summary(
- request: Request,
- format: str = Query("text", regex="^(txt|pdf|docx)$")
-):
- """
- Generate and return a summary of the current session chat.
- Can return as plain txt, PDF, or DOCX.
- """
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- if not session.messages:
- return {"error": "No messages in this session."}
-
- llm = next(iter(chat_orchestrator.personas.values())).llm
- summary_text = await generate_summary_from_messages(session.messages, llm)
-
- if format == "txt":
- return prepare_export_response(summary_text, "txt", filename_prefix="chat_summary")
-
- elif format == "docx":
- return prepare_export_response(summary_text, "docx", filename_prefix="chat_summary")
-
- elif format == "pdf":
- # Parse and render using block formatting
- blocks = [{"type": "heading", "text": "Chat Summary"}] + parse_summary_to_blocks(summary_text)
-
- file_stream = generate_pdf_file_from_blocks(blocks)
- return StreamingResponse(
- file_stream,
- media_type="application/pdf",
- headers={"Content-Disposition": "attachment; filename=chat_summary.pdf"}
- )
-
- except Exception as e:
- logger.error(f"Error in chat-summary endpoint: {str(e)}")
- return {"error": "Summary generation failed", "detail": str(e)}
-
-
-
-
-# Add new endpoint to get document statistics
-@router.get("/document-stats")
-async def get_document_stats(request: Request):
- """Get statistics about uploaded documents in vector database"""
- try:
- session_id = get_or_create_session_for_request(request)
- rag_manager = get_rag_manager()
-
- stats = rag_manager.get_document_stats(session_id)
- return stats
-
- except Exception as e:
- logger.error(f"Error getting document stats: {str(e)}")
- return {"total_chunks": 0, "total_documents": 0, "documents": []}
-
-# Get uploaded files (SAME INTERFACE)
-@router.get("/uploaded-files")
-async def get_uploaded_filenames(request: Request):
- """Get uploaded files - SAME INTERFACE"""
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
- return {"files": session.uploaded_files}
- except Exception as e:
- logger.error(f"Error getting uploaded files: {str(e)}")
- return {"files": []}
-
-# Context endpoint (SAME INTERFACE)
-@router.get("/context")
-async def get_context(request: Request):
- """Get context - ENHANCED with RAG information"""
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- # Get RAG statistics
- rag_stats = session.get_rag_stats()
-
- return {
- "messages": session.messages,
- "rag_info": {
- "total_documents": rag_stats.get("total_documents", 0),
- "total_chunks": rag_stats.get("total_chunks", 0),
- "documents": rag_stats.get("documents", [])
- }
- }
- except Exception as e:
- logger.error(f"Error getting context: {str(e)}")
- return {"messages": [], "rag_info": {"total_documents": 0, "total_chunks": 0}}
-
-@router.post("/reset-session")
-async def reset_session(request: Request):
- """Reset session - ENHANCED with RAG cleanup"""
- try:
- session_id = get_or_create_session_for_request(request)
-
- # Use the enhanced reset that clears both conversation and vector DB
- success = session_manager.reset_session_completely(session_id)
-
- if success:
- return {"status": "reset", "message": "Session and all documents reset successfully"}
- else:
- return {"status": "error", "message": "Failed to reset session"}
- except Exception as e:
- logger.error(f"Error resetting session: {e}")
- return {"status": "error", "message": "Failed to reset session"}
-
-
-# Legacy model endpoints (SAME INTERFACE)
-@router.post("/switch-model")
-async def switch_model(model_name: str = Body(...)):
- """Legacy model switching - SAME INTERFACE"""
- if "gemini" in model_name.lower():
- return await switch_provider(ProviderSwitch(provider="gemini"))
- else:
- return await switch_provider(ProviderSwitch(provider="ollama"))
-
-@router.get("/current-model")
-async def get_current_model():
- """Legacy model info - SAME INTERFACE"""
- model_name = llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash"
- return {
- "model": model_name,
- "provider": current_provider
- }
-
-@router.post("/search-documents")
-async def search_documents(request: Request, query: str = Body(..., embed=True), persona: str = Body("", embed=True)):
- """
- Search uploaded documents using RAG
-
- This endpoint allows direct document search for debugging/testing
- """
- try:
- session_id = get_or_create_session_for_request(request)
- rag_manager = get_rag_manager()
-
- # Get persona context for search enhancement
- persona_contexts = {
- "methodologist": "methodology research design analysis",
- "theorist": "theory theoretical framework conceptual",
- "pragmatist": "practical application implementation"
- }
- persona_context = persona_contexts.get(persona, "")
-
- # Search documents
- results = rag_manager.search_documents(
- query=query,
- session_id=session_id,
- persona_context=persona_context,
- n_results=5
- )
-
- return {
- "query": query,
- "persona_filter": persona,
- "results_count": len(results),
- "results": results
- }
-
- except Exception as e:
- logger.error(f"Error searching documents: {str(e)}")
- return {"query": query, "results_count": 0, "results": [], "error": str(e)}
-
-@router.get("/session-stats")
-async def get_session_stats(request: Request):
- """Get comprehensive session statistics including RAG data"""
- try:
- session_id = get_or_create_session_for_request(request)
- stats = session_manager.get_session_stats(session_id)
- return stats
- except Exception as e:
- logger.error(f"Error getting session stats: {str(e)}")
- return {"error": str(e)}
-
-
-@router.get("/debug/personas")
-async def debug_personas(request: Request):
- """Debug personas - ENHANCED with RAG information"""
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- # Get RAG statistics
- rag_manager = get_rag_manager()
- rag_stats = rag_manager.get_document_stats(session_id)
-
- return {
- "personas": {
- pid: {
- "name": persona.name,
- "prompt": persona.system_prompt[:100] + "...",
- "retrieval_keywords": chat_orchestrator._get_persona_context_keywords(pid)
- } for pid, persona in chat_orchestrator.personas.items()
- },
- "session_info": {
- "context_length": len(session.messages),
- "uploaded_files": session.uploaded_files,
- "rag_stats": rag_stats
- },
- "current_provider": current_provider,
- "rag_enabled": True
- }
- except Exception as e:
- logger.error(f"Error in debug endpoint: {str(e)}")
- return {
- "personas": {},
- "session_info": {"context_length": 0},
- "current_provider": current_provider,
- "rag_enabled": False,
- "error": str(e)
- }
-
-@router.get("/debug/ranked-personas")
-async def get_ranked_personas(request: Request, k: int = Query(3, ge=1, le=10)):
- """
- Debug endpoint: Get top-k ranked personas based on current session context.
- Uses LLM to rank based on latest conversation messages.
- """
- try:
- session_id = get_or_create_session_for_request(request)
-
- # Call the ranking method
- top_personas = await chat_orchestrator.get_top_personas(session_id=session_id, k=k)
-
- # Include some metadata for debug purposes
- return {
- "ranked_personas": top_personas,
- "available_personas": list(chat_orchestrator.personas.keys()),
- "session_id": session_id
- }
- except Exception as e:
- logger.error(f"Error in /debug/ranked-personas: {e}")
- return {
- "ranked_personas": [],
- "error": str(e)
- }
-
-
-@router.post("/chat/{persona_id}")
-async def chat_with_specific_persona(persona_id: str, message: ChatMessage, request: Request):
- """
- Chat with a specific persona - Enhanced with RAG debugging
-
- This endpoint helps debug RAG integration by testing individual personas
- """
- try:
- session_id = get_or_create_session_for_request(request, message.session_id)
-
- # Validate persona exists
- if persona_id not in chat_orchestrator.personas:
- available_personas = list(chat_orchestrator.personas.keys())
- raise HTTPException(
- status_code=400,
- detail=f"Persona '{persona_id}' not found. Available: {available_personas}"
- )
-
- # Use the enhanced orchestrator method
- result = await chat_orchestrator.chat_with_persona(
- user_input=message.user_input,
- persona_id=persona_id,
- session_id=session_id,
- response_length=message.response_length or "medium"
- )
-
- # Fix: Handle the response structure properly
- if result.get("type") == "single_persona_response" and "persona" in result:
- persona_data = result["persona"]
-
- # Add debugging information
- result["debug_info"] = {
- "persona_id": persona_id,
- "session_id": session_id,
- "query_length": len(message.user_input),
- "rag_manager_available": True,
- "used_documents": persona_data.get("used_documents", False),
- "chunks_used": persona_data.get("document_chunks_used", 0)
- }
-
- return result
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error in individual persona chat: {str(e)}")
- return {
- "type": "error",
- "message": f"Error chatting with {persona_id}: {str(e)}",
- "persona_id": persona_id
- }
-
-@router.get("/debug/enhanced-personas")
-async def debug_enhanced_personas(request: Request):
- """
- Enhanced debug endpoint with document context information
- """
- try:
- session_id = get_or_create_session_for_request(request)
- session = session_manager.get_session(session_id)
-
- # Get enhanced RAG statistics
- rag_manager = get_rag_manager()
- rag_stats = rag_manager.get_document_stats(session_id)
-
- # Analyze document awareness capabilities
- document_analysis = {}
- if rag_stats.get("documents"):
- for doc in rag_stats["documents"]:
- document_analysis[doc["filename"]] = {
- "chunks_available": doc["chunks"],
- "estimated_tokens": doc["estimated_tokens"],
- "sections_identified": doc["sections"],
- "content_types_detected": {
- "has_methodology": doc.get("has_methodology", False),
- "has_theory": doc.get("has_theory", False),
- "has_references": doc.get("has_references", False)
- }
- }
-
- return {
- "personas": {
- pid: {
- "name": persona.name,
- "expertise_area": persona.name.split(" - ")[1] if " - " in persona.name else "General",
- "prompt_quality": "enhanced" if len(persona.system_prompt) > 500 else "basic",
- "document_handling_enabled": "document awareness" in persona.system_prompt.lower(),
- "retrieval_keywords": chat_orchestrator._get_enhanced_persona_context_keywords(pid)[:100] + "...",
- "temperature": getattr(persona, 'temperature', 5)
- } for pid, persona in chat_orchestrator.personas.items()
- },
- "session_info": {
- "context_length": len(session.messages),
- "uploaded_files": session.uploaded_files,
- "rag_stats": rag_stats,
- "document_analysis": document_analysis
- },
- "system_capabilities": {
- "document_name_recognition": True,
- "cross_document_analysis": True,
- "persona_specialized_retrieval": True,
- "enhanced_attribution": True,
- "query_document_detection": True
- },
- "current_provider": current_provider,
- "rag_enabled": True,
- "enhancement_level": "advanced"
- }
- except Exception as e:
- logger.error(f"Error in enhanced debug endpoint: {str(e)}")
- return {
- "error": str(e),
- "enhancement_level": "error",
- "rag_enabled": False
- }
-
-@router.get("/document-insights/{filename}")
-async def get_document_insights(filename: str, request: Request):
- """
- NEW ENDPOINT: Get insights about a specific uploaded document
- """
- try:
- session_id = get_or_create_session_for_request(request)
- rag_manager = get_rag_manager()
-
- # Get document statistics
- stats = rag_manager.get_document_stats(session_id)
-
- # Find the specific document
- document_info = None
- for doc in stats.get("documents", []):
- if doc["filename"] == filename:
- document_info = doc
- break
-
- if not document_info:
- raise HTTPException(status_code=404, detail=f"Document {filename} not found")
-
- # Get a sample of content from this document
- results = rag_manager.collection.get(
- where={"session_id": session_id, "filename": filename},
- limit=3,
- include=["documents", "metadatas"]
- )
-
- sample_sections = []
- if results["documents"]:
- for doc, metadata in zip(results["documents"], results["metadatas"]):
- sample_sections.append({
- "section": metadata.get("document_section", "unknown"),
- "content_preview": doc[:200] + "..." if len(doc) > 200 else doc,
- "keywords": metadata.get("keywords", "")
- })
-
- return {
- "filename": filename,
- "document_title": document_info.get("title", filename),
- "file_type": document_info.get("file_type", "unknown"),
- "statistics": {
- "total_chunks": document_info["chunks"],
- "estimated_tokens": document_info["estimated_tokens"],
- "sections_identified": document_info["sections"]
- },
- "content_analysis": {
- "has_methodology": document_info.get("has_methodology", False),
- "has_theory": document_info.get("has_theory", False),
- "has_references": document_info.get("has_references", False)
- },
- "sample_sections": sample_sections,
- "suggested_queries": [
- f"What methodology does my {filename} propose?",
- f"What are the key theoretical concepts in {filename}?",
- f"What are the main findings in my {document_info.get('title', filename)}?",
- f"How can I improve the approach described in {filename}?"
- ]
- }
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error getting document insights: {str(e)}")
- raise HTTPException(status_code=500, detail=f"Error analyzing document: {str(e)}")
-
-# Also add a debug endpoint to check RAG status:
-
-@router.get("/debug/rag-status")
-async def debug_rag_status(request: Request):
- """
- Debug endpoint to check RAG system status
- """
- try:
- session_id = get_or_create_session_for_request(request)
-
- # Get RAG manager
- rag_manager = get_rag_manager()
-
- # Get session stats
- session_stats = session_manager.get_session_stats(session_id)
-
- # Test a simple search
- test_search = rag_manager.search_documents(
- query="test methodology research",
- session_id=session_id,
- persona_context="",
- n_results=3
- )
-
- return {
- "rag_manager_healthy": True,
- "session_id": session_id,
- "session_stats": session_stats.get("rag_stats", {}),
- "test_search_results": len(test_search),
- "test_search_details": [
- {
- "relevance": chunk.get("relevance_score", 0),
- "distance": chunk.get("distance", "unknown"),
- "text_length": len(chunk.get("text", "")),
- "filename": chunk.get("metadata", {}).get("filename", "unknown")
- }
- for chunk in test_search[:3]
- ],
- "persona_keywords": {
- pid: chat_orchestrator._get_persona_context_keywords(pid)
- for pid in chat_orchestrator.personas.keys()
- }
- }
-
- except Exception as e:
- logger.error(f"Error in RAG debug: {str(e)}")
- return {
- "rag_manager_healthy": False,
- "error": str(e),
- "session_id": session_id if 'session_id' in locals() else "unknown"
- }
-
-# Ask endpoint (SAME INTERFACE)
-class PersonaQuery(BaseModel):
- question: str
- persona: str
-
-@router.post("/ask/")
-async def ask_question(query: PersonaQuery, request: Request):
- """Ask question - SAME INTERFACE"""
- try:
- session_id = get_or_create_session_for_request(request)
-
- # Use the new orchestrator
- result = await chat_orchestrator.chat_with_persona(
- user_input=query.question,
- persona_id=query.persona,
- session_id=session_id
- )
-
- if result["type"] == "single_persona_response":
- response_text = result["persona"]["response"]
- else:
- response_text = result.get("message", "I'm having trouble responding right now.")
-
- return {"response": response_text}
-
- except Exception as e:
- logger.error(f"Error in ask endpoint: {str(e)}")
- return {"response": "I encountered an error. Please try again."}
-
-
-
-# Root endpoint (SAME INTERFACE)
-@router.get("/")
-def root():
- """Root endpoint - SAME INTERFACE with updated info"""
- return {
- "message": "Multi-LLM PhD Advisor Backend is up and running",
- "version": __version__,
- "features": [
- "Improved Session Management",
- "Unified Context Handling",
- "Ollama Support",
- "Gemini API Support",
- "Provider Switching"
- ]
- }
\ No newline at end of file
diff --git a/multi_llm_chatbot_backend/app/api/routes/debug.py b/multi_llm_chatbot_backend/app/api/routes/debug.py
index 844c321f..47c3dec5 100644
--- a/multi_llm_chatbot_backend/app/api/routes/debug.py
+++ b/multi_llm_chatbot_backend/app/api/routes/debug.py
@@ -4,7 +4,7 @@
from app.core.bootstrap import chat_orchestrator
import logging
-from app.api.old_routes import get_or_create_session_for_request
+from app.api.utils import get_or_create_session_for_request_async
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@
@router.get("/debug/personas")
async def debug_personas(request: Request):
try:
- session_id = get_or_create_session_for_request(request)
+ session_id = await get_or_create_session_for_request_async(request)
session = session_manager.get_session(session_id)
rag_manager = get_rag_manager()
rag_stats = rag_manager.get_document_stats(session_id)
@@ -45,7 +45,7 @@ async def debug_personas(request: Request):
@router.get("/debug/ranked-personas")
async def get_ranked_personas(request: Request, k: int = Query(3, ge=1, le=10)):
try:
- session_id = get_or_create_session_for_request(request)
+ session_id = await get_or_create_session_for_request_async(request)
top_personas = await chat_orchestrator.get_top_personas(session_id=session_id, k=k)
return {
"ranked_personas": top_personas,
@@ -62,7 +62,7 @@ async def get_ranked_personas(request: Request, k: int = Query(3, ge=1, le=10)):
@router.get("/debug/rag-status")
async def debug_rag_status(request: Request):
try:
- session_id = get_or_create_session_for_request(request)
+ session_id = await get_or_create_session_for_request_async(request)
rag_manager = get_rag_manager()
session_stats = session_manager.get_session_stats(session_id)
diff --git a/phd-advisor-frontend/src/pages/ChatPage.js b/phd-advisor-frontend/src/pages/ChatPage.js
index 68251109..063a454b 100644
--- a/phd-advisor-frontend/src/pages/ChatPage.js
+++ b/phd-advisor-frontend/src/pages/ChatPage.js
@@ -417,7 +417,7 @@ const handleNewChat = async (sessionId = null) => {
body: JSON.stringify({
user_input: inputMessage,
response_length: 'medium',
- chat_session_id: currentSessionId // Include current session ID
+ chat_session_id: sessionId
}),
});
From 724cb3c7e4158f9e8567e65a25db9cf95271dcfb Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Fri, 29 May 2026 20:47:26 +0000
Subject: [PATCH 11/31] Increment Version to 2.0.1a5
---
multi_llm_chatbot_backend/app/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/multi_llm_chatbot_backend/app/version.py b/multi_llm_chatbot_backend/app/version.py
index ad015959..0dabcc30 100644
--- a/multi_llm_chatbot_backend/app/version.py
+++ b/multi_llm_chatbot_backend/app/version.py
@@ -1,4 +1,4 @@
-__version__ = "2.0.1a4"
+__version__ = "2.0.1a5"
if __name__ == "__main__":
print(__version__)
From c38d5eff270862632f408dc56b3f4bf90ed42742 Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Fri, 29 May 2026 20:47:52 +0000
Subject: [PATCH 12/31] Update Changelog
---
CHANGELOG.md | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 705186c8..ea07a5ca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,17 @@
# Changelog
+## [2.0.1a5](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a5) (2026-05-29)
+
+[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a4...2.0.1a5)
+
+**Implemented enhancements:**
+
+- \[FEAT\] Evaluate old\_routes.py for deprecation [\#74](https://github.com/NeonGeckoCom/CCAI-Demo/issues/74)
+
+**Merged pull requests:**
+
+- deprecated old\_routes.py [\#77](https://github.com/NeonGeckoCom/CCAI-Demo/pull/77) ([NeonCharlie-24](https://github.com/NeonCharlie-24))
+
## [2.0.1a4](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a4) (2026-05-28)
[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a3...2.0.1a4)
From a94166fdb3e92e02ab67b47f2e41722c820574c8 Mon Sep 17 00:00:00 2001
From: NeonCharlie-24
Date: Fri, 29 May 2026 16:32:19 -0700
Subject: [PATCH 13/31] BrainForge LLM Integration (#62)
* added BrainForge configuration schema to LLMConfig.
* added brainforge auth manager with login, refresh and health check.
* added brainforge LLM client using the openai compatible endpoint.
* added brainforge persona sync to register brainforge advisors.
* added brainforge persona sync to merge brainforge personas with the hardcoded yaml config personas on app startup.
* added periodic brainforge persona sync to check for new/stale personas at the brainforge endpoint. checking every 10 minutes.
* fixed downstream effects of the periodic sync function and cleaned up some unused code.
* fixed orchestrator attempting to use brainforge llm for persona selection and added guard against stale personas causing a crash in chat stream.
* added unit test for brainforge auth, client and persona sync.
* removed BrainForgeConfig api_url envvar fallback and updated sync_interval name to include seconds label.
* Updated brainforge prompt to be bare minimum length which produces best results so far for Neon AI assistant model.
* passed structured_output json as an extra_body parameter and got improved model responses.
* added a response maxLength to the brainforge JSON schema to allow responses to be passed through _ensure_compact_shape without truncation.
* doubled max_tokens for brainforge to prevent JSON truncation.
* updated brainforge tests to factor in double max_tokens in payload.
* refactor PERSONA_ID_PREFIX variable name to be more self documenting.
* refactor SKIP_PERSONA_NAMES to be private.
* use orchestrator LLM for needs_clarification_improved instead of persona LLM so BrainForge is never called to handle clarification requests.
* added admin script to list all available advisors (including BrainForge) so config can easily be updated.
* filter BrainForge advisors through whitelist before showing them on frontend.
---
docker-compose.yml | 2 +
.../app/api/routes/chat.py | 16 +-
.../app/api/routes/provider.py | 6 +-
multi_llm_chatbot_backend/app/config.py | 8 +
.../app/core/brainforge_sync.py | 170 ++++++++++++
.../app/core/context_manager.py | 2 +-
.../app/core/improved_orchestrator.py | 25 +-
.../app/llm/brainforge_auth.py | 106 ++++++++
.../app/llm/improved_brainforge_client.py | 184 +++++++++++++
multi_llm_chatbot_backend/app/main.py | 45 +++-
.../app/tests/unit/test_brainforge_auth.py | 180 +++++++++++++
.../app/tests/unit/test_brainforge_client.py | 248 +++++++++++++++++
.../app/tests/unit/test_brainforge_sync.py | 253 ++++++++++++++++++
.../app/tests/unit/test_clarification.py | 32 +--
phd_config.yaml | 18 ++
scripts/list_available_advisors.py | 221 +++++++++++++++
16 files changed, 1488 insertions(+), 28 deletions(-)
create mode 100644 multi_llm_chatbot_backend/app/core/brainforge_sync.py
create mode 100644 multi_llm_chatbot_backend/app/llm/brainforge_auth.py
create mode 100644 multi_llm_chatbot_backend/app/llm/improved_brainforge_client.py
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_brainforge_auth.py
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_brainforge_client.py
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_brainforge_sync.py
create mode 100755 scripts/list_available_advisors.py
diff --git a/docker-compose.yml b/docker-compose.yml
index 8381546f..3dad8802 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -16,6 +16,8 @@ services:
JWT_SECRET_KEY: ${JWT_SECRET_KEY:-CHANGEME-by-overriding-in-dot-env-file}
GEMINI_API_KEY: ${GEMINI_API_KEY:-?}
VLLM_API_KEY: ${VLLM_API_KEY:-}
+ BRAINFORGE_USERNAME: ${BRAINFORGE_USERNAME:-}
+ BRAINFORGE_PASSWORD: ${BRAINFORGE_PASSWORD:-}
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
GEMINI_MODEL: gemini-2.5-flash
CONFIG_PATH: ${CONFIG_PATH:-/ccai/phd_config.yaml}
diff --git a/multi_llm_chatbot_backend/app/api/routes/chat.py b/multi_llm_chatbot_backend/app/api/routes/chat.py
index 8da03f86..3a600ab9 100644
--- a/multi_llm_chatbot_backend/app/api/routes/chat.py
+++ b/multi_llm_chatbot_backend/app/api/routes/chat.py
@@ -183,7 +183,19 @@ async def _event_generator():
async def _run(pid: str) -> None:
try:
+ # Guard against the persona being removed mid-request — return a
+ # fallback response instead of crashing and hanging the stream.
persona = chat_orchestrator.get_persona(pid)
+ if persona is None:
+ logger.warning("Persona %s was unregistered before response generation", pid)
+ await done_queue.put({
+ "persona_id": pid,
+ "persona_name": pid,
+ "response": "This advisor is temporarily unavailable. Please try again.",
+ "used_documents": False,
+ "document_chunks_used": 0,
+ })
+ return
result = await chat_orchestrator.generate_single_persona_response(
session, persona,
message.response_length or "medium",
@@ -193,8 +205,8 @@ async def _run(pid: str) -> None:
except Exception as e:
logger.exception(f"chat-stream _run failed for {pid}: {e}")
await done_queue.put({
- "persona_id": persona.id,
- "persona_name": persona.name,
+ "persona_id": pid,
+ "persona_name": getattr(persona, "name", pid),
"response": f"I ran into a technical issue. Please try again. ({e!s})",
"used_documents": False,
"document_chunks_used": 0,
diff --git a/multi_llm_chatbot_backend/app/api/routes/provider.py b/multi_llm_chatbot_backend/app/api/routes/provider.py
index 7185f70c..b7760e74 100644
--- a/multi_llm_chatbot_backend/app/api/routes/provider.py
+++ b/multi_llm_chatbot_backend/app/api/routes/provider.py
@@ -5,6 +5,7 @@
from app.llm.improved_vllm_client import ImprovedVllmClient
from app.models.default_personas import get_default_personas
from app.core.bootstrap import chat_orchestrator, llm, current_provider, available_providers
+from app.core.brainforge_sync import BRAINFORGE_PERSONA_PREFIX
from pydantic import BaseModel
import os
import logging
@@ -72,7 +73,10 @@ async def switch_provider(provider_data: ProviderSwitch):
chat_orchestrator.llm_client = new_llm
new_personas = get_default_personas(new_llm)
- chat_orchestrator.personas.clear()
+ # Clear only non-BrainForge personas; BF advisors have their own LLM clients
+ non_bf_ids = [pid for pid in chat_orchestrator.personas if not pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_")]
+ for pid in non_bf_ids:
+ chat_orchestrator.unregister_persona(pid)
for persona in new_personas:
chat_orchestrator.register_persona(persona)
diff --git a/multi_llm_chatbot_backend/app/config.py b/multi_llm_chatbot_backend/app/config.py
index 65d276b8..753a1048 100644
--- a/multi_llm_chatbot_backend/app/config.py
+++ b/multi_llm_chatbot_backend/app/config.py
@@ -282,10 +282,18 @@ class VllmConfig(BaseModel):
api_key: str = Field(default=os.getenv("VLLM_API_KEY", ""))
+class BrainForgeConfig(BaseModel):
+ api_url: str = ""
+ username: str = Field(default=os.getenv("BRAINFORGE_USERNAME", ""))
+ password: str = Field(default=os.getenv("BRAINFORGE_PASSWORD", ""))
+ sync_interval_seconds: int = 600
+
+
class LLMConfig(BaseModel):
gemini: GeminiConfig = GeminiConfig()
ollama: OllamaConfig = OllamaConfig()
vllm: VllmConfig = VllmConfig()
+ brainforge: BrainForgeConfig = BrainForgeConfig()
class RAGConfig(BaseModel):
diff --git a/multi_llm_chatbot_backend/app/core/brainforge_sync.py b/multi_llm_chatbot_backend/app/core/brainforge_sync.py
new file mode 100644
index 00000000..ba46f2a2
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/core/brainforge_sync.py
@@ -0,0 +1,170 @@
+"""BrainForge persona sync — fetches models and personas from the BrainForge
+API and registers them as advisors in the orchestrator.
+
+Called at startup and periodically via a background loop.
+"""
+
+import asyncio
+import logging
+import re
+from typing import List
+
+import httpx
+
+from app.config import get_settings
+from app.llm.brainforge_auth import BrainForgeAuthManager
+from app.llm.improved_brainforge_client import ImprovedBrainForgeClient
+from app.models.persona import Persona
+
+logger = logging.getLogger(__name__)
+
+BRAINFORGE_PERSONA_PREFIX = "bf"
+_SKIP_PERSONA_NAMES = {"vanilla"}
+
+
+def _make_persona_id(model_name: str, persona_name: str) -> str:
+ """Generate a stable, unique persona ID like 'bf_neonai_NeonAI'."""
+ short_model = model_name.rsplit("/", 1)[-1].lower()
+ # Sanitize persona name for use in URLs and dict keys
+ safe_name = re.sub(r"[^a-zA-Z0-9]+", "_", persona_name).strip("_")
+ return f"{BRAINFORGE_PERSONA_PREFIX}_{short_model}_{safe_name}"
+
+
+async def fetch_brainforge_models(auth: BrainForgeAuthManager, api_url: str) -> list:
+ """Fetch all models and their personas from BrainForge."""
+ try:
+ token = await auth.get_token()
+ async with httpx.AsyncClient(timeout=15) as client:
+ resp = await client.post(
+ f"{api_url}/brainforge/get_models",
+ headers={"Authorization": f"Bearer {token}"},
+ )
+
+ if resp.status_code != 200:
+ logger.warning("BrainForge get_models returned %s", resp.status_code)
+ return []
+
+ return resp.json().get("models", [])
+
+ except Exception as exc:
+ logger.warning("Failed to fetch BrainForge models: %s", exc)
+ return []
+
+
+def build_brainforge_personas(
+ models: list,
+ auth: BrainForgeAuthManager,
+ api_url: str,
+) -> List[Persona]:
+ """Build Persona objects from BrainForge model/persona data."""
+ personas = []
+
+ for model in models:
+ model_name = model.get("name", "")
+ model_version = model.get("version", "")
+ model_id = f"{model_name}@{model_version}"
+
+ for p in model.get("personas", []):
+ persona_name = p.get("persona_name", "")
+
+ if persona_name.lower() in _SKIP_PERSONA_NAMES:
+ continue
+
+ if not p.get("enabled", True):
+ continue
+
+ system_prompt = p.get("system_prompt") or p.get("description") or ""
+ if not system_prompt:
+ logger.debug(
+ "Skipping BrainForge persona %s (no prompt)", persona_name
+ )
+ continue
+
+ pid = _make_persona_id(model_name, persona_name)
+
+ llm_client = ImprovedBrainForgeClient(
+ api_url=api_url,
+ model_id=model_id,
+ auth_manager=auth,
+ )
+
+ persona = Persona(
+ id=pid,
+ name=persona_name,
+ system_prompt=system_prompt,
+ llm=llm_client,
+ temperature=5,
+ )
+ personas.append(persona)
+
+ return personas
+
+
+async def async_sync_brainforge_personas(orchestrator) -> int:
+ """Fetch BrainForge personas and reconcile with the orchestrator.
+
+ Registers new personas, updates existing ones, and removes stale ones
+ that are no longer advertised by BrainForge. Returns the number of
+ personas currently registered after reconciliation.
+ """
+ settings = get_settings()
+ bf_config = settings.llm.brainforge
+
+ if not bf_config.api_url:
+ logger.debug("BrainForge not configured, skipping persona sync")
+ return 0
+
+ if not bf_config.username or not bf_config.password:
+ logger.warning("BrainForge credentials not set, skipping persona sync")
+ return 0
+
+ api_url = bf_config.api_url.rstrip("/")
+ auth = BrainForgeAuthManager(api_url, bf_config.username, bf_config.password)
+
+ models = await fetch_brainforge_models(auth, api_url)
+ if not models:
+ logger.warning("No BrainForge models available, no personas registered")
+ return 0
+
+ personas = build_brainforge_personas(models, auth, api_url)
+ fresh_ids = {p.id for p in personas}
+
+ stale_ids = [
+ pid for pid in orchestrator.personas
+ if pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_") and pid not in fresh_ids
+ ]
+ for pid in stale_ids:
+ orchestrator.unregister_persona(pid)
+
+ added = 0
+ for persona in personas:
+ is_new = persona.id not in orchestrator.personas
+ orchestrator.register_persona(persona)
+ if is_new:
+ added += 1
+
+ if added or stale_ids:
+ logger.info(
+ "BrainForge sync: +%d new, -%d stale, %d total",
+ added, len(stale_ids), len(fresh_ids),
+ )
+
+ return len(fresh_ids)
+
+
+async def periodic_sync_loop(orchestrator) -> None:
+ """Background task that re-syncs BrainForge personas on a timer."""
+ settings = get_settings()
+ interval = settings.llm.brainforge.sync_interval_seconds
+
+ if interval <= 0:
+ logger.info("BrainForge periodic sync disabled (sync_interval_seconds=%d)", interval)
+ return
+
+ logger.info("BrainForge periodic sync started (every %ds)", interval)
+ while True:
+ await asyncio.sleep(interval)
+ try:
+ await async_sync_brainforge_personas(orchestrator)
+ except Exception as exc:
+ logger.warning("BrainForge periodic sync error: %s", exc)
diff --git a/multi_llm_chatbot_backend/app/core/context_manager.py b/multi_llm_chatbot_backend/app/core/context_manager.py
index 64416b92..2b2bef19 100644
--- a/multi_llm_chatbot_backend/app/core/context_manager.py
+++ b/multi_llm_chatbot_backend/app/core/context_manager.py
@@ -134,7 +134,7 @@ def _format_for_provider(self, messages: List[dict], system_prompt: str, provide
return self._format_for_gemini(messages, system_prompt)
elif provider.lower() in ["ollama", "mistral"]:
return self._format_for_ollama(messages, system_prompt)
- elif provider.lower() == "vllm":
+ elif provider.lower() in ["vllm", "brainforge"]:
return self._format_for_vllm(messages, system_prompt)
else:
# Default format
diff --git a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
index 2191b125..12f63a78 100644
--- a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
+++ b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
@@ -25,10 +25,18 @@ def __init__(self, llm_client: LLMClient = None):
self.context_manager = get_context_manager()
def register_persona(self, persona: Persona):
- """Register a persona with the orchestrator"""
+ """Register or update a persona in the orchestrator."""
+ is_new = persona.id not in self.personas
self.personas[persona.id] = persona
- logger.info(f"Registered persona: {persona.id} ({persona.name})")
+ if is_new:
+ logger.info(f"Registered persona: {persona.id} ({persona.name})")
+ def unregister_persona(self, persona_id: str):
+ """Remove a persona from the orchestrator."""
+ removed = self.personas.pop(persona_id, None)
+ if removed:
+ logger.info(f"Unregistered persona: {persona_id} ({removed.name})")
+
def get_persona(self, persona_id: str) -> Optional[Persona]:
"""Get a specific persona"""
return self.personas.get(persona_id)
@@ -282,7 +290,9 @@ async def needs_clarification_improved(self, session: ConversationContext, user_
raw = None
try:
- llm = next(iter(self.personas.values())).llm
+ # Use the orchestrator's own LLM rather than a persona's — BrainForge
+ # persona LLMs may not support the prompt format used here.
+ llm = self.llm_client
raw = await llm.generate(
system_prompt=system_prompt,
context=[{"role": "user", "content": user_prompt}],
@@ -345,7 +355,9 @@ async def generate_contextual_clarification(self, user_input: str) -> Dict[str,
)
try:
- llm = next(iter(self.personas.values())).llm
+ # Use the orchestrator's own LLM rather than a persona's — BrainForge
+ # persona LLMs may not support the prompt format used here.
+ llm = self.llm_client
raw = await llm.generate(
system_prompt=system_prompt,
context=[{"role": "user", "content": user_prompt}],
@@ -890,8 +902,9 @@ async def get_top_personas(self, session_id: str, k: int = 3,
logger.warning("No personas available after filtering.")
return []
- # Use the LLM from one of the existing persona objects
- llm = next(iter(pool.values())).llm
+ # Use the orchestrator's own LLM rather than a persona's — BrainForge
+ # persona LLMs may not support the prompt format used here.
+ llm = self.llm_client
# Use recent conversation context (last 5 messages)
recent_context = "\n".join(
diff --git a/multi_llm_chatbot_backend/app/llm/brainforge_auth.py b/multi_llm_chatbot_backend/app/llm/brainforge_auth.py
new file mode 100644
index 00000000..f273ecc0
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/llm/brainforge_auth.py
@@ -0,0 +1,106 @@
+import asyncio
+import logging
+import time
+import uuid
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+TOKEN_EXPIRY_BUFFER = 60 # refresh this many seconds before actual expiry
+
+
+class BrainForgeAuthManager:
+    """Manages authentication tokens for the BrainForge (HANA) API.
+
+    Handles login, caching, and automatic refresh so callers can simply
+    await ``get_token()`` to obtain a valid bearer token.
+
+    Token state is replaced as a unit: an auth response that lacks any of
+    ``access_token``, ``refresh_token`` or ``expiration`` is rejected before
+    anything is stored, so a malformed reply cannot leave a new access token
+    paired with a stale refresh token or expiry.
+    """
+
+    def __init__(self, api_url: str, username: str, password: str):
+        """Remember the endpoint and credentials; no request is made here.
+
+        Args:
+            api_url: Base URL of the BrainForge API; a trailing slash is dropped.
+            username: Login name sent to ``/auth/login``.
+            password: Password sent to ``/auth/login``.
+        """
+        self._api_url = api_url.rstrip("/")
+        self._username = username
+        self._password = password
+        # Random suffix gives every manager instance its own client id.
+        self._client_id = f"ccai-backend-{uuid.uuid4().hex[:8]}"
+
+        self._access_token: str | None = None
+        self._refresh_token: str | None = None
+        # Compared against time.time() in get_token().
+        self._expiration: float = 0.0
+
+        # Held for the whole of get_token() so concurrent callers do not
+        # each start their own refresh or login.
+        self._lock = asyncio.Lock()
+
+    async def get_token(self) -> str:
+        """Return a valid bearer token, refreshing or re-logging in as needed.
+
+        Raises:
+            RuntimeError: If a login is required and it fails.
+        """
+        async with self._lock:
+            # Treat the token as expired TOKEN_EXPIRY_BUFFER seconds early.
+            if self._access_token and time.time() < self._expiration - TOKEN_EXPIRY_BUFFER:
+                return self._access_token
+
+            # Prefer a refresh; fall through to a full login if it fails.
+            if self._refresh_token and await self._refresh():
+                return self._access_token
+
+            await self._login()
+            return self._access_token
+
+    async def health_check(self) -> bool:
+        """Return True if we can successfully authenticate."""
+        try:
+            await self.get_token()
+            return True
+        except Exception as exc:
+            # Best-effort probe: report the failure but leave a trace of why.
+            logger.debug("BrainForge auth health check failed: %s", exc)
+            return False
+
+    def _store_tokens(self, data) -> None:
+        """Validate an auth response body and replace all token state at once.
+
+        Args:
+            data: Decoded JSON body of a login or refresh response.
+
+        Raises:
+            RuntimeError: If any of the three token fields is missing or the
+                body is not a mapping. Nothing is stored in that case.
+        """
+        try:
+            access_token = data["access_token"]
+            refresh_token = data["refresh_token"]
+            expiration = data["expiration"]
+        except (KeyError, TypeError) as exc:
+            raise RuntimeError(
+                "BrainForge authentication response is missing token fields"
+            ) from exc
+
+        self._access_token = access_token
+        self._refresh_token = refresh_token
+        self._expiration = expiration
+
+    async def _login(self) -> None:
+        """Authenticate with username/password and store tokens.
+
+        Raises:
+            RuntimeError: On a non-200 reply, a body that is not JSON, or a
+                body without the expected token fields.
+        """
+        url = f"{self._api_url}/auth/login"
+        payload = {
+            "username": self._username,
+            "password": self._password,
+            "token_name": "ccai-backend",
+            "client_id": self._client_id,
+        }
+
+        async with httpx.AsyncClient(timeout=15) as client:
+            resp = await client.post(url, json=payload)
+
+        if resp.status_code != 200:
+            logger.error(
+                "BrainForge login failed: %s %s", resp.status_code, resp.text[:200]
+            )
+            raise RuntimeError(
+                f"BrainForge authentication failed (HTTP {resp.status_code})"
+            )
+
+        try:
+            data = resp.json()
+        except ValueError as exc:
+            raise RuntimeError("BrainForge login returned a non-JSON body") from exc
+
+        self._store_tokens(data)
+        logger.info("BrainForge login successful (user=%s)", self._username)
+
+    async def _refresh(self) -> bool:
+        """Attempt to refresh the access token. Returns False on failure.
+
+        Any error, including a malformed response body, is logged and
+        reported as False so that get_token() falls back to a full login.
+        """
+        url = f"{self._api_url}/auth/refresh"
+        payload = {"access_token": self._access_token, "refresh_token": self._refresh_token}
+
+        try:
+            async with httpx.AsyncClient(timeout=15) as client:
+                resp = await client.post(url, json=payload)
+
+            if resp.status_code != 200:
+                logger.warning(
+                    "BrainForge token refresh failed (%s), will re-login",
+                    resp.status_code,
+                )
+                return False
+
+            self._store_tokens(resp.json())
+            logger.debug("BrainForge token refreshed successfully")
+            return True
+
+        except Exception as exc:
+            logger.warning("BrainForge token refresh error: %s, will re-login", exc)
+            return False
diff --git a/multi_llm_chatbot_backend/app/llm/improved_brainforge_client.py b/multi_llm_chatbot_backend/app/llm/improved_brainforge_client.py
new file mode 100644
index 00000000..198dea10
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/llm/improved_brainforge_client.py
@@ -0,0 +1,184 @@
+import json
+import logging
+import re
+from typing import Any, Callable, Dict, List, Optional
+
+import httpx
+
+from app.llm.llm_client import LLMClient, ToolCallResult
+from app.llm.brainforge_auth import BrainForgeAuthManager
+from app.core.context_manager import get_context_manager
+
+logger = logging.getLogger(__name__)
+
+# JSON Schema for the structured reply that ImprovedBrainForgeClient.generate()
+# requests: a "thought" string, exactly three "what_to_do" strings, and a
+# "next_step" string. generate() renders a reply carrying all three keys as
+# Markdown sections.
+_STRUCTURED_OUTPUT_SCHEMA: Dict[str, Any] = {
+ "type": "object",
+ "properties": {
+ "thought": {"type": "string", "maxLength": 225},
+ "what_to_do": {
+ "type": "array",
+ "items": {"type": "string", "maxLength": 160},
+ "minItems": 3,
+ "maxItems": 3,
+ },
+ "next_step": {"type": "string", "maxLength": 225},
+ },
+ "required": ["thought", "what_to_do", "next_step"],
+}
+
+
+class ImprovedBrainForgeClient(LLMClient):
+    """LLM client for BrainForge via its OpenAI-compatible endpoint.
+
+    Uses bearer-token auth (managed by BrainForgeAuthManager) and sends
+    requests to ``/brainforge/openai/chat/completions``.
+    """
+
+    def __init__(
+        self,
+        api_url: str,
+        username: str = "",
+        password: str = "",
+        model_id: Optional[str] = None,
+        auth_manager: Optional[BrainForgeAuthManager] = None,
+    ):
+        """Configure the client.
+
+        Args:
+            api_url: Base URL of the BrainForge API; a trailing slash is dropped.
+            username: Used to build an auth manager when none is supplied.
+            password: Used to build an auth manager when none is supplied.
+            model_id: Model to request; discovered on first use when falsy.
+            auth_manager: Existing auth manager to use instead of building one.
+        """
+        self.api_url = api_url.rstrip("/")
+        self.model_id = model_id
+        self._auth = auth_manager or BrainForgeAuthManager(api_url, username, password)
+        self.context_manager = get_context_manager()
+
+    async def _post_get_models(self, timeout: float) -> httpx.Response:
+        """POST to ``/brainforge/get_models`` with a bearer token and return the response."""
+        token = await self._auth.get_token()
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            return await client.post(
+                f"{self.api_url}/brainforge/get_models",
+                headers={"Authorization": f"Bearer {token}"},
+            )
+
+    async def refresh_model(self):
+        """Discover the first available model from BrainForge.
+
+        Sets ``self.model_id`` to ``"<name>@<version>"`` of the first model
+        listed.
+
+        Raises:
+            ValueError: If the request does not return 200 or lists no models.
+        """
+        resp = await self._post_get_models(timeout=15)
+
+        if resp.status_code != 200:
+            raise ValueError(f"BrainForge get_models failed: HTTP {resp.status_code}")
+
+        models = resp.json().get("models", [])
+        if not models:
+            raise ValueError("No models available on BrainForge")
+
+        self.model_id = f"{models[0]['name']}@{models[0]['version']}"
+        logger.info("BrainForge auto-selected model: %s", self.model_id)
+
+    @staticmethod
+    def _clean_bullet(item: str) -> str:
+        """Strip list markers and Markdown bold from one ``what_to_do`` item.
+
+        A bold span at the start of the item is treated as a label and
+        rendered as ``Label: ``, whether the colon was written inside or
+        outside the asterisks. Bold spans elsewhere are simply unwrapped.
+        """
+        cleaned = re.sub(r"^-\s*", "", item)
+        cleaned = re.sub(r"^\d+\.\s*", "", cleaned)
+        # "**Label:** x", "**Label**: x" and "**Label** x" all become "Label: x".
+        cleaned = re.sub(r"^\*\*(.+?):?\*\*:?\s*", r"\1: ", cleaned)
+        # Bold that is not a leading label must not gain a colon.
+        cleaned = re.sub(r"\*\*(.+?)\*\*", r"\1", cleaned)
+        return cleaned.strip()
+
+    @classmethod
+    def _render_structured(cls, text: str) -> Optional[str]:
+        """Render a structured JSON reply as Markdown, or return None.
+
+        None means *text* is not JSON, is JSON without all of ``thought``,
+        ``what_to_do`` and ``next_step``, or could not be processed; the
+        caller then falls back to plain-text handling.
+        """
+        try:
+            parsed = json.loads(text)
+            expected_keys = {"thought", "what_to_do", "next_step"}
+            if not (isinstance(parsed, dict) and expected_keys.issubset(parsed.keys())):
+                return None
+
+            bullets = [cls._clean_bullet(item) for item in parsed["what_to_do"]]
+            return (
+                f"### Thought\n{parsed['thought']}\n\n"
+                f"### What to do\n"
+                + "\n".join(f"- {b}" for b in bullets)
+                + f"\n\n### Next step\n{parsed['next_step']}"
+            )
+        except (json.JSONDecodeError, KeyError, TypeError) as exc:
+            logger.warning("BrainForge JSON parse failed (%s): %s", type(exc).__name__, exc)
+            return None
+
+    async def generate(
+        self,
+        system_prompt: str,
+        context: List[dict],
+        temperature: float,
+        max_tokens: int,
+        response_mime_type: str = None,
+    ) -> str:
+        """Send a chat completion request and return the reply as text.
+
+        Args:
+            system_prompt: System message handed to the context manager.
+            context: Conversation messages handed to the context manager.
+            temperature: Sampling temperature forwarded to the API.
+            max_tokens: Token budget; doubled before it is sent.
+            response_mime_type: Not used by this client.
+
+        Returns:
+            Markdown built from a structured JSON reply when one is
+            returned, otherwise the cleaned plain text. Failures are
+            reported as a user-facing message string rather than raised.
+        """
+        # JSON structured outputs need more tokens than plain text due to
+        # syntax overhead ({, ", :, [, etc.). Scale up to avoid truncation.
+        max_tokens = int(max_tokens * 2)
+
+        try:
+            context_window = self.context_manager.prepare_context_for_llm(
+                messages=context,
+                system_prompt=system_prompt,
+                llm_provider="brainforge",
+            )
+
+            logger.debug(
+                "BrainForge context prepared: %d messages, ~%d tokens, truncated=%s",
+                len(context_window.messages),
+                context_window.total_tokens,
+                context_window.truncated,
+            )
+
+            if not self.model_id:
+                await self.refresh_model()
+
+            token = await self._auth.get_token()
+
+            # NOTE(review): "extra_body" is sent as a literal JSON key here;
+            # confirm the BrainForge endpoint reads structured_outputs from
+            # that key rather than from the top level of the request.
+            payload = {
+                "model": self.model_id,
+                "messages": context_window.messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "extra_body": {
+                    "structured_outputs": {"json": _STRUCTURED_OUTPUT_SCHEMA},
+                },
+            }
+
+            async with httpx.AsyncClient(timeout=90) as client:
+                resp = await client.post(
+                    f"{self.api_url}/brainforge/openai/chat/completions",
+                    headers={
+                        "Authorization": f"Bearer {token}",
+                        "Content-Type": "application/json",
+                    },
+                    json=payload,
+                )
+
+            if resp.status_code != 200:
+                logger.error(
+                    "BrainForge API error: %s - %s",
+                    resp.status_code,
+                    resp.text[:200],
+                )
+                if resp.status_code == 500 and "model not found" in resp.text.lower():
+                    # Forget the model so the next call re-discovers one.
+                    logger.info("Model not found, will re-discover on next request")
+                    self.model_id = None
+                return "The AI service encountered an error. Please try again."
+
+            data = resp.json()
+            text = data["choices"][0]["message"]["content"].strip()
+
+            rendered = self._render_structured(text)
+            if rendered is not None:
+                return rendered
+
+            # Fallback: plain text response (structured_outputs not active)
+            return self._clean_response(text)
+
+        except httpx.ConnectError:
+            logger.error("Unable to connect to BrainForge at %s", self.api_url)
+            return "I'm unable to connect to the BrainForge service. Please try again later."
+        except httpx.TimeoutException:
+            logger.error("BrainForge request timed out")
+            return "The BrainForge service is taking too long to respond. Please try again."
+        except RuntimeError as e:
+            logger.error("BrainForge auth failure: %s", e)
+            return "Unable to authenticate with BrainForge. Please check credentials."
+        except Exception as e:
+            # Keep the traceback: this branch covers everything unforeseen.
+            logger.error("Unexpected error in BrainForge client: %s", e, exc_info=True)
+            return "I encountered an unexpected error. Please try again."
+
+    async def health_check(self) -> bool:
+        """Check if BrainForge is reachable and authenticated."""
+        try:
+            resp = await self._post_get_models(timeout=10)
+            return resp.status_code == 200
+        except Exception as exc:
+            logger.debug("BrainForge health check failed: %s", exc)
+            return False
diff --git a/multi_llm_chatbot_backend/app/main.py b/multi_llm_chatbot_backend/app/main.py
index cce60530..677df8f2 100644
--- a/multi_llm_chatbot_backend/app/main.py
+++ b/multi_llm_chatbot_backend/app/main.py
@@ -1,3 +1,4 @@
+import asyncio
import os
from dotenv import load_dotenv
@@ -36,8 +37,13 @@
async def lifespan(app: FastAPI):
# Startup
await connect_to_mongo()
+ from app.core.bootstrap import chat_orchestrator
+ from app.core.brainforge_sync import async_sync_brainforge_personas, periodic_sync_loop
+ await async_sync_brainforge_personas(chat_orchestrator)
+ sync_task = asyncio.create_task(periodic_sync_loop(chat_orchestrator))
yield
# Shutdown
+ sync_task.cancel()
await close_mongo_connection()
app = FastAPI(
@@ -79,8 +85,43 @@ async def lifespan(app: FastAPI):
# ---------------------------------------------------------------------------
@app.get("/api/config")
def get_public_config():
- """Return the public (non-secret) application configuration."""
- return settings.get_frontend_config()
+ """Return the public (non-secret) application configuration.
+
+ Merges statically-configured personas (from YAML) with dynamically
+ discovered BrainForge personas so the frontend sees all advisors in
+ a single response.
+ """
+ from app.core.bootstrap import chat_orchestrator
+ from app.config import generate_persona_colors
+ from app.core.brainforge_sync import BRAINFORGE_PERSONA_PREFIX
+
+ config = settings.get_frontend_config()
+
+ static_ids = {p["id"] for p in config["personas"]["items"]}
+ allowed = settings.personas.allowed_advisors
+
+ for pid, persona in chat_orchestrator.personas.items():
+ if not pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_"):
+ continue
+ if pid in static_ids:
+ continue
+ if allowed is not None and pid not in allowed:
+ continue
+
+ colors = generate_persona_colors(persona.name)
+ config["personas"]["items"].append({
+ "id": pid,
+ "name": persona.name,
+ "role": "BrainForge Advisor",
+ "summary": persona.system_prompt[:120] if persona.system_prompt else "",
+ "color": colors["color"],
+ "bg_color": colors["bg_color"],
+ "dark_color": colors["dark_color"],
+ "dark_bg_color": colors["dark_bg_color"],
+ "image": "icon://Brain",
+ })
+
+ return config
@app.get("/")
def root():
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_auth.py b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_auth.py
new file mode 100644
index 00000000..e18440df
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_auth.py
@@ -0,0 +1,180 @@
+import asyncio
+import time
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from app.llm.brainforge_auth import BrainForgeAuthManager, TOKEN_EXPIRY_BUFFER
+
+FAKE_URL = "https://fake.brainforge.example.com"
+FAKE_USER = "testuser"
+FAKE_PASS = "testpass"
+
+
+def _make_login_response():
+    """Return a stand-in for the HTTP response of a successful login.
+
+    The mock reports status 200 and its ``json()`` yields fixed tokens with
+    an expiration one hour after the moment this helper is called.
+    """
+    body = {
+        "access_token": "access-abc",
+        "refresh_token": "refresh-xyz",
+        "expiration": time.time() + 3600,
+    }
+    return MagicMock(status_code=200, **{"json.return_value": body})
+
+
+def _make_refresh_response():
+    """Return a stand-in for the HTTP response of a successful token refresh.
+
+    The mock reports status 200 and its ``json()`` yields the refreshed
+    tokens with an expiration one hour after the moment of the call.
+    """
+    body = {
+        "access_token": "access-refreshed",
+        "refresh_token": "refresh-new",
+        "expiration": time.time() + 3600,
+    }
+    return MagicMock(status_code=200, **{"json.return_value": body})
+
+
+def _setup_http_mock(MockAsyncClient, response):
+    """Make the patched ``AsyncClient`` hand out a client whose post() returns *response*.
+
+    Returns the inner client mock so a test can inspect or reconfigure
+    its ``post`` call.
+    """
+    http = AsyncMock()
+    http.post.return_value = response
+
+    # ``async with AsyncClient(...) as c`` must yield ``http`` as ``c``.
+    context = MockAsyncClient.return_value
+    context.__aenter__ = AsyncMock(return_value=http)
+    context.__aexit__ = AsyncMock(return_value=False)
+    return http
+
+
+@patch("app.llm.brainforge_auth.httpx.AsyncClient")
+class TestBrainForgeAuthManager(unittest.TestCase):
+    """Tests for BrainForgeAuthManager with ``httpx.AsyncClient`` patched out.
+
+    The class-level patch hands the replaced ``AsyncClient`` to every test
+    method as its extra argument.
+    """
+
+    @staticmethod
+    def _manager(url=FAKE_URL):
+        """Build a manager pointed at *url* with the fake credentials."""
+        return BrainForgeAuthManager(url, FAKE_USER, FAKE_PASS)
+
+    # ------------------------------------------------------------------
+    # Construction
+    # ------------------------------------------------------------------
+
+    def test_constructor_stores_attributes(self, MockAsyncClient):
+        manager = self._manager()
+        self.assertEqual(manager._api_url, FAKE_URL)
+        self.assertEqual(manager._username, FAKE_USER)
+        self.assertEqual(manager._password, FAKE_PASS)
+        for token in (manager._access_token, manager._refresh_token):
+            self.assertIsNone(token)
+
+    def test_constructor_strips_trailing_slash(self, MockAsyncClient):
+        manager = self._manager(f"{FAKE_URL}/")
+        self.assertEqual(manager._api_url, FAKE_URL)
+
+    # ------------------------------------------------------------------
+    # Login
+    # ------------------------------------------------------------------
+
+    def test_login_sends_correct_payload(self, MockAsyncClient):
+        manager = self._manager()
+        http = _setup_http_mock(MockAsyncClient, _make_login_response())
+
+        asyncio.run(manager._login())
+
+        positional, keyword = http.post.call_args
+        self.assertEqual(positional[0], f"{FAKE_URL}/auth/login")
+        sent = keyword["json"]
+        self.assertEqual(sent["username"], FAKE_USER)
+        self.assertEqual(sent["password"], FAKE_PASS)
+
+    def test_login_stores_tokens_on_success(self, MockAsyncClient):
+        manager = self._manager()
+        _setup_http_mock(MockAsyncClient, _make_login_response())
+
+        asyncio.run(manager._login())
+
+        self.assertEqual(manager._access_token, "access-abc")
+        self.assertEqual(manager._refresh_token, "refresh-xyz")
+        self.assertGreater(manager._expiration, time.time())
+
+    def test_login_raises_on_non_200(self, MockAsyncClient):
+        manager = self._manager()
+        rejected = MagicMock(
+            status_code=401,
+            text='{"detail":"Invalid username or password"}',
+        )
+        _setup_http_mock(MockAsyncClient, rejected)
+
+        with self.assertRaises(RuntimeError):
+            asyncio.run(manager._login())
+
+    # ------------------------------------------------------------------
+    # Token refresh
+    # ------------------------------------------------------------------
+
+    def test_refresh_updates_tokens_on_success(self, MockAsyncClient):
+        manager = self._manager()
+        manager._access_token = "old-access"
+        manager._refresh_token = "old-refresh"
+        _setup_http_mock(MockAsyncClient, _make_refresh_response())
+
+        outcome = asyncio.run(manager._refresh())
+
+        self.assertTrue(outcome)
+        self.assertEqual(manager._access_token, "access-refreshed")
+        self.assertEqual(manager._refresh_token, "refresh-new")
+
+    def test_refresh_returns_false_on_failure(self, MockAsyncClient):
+        manager = self._manager()
+        manager._access_token = "old-access"
+        manager._refresh_token = "old-refresh"
+        _setup_http_mock(MockAsyncClient, MagicMock(status_code=401))
+
+        self.assertFalse(asyncio.run(manager._refresh()))
+
+    # ------------------------------------------------------------------
+    # get_token flow
+    # ------------------------------------------------------------------
+
+    def test_get_token_returns_cached_when_valid(self, MockAsyncClient):
+        manager = self._manager()
+        manager._access_token = "cached-token"
+        manager._expiration = time.time() + 3600
+
+        self.assertEqual(asyncio.run(manager.get_token()), "cached-token")
+
+    def test_get_token_refreshes_when_near_expiry(self, MockAsyncClient):
+        manager = self._manager()
+        manager._access_token = "expiring-token"
+        manager._refresh_token = "has-refresh"
+        # One second inside the early-expiry buffer, so a refresh is due.
+        manager._expiration = time.time() + (TOKEN_EXPIRY_BUFFER - 1)
+        _setup_http_mock(MockAsyncClient, _make_refresh_response())
+
+        self.assertEqual(asyncio.run(manager.get_token()), "access-refreshed")
+
+    def test_get_token_logins_when_no_tokens(self, MockAsyncClient):
+        manager = self._manager()
+        _setup_http_mock(MockAsyncClient, _make_login_response())
+
+        self.assertEqual(asyncio.run(manager.get_token()), "access-abc")
+
+    # ------------------------------------------------------------------
+    # Health check
+    # ------------------------------------------------------------------
+
+    def test_health_check_true_on_valid_token(self, MockAsyncClient):
+        manager = self._manager()
+        manager._access_token = "valid-token"
+        manager._expiration = time.time() + 3600
+
+        self.assertTrue(asyncio.run(manager.health_check()))
+
+    def test_health_check_false_on_exception(self, MockAsyncClient):
+        manager = self._manager()
+
+        http = AsyncMock()
+        http.post.side_effect = Exception("connection refused")
+        context = MockAsyncClient.return_value
+        context.__aenter__ = AsyncMock(return_value=http)
+        context.__aexit__ = AsyncMock(return_value=False)
+
+        self.assertFalse(asyncio.run(manager.health_check()))
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_client.py b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_client.py
new file mode 100644
index 00000000..50db6b02
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_client.py
@@ -0,0 +1,248 @@
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+
+from app.llm.improved_brainforge_client import ImprovedBrainForgeClient
+
+FAKE_URL = "https://fake.brainforge.example.com"
+FAKE_MODEL = "BrainForge/neonai@2026.01.26"
+
+
+def _make_chat_response(content="Hello from BrainForge"):
+    """Return a stand-in for the HTTP response of a successful chat completion.
+
+    The message content is *content* with one space added on each side.
+    """
+    message = {"content": f" {content} "}
+    body = {"choices": [{"message": message}]}
+    return MagicMock(status_code=200, **{"json.return_value": body})
+
+
+def _make_models_response(model_name="BrainForge/neonai", version="2026.01.26"):
+    """Return a stand-in for a 200 reply from get_models listing one model."""
+    only_model = {"name": model_name, "version": version}
+    return MagicMock(
+        status_code=200,
+        **{"json.return_value": {"models": [only_model]}},
+    )
+
+
+def _make_client(auth_manager=None, model_id=FAKE_MODEL):
+    """Build an ImprovedBrainForgeClient whose auth always yields "fake-token".
+
+    A supplied *auth_manager* is reused, but its ``get_token`` is replaced.
+    """
+    auth = auth_manager if auth_manager else AsyncMock()
+    auth.get_token = AsyncMock(return_value="fake-token")
+    settings = {"api_url": FAKE_URL, "model_id": model_id, "auth_manager": auth}
+    return ImprovedBrainForgeClient(**settings)
+
+
+def _setup_http_mock(MockHttpClient, response):
+    """Make the patched HTTP client class yield a client whose post() returns *response*.
+
+    Returns the inner client mock so a test can inspect or reconfigure
+    its ``post`` call.
+    """
+    inner = AsyncMock()
+    inner.post.return_value = response
+
+    # ``async with AsyncClient(...) as c`` must yield ``inner`` as ``c``.
+    entered = MockHttpClient.return_value
+    entered.__aenter__ = AsyncMock(return_value=inner)
+    entered.__aexit__ = AsyncMock(return_value=False)
+    return inner
+
+
+@patch("app.llm.improved_brainforge_client.httpx.AsyncClient")
+@patch("app.llm.improved_brainforge_client.get_context_manager")
+class TestImprovedBrainForgeClient(unittest.TestCase):
+    """Tests for ImprovedBrainForgeClient with HTTP and the context manager patched.
+
+    Every test method receives the patched ``get_context_manager`` first and
+    the patched ``httpx.AsyncClient`` second.
+    """
+
+    @staticmethod
+    def _ask(client):
+        """Run generate() with the short prompt most tests share."""
+        return asyncio.run(client.generate(
+            system_prompt="Test",
+            context=[{"role": "user", "content": "Hi"}],
+            temperature=0.5,
+            max_tokens=50,
+        ))
+
+    @staticmethod
+    def _http_with_side_effect(MockHttpClient, effect):
+        """Wire the patched client so post() follows *effect* (an error or a sequence)."""
+        inner = AsyncMock()
+        inner.post.side_effect = effect
+        entered = MockHttpClient.return_value
+        entered.__aenter__ = AsyncMock(return_value=inner)
+        entered.__aexit__ = AsyncMock(return_value=False)
+        return inner
+
+    # ------------------------------------------------------------------
+    # Construction
+    # ------------------------------------------------------------------
+
+    def test_constructor_stores_attributes(self, mock_ctx, MockHttpClient):
+        built = ImprovedBrainForgeClient(
+            api_url=FAKE_URL, model_id=FAKE_MODEL, auth_manager=AsyncMock(),
+        )
+        self.assertEqual(built.api_url, FAKE_URL)
+        self.assertEqual(built.model_id, FAKE_MODEL)
+
+    def test_constructor_strips_trailing_slash(self, mock_ctx, MockHttpClient):
+        built = ImprovedBrainForgeClient(
+            api_url=f"{FAKE_URL}/", model_id=FAKE_MODEL, auth_manager=AsyncMock(),
+        )
+        self.assertEqual(built.api_url, FAKE_URL)
+
+    def test_constructor_accepts_auth_manager(self, mock_ctx, MockHttpClient):
+        supplied = AsyncMock()
+        built = ImprovedBrainForgeClient(api_url=FAKE_URL, auth_manager=supplied)
+        self.assertIs(built._auth, supplied)
+
+    def test_constructor_creates_auth_from_credentials(self, mock_ctx, MockHttpClient):
+        built = ImprovedBrainForgeClient(
+            api_url=FAKE_URL, username="user", password="pass",
+        )
+        self.assertIsNotNone(built._auth)
+        self.assertEqual(built._auth._username, "user")
+
+    # ------------------------------------------------------------------
+    # generate — happy path
+    # ------------------------------------------------------------------
+
+    def test_generate_returns_cleaned_response(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        _setup_http_mock(MockHttpClient, _make_chat_response("Here is my response."))
+
+        reply = asyncio.run(client.generate(
+            system_prompt="You are helpful.",
+            context=[{"role": "user", "content": "Hello"}],
+            temperature=0.7,
+            max_tokens=100,
+        ))
+        self.assertEqual(reply, "Here is my response.")
+
+    def test_generate_sends_correct_payload(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        http = _setup_http_mock(MockHttpClient, _make_chat_response())
+
+        self._ask(client)
+
+        positional, keyword = http.post.call_args
+        self.assertIn("/brainforge/openai/chat/completions", positional[0])
+        sent = keyword["json"]
+        self.assertEqual(sent["model"], FAKE_MODEL)
+        self.assertEqual(sent["temperature"], 0.5)
+        self.assertEqual(sent["max_tokens"], 100)  # 50 * 2 (JSON overhead scaling)
+
+    def test_generate_auto_discovers_model_when_none(self, mock_ctx, MockHttpClient):
+        client = _make_client(model_id=None)
+        # First post() answers get_models, the second the chat completion.
+        self._http_with_side_effect(
+            MockHttpClient,
+            [_make_models_response(), _make_chat_response()],
+        )
+
+        self._ask(client)
+
+        self.assertEqual(client.model_id, "BrainForge/neonai@2026.01.26")
+
+    # ------------------------------------------------------------------
+    # generate — error handling
+    # ------------------------------------------------------------------
+
+    def test_generate_handles_non_200(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        invalid = MagicMock(status_code=422, text='{"detail":"validation error"}')
+        _setup_http_mock(MockHttpClient, invalid)
+
+        self.assertIn("error", self._ask(client).lower())
+
+    def test_generate_clears_model_on_500_model_not_found(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        missing = MagicMock(status_code=500, text='{"detail":"model not found"}')
+        _setup_http_mock(MockHttpClient, missing)
+
+        self._ask(client)
+
+        self.assertIsNone(client.model_id)
+
+    def test_generate_handles_connect_error(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        self._http_with_side_effect(
+            MockHttpClient, httpx.ConnectError("Connection refused"),
+        )
+
+        self.assertIn("unable to connect", self._ask(client).lower())
+
+    def test_generate_handles_timeout(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        self._http_with_side_effect(
+            MockHttpClient, httpx.TimeoutException("timed out"),
+        )
+
+        self.assertIn("too long", self._ask(client).lower())
+
+    def test_generate_handles_auth_failure(self, mock_ctx, MockHttpClient):
+        failing_auth = AsyncMock()
+        failing_auth.get_token.side_effect = RuntimeError("auth failed")
+        client = ImprovedBrainForgeClient(
+            api_url=FAKE_URL, model_id=FAKE_MODEL, auth_manager=failing_auth,
+        )
+
+        self.assertIn("authenticate", self._ask(client).lower())
+
+    # ------------------------------------------------------------------
+    # health_check
+    # ------------------------------------------------------------------
+
+    def test_health_check_true_on_200(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        _setup_http_mock(MockHttpClient, MagicMock(status_code=200))
+
+        self.assertTrue(asyncio.run(client.health_check()))
+
+    def test_health_check_false_on_exception(self, mock_ctx, MockHttpClient):
+        client = _make_client()
+        self._http_with_side_effect(MockHttpClient, Exception("boom"))
+
+        self.assertFalse(asyncio.run(client.health_check()))
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_sync.py b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_sync.py
new file mode 100644
index 00000000..42882b32
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_brainforge_sync.py
@@ -0,0 +1,253 @@
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from app.core.brainforge_sync import (
+ BRAINFORGE_PERSONA_PREFIX,
+ _SKIP_PERSONA_NAMES,
+ _make_persona_id,
+ build_brainforge_personas,
+ async_sync_brainforge_personas,
+)
+
+
+FAKE_URL = "https://fake.brainforge.example.com"
+
+SAMPLE_MODELS = [
+ {
+ "name": "BrainForge/neonai",
+ "version": "2026.01.26",
+ "personas": [
+ {
+ "persona_name": "NeonAI",
+ "system_prompt": "You are the Neon AI assistant.",
+ "enabled": True,
+ },
+ {
+ "persona_name": "vanilla",
+ "system_prompt": None,
+ "enabled": True,
+ },
+ ],
+ },
+ {
+ "name": "BrainForge/NucleotidingsLLM",
+ "version": "2026.04.03",
+ "personas": [
+ {
+ "persona_name": "Nucleotidings",
+ "system_prompt": "You are a bioinformatics assistant.",
+ "enabled": True,
+ },
+ {
+ "persona_name": "DisabledBot",
+ "system_prompt": "I am disabled.",
+ "enabled": False,
+ },
+ ],
+ },
+]
+
+
+class TestMakePersonaId(unittest.TestCase):
+ """Tests for the _make_persona_id helper."""
+
+ def test_basic_id_generation(self):
+ pid = _make_persona_id("BrainForge/neonai", "NeonAI")
+ self.assertEqual(pid, "bf_neonai_NeonAI")
+
+ def test_slugifies_special_characters(self):
+ pid = _make_persona_id("BrainForge/test-model", "Dr. Smith's Bot!")
+ self.assertTrue(pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_"))
+ self.assertNotIn(" ", pid)
+ self.assertNotIn(".", pid)
+ self.assertNotIn("'", pid)
+ self.assertNotIn("!", pid)
+
+ def test_strips_trailing_underscores(self):
+ pid = _make_persona_id("BrainForge/model", " spaces ")
+ self.assertFalse(pid.endswith("_"))
+
+ def test_uses_last_segment_of_model_name(self):
+ pid = _make_persona_id("Org/SubOrg/deepmodel", "Bot")
+ self.assertIn("deepmodel", pid)
+ self.assertNotIn("Org", pid)
+
+ def test_prefix_is_correct(self):
+ pid = _make_persona_id("BrainForge/neonai", "NeonAI")
+ self.assertTrue(pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_"))
+
+
+@patch("app.llm.improved_brainforge_client.get_context_manager")
+class TestBuildBrainforgePersonas(unittest.TestCase):
+ """Tests for build_brainforge_personas."""
+
+ def test_builds_personas_from_model_data(self, mock_ctx):
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(SAMPLE_MODELS, mock_auth, FAKE_URL)
+
+ names = {p.name for p in personas}
+ self.assertIn("NeonAI", names)
+ self.assertIn("Nucleotidings", names)
+
+ for p in personas:
+ self.assertTrue(p.id.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_"))
+
+ def test_skips_vanilla_persona(self, mock_ctx):
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(SAMPLE_MODELS, mock_auth, FAKE_URL)
+
+ names = {p.name for p in personas}
+ for skip_name in _SKIP_PERSONA_NAMES:
+ self.assertNotIn(skip_name, names)
+
+ def test_skips_disabled_persona(self, mock_ctx):
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(SAMPLE_MODELS, mock_auth, FAKE_URL)
+
+ names = {p.name for p in personas}
+ self.assertNotIn("DisabledBot", names)
+
+ def test_skips_persona_without_prompt(self, mock_ctx):
+ models = [{
+ "name": "BrainForge/test",
+ "version": "1.0",
+ "personas": [
+ {"persona_name": "EmptyBot", "system_prompt": "", "enabled": True},
+ {"persona_name": "NullBot", "enabled": True},
+ ],
+ }]
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(models, mock_auth, FAKE_URL)
+ self.assertEqual(len(personas), 0)
+
+ def test_persona_has_correct_model_id(self, mock_ctx):
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(SAMPLE_MODELS, mock_auth, FAKE_URL)
+
+ neon_persona = next(p for p in personas if p.name == "NeonAI")
+ self.assertEqual(neon_persona.llm.model_id, "BrainForge/neonai@2026.01.26")
+
+ def test_persona_shares_auth_manager(self, mock_ctx):
+ mock_auth = AsyncMock()
+ personas = build_brainforge_personas(SAMPLE_MODELS, mock_auth, FAKE_URL)
+
+ for p in personas:
+ self.assertIs(p.llm._auth, mock_auth)
+
+
+def _make_mock_orchestrator(existing_personas=None):
+ """Create a MagicMock orchestrator whose register_persona/unregister_persona calls update a real ``personas`` dict seeded from existing_personas."""
+ orch = MagicMock()
+ orch.personas = dict(existing_personas or {})  # shallow copy; None becomes an empty dict
+
+ def register_side_effect(persona):
+ orch.personas[persona.id] = persona
+
+ def unregister_side_effect(pid):
+ orch.personas.pop(pid, None)  # unknown ids are ignored rather than raising KeyError
+
+ orch.register_persona.side_effect = register_side_effect
+ orch.unregister_persona.side_effect = unregister_side_effect
+ return orch
+
+
+def _make_mock_settings(api_url=FAKE_URL, username="user", password="pass"):
+ """Create MagicMock settings exposing llm.brainforge api_url/username/password plus a fixed 300-second sync interval."""
+ settings = MagicMock()
+ settings.llm.brainforge.api_url = api_url
+ settings.llm.brainforge.username = username
+ settings.llm.brainforge.password = password
+ settings.llm.brainforge.sync_interval_seconds = 300
+ return settings
+
+
+@patch("app.llm.improved_brainforge_client.get_context_manager")
+@patch("app.core.brainforge_sync.fetch_brainforge_models", new_callable=AsyncMock)
+@patch("app.core.brainforge_sync.get_settings")
+class TestAsyncSyncBrainforgePersonas(unittest.TestCase):
+ """Tests for async_sync_brainforge_personas."""
+
+ def test_sync_registers_new_personas(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings()
+ mock_fetch.return_value = SAMPLE_MODELS
+
+ orch = _make_mock_orchestrator()
+ count = asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertEqual(count, 2)
+ self.assertEqual(orch.register_persona.call_count, 2)
+ registered_ids = {call.args[0].id for call in orch.register_persona.call_args_list}
+ self.assertTrue(all(pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_") for pid in registered_ids))
+
+ def test_sync_removes_stale_personas(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings()
+ mock_fetch.return_value = SAMPLE_MODELS[:1]
+
+ stale_persona = MagicMock()
+ stale_persona.id = "bf_nucleotidingsllm_Nucleotidings"
+ stale_persona.name = "Nucleotidings"
+
+ orch = _make_mock_orchestrator({
+ "bf_nucleotidingsllm_Nucleotidings": stale_persona,
+ })
+ asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertNotIn("bf_nucleotidingsllm_Nucleotidings", orch.personas)
+
+ def test_sync_upserts_existing_personas(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings()
+ mock_fetch.return_value = SAMPLE_MODELS
+
+ existing = MagicMock()
+ existing.id = "bf_neonai_NeonAI"
+ existing.name = "NeonAI"
+
+ orch = _make_mock_orchestrator({"bf_neonai_NeonAI": existing})
+ asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertTrue(orch.register_persona.called)
+ re_registered_ids = {call.args[0].id for call in orch.register_persona.call_args_list}
+ self.assertIn("bf_neonai_NeonAI", re_registered_ids)
+
+ def test_sync_skips_when_no_url(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings(api_url="")
+
+ orch = _make_mock_orchestrator()
+ count = asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertEqual(count, 0)
+ mock_fetch.assert_not_called()
+
+ def test_sync_skips_when_no_credentials(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings(username="", password="")
+
+ orch = _make_mock_orchestrator()
+ count = asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertEqual(count, 0)
+ mock_fetch.assert_not_called()
+
+ def test_sync_skips_when_no_models(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings()
+ mock_fetch.return_value = []
+
+ orch = _make_mock_orchestrator()
+ count = asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertEqual(count, 0)
+
+ def test_sync_preserves_non_bf_personas(self, mock_settings, mock_fetch, mock_ctx):
+ mock_settings.return_value = _make_mock_settings()
+ mock_fetch.return_value = SAMPLE_MODELS
+
+ static_persona = MagicMock()
+ static_persona.id = "critic"
+ static_persona.name = "Constructive Critic"
+
+ orch = _make_mock_orchestrator({"critic": static_persona})
+ asyncio.run(async_sync_brainforge_personas(orch))
+
+ self.assertIn("critic", orch.personas)
+ unregistered_ids = [call.args[0] for call in orch.unregister_persona.call_args_list]
+ self.assertNotIn("critic", unregistered_ids)
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_clarification.py b/multi_llm_chatbot_backend/app/tests/unit/test_clarification.py
index d8f63645..1f5dac56 100644
--- a/multi_llm_chatbot_backend/app/tests/unit/test_clarification.py
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_clarification.py
@@ -28,10 +28,10 @@ def _make_mock_settings():
return settings
-def _make_orchestrator(persona_llm=None):
+def _make_orchestrator(persona_llm=None, orchestrator_llm=None):
"""Build an orchestrator with mocked dependencies, bypassing __init__."""
orch = ImprovedChatOrchestrator.__new__(ImprovedChatOrchestrator)
- orch.llm_client = None
+ orch.llm_client = orchestrator_llm
orch.session_manager = MagicMock()
orch.context_manager = MagicMock()
@@ -61,7 +61,7 @@ def test_skips_when_session_has_multiple_user_messages(self, mock_settings):
mock_settings.return_value = _make_mock_settings()
llm = MagicMock()
llm.generate = AsyncMock()
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=3)
result = self._run(
@@ -78,7 +78,7 @@ def test_proceeds_when_session_has_one_user_message(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(orch.needs_clarification_improved(session, "explain transformers"))
@@ -96,7 +96,7 @@ def test_returns_false_when_llm_says_clear(self, mock_settings):
"needs_clarification": False,
"reason": "The user asked about a specific topic.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
result = self._run(
@@ -116,7 +116,7 @@ def test_returns_true_when_llm_says_vague(self, mock_settings):
"needs_clarification": True,
"reason": "Single generic word with no topic.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
result = self._run(
@@ -137,7 +137,7 @@ def test_rejects_string_false_and_falls_back(self, mock_settings):
"needs_clarification": "false",
"reason": "Should have been a boolean.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
orch.needs_clarification = MagicMock(return_value=False)
session = _make_session(user_message_count=1)
@@ -155,7 +155,7 @@ def test_rejects_string_true_and_falls_back(self, mock_settings):
"needs_clarification": "true",
"reason": "Should have been a boolean.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
orch.needs_clarification = MagicMock(return_value=True)
session = _make_session(user_message_count=1)
@@ -173,7 +173,7 @@ def test_rejects_missing_key_and_falls_back(self, mock_settings):
llm.generate = AsyncMock(return_value=json.dumps({
"reason": "Forgot the main field.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
orch.needs_clarification = MagicMock(return_value=True)
session = _make_session(user_message_count=1)
@@ -192,7 +192,7 @@ def test_falls_back_on_malformed_json(self, mock_settings):
mock_settings.return_value = _make_mock_settings()
llm = MagicMock()
llm.generate = AsyncMock(return_value="this is not json at all")
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
orch.needs_clarification = MagicMock(return_value=False)
session = _make_session(user_message_count=1)
@@ -211,7 +211,7 @@ def test_falls_back_on_llm_exception(self, mock_settings):
mock_settings.return_value = _make_mock_settings()
llm = MagicMock()
llm.generate = AsyncMock(side_effect=RuntimeError("connection refused"))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
orch.needs_clarification = MagicMock(return_value=True)
session = _make_session(user_message_count=1)
@@ -250,7 +250,7 @@ def test_llm_called_with_json_mode_and_zero_temp(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(
@@ -269,7 +269,7 @@ def test_system_prompt_includes_app_context(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(
@@ -287,7 +287,7 @@ def test_system_prompt_includes_domain_keywords(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(
@@ -306,7 +306,7 @@ def test_system_prompt_includes_advisor_names(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(
@@ -323,7 +323,7 @@ def test_user_input_passed_in_user_prompt(self, mock_settings):
"needs_clarification": False,
"reason": "Clear.",
}))
- orch = _make_orchestrator(persona_llm=llm)
+ orch = _make_orchestrator(persona_llm=llm, orchestrator_llm=llm)
session = _make_session(user_message_count=1)
self._run(
diff --git a/phd_config.yaml b/phd_config.yaml
index 92df8162..121605cd 100644
--- a/phd_config.yaml
+++ b/phd_config.yaml
@@ -118,6 +118,22 @@ personas:
# Optional whitelist of advisor IDs. When set, only these advisors are
# available. Omit or leave unset to allow all enabled personas.
allowed_advisors:
+ - "critic"
+ - "empathetic"
+ - "methodologist"
+ - "minimalist"
+ - "motivator"
+ - "pragmatist"
+ - "socratic"
+ - "storyteller"
+ - "theorist"
+ - "visionary"
+ - "bf_nucleotidingsllm_NucleotidingsAI"
+ - "bf_logisticsllm_Logistics_Expert_LLM"
+ - "bf_logisticsllm_Agile_Project_Management_Instructor"
+ - "bf_security_CybersecurityExpert"
+ - "bf_shakespeare_Historian"
+ - "bf_shakespeare_Shakespeare"
# ── Orchestrator / Clarification ───────────────────────────────────────────
@@ -163,6 +179,8 @@ llm:
model: "llama3.2:1b"
vllm:
api_url: https://rtx6000blackwell-1.neonaiservices2.com/vllm0
+ brainforge:
+ api_url: https://hana.neonaialpha.com
rag:
embedding_model: "all-MiniLM-L6-v2"
diff --git a/scripts/list_available_advisors.py b/scripts/list_available_advisors.py
new file mode 100755
index 00000000..5991b5de
--- /dev/null
+++ b/scripts/list_available_advisors.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+"""Discover all available advisors (static + BrainForge) and print their IDs
+for use in the allowed_advisors whitelist.
+
+Usage:
+ # Using credentials from .env / phd_config.yaml:
+ python3 scripts/list_available_advisors.py
+
+ # Override BrainForge credentials via flags:
+ python3 scripts/list_available_advisors.py \
+ --api-url https://hana.neonaialpha.com \
+ --username admin \
+ --password secret
+
+ # Static advisors only (no BrainForge connection needed):
+ python3 scripts/list_available_advisors.py --skip-brainforge
+"""
+
+import argparse
+import os
+import re
+import sys
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+load_dotenv(REPO_ROOT / ".env")
+
+try:
+ import httpx
+except ImportError:
+ print("Error: httpx is required. Install with: pip install httpx", file=sys.stderr)
+ sys.exit(1)
+
+try:
+ import yaml
+except ImportError:
+ yaml = None
+
+CONFIG_PATH = REPO_ROOT / "phd_config.yaml"
+
+BRAINFORGE_PERSONA_PREFIX = "bf"
+SKIP_PERSONA_NAMES = {"vanilla"}
+
+
+def _make_persona_id(model_name: str, persona_name: str) -> str:
+ """Build `<prefix>_<model>_<persona>`; intended to mirror the ID generation logic in brainforge_sync.py."""
+ short_model = model_name.rsplit("/", 1)[-1].lower()  # lowercased text after the last "/"
+ safe_name = re.sub(r"[^a-zA-Z0-9]+", "_", persona_name).strip("_")  # non-alphanumeric runs -> "_", edge underscores trimmed
+ return f"{BRAINFORGE_PERSONA_PREFIX}_{short_model}_{safe_name}"
+
+
+def load_yaml_config() -> dict:
+ """Parse CONFIG_PATH (phd_config.yaml) with yaml.safe_load; returns {} if the file is missing, PyYAML is not installed, or the parsed result is falsy."""
+ if not CONFIG_PATH.exists():
+ return {}
+ if yaml is None:
+ print("Warning: PyYAML not installed, cannot read config file.", file=sys.stderr)
+ return {}
+ with open(CONFIG_PATH, "r") as f:
+ return yaml.safe_load(f) or {}
+
+
+def get_static_personas(config: dict) -> list:
+ """Return one (id, name) tuple per *.yaml file in config["personas"]["personas_dir"], in sorted path order; [] if the directory is unset or not a directory."""
+ personas_cfg = config.get("personas", {})
+ personas_dir = personas_cfg.get("personas_dir", "")
+
+ if not personas_dir:
+ return []
+
+ dir_path = Path(personas_dir)
+ if not dir_path.is_absolute():
+ dir_path = CONFIG_PATH.parent / dir_path  # relative paths are resolved against the config file's directory
+
+ if not dir_path.is_dir():
+ return []
+
+ results = []
+ for f in sorted(dir_path.glob("*.yaml")):
+ if yaml is None:
+ pid = f.stem  # without PyYAML the file stem stands in for both id and name
+ results.append((pid, pid))
+ continue
+ with open(f, "r") as fh:
+ data = yaml.safe_load(fh) or {}
+ pid = data.get("id", f.stem)
+ name = data.get("name", pid)
+ results.append((pid, name))
+
+ return results
+
+
+def get_brainforge_credentials(args, config: dict) -> tuple:
+ """Resolve (api_url, username, password): username/password from args > env > config; api_url from args > config only (no env lookup)."""
+ bf_cfg = config.get("llm", {}).get("brainforge", {})
+
+ api_url = (args.api_url or bf_cfg.get("api_url") or "").rstrip("/")  # trailing slashes dropped
+ username = args.username or os.getenv("BRAINFORGE_USERNAME", "") or bf_cfg.get("username", "")
+ password = args.password or os.getenv("BRAINFORGE_PASSWORD", "") or bf_cfg.get("password", "")
+
+ return api_url, username, password
+
+
+def login(api_url: str, username: str, password: str) -> str:
+ """POST the credentials to {api_url}/auth/login and return the response's access_token; on a non-200 status, print the error to stderr and exit with status 1."""
+ resp = httpx.post(
+ f"{api_url}/auth/login",
+ json={
+ "username": username,
+ "password": password,
+ "token_name": "persona-discovery",
+ "client_id": "list-brainforge-personas-script",
+ },
+ timeout=15,
+ )
+ if resp.status_code != 200:
+ print(f"Error: Login failed (HTTP {resp.status_code}): {resp.text[:200]}", file=sys.stderr)
+ sys.exit(1)
+
+ return resp.json()["access_token"]
+
+
+def fetch_models(api_url: str, token: str) -> list:
+ """POST to {api_url}/brainforge/get_models with a Bearer token and return the response's models list ([] if the key is absent); exit with status 1 on a non-200 status."""
+ resp = httpx.post(
+ f"{api_url}/brainforge/get_models",
+ headers={"Authorization": f"Bearer {token}"},
+ timeout=15,
+ )
+ if resp.status_code != 200:
+ print(f"Error: get_models failed (HTTP {resp.status_code}): {resp.text[:200]}", file=sys.stderr)
+ sys.exit(1)
+
+ return resp.json().get("models", [])
+
+
+def get_brainforge_personas(api_url: str, username: str, password: str) -> list:
+ """Log in, fetch models, and return (persona_id, persona_name, model_name) tuples for enabled, non-skipped personas that have a prompt; null JSON fields are treated as missing."""
+ print(f" Connecting to BrainForge at {api_url} ...")
+ token = login(api_url, username, password)
+ print(" Authenticated successfully.\n")
+
+ models = fetch_models(api_url, token)
+ if not models:
+ return []
+
+ results = []
+ for model in models:
+ model_name = model.get("name") or ""  # `or` also covers an explicit JSON null, which .get(key, default) does not
+ for p in model.get("personas") or []:
+ persona_name = p.get("persona_name") or ""
+ if persona_name.lower() in SKIP_PERSONA_NAMES:
+ continue
+ if not p.get("enabled", True):
+ continue
+ system_prompt = p.get("system_prompt") or p.get("description") or ""
+ if not system_prompt:
+ continue
+ pid = _make_persona_id(model_name, persona_name)
+ results.append((pid, persona_name, model_name))
+
+ return results
+
+
+def main() -> int:  # CLI entry point; the value returned here is always 0
+ parser = argparse.ArgumentParser(
+ description="List all available persona IDs (static + BrainForge) for the allowed_advisors config."
+ )
+ parser.add_argument("--api-url", help="BrainForge API URL")
+ parser.add_argument("--username", help="BrainForge username")
+ parser.add_argument("--password", help="BrainForge password")
+ parser.add_argument("--skip-brainforge", action="store_true", help="Only list static personas")
+ args = parser.parse_args()
+
+ config = load_yaml_config()
+
+ # --- Static personas: one (id, name) pair per persona YAML file ---
+ static = get_static_personas(config)
+ print(f"Static personas (from persona YAML files): {len(static)} found\n")
+ for pid, name in static:
+ print(f' - "{pid}" # {name}')
+
+ # --- BrainForge personas: queried only when a URL and both credentials are available ---
+ bf_personas = []
+ if not args.skip_brainforge:
+ api_url, username, password = get_brainforge_credentials(args, config)
+
+ if not api_url:
+ print("\n BrainForge: skipped (no API URL configured)")
+ elif not username or not password:
+ print("\n BrainForge: skipped (no credentials available)")
+ else:
+ print()
+ bf_personas = get_brainforge_personas(api_url, username, password)
+ print(f" BrainForge personas: {len(bf_personas)} found\n")
+ for pid, name, model in bf_personas:
+ print(f' - "{pid}" # {name} ({model})')
+ else:
+ print("\n BrainForge: skipped (--skip-brainforge)")
+
+ # --- Combined YAML output: static ids first, then BrainForge ids ---
+ all_ids = [pid for pid, _ in static] + [pid for pid, _, _ in bf_personas]
+
+ if not all_ids:
+ print("\nNo personas found.")
+ return 0
+
+ print("\n" + "=" * 60)
+ print("YAML-ready allowed_advisors (copy into phd_config.yaml):")
+ print("=" * 60 + "\n")
+ print(" allowed_advisors:")
+ for pid in all_ids:
+ print(f' - "{pid}"')
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
From e3c0c62caaeca4afda7ce3a6150aa4f19358c02d Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Fri, 29 May 2026 23:32:35 +0000
Subject: [PATCH 14/31] Increment Version to 2.0.1a6
---
multi_llm_chatbot_backend/app/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/multi_llm_chatbot_backend/app/version.py b/multi_llm_chatbot_backend/app/version.py
index 0dabcc30..7b440d62 100644
--- a/multi_llm_chatbot_backend/app/version.py
+++ b/multi_llm_chatbot_backend/app/version.py
@@ -1,4 +1,4 @@
-__version__ = "2.0.1a5"
+__version__ = "2.0.1a6"
if __name__ == "__main__":
print(__version__)
From 4b8ce599865944e6198015af28b9cc409bcfb298 Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Fri, 29 May 2026 23:32:54 +0000
Subject: [PATCH 15/31] Update Changelog
---
CHANGELOG.md | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea07a5ca..2d24b7e0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,17 @@
# Changelog
+## [2.0.1a6](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a6) (2026-05-29)
+
+[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a5...2.0.1a6)
+
+**Implemented enhancements:**
+
+- \[FEAT\] Support BrainForge LLM Backend [\#49](https://github.com/NeonGeckoCom/CCAI-Demo/issues/49)
+
+**Merged pull requests:**
+
+- BrainForge LLM Integration [\#62](https://github.com/NeonGeckoCom/CCAI-Demo/pull/62) ([NeonCharlie-24](https://github.com/NeonCharlie-24))
+
## [2.0.1a5](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a5) (2026-05-29)
[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a4...2.0.1a5)
From c969fb568b9ba554e0a4dc0a209e215e78eb6658 Mon Sep 17 00:00:00 2001
From: NeonCharlie-24
Date: Fri, 29 May 2026 18:04:55 -0700
Subject: [PATCH 16/31] persist advisor responses to db in /chat-stream
endpoint. (#78)
* persist advisor responses to db in /chat-stream endpoint.
* persist reply messages to db in /reply-to-advisor
* persist expanded messages to db in /chat/{persona_id}.
* persist document upload notifications to db in /upload-document.
* deleted saveMessageToSession function and all call sites from ChatPage.js to remove all frontend message persistence.
* bug fix to /chat/{persona_id} endpoint to reuse the existing chat session.
* added TODO to deprecate /chat-sessions/{session_id}/messages endpoint once we confirm changes work.
* added PersistMessage Pydantic model to validate chat messages.
* persist errors on the backend where they were previously being persisted on the frontend by saveMessageToSession.
* deprecated the /chat-sessions/{session_id}/messages endpoint.
* fix advisorName mapping on frontend for messages loaded from the db.
* revert back to creating a fresh session for the expand function (same implementation prior to this feature).
* rename PersistMessage persona_name to advisorName to match what the frontend expects.
* added model validation for advisor and clarification and small additional model for replyTo.
* added model validator to verify replyTo exists when isReply is True.
* fixed sidebar counter to refresh on all messages and added replyTo metadata to advisor replies.
* fixed sidebar counter to increment when user sends reply or expands message.
---
.../app/api/routes/chat.py | 144 +++++++++-
.../app/api/routes/chat_sessions.py | 51 +---
.../app/api/routes/documents.py | 9 +-
multi_llm_chatbot_backend/app/models/user.py | 56 +++-
.../app/tests/unit/test_chat_sessions.py | 50 +---
.../unit/test_chat_stream_persistence.py | 259 ++++++++++++++++++
phd-advisor-frontend/src/pages/ChatPage.js | 67 +----
7 files changed, 473 insertions(+), 163 deletions(-)
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py
diff --git a/multi_llm_chatbot_backend/app/api/routes/chat.py b/multi_llm_chatbot_backend/app/api/routes/chat.py
index 3a600ab9..aaad4e0c 100644
--- a/multi_llm_chatbot_backend/app/api/routes/chat.py
+++ b/multi_llm_chatbot_backend/app/api/routes/chat.py
@@ -17,7 +17,7 @@
from app.core.database import get_database
from app.core.persona_filter import get_available_persona_ids
from app.core.session_manager import get_session_manager
-from app.models.user import User
+from app.models.user import PersistMessage, ReplyToRef, User
logger = logging.getLogger(__name__)
@@ -27,6 +27,7 @@
# Enhanced data models
class UserInput(BaseModel):
user_input: str
+ chat_session_id: Optional[str] = None
class ChatMessage(BaseModel):
user_input: str
@@ -97,11 +98,13 @@ async def _event_generator():
# Append user message to in-memory session and persist to MongoDB
session.append_message("user", message.user_input)
if message.chat_session_id:
- await persist_message(message.chat_session_id, {
- "id": str(ObjectId()),
- "type": "user",
- "content": message.user_input,
- })
+ await persist_message(
+ message.chat_session_id,
+ PersistMessage(type="user", content=message.user_input),
+ )
+ yield ChatStreamLine(
+ type="progress", data={"phase": "received"},
+ ).to_ndjson()
if await chat_orchestrator.needs_clarification_improved(session, message.user_input):
clar = await chat_orchestrator.generate_contextual_clarification(message.user_input)
@@ -122,7 +125,18 @@ async def _event_generator():
# directly and skip persona generation.
tool_result = await chat_orchestrator.get_tool_response(message.user_input)
if tool_result.used_tool:
+ # Append the tool response to the in-memory session and persist it to MongoDB
session.append_message("orchestrator", tool_result.text)
+ if message.chat_session_id:
+ await persist_message(
+ message.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id="orchestrator",
+ advisorName="Orchestrator",
+ content=tool_result.text,
+ ),
+ )
yield ChatStreamLine(
type="advisor",
data={
@@ -216,6 +230,18 @@ async def _run(pid: str) -> None:
for _ in range(len(tasks)):
result = await done_queue.get()
+ if message.chat_session_id:
+ await persist_message(
+ message.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id=result["persona_id"],
+ advisorName=result["persona_name"],
+ content=result["response"],
+ used_documents=result.get("used_documents", False),
+ document_chunks_used=result.get("document_chunks_used", 0),
+ ),
+ )
line = ChatStreamLine(
type="advisor",
data={
@@ -372,7 +398,17 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request:
# Use async session management
session_id = await get_or_create_session_for_request_async(request)
-
+
+ if input.chat_session_id:
+ await persist_message(
+ input.chat_session_id,
+ PersistMessage(
+ type="user",
+ content=input.user_input,
+ isExpandRequest=True,
+ ),
+ )
+
result = await chat_orchestrator.chat_with_persona(
user_input=input.user_input,
persona_id=persona_id,
@@ -382,30 +418,64 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request:
# Handle response structure
if result.get("type") == "single_persona_response" and "persona" in result:
persona_data = result["persona"]
+ if input.chat_session_id:
+ await persist_message(
+ input.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id=persona_data["persona_id"],
+ advisorName=persona_data["persona_name"],
+ content=persona_data["response"],
+ isExpansion=True,
+ ),
+ )
return {
"persona": persona_data["persona_name"],
"persona_id": persona_data["persona_id"],
"response": persona_data["response"]
}
elif "persona_id" in result and "response" in result:
+ if input.chat_session_id:
+ await persist_message(
+ input.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id=result["persona_id"],
+ advisorName=result["persona_name"],
+ content=result["response"],
+ isExpansion=True,
+ ),
+ )
return {
"persona": result["persona_name"],
"persona_id": result["persona_id"],
"response": result["response"]
}
else:
+ error_content = "Sorry, I received an unexpected response format. Please try again."
+ if input.chat_session_id:
+ await persist_message(
+ input.chat_session_id,
+ PersistMessage(type="error", content=error_content),
+ )
return {
"persona": "System",
- "response": "I'm having trouble generating a response right now. Please try again."
+ "response": error_content,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in chat_with_specific_advisor: {e}")
+ error_content = "Sorry, I encountered an error while expanding the message. Please try again."
+ if input.chat_session_id:
+ await persist_message(
+ input.chat_session_id,
+ PersistMessage(type="error", content=error_content),
+ )
return {
"persona": "System",
- "response": "I'm having trouble generating a response right now. Please try again."
+ "response": error_content,
}
@router.post("/reply-to-advisor")
@@ -422,7 +492,21 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
session_id = await get_or_create_session_for_request_async(request)
session = session_manager.get_session(session_id)
-
+
+ if reply.chat_session_id:
+ await persist_message(
+ reply.chat_session_id,
+ PersistMessage(
+ type="user",
+ content=reply.user_input,
+ replyTo=ReplyToRef(
+ advisorId=reply.advisor_id,
+ advisorName=chat_orchestrator.get_persona(reply.advisor_id).name,
+ messageId=reply.original_message_id,
+ ),
+ ),
+ )
+
# Find the original message being replied to for context
original_message = None
if reply.original_message_id:
@@ -445,6 +529,22 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
# Handle response structure
if result.get("type") == "single_persona_response" and "persona" in result:
persona_data = result["persona"]
+ if reply.chat_session_id:
+ await persist_message(
+ reply.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id=persona_data["persona_id"],
+ advisorName=persona_data["persona_name"],
+ content=persona_data["response"],
+ isReply=True,
+ replyTo=ReplyToRef(
+ advisorId=reply.advisor_id,
+ advisorName=persona_data["persona_name"],
+ messageId=reply.original_message_id,
+ ),
+ ),
+ )
return {
"type": "advisor_reply",
"persona": persona_data["persona_name"],
@@ -453,6 +553,22 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
"original_message_id": reply.original_message_id
}
elif "persona_id" in result and "response" in result:
+ if reply.chat_session_id:
+ await persist_message(
+ reply.chat_session_id,
+ PersistMessage(
+ type="advisor",
+ persona_id=result["persona_id"],
+ advisorName=result["persona_name"],
+ content=result["response"],
+ isReply=True,
+ replyTo=ReplyToRef(
+ advisorId=reply.advisor_id,
+ advisorName=result["persona_name"],
+ messageId=reply.original_message_id,
+ ),
+ ),
+ )
return {
"type": "advisor_reply",
"persona": result["persona_name"],
@@ -471,10 +587,16 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
raise
except Exception as e:
logger.error(f"Error in reply_to_advisor: {e}")
+ error_content = "Sorry, I encountered an error with your reply. Please try again."
+ if reply.chat_session_id:
+ await persist_message(
+ reply.chat_session_id,
+ PersistMessage(type="error", content=error_content),
+ )
return {
"type": "error",
"persona": "System",
- "response": "I'm having trouble generating a reply right now. Please try again."
+ "response": error_content,
}
@router.post("/ask/")
diff --git a/multi_llm_chatbot_backend/app/api/routes/chat_sessions.py b/multi_llm_chatbot_backend/app/api/routes/chat_sessions.py
index 42f71790..a00d1922 100644
--- a/multi_llm_chatbot_backend/app/api/routes/chat_sessions.py
+++ b/multi_llm_chatbot_backend/app/api/routes/chat_sessions.py
@@ -2,7 +2,7 @@
from typing import List, Optional
from datetime import datetime
from bson import ObjectId
-from app.models.user import User, ChatSession, ChatSessionResponse
+from app.models.user import User, ChatSession, ChatSessionResponse, PersistMessage
from app.core.auth import get_current_active_user
from app.core.database import get_database
from pydantic import BaseModel
@@ -19,14 +19,11 @@ class UpdateChatSessionRequest(BaseModel):
title: Optional[str] = None
messages: Optional[List[dict]] = None
-class SaveMessageRequest(BaseModel):
- session_id: str
- message: dict
-async def persist_message(session_id: str, message: dict):
+async def persist_message(session_id: str, message: PersistMessage):
"""Write a single message to a MongoDB chat session."""
db = get_database()
- msg = message.copy()
+ msg = message.model_dump(exclude_none=True)
if "timestamp" not in msg:
msg["timestamp"] = datetime.utcnow().isoformat()
await db.chat_sessions.update_one(
@@ -246,48 +243,6 @@ async def update_chat_session(
detail="Could not update chat session"
)
-@router.post("/chat-sessions/{session_id}/messages")
-async def save_message_to_session(
- session_id: str,
- request: SaveMessageRequest,
- current_user: User = Depends(get_current_active_user)
-):
- """
- Add a message to a chat session.
- @param session_id: MongoDB ObjectId of the chat session
- @param request: SaveMessageRequest with the message dict
- @param current_user: Authenticated user from dependency injection
- @return: Dict with a confirmation message
- """
- try:
- db = get_database()
-
- # Verify session belongs to user
- session_data = await db.chat_sessions.find_one({
- "_id": ObjectId(session_id),
- "user_id": current_user.id,
- "is_active": True
- })
-
- if not session_data:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
- )
-
- await persist_message(session_id, request.message)
-
- return {"message": "Message saved successfully"}
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"Error saving message: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Could not save message"
- )
-
@router.delete("/chat-sessions")
diff --git a/multi_llm_chatbot_backend/app/api/routes/documents.py b/multi_llm_chatbot_backend/app/api/routes/documents.py
index 4756e276..54a240b4 100644
--- a/multi_llm_chatbot_backend/app/api/routes/documents.py
+++ b/multi_llm_chatbot_backend/app/api/routes/documents.py
@@ -9,9 +9,10 @@
from app.utils.file_export import prepare_export_response, generate_pdf_file_from_blocks
from app.core.session_manager import get_session_manager
from app.core.bootstrap import chat_orchestrator
+from app.api.routes.chat_sessions import persist_message
from app.core.auth import get_current_active_user
from app.core.database import get_database
-from app.models.user import User
+from app.models.user import PersistMessage, User
from bson import ObjectId
import logging
import re
@@ -217,6 +218,12 @@ async def upload_document(
f"Document uploaded: '{doc_title}' ({file.filename}) - {rag_result['chunks_created']} sections processed, ~{rag_result['total_tokens']} tokens analyzed. You can now ask questions about this document by referencing it by name."
)
+ if chat_session_id:
+ await persist_message(chat_session_id, PersistMessage(
+ type="document_upload",
+ content=f"Document uploaded: {file.filename} ({rag_result['chunks_created']} sections processed)",
+ ))
+
# Return session info for frontend tracking
return {
"message": f"Document '{file.filename}' uploaded and processed successfully.",
diff --git a/multi_llm_chatbot_backend/app/models/user.py b/multi_llm_chatbot_backend/app/models/user.py
index 2a3c9e16..24b08957 100644
--- a/multi_llm_chatbot_backend/app/models/user.py
+++ b/multi_llm_chatbot_backend/app/models/user.py
@@ -1,5 +1,5 @@
-from pydantic import BaseModel, EmailStr, Field, ConfigDict
-from typing import Optional, List, Any
+from pydantic import BaseModel, EmailStr, Field, ConfigDict, model_validator
+from typing import Literal, Optional, List, Any
from datetime import datetime
from bson import ObjectId
@@ -62,6 +62,56 @@ class UserResponse(BaseModel):
created_at: datetime
last_login: Optional[datetime] = None
+MessageType = Literal[
+ "user", "advisor", "error", "clarification", "document_upload", "system",
+]
+
+
+class ReplyToRef(BaseModel):
+ """Reference to the advisor message being replied to."""
+ advisorId: str
+ advisorName: str
+ messageId: str
+
+
+class PersistMessage(BaseModel):
+ """Schema for a single message stored in a ChatSession's messages array."""
+ id: str = Field(default_factory=lambda: str(ObjectId()))
+ type: MessageType
+ content: str
+ timestamp: Optional[str] = None
+ # Advisor-specific
+ persona_id: Optional[str] = None
+ advisorName: Optional[str] = None
+ used_documents: bool = False
+ document_chunks_used: int = 0
+ # Clarification-specific
+ suggestions: Optional[List[str]] = None
+ # Reply/expand metadata
+ isReply: bool = False
+ isExpansion: bool = False
+ isExpandRequest: bool = False
+ replyTo: Optional[ReplyToRef] = None
+
+ @model_validator(mode='after')
+ def check_type_constraints(self):
+ if self.type == 'advisor':
+ if not self.persona_id:
+ raise ValueError("persona_id is required for advisor messages")
+ if not self.advisorName:
+ raise ValueError("advisorName is required for advisor messages")
+ elif self.type == 'clarification':
+ if not self.suggestions:
+ raise ValueError("a non-empty suggestions list is required for clarification messages")
+ return self
+
+ @model_validator(mode='after')
+ def check_reply_metadata(self):
+ if self.isReply and not self.replyTo:
+ raise ValueError("replyTo is required when isReply is True")
+ return self
+
+
class ChatSession(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
@@ -72,7 +122,7 @@ class ChatSession(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
user_id: PyObjectId
title: str
- messages: List[dict] = []
+ messages: List[PersistMessage] = []
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
is_active: bool = True
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_chat_sessions.py b/multi_llm_chatbot_backend/app/tests/unit/test_chat_sessions.py
index 138ff084..2500bad1 100644
--- a/multi_llm_chatbot_backend/app/tests/unit/test_chat_sessions.py
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_chat_sessions.py
@@ -12,18 +12,16 @@
from app.api.routes.chat_sessions import ( # noqa: E402
CreateChatSessionRequest,
UpdateChatSessionRequest,
- SaveMessageRequest,
persist_message,
create_chat_session,
get_user_chat_sessions,
get_chat_sessions_count,
get_chat_session,
update_chat_session,
- save_message_to_session,
delete_all_chat_sessions,
delete_chat_session,
)
-from app.models.user import User # noqa: E402
+from app.models.user import PersistMessage, User # noqa: E402
FAKE_USER_ID = ObjectId()
OTHER_USER_ID = ObjectId()
@@ -92,7 +90,7 @@ def test_appends_message_with_auto_timestamp(self, mock_get_db):
db = _mock_db()
mock_get_db.return_value = db
- msg = {"type": "user", "content": "hello"}
+ msg = PersistMessage(type="user", content="hello")
asyncio.run(persist_message(str(FAKE_SESSION_ID), msg))
args = db.chat_sessions.update_one.call_args
@@ -105,7 +103,7 @@ def test_preserves_existing_timestamp(self, mock_get_db):
db = _mock_db()
mock_get_db.return_value = db
- msg = {"type": "user", "content": "hi", "timestamp": "2025-01-01T00:00:00"}
+ msg = PersistMessage(type="user", content="hi", timestamp="2025-01-01T00:00:00")
asyncio.run(persist_message(str(FAKE_SESSION_ID), msg))
pushed = db.chat_sessions.update_one.call_args[0][1]["$push"]["messages"]
@@ -299,48 +297,6 @@ def test_returns_404_for_nonexistent_session(self, mock_get_db):
# ------------------------------------------------------------------
-@patch("app.api.routes.chat_sessions.get_database")
-class TestSaveMessageToSession(unittest.TestCase):
-
- def test_saves_message_to_valid_session(self, mock_get_db):
- db = _mock_db()
- db.chat_sessions.find_one.return_value = _make_session_doc()
- mock_get_db.return_value = db
-
- user = _make_fake_user()
- req = SaveMessageRequest(
- session_id=str(FAKE_SESSION_ID),
- message={"type": "user", "content": "test"},
- )
-
- result = asyncio.run(
- save_message_to_session(
- session_id=str(FAKE_SESSION_ID), request=req, current_user=user
- )
- )
-
- self.assertEqual(result["message"], "Message saved successfully")
-
- def test_returns_404_for_nonexistent_session(self, mock_get_db):
- db = _mock_db()
- db.chat_sessions.find_one.return_value = None
- mock_get_db.return_value = db
-
- user = _make_fake_user()
- req = SaveMessageRequest(
- session_id=str(FAKE_SESSION_ID),
- message={"type": "user", "content": "test"},
- )
-
- with self.assertRaises(HTTPException) as ctx:
- asyncio.run(
- save_message_to_session(
- session_id=str(FAKE_SESSION_ID), request=req, current_user=user
- )
- )
-
- self.assertEqual(ctx.exception.status_code, 404)
-
# ------------------------------------------------------------------
# DELETE /chat-sessions (bulk delete)
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py b/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py
new file mode 100644
index 00000000..608d04d6
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py
@@ -0,0 +1,259 @@
+import unittest
+
+from bson import ObjectId
+
+from app.models.user import PersistMessage, ReplyToRef
+
+
+# ------------------------------------------------------------------
+# PersistMessage – advisor type
+# ------------------------------------------------------------------
+
+
+ADVISOR_REQUIRED_FIELDS = {"id", "type", "persona_id", "advisorName", "content",
+ "used_documents", "document_chunks_used"}
+
+
+class TestAdvisorPersistMessage(unittest.TestCase):
+
+ def test_includes_all_required_fields(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="advisor_a",
+ advisorName="Advisor A",
+ content="Some advice.",
+ ).model_dump(exclude_none=True)
+ self.assertTrue(ADVISOR_REQUIRED_FIELDS.issubset(msg.keys()),
+ f"Missing fields: {ADVISOR_REQUIRED_FIELDS - msg.keys()}")
+
+ def test_type_is_advisor(self):
+ msg = PersistMessage(
+ type="advisor", persona_id="x", advisorName="X", content="c",
+ )
+ self.assertEqual(msg.type, "advisor")
+
+ def test_advisor_name_stored(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="methodologist",
+ advisorName="Dr. Method",
+ content="content",
+ )
+ self.assertEqual(msg.advisorName, "Dr. Method")
+ self.assertEqual(msg.persona_id, "methodologist")
+
+ def test_defaults_for_document_fields(self):
+ msg = PersistMessage(
+ type="advisor", persona_id="x", advisorName="X", content="c",
+ )
+ self.assertFalse(msg.used_documents)
+ self.assertEqual(msg.document_chunks_used, 0)
+
+ def test_explicit_document_fields(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="x",
+ advisorName="X",
+ content="c",
+ used_documents=True,
+ document_chunks_used=5,
+ )
+ self.assertTrue(msg.used_documents)
+ self.assertEqual(msg.document_chunks_used, 5)
+
+ def test_id_is_valid_objectid_string(self):
+ msg = PersistMessage(
+ type="advisor", persona_id="x", advisorName="X", content="c",
+ )
+ ObjectId(msg.id)
+
+ def test_each_call_generates_unique_id(self):
+ ids = {
+ PersistMessage(
+ type="advisor", persona_id="x", advisorName="X", content="c",
+ ).id
+ for _ in range(10)
+ }
+ self.assertEqual(len(ids), 10)
+
+ def test_reply_flag(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="x",
+ advisorName="X",
+ content="c",
+ isReply=True,
+ replyTo=ReplyToRef(
+ advisorId="y", advisorName="Y", messageId="msg_1",
+ ),
+ )
+ self.assertTrue(msg.isReply)
+ self.assertIsNotNone(msg.replyTo)
+
+ def test_orchestrator_message_shape(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="orchestrator",
+ advisorName="Orchestrator",
+ content="Tool output here",
+ )
+ self.assertEqual(msg.persona_id, "orchestrator")
+ self.assertEqual(msg.advisorName, "Orchestrator")
+ self.assertEqual(msg.content, "Tool output here")
+ self.assertEqual(msg.type, "advisor")
+
+ def test_expansion_flag(self):
+ msg = PersistMessage(
+ type="advisor",
+ persona_id="theorist",
+ advisorName="Dr. Theory",
+ content="Here is a deeper explanation...",
+ isExpansion=True,
+ )
+ self.assertEqual(msg.type, "advisor")
+ self.assertTrue(msg.isExpansion)
+ self.assertEqual(msg.persona_id, "theorist")
+
+
+# ------------------------------------------------------------------
+# PersistMessage – user type
+# ------------------------------------------------------------------
+
+
+USER_REQUIRED_FIELDS = {"id", "type", "content"}
+
+
+class TestUserPersistMessage(unittest.TestCase):
+
+ def test_includes_required_fields(self):
+ msg = PersistMessage(type="user", content="hello").model_dump(exclude_none=True)
+ self.assertTrue(USER_REQUIRED_FIELDS.issubset(msg.keys()),
+ f"Missing fields: {USER_REQUIRED_FIELDS - msg.keys()}")
+
+ def test_type_is_user(self):
+ msg = PersistMessage(type="user", content="hello")
+ self.assertEqual(msg.type, "user")
+
+ def test_content_preserved(self):
+ msg = PersistMessage(type="user", content="Tell me more")
+ self.assertEqual(msg.content, "Tell me more")
+
+ def test_id_is_valid_objectid_string(self):
+ msg = PersistMessage(type="user", content="hello")
+ ObjectId(msg.id)
+
+ def test_each_call_generates_unique_id(self):
+ ids = {
+ PersistMessage(type="user", content="hello").id
+ for _ in range(10)
+ }
+ self.assertEqual(len(ids), 10)
+
+ def test_reply_to_metadata(self):
+ msg = PersistMessage(
+ type="user",
+ content="I disagree",
+ replyTo=ReplyToRef(
+ advisorId="methodologist",
+ advisorName="Dr. Method",
+ messageId="msg_123",
+ ),
+ )
+ self.assertEqual(msg.replyTo.advisorId, "methodologist")
+ self.assertEqual(msg.replyTo.advisorName, "Dr. Method")
+ self.assertEqual(msg.replyTo.messageId, "msg_123")
+
+ def test_plain_message_has_no_replyTo(self):
+ msg = PersistMessage(type="user", content="hello").model_dump(exclude_none=True)
+ self.assertNotIn("replyTo", msg)
+
+ def test_expand_request_shape(self):
+ msg = PersistMessage(
+ type="user",
+ content="Please expand on your previous response...",
+ isExpandRequest=True,
+ )
+ self.assertEqual(msg.type, "user")
+ self.assertTrue(msg.isExpandRequest)
+
+
+# ------------------------------------------------------------------
+# PersistMessage – error type
+# ------------------------------------------------------------------
+
+
+class TestErrorPersistMessage(unittest.TestCase):
+
+ def test_type_is_error(self):
+ msg = PersistMessage(type="error", content="Something went wrong")
+ self.assertEqual(msg.type, "error")
+
+
+# ------------------------------------------------------------------
+# PersistMessage – type validation
+# ------------------------------------------------------------------
+
+
+class TestPersistMessageTypeValidation(unittest.TestCase):
+
+ def test_rejects_invalid_type(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="bogus", content="hello")
+
+ def test_all_valid_types_accepted(self):
+ valid = {
+ "user": {},
+ "advisor": {"persona_id": "x", "advisorName": "X"},
+ "error": {},
+ "clarification": {"suggestions": ["Try this"]},
+ "document_upload": {},
+ "system": {},
+ }
+ for t, kwargs in valid.items():
+ msg = PersistMessage(type=t, content="test", **kwargs)
+ self.assertEqual(msg.type, t)
+
+
+# ------------------------------------------------------------------
+# PersistMessage – model validators
+# ------------------------------------------------------------------
+
+
+class TestPersistMessageValidators(unittest.TestCase):
+
+ def test_advisor_without_persona_id_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="advisor", advisorName="X", content="c")
+
+ def test_advisor_without_advisor_name_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="advisor", persona_id="x", content="c")
+
+ def test_advisor_without_both_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="advisor", content="c")
+
+ def test_clarification_without_suggestions_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="clarification", content="Need more info")
+
+ def test_clarification_with_empty_suggestions_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(type="clarification", content="Need more info", suggestions=[])
+
+ def test_reply_without_reply_to_rejected(self):
+ from pydantic import ValidationError
+ with self.assertRaises(ValidationError):
+ PersistMessage(
+ type="advisor",
+ persona_id="x",
+ advisorName="X",
+ content="c",
+ isReply=True,
+ )
diff --git a/phd-advisor-frontend/src/pages/ChatPage.js b/phd-advisor-frontend/src/pages/ChatPage.js
index 063a454b..053811f8 100644
--- a/phd-advisor-frontend/src/pages/ChatPage.js
+++ b/phd-advisor-frontend/src/pages/ChatPage.js
@@ -215,30 +215,6 @@ const loadChatSession = async (sessionId) => {
}
};
-// Save a message to the current session
-const saveMessageToSession = async (message) => {
- if (!currentSessionId || !authToken) return;
-
- try {
- await fetch(`${process.env.REACT_APP_API_URL}/api/chat-sessions/${currentSessionId}/messages`, {
- method: 'POST',
- headers: {
- 'Authorization': `Bearer ${authToken}`,
- 'Content-Type': 'application/json'
- },
- body: JSON.stringify({
- session_id: currentSessionId,
- message: {
- ...message,
- timestamp: message.timestamp.toISOString()
- }
- })
- });
- } catch (error) {
- console.error('Error saving message to session:', error);
- }
-};
-
// Update session title based on first message
const updateSessionTitle = async (sessionId, newTitle) => {
if (!sessionId || !authToken) return;
@@ -364,10 +340,6 @@ const handleNewChat = async (sessionId = null) => {
current_session_id: currentSessionId
});
- // Save document upload message to database if we have a current session
- if (currentSessionId) {
- await saveMessageToSession(documentMessage);
- }
};
@@ -428,6 +400,7 @@ const handleNewChat = async (sessionId = null) => {
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
+ let refreshedForUserMessage = false;
while (true) {
const { done, value } = await reader.read();
@@ -441,6 +414,11 @@ const handleNewChat = async (sessionId = null) => {
if (!line.trim()) continue;
const payload = JSON.parse(line);
+ if (!refreshedForUserMessage) {
+ setSidebarRefreshTrigger(prev => prev + 1);
+ refreshedForUserMessage = true;
+ }
+
const d = payload.data || {};
switch (payload.type) {
@@ -457,7 +435,6 @@ const handleNewChat = async (sessionId = null) => {
};
setMessages(prev => [...prev, msg]);
setThinkingAdvisors(prev => prev.filter(a => a !== d.persona_id));
- await saveMessageToSession(msg);
break;
}
case 'clarification':
@@ -530,10 +507,8 @@ const handleNewChat = async (sessionId = null) => {
};
setMessages(prev => [...prev, replyMessage]);
-
- // Save reply message to database with explicit session ID
- await saveMessageToSession(replyMessage, sessionId);
-
+ setSidebarRefreshTrigger(prev => prev + 1);
+
setIsLoading(true);
setThinkingAdvisors([replyContext.persona_id]);
@@ -568,9 +543,6 @@ const handleNewChat = async (sessionId = null) => {
timestamp: new Date()
};
setMessages(prev => [...prev, replyResponseMessage]);
-
- // Save advisor reply to database
- await saveMessageToSession(replyResponseMessage, sessionId);
}
} catch (error) {
@@ -582,13 +554,11 @@ const handleNewChat = async (sessionId = null) => {
timestamp: new Date()
};
setMessages(prev => [...prev, errorMessage]);
-
- // Save error message to database
- await saveMessageToSession(errorMessage, sessionId);
}
setIsLoading(false);
setThinkingAdvisors([]);
+ setSidebarRefreshTrigger(prev => prev + 1);
};
const handleCopyMessage = (messageId, content) => {
@@ -614,10 +584,8 @@ const handleNewChat = async (sessionId = null) => {
expandsMessageId: messageId
};
setMessages(prev => [...prev, expandMessage]);
-
- // Save expand request to database
- await saveMessageToSession(expandMessage);
-
+ setSidebarRefreshTrigger(prev => prev + 1);
+
setIsLoading(true);
setThinkingAdvisors([advisorId]);
@@ -629,7 +597,8 @@ const handleNewChat = async (sessionId = null) => {
},
body: JSON.stringify({
user_input: expandPrompt,
- response_length: 'long'
+ response_length: 'long',
+ chat_session_id: currentSessionId
}),
});
@@ -651,9 +620,6 @@ const handleNewChat = async (sessionId = null) => {
timestamp: new Date()
};
setMessages(prev => [...prev, expandedMessage]);
-
- // Save expanded response to database
- await saveMessageToSession(expandedMessage);
} else {
const errorMessage = {
id: generateMessageId(),
@@ -662,9 +628,6 @@ const handleNewChat = async (sessionId = null) => {
timestamp: new Date()
};
setMessages(prev => [...prev, errorMessage]);
-
- // Save error message to database
- await saveMessageToSession(errorMessage);
}
} catch (error) {
@@ -676,13 +639,11 @@ const handleNewChat = async (sessionId = null) => {
timestamp: new Date()
};
setMessages(prev => [...prev, errorMessage]);
-
- // Save error message to database
- await saveMessageToSession(errorMessage);
}
setIsLoading(false);
setThinkingAdvisors([]);
+ setSidebarRefreshTrigger(prev => prev + 1);
};
const handleReplyToMessage = (message) => {
From af809cb47981b32a77382990e170fd5ee034a359 Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Sat, 30 May 2026 01:05:19 +0000
Subject: [PATCH 17/31] Increment Version to 2.0.1a7
---
multi_llm_chatbot_backend/app/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/multi_llm_chatbot_backend/app/version.py b/multi_llm_chatbot_backend/app/version.py
index 7b440d62..733f3665 100644
--- a/multi_llm_chatbot_backend/app/version.py
+++ b/multi_llm_chatbot_backend/app/version.py
@@ -1,4 +1,4 @@
-__version__ = "2.0.1a6"
+__version__ = "2.0.1a7"
if __name__ == "__main__":
print(__version__)
From ef41a67a281c291c7060ba3b4ced492963aed81f Mon Sep 17 00:00:00 2001
From: NeonDaniel
Date: Sat, 30 May 2026 01:05:39 +0000
Subject: [PATCH 18/31] Update Changelog
---
CHANGELOG.md | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2d24b7e0..b639327d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,17 @@
# Changelog
+## [2.0.1a7](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a7) (2026-05-30)
+
+[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a6...2.0.1a7)
+
+**Implemented enhancements:**
+
+- \[FEAT\] Move all message persistence from frontend to backend [\#46](https://github.com/NeonGeckoCom/CCAI-Demo/issues/46)
+
+**Merged pull requests:**
+
+- persist advisor responses to db in /chat-stream endpoint. [\#78](https://github.com/NeonGeckoCom/CCAI-Demo/pull/78) ([NeonCharlie-24](https://github.com/NeonCharlie-24))
+
## [2.0.1a6](https://github.com/NeonGeckoCom/CCAI-Demo/tree/2.0.1a6) (2026-05-29)
[Full Changelog](https://github.com/NeonGeckoCom/CCAI-Demo/compare/2.0.1a5...2.0.1a6)
From 300ed29e02b85a208b5a332273874853e4ae8494 Mon Sep 17 00:00:00 2001
From: NeonCharlie-24
Date: Fri, 5 Jun 2026 12:45:13 -0700
Subject: [PATCH 19/31] Feat/support hybrid model selection rebase (#80)
* feat(hybrid): add Hybrid provider option to ProviderDropdown (UI scaffold)
Adds a 4th "Hybrid" option to the provider dropdown with a purple
"Mixed" badge and dark-mode fix, plus a Settings2 configure button
on the selected hybrid row. Passes new `defaultBackend` and
`backendLocked` (BrainForge) fields from `/api/config` personas
through AppConfigContext for the upcoming mapping UI.
Backend needs:
- `default_backend` + `brainforge` flags on persona config
- accept "hybrid" on POST /switch-provider
- GET/POST /hybrid-config for { orchestrator, personas{} } map
- per-persona inference routing; BrainForge hard-pinned
* added per-user llm backend selection with uniform (all same provider) and hybrid (different llm per persona/orchestrator) modes.
* Enhance ChatPage with hybrid LLM configuration support
- Added HybridConfigModal for configuring hybrid LLM settings.
- Refactored state management to support both uniform and hybrid modes.
- Each advisor can now have its own model
* added filter for available backends with health status check set to ping every 300 seconds.
* added unit tests for the available backends filter and health check.
* Pending changes to build backend the menus will be moved to the welcome screens and settings pages once merges are completed.
* fixed bootstrap.py import causing test_available_backends failure.
* added conftest.py to simplify mock module imports and unit tests for LLM provider config.
* restored needs_clarification_improved function lost during rebase.
* Enhance SettingsModal with user profile and account management features
- Added functionality for updating user profile information (first name, last name).
- Implemented password change and account deletion processes with confirmation.
- Improved modal behavior to prevent accidental closure during text selection.
- Updated ChatPage to integrate new SettingsModal features, including user update and sign-out callbacks.
* fix stubbing issue in conftest.py from rebase.
* fix backend config values to lock brainforge models on frontend and prevent their underlying models from being changed.
* added admin-level enabled toggle to each provider and set default backend dynamically.
* replaced frontend hardcoded gemini fallback with dynamic defaults.
* fallback to default_backend when hybrid mode has no overrides.
* add test case for gemini missing.
* added configurable default_backend parameter to config.yaml.
* Fixed the black on black text and added the default option
---------
Co-authored-by: Neon:ryan
---
.../app/api/routes/chat.py | 86 ++++-
.../app/api/routes/provider.py | 150 +++-----
multi_llm_chatbot_backend/app/config.py | 5 +
.../app/core/bootstrap.py | 95 ++++-
.../app/core/improved_orchestrator.py | 63 +--
.../app/llm/improved_ollama_client.py | 10 +-
.../app/llm/improved_vllm_client.py | 10 +
.../app/llm/llm_client.py | 9 +
multi_llm_chatbot_backend/app/main.py | 7 +
.../app/models/persona.py | 10 +-
multi_llm_chatbot_backend/app/models/user.py | 31 +-
.../app/tests/unit/conftest.py | 43 ++-
.../app/tests/unit/test_available_backends.py | 114 ++++++
.../tests/unit/test_llm_provider_config.py | 364 ++++++++++++++++++
.../src/components/AdvisorConfigPanel.js | 193 ++++++++++
.../src/components/AvatarPickerModal 2.js | 75 ++++
.../src/components/HybridConfigModal.js | 95 +++++
.../src/components/ProviderDropdown.js | 36 +-
.../src/components/SettingsModal.js | 88 ++++-
.../src/components/Sidebar.js | 8 +-
.../src/components/WelcomeModelPicker.js | 162 ++++++++
.../src/contexts/AppConfigContext.js | 2 +
phd-advisor-frontend/src/pages/ChatPage.js | 110 ++++--
phd-advisor-frontend/src/styles/ChatPage.css | 34 ++
phd_config.yaml | 4 +
25 files changed, 1587 insertions(+), 217 deletions(-)
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py
create mode 100644 multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py
create mode 100644 phd-advisor-frontend/src/components/AdvisorConfigPanel.js
create mode 100644 phd-advisor-frontend/src/components/AvatarPickerModal 2.js
create mode 100644 phd-advisor-frontend/src/components/HybridConfigModal.js
create mode 100644 phd-advisor-frontend/src/components/WelcomeModelPicker.js
diff --git a/multi_llm_chatbot_backend/app/api/routes/chat.py b/multi_llm_chatbot_backend/app/api/routes/chat.py
index aaad4e0c..f01eb1f9 100644
--- a/multi_llm_chatbot_backend/app/api/routes/chat.py
+++ b/multi_llm_chatbot_backend/app/api/routes/chat.py
@@ -13,7 +13,7 @@
from app.api.utils import get_or_create_session_for_request_async
from app.core.auth import get_current_active_user
from app.config import get_settings
-from app.core.bootstrap import chat_orchestrator
+from app.core.bootstrap import chat_orchestrator, get_llm_client
from app.core.database import get_database
from app.core.persona_filter import get_available_persona_ids
from app.core.session_manager import get_session_manager
@@ -24,6 +24,41 @@
router = APIRouter()
session_manager = get_session_manager()
+
+def resolve_llm_clients(user: User) -> Dict[str, Any]:
+ """Resolve LLM clients from a user's stored configuration.
+
+ Returns ``{"orchestrator": LLMClient | None, "personas": {id: LLMClient} | None}``.
+
+ - No saved config: both values are ``None``; callers fall back to
+ orchestrator/persona defaults.
+ - Uniform mode: the same cached client is returned for the orchestrator
+ and every persona.
+ - Hybrid mode: the orchestrator and each persona may receive different
+ clients based on the user's per-persona mapping.
+ """
+ config = user.llm_config
+ if config is None:
+ return {"orchestrator": None, "personas": None}
+
+ if config.mode == "uniform":
+ client = get_llm_client(config.default_backend)
+ persona_clients = {
+ pid: client for pid in chat_orchestrator.personas
+ }
+ return {"orchestrator": client, "personas": persona_clients}
+
+ # Hybrid mode
+ orchestrator_backend = config.orchestrator_backend or config.default_backend
+ orchestrator_client = get_llm_client(orchestrator_backend)
+
+ persona_clients = {}
+ for pid in chat_orchestrator.personas:
+ backend = (config.persona_backends or {}).get(pid, config.default_backend)
+ persona_clients[pid] = get_llm_client(backend)
+
+ return {"orchestrator": orchestrator_client, "personas": persona_clients}
+
# Enhanced data models
class UserInput(BaseModel):
user_input: str
@@ -81,6 +116,11 @@ async def chat_stream(
async def _event_generator():
try:
+ # Resolve per-user LLM clients from their stored config
+ llm_clients = resolve_llm_clients(current_user)
+ orchestrator_llm = llm_clients["orchestrator"]
+ persona_llms = llm_clients["personas"]
+
# Load or create the in-memory session
if message.chat_session_id:
sid = f"chat_{message.chat_session_id}"
@@ -107,7 +147,9 @@ async def _event_generator():
).to_ndjson()
if await chat_orchestrator.needs_clarification_improved(session, message.user_input):
- clar = await chat_orchestrator.generate_contextual_clarification(message.user_input)
+ clar = await chat_orchestrator.generate_contextual_clarification(
+ message.user_input, llm_client=orchestrator_llm,
+ )
yield ChatStreamLine(
type="clarification",
data={
@@ -123,7 +165,9 @@ async def _event_generator():
# If an enabled tool can handle this query, return its response
# directly and skip persona generation.
- tool_result = await chat_orchestrator.get_tool_response(message.user_input)
+ tool_result = await chat_orchestrator.get_tool_response(
+ message.user_input, llm_client=orchestrator_llm,
+ )
if tool_result.used_tool:
# Append user message to in-memory session and persist to MongoDB
session.append_message("orchestrator", tool_result.text)
@@ -164,6 +208,7 @@ async def _event_generator():
top_personas = await chat_orchestrator.get_top_personas(
session_id=sid,
allowed_ids=available,
+ llm_client=orchestrator_llm,
)
# Guard against race condition where all selected advisors
@@ -210,9 +255,11 @@ async def _run(pid: str) -> None:
"document_chunks_used": 0,
})
return
+ persona_llm = (persona_llms or {}).get(pid)
result = await chat_orchestrator.generate_single_persona_response(
session, persona,
message.response_length or "medium",
+ llm_client=persona_llm,
)
session.append_message(pid, result["response"])
await done_queue.put(result)
@@ -390,7 +437,10 @@ async def create_new_chat(
raise HTTPException(status_code=500, detail="Failed to create new chat")
@router.post("/chat/{persona_id}")
-async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: Request):
+async def chat_with_specific_advisor(
+ persona_id: str, input: UserInput, request: Request,
+ current_user: User = Depends(get_current_active_user),
+):
"""Chat with a specific advisor - UPDATED"""
try:
if persona_id not in chat_orchestrator.personas:
@@ -408,11 +458,15 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request:
isExpandRequest=True,
),
)
+
+ llm_clients = resolve_llm_clients(current_user)
+ persona_llm = (llm_clients["personas"] or {}).get(persona_id)
result = await chat_orchestrator.chat_with_persona(
user_input=input.user_input,
persona_id=persona_id,
- session_id=session_id
+ session_id=session_id,
+ llm_client=persona_llm,
)
# Handle response structure
@@ -479,7 +533,10 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request:
}
@router.post("/reply-to-advisor")
-async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
+async def reply_to_advisor(
+ reply: ReplyToAdvisor, request: Request,
+ current_user: User = Depends(get_current_active_user),
+):
"""Reply to a specific advisor with proper context - UPDATED"""
try:
if reply.advisor_id not in chat_orchestrator.personas:
@@ -520,10 +577,14 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
if original_message:
contextual_input = f"[Replying to your previous message: '{original_message[:100]}...'] {reply.user_input}"
+ llm_clients = resolve_llm_clients(current_user)
+ advisor_llm = (llm_clients["personas"] or {}).get(reply.advisor_id)
+
result = await chat_orchestrator.chat_with_persona(
user_input=contextual_input,
persona_id=reply.advisor_id,
- session_id=session_id
+ session_id=session_id,
+ llm_client=advisor_llm,
)
# Handle response structure
@@ -600,15 +661,22 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request):
}
@router.post("/ask/")
-async def ask_question(query: PersonaQuery, request: Request):
+async def ask_question(
+ query: PersonaQuery, request: Request,
+ current_user: User = Depends(get_current_active_user),
+):
"""Ask question - UPDATED"""
try:
session_id = await get_or_create_session_for_request_async(request)
+ llm_clients = resolve_llm_clients(current_user)
+ persona_llm = (llm_clients["personas"] or {}).get(query.persona)
+
result = await chat_orchestrator.chat_with_persona(
user_input=query.question,
persona_id=query.persona,
- session_id=session_id
+ session_id=session_id,
+ llm_client=persona_llm,
)
if result["type"] == "single_persona_response":
diff --git a/multi_llm_chatbot_backend/app/api/routes/provider.py b/multi_llm_chatbot_backend/app/api/routes/provider.py
index b7760e74..7d8b45f4 100644
--- a/multi_llm_chatbot_backend/app/api/routes/provider.py
+++ b/multi_llm_chatbot_backend/app/api/routes/provider.py
@@ -1,108 +1,72 @@
-from fastapi import APIRouter, Body, HTTPException
-from app.config import get_settings
-from app.llm.improved_gemini_client import ImprovedGeminiClient
-from app.llm.improved_ollama_client import ImprovedOllamaClient
-from app.llm.improved_vllm_client import ImprovedVllmClient
-from app.models.default_personas import get_default_personas
-from app.core.bootstrap import chat_orchestrator, llm, current_provider, available_providers
-from app.core.brainforge_sync import BRAINFORGE_PERSONA_PREFIX
-from pydantic import BaseModel
-import os
+from fastapi import APIRouter, Depends, HTTPException, status
+from app.core.auth import get_current_active_user
+from app.core.bootstrap import (
+ chat_orchestrator, get_llm_client, AVAILABLE_BACKENDS, _is_backend_enabled,
+)
+from app.core.database import get_database
+from app.models.user import User, UserLLMConfig
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
-def create_llm_client(provider: str = None):
- global current_provider
- if provider is None:
- provider = current_provider
-
- if provider == "gemini":
- try:
- return ImprovedGeminiClient(model_name=os.getenv("GEMINI_MODEL"))
- except ValueError as e:
- logger.warning(f"Gemini API key not found, falling back to Ollama: {e}")
- return ImprovedOllamaClient(model_name="llama3.2:1b")
- elif provider == "ollama":
- return ImprovedOllamaClient(model_name="llama3.2:1b")
- elif provider == "vllm":
- settings = get_settings()
- if not settings.llm.vllm.api_url:
- raise ValueError("No vLLM endpoint configured. Set llm.vllm.api_url in your config.")
- return ImprovedVllmClient(
- api_url=settings.llm.vllm.api_url,
- api_key=settings.llm.vllm.api_key,
- )
- else:
- raise ValueError(f"Unknown provider: {provider}")
-
-# Initialize LLM and personas
-llm = create_llm_client(current_provider)
-DEFAULT_PERSONAS = get_default_personas(llm)
-for persona in DEFAULT_PERSONAS:
- chat_orchestrator.register_persona(persona)
-
-class ProviderSwitch(BaseModel):
- provider: str
@router.get("/current-provider")
-async def get_current_provider():
+async def get_current_provider(
+ current_user: User = Depends(get_current_active_user),
+):
+ """Return the authenticated user's LLM configuration.
+
+ When the user has no saved ``llm_config`` a default ``UserLLMConfig()``
+ is reported instead. The response has two keys: ``llm_config`` (the
+ config dumped to a dict) and ``available_backends`` (the
+ ``AVAILABLE_BACKENDS`` list imported from ``app.core.bootstrap``).
+ """
+ config = current_user.llm_config or UserLLMConfig()
return {
- "current_provider": current_provider,
- "available_providers": available_providers,
- "model_info": {
- "name": llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash",
- "provider": current_provider
- }
+ "llm_config": config.model_dump(),
+ "available_backends": AVAILABLE_BACKENDS,
}
-@router.post("/switch-provider")
-async def switch_provider(provider_data: ProviderSwitch):
- global current_provider, llm
-
- if provider_data.provider not in available_providers:
- raise HTTPException(status_code=400, detail=f"Unknown provider: {provider_data.provider}. Available: {available_providers}")
-
- try:
- current_provider = provider_data.provider
- new_llm = create_llm_client(current_provider)
- llm = new_llm
- chat_orchestrator.llm_client = new_llm
-
- new_personas = get_default_personas(new_llm)
- # Clear only non-BrainForge personas; BF advisors have their own LLM clients
- non_bf_ids = [pid for pid in chat_orchestrator.personas if not pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_")]
- for pid in non_bf_ids:
- chat_orchestrator.unregister_persona(pid)
- for persona in new_personas:
- chat_orchestrator.register_persona(persona)
-
- return {
- "message": f"Successfully switched to {current_provider}",
- "current_provider": current_provider,
- "model_info": {
- "name": new_llm.model_name if hasattr(new_llm, 'model_name') else "gemini-2.0-flash",
- "provider": current_provider
- }
- }
-
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to switch to {provider_data.provider}: {str(e)}")
-
-@router.post("/switch-model")
-async def switch_model(model_name: str = Body(...)):
- if "gemini" in model_name.lower():
- return await switch_provider(ProviderSwitch(provider="gemini"))
- else:
- return await switch_provider(ProviderSwitch(provider="ollama"))
+@router.post("/switch-provider")
+async def switch_provider(
+    llm_config: UserLLMConfig,
+    current_user: User = Depends(get_current_active_user),
+):
+    """Persist the user's LLM configuration to their profile.
+
+    Nothing is written until the submitted config passes every check:
+
+    1. In hybrid mode, every key of ``persona_backends`` must be a persona
+       currently registered with the orchestrator.
+    2. Every backend the config refers to (default, orchestrator, and each
+       persona mapping) must be enabled in the admin config.
+    3. A client for each of those backends must be constructible.
+
+    Returns a dict with a confirmation ``message`` and the stored
+    ``llm_config``.
+
+    Raises:
+        HTTPException: 400 when any of the checks above fails.
+    """
+    if llm_config.mode == "hybrid" and llm_config.persona_backends:
+        registered = set(chat_orchestrator.personas.keys())
+        unknown = set(llm_config.persona_backends.keys()) - registered
+        if unknown:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Unknown persona IDs: {sorted(unknown)}. "
+                       f"Valid IDs: {sorted(registered)}",
+            )
+
+    backends_to_check = {llm_config.default_backend}
+    if llm_config.orchestrator_backend:
+        backends_to_check.add(llm_config.orchestrator_backend)
+    if llm_config.persona_backends:
+        backends_to_check.update(llm_config.persona_backends.values())
+
+    # Sorted so that, when several backends are invalid, the one reported to
+    # the client is the same on every request (set order is not stable).
+    for backend in sorted(backends_to_check):
+        if not _is_backend_enabled(backend):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Backend {backend!r} is disabled by the administrator.",
+            )
+        try:
+            get_llm_client(backend)
+        except Exception as exc:
+            logger.warning("Rejected LLM config: backend %r is not usable: %s", backend, exc)
+            # Chain the original error so the root cause stays in tracebacks.
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"Backend {backend!r} is not configured: {exc}",
+            ) from exc
+
+    db = get_database()
+    await db.users.update_one(
+        {"_id": current_user.id},
+        {"$set": {"llm_config": llm_config.model_dump()}},
+    )
-@router.get("/current-model")
-async def get_current_model():
-    model_name = llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash"
     return {
-        "model": model_name,
-        "provider": current_provider
+        "message": "LLM configuration updated",
+        "llm_config": llm_config.model_dump(),
     }
diff --git a/multi_llm_chatbot_backend/app/config.py b/multi_llm_chatbot_backend/app/config.py
index 753a1048..c8154837 100644
--- a/multi_llm_chatbot_backend/app/config.py
+++ b/multi_llm_chatbot_backend/app/config.py
@@ -253,6 +253,7 @@ def _warn_connection_envvar(self):
class GeminiConfig(BaseModel):
+ enabled: bool = True
api_key: str = Field(default=os.getenv("GEMINI_API_KEY"))
model: str = "gemini-2.5-flash"
@@ -272,12 +273,14 @@ def _warn_gemini_envvar(self):
class OllamaConfig(BaseModel):
+ enabled: bool = True
model: str = "llama3.2:1b"
# TODO: Drop support for `OLLAMA_BASE_URL` envvar handling
base_url: str = Field(default=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
class VllmConfig(BaseModel):
+ enabled: bool = True
api_url: str = ""
api_key: str = Field(default=os.getenv("VLLM_API_KEY", ""))
@@ -290,10 +293,12 @@ class BrainForgeConfig(BaseModel):
class LLMConfig(BaseModel):
+ default_backend: str = ""
gemini: GeminiConfig = GeminiConfig()
ollama: OllamaConfig = OllamaConfig()
vllm: VllmConfig = VllmConfig()
brainforge: BrainForgeConfig = BrainForgeConfig()
+ health_check_interval_seconds: int = 300
class RAGConfig(BaseModel):
diff --git a/multi_llm_chatbot_backend/app/core/bootstrap.py b/multi_llm_chatbot_backend/app/core/bootstrap.py
index c08d873e..99283643 100644
--- a/multi_llm_chatbot_backend/app/core/bootstrap.py
+++ b/multi_llm_chatbot_backend/app/core/bootstrap.py
@@ -1,22 +1,34 @@
# app/core/bootstrap.py
+import asyncio
+import logging
+
from app.config import get_settings
from app.llm.improved_gemini_client import ImprovedGeminiClient
from app.llm.improved_ollama_client import ImprovedOllamaClient
from app.llm.improved_vllm_client import ImprovedVllmClient
from app.core.improved_orchestrator import ImprovedChatOrchestrator
from app.models.default_personas import get_default_personas
+from app.models.user import LLM_BACKENDS
+from app.llm.llm_client import LLMClient
+
+logger = logging.getLogger(__name__)
settings = get_settings()
-current_provider = "gemini"
-available_providers = ["ollama", "gemini", "vllm"]
+_client_cache = {}
-def create_llm_client(provider=None):
- if provider is None:
- provider = current_provider
- if provider == "gemini":
+
+def create_llm_client(backend: str = None):
+ """Create an LLM client for the given backend name."""
+ if backend is None:
+ backend = settings.llm.default_backend
+ if backend not in LLM_BACKENDS:
+ raise ValueError(
+ f"Unknown backend {backend!r}. Must be one of {LLM_BACKENDS}"
+ )
+ if backend == "gemini":
return ImprovedGeminiClient(model_name=settings.llm.gemini.model)
- elif provider == "vllm":
+ elif backend == "vllm":
if not settings.llm.vllm.api_url:
raise ValueError("No vLLM endpoint configured. Set llm.vllm.api_url in your config.")
return ImprovedVllmClient(
@@ -29,7 +41,74 @@ def create_llm_client(provider=None):
base_url=settings.llm.ollama.base_url,
)
-llm = create_llm_client()
+
+def get_llm_client(backend: str) -> LLMClient:
+    """Look up the shared client for *backend*, building and caching it if absent.
+
+    Any error raised by ``create_llm_client`` propagates to the caller and
+    nothing is cached for that backend.
+    """
+    if backend in _client_cache:
+        return _client_cache[backend]
+    client = create_llm_client(backend)
+    _client_cache[backend] = client
+    return client
+
+
+def _is_backend_enabled(backend: str) -> bool:
+    """Report whether the admin config leaves *backend* switched on.
+
+    A backend with no config section, or a section without an ``enabled``
+    attribute, is treated as enabled.
+    """
+    return getattr(getattr(settings.llm, backend, None), "enabled", True)
+
+
+def get_available_backends() -> list:
+    """Return backends that are enabled and properly configured (sync, used at startup).
+
+    A backend is included when it is enabled in the admin config and
+    ``get_llm_client`` can return a client for it. ``health_check`` is not
+    called here. Backends that are skipped are logged with the reason so a
+    later "no backends available" failure can be diagnosed from the log.
+    """
+    available = []
+    for backend in LLM_BACKENDS:
+        if not _is_backend_enabled(backend):
+            logger.info("LLM backend %r is disabled in config; skipping", backend)
+            continue
+        try:
+            get_llm_client(backend)
+        except Exception as exc:
+            # Best-effort probe: an unusable backend is expected (e.g. not
+            # configured), so record why and move on instead of failing.
+            logger.info("LLM backend %r is not usable: %s", backend, exc)
+            continue
+        available.append(backend)
+    return available
+
+
+async def refresh_available_backends():
+    """Re-check which backends are enabled, configured, and reachable.
+
+    For each enabled backend a client is obtained and its ``health_check``
+    awaited; only backends whose check returns truthy are kept.
+    ``AVAILABLE_BACKENDS`` is updated in place (slice assignment), so modules
+    that imported the list by name observe the refreshed contents.
+    """
+    available = []
+    for backend in LLM_BACKENDS:
+        if not _is_backend_enabled(backend):
+            continue
+        try:
+            client = get_llm_client(backend)
+            if await client.health_check():
+                available.append(backend)
+            else:
+                logger.debug("LLM backend %r failed its health check", backend)
+        except Exception as exc:
+            # Best-effort: a backend that cannot be built or probed is simply
+            # reported as unavailable. Debug level because this runs on a timer.
+            logger.debug("LLM backend %r could not be checked: %s", backend, exc)
+    AVAILABLE_BACKENDS[:] = available
+
+
+async def _backend_health_loop():
+    """Background task that periodically refreshes AVAILABLE_BACKENDS.
+
+    Loops forever, sleeping ``settings.llm.health_check_interval_seconds``
+    between refreshes, until the task is cancelled. A failure in one refresh
+    is logged and the loop continues, so a single unexpected error cannot
+    silently stop all future refreshes.
+    """
+    interval = settings.llm.health_check_interval_seconds
+    while True:
+        try:
+            await refresh_available_backends()
+        except Exception:
+            # On Python 3.8+ asyncio.CancelledError derives from BaseException,
+            # so task cancellation still propagates past this handler.
+            logger.exception(
+                "Backend health refresh failed; retrying in %s seconds", interval
+            )
+        await asyncio.sleep(interval)
+
+
+# Resolve the default backend: prefer the configured default, fall back to the first available.
+AVAILABLE_BACKENDS = get_available_backends()
+if settings.llm.default_backend in AVAILABLE_BACKENDS:
+    DEFAULT_BACKEND = settings.llm.default_backend
+elif AVAILABLE_BACKENDS:
+    DEFAULT_BACKEND = AVAILABLE_BACKENDS[0]
+    logger.warning(
+        "Configured default_backend %r is not available; falling back to %r",
+        settings.llm.default_backend, DEFAULT_BACKEND,
+    )
+else:
+    raise RuntimeError(
+        "No LLM backends are available. Check your config.yaml — "
+        "at least one backend must be enabled and properly configured."
+    )
+# get_available_backends() obtained every backend it returned through
+# get_llm_client(), so a client for DEFAULT_BACKEND is already cached.
+# Reuse it instead of constructing a second instance and overwriting the cache.
+llm = get_llm_client(DEFAULT_BACKEND)
chat_orchestrator = ImprovedChatOrchestrator(llm_client=llm)
DEFAULT_PERSONAS = get_default_personas(llm)
diff --git a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
index 12f63a78..774c8687 100644
--- a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
+++ b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py
@@ -45,7 +45,8 @@ def list_personas(self) -> List[str]:
"""List all available persona IDs"""
return list(self.personas.keys())
- async def get_tool_response(self, user_message: str) -> ToolCallResult:
+ async def get_tool_response(self, user_message: str,
+ llm_client: LLMClient = None) -> ToolCallResult:
"""Check whether a tool can handle *user_message*.
If tools are disabled in config, no LLM client is available, or the
@@ -53,7 +54,8 @@ async def get_tool_response(self, user_message: str) -> ToolCallResult:
``ToolCallResult(used_tool=False)``. Otherwise executes the tool and
returns the grounded response with ``used_tool=True``.
"""
- if self.llm_client is None:
+ effective_llm = llm_client or self.llm_client
+ if effective_llm is None:
return ToolCallResult(text="", used_tool=False)
settings = get_settings()
@@ -80,7 +82,7 @@ async def get_tool_response(self, user_message: str) -> ToolCallResult:
"to present structured data like course listings or professor ratings."
)
- return await self.llm_client.generate_with_tools(
+ return await effective_llm.generate_with_tools(
system_prompt=system_prompt,
user_message=user_message,
tool_definitions=tool_definitions,
@@ -325,7 +327,8 @@ async def needs_clarification_improved(self, session: ConversationContext, user_
logger.warning("Falling back to rule-based clarification check")
return self.needs_clarification(session, user_input)
- async def generate_contextual_clarification(self, user_input: str) -> Dict[str, Any]:
+ async def generate_contextual_clarification(self, user_input: str,
+ llm_client: LLMClient = None) -> Dict[str, Any]:
"""
Use the LLM to produce a clarification question and clickable
suggestions that are tailored to what the user actually typed.
@@ -355,10 +358,8 @@ async def generate_contextual_clarification(self, user_input: str) -> Dict[str,
)
try:
- # Use the orchestrator's own LLM rather than a persona's — BrainForge
- # persona LLMs may not support the prompt format used here.
- llm = self.llm_client
- raw = await llm.generate(
+ effective_llm = llm_client or self.llm_client
+ raw = await effective_llm.generate(
system_prompt=system_prompt,
context=[{"role": "user", "content": user_prompt}],
temperature=0.4,
@@ -391,9 +392,14 @@ async def generate_contextual_clarification(self, user_input: str) -> Dict[str,
"suggestions": fallback_suggestions,
}
- async def generate_persona_responses(self, session: ConversationContext, response_length: str = "medium"):
+ async def generate_persona_responses(self, session: ConversationContext,
+ response_length: str = "medium",
+ llm_clients: Dict[str, LLMClient] = None):
"""
- Generate responses from all personas with enhanced RAG integration
+ Generate responses from all personas with enhanced RAG integration.
+
+ *llm_clients* maps persona IDs to the LLM client each should use.
+ Personas not present in the dict fall back to their default client.
"""
responses = []
@@ -401,7 +407,10 @@ async def generate_persona_responses(self, session: ConversationContext, respons
logger.info(f"Generating response for {persona_id} with enhanced RAG")
# Generate persona response with enhanced RAG
- response_data = await self.generate_single_persona_response(session, persona, response_length)
+ persona_llm = (llm_clients or {}).get(persona_id)
+ response_data = await self.generate_single_persona_response(
+ session, persona, response_length, llm_client=persona_llm,
+ )
# Add persona response to session context
session.append_message(persona_id, response_data["response"])
@@ -410,9 +419,14 @@ async def generate_persona_responses(self, session: ConversationContext, respons
return responses
- async def generate_single_persona_response(self, session, persona, response_length: str = "medium"):
+ async def generate_single_persona_response(self, session, persona,
+ response_length: str = "medium",
+ llm_client: LLMClient = None):
"""
- Enhanced version - Generate response from a single persona with enhanced RAG integration
+ Enhanced version - Generate response from a single persona with enhanced RAG integration.
+
+ *llm_client* is forwarded to ``persona.respond()``; when ``None`` the
+ persona uses its default (system-default) client.
"""
try:
# Get the user's latest message for document retrieval
@@ -441,7 +455,7 @@ async def generate_single_persona_response(self, session, persona, response_leng
)
# Generate response with enhanced context
- response = await persona.respond(enhanced_context, response_length)
+ response = await persona.respond(enhanced_context, response_length, llm=llm_client)
# Validate and improve response quality
if not self._is_valid_response(response, persona.id):
@@ -813,9 +827,13 @@ def _get_persona_context_keywords(self, persona_id: str) -> str:
"""
return self._get_enhanced_persona_context_keywords(persona_id)
- async def chat_with_persona(self, user_input: str, persona_id: str, session_id: str, response_length: str = "medium") -> Dict[str, Any]:
+ async def chat_with_persona(self, user_input: str, persona_id: str,
+ session_id: str, response_length: str = "medium",
+ llm_client: LLMClient = None) -> Dict[str, Any]:
"""
- Chat with a specific persona directly - FIXED for consistent document access
+ Chat with a specific persona directly - FIXED for consistent document access.
+
+ *llm_client* is forwarded to the persona's response generation.
"""
try:
persona = self.get_persona(persona_id)
@@ -838,7 +856,9 @@ async def chat_with_persona(self, user_input: str, persona_id: str, session_id:
logger.info(f"Generating response for {persona_id} with session {session_id}")
# Generate response from single persona using consistent session ID
- response_data = await self.generate_single_persona_response(session, persona, response_length)
+ response_data = await self.generate_single_persona_response(
+ session, persona, response_length, llm_client=llm_client,
+ )
# Add response to session
session.append_message(persona_id, response_data["response"])
@@ -884,7 +904,8 @@ async def chat_with_persona(self, user_input: str, persona_id: str, session_id:
async def get_top_personas(self, session_id: str, k: int = 3,
- allowed_ids: Optional[List[str]] = None) -> List[str]:
+ allowed_ids: Optional[List[str]] = None,
+ llm_client: LLMClient = None) -> List[str]:
"""
Use the LLM to rank personas based on current session context.
Falls back to default persona order if LLM fails or returns invalid data.
@@ -902,9 +923,7 @@ async def get_top_personas(self, session_id: str, k: int = 3,
logger.warning("No personas available after filtering.")
return []
- # Use the orchestrator's own LLM rather than a persona's — BrainForge
- # persona LLMs may not support the prompt format used here.
- llm = self.llm_client
+ effective_llm = llm_client or self.llm_client
# Use recent conversation context (last 5 messages)
recent_context = "\n".join(
@@ -935,7 +954,7 @@ async def get_top_personas(self, session_id: str, k: int = 3,
{persona_descriptions}
""".strip()
- llm_response = await llm.generate(
+ llm_response = await effective_llm.generate(
system_prompt=f"You are an assistant that selects the best advisors for a user of {app_title}.",
context=[{"role": "user", "content": prompt}],
temperature=0.4,
diff --git a/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py b/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py
index 63074523..ccf0d322 100644
--- a/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py
+++ b/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py
@@ -110,4 +110,12 @@ def _is_poor_quality(self, response: str) -> bool:
len(response.split()) > 150, # Too verbose
response.count("?") > 3, # Too many questions
]
- return any(poor_indicators)
\ No newline at end of file
+ return any(poor_indicators)
+
+ async def health_check(self) -> bool:
+ """Probe ``{base_url}/api/tags`` to see whether the server answers.
+
+ Returns ``True`` when the GET response reports success
+ (``resp.is_success``). Returns ``False`` for a non-success response
+ or when any exception is raised; the request uses a 10 second timeout.
+ """
+ try:
+ async with httpx.AsyncClient(timeout=10.0) as client:
+ resp = await client.get(f"{self.base_url}/api/tags")
+ return resp.is_success
+ except Exception:
+ return False
\ No newline at end of file
diff --git a/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py b/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py
index de9a927b..5b568358 100644
--- a/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py
+++ b/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py
@@ -187,4 +187,14 @@ async def generate_with_tools(
used_tool=False,
)
+ async def health_check(self) -> bool:
+ """Probe ``{api_url}/v1/models`` to see whether the server answers.
+
+ A ``Bearer`` Authorization header is sent only when ``self.api_key``
+ is set. Returns ``True`` when the GET response reports success
+ (``resp.is_success``). Returns ``False`` for a non-success response
+ or when any exception is raised; the request uses a 10 second timeout.
+ """
+ # Imported inside the method, so httpx is loaded only when a probe runs.
+ import httpx
+ try:
+ headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
+ async with httpx.AsyncClient(timeout=10.0) as client:
+ resp = await client.get(f"{self.api_url}/v1/models", headers=headers)
+ return resp.is_success
+ except Exception:
+ return False
+
diff --git a/multi_llm_chatbot_backend/app/llm/llm_client.py b/multi_llm_chatbot_backend/app/llm/llm_client.py
index 995500da..74ee8f86 100644
--- a/multi_llm_chatbot_backend/app/llm/llm_client.py
+++ b/multi_llm_chatbot_backend/app/llm/llm_client.py
@@ -67,6 +67,15 @@ async def generate_with_tools(
)
return ToolCallResult(text=text, used_tool=False)
+ async def health_check(self) -> bool:
+ """Check if the backend service is reachable.
+
+ Subclasses for self-hosted services should override this with a
+ lightweight network probe. The default returns True (suitable for
+ cloud APIs where the config check is sufficient).
+
+ Returns:
+ ``True`` unconditionally in this base implementation; no request
+ is made.
+ """
+ return True
+
def _clean_response(self, response: str) -> str:
"""Clean up response text, preserving Markdown formatting."""
response = response.replace("\r\n", "\n").replace("\r", "\n")
diff --git a/multi_llm_chatbot_backend/app/main.py b/multi_llm_chatbot_backend/app/main.py
index 677df8f2..d8214fcb 100644
--- a/multi_llm_chatbot_backend/app/main.py
+++ b/multi_llm_chatbot_backend/app/main.py
@@ -11,6 +11,8 @@
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
+import asyncio
+
# Load configuration FIRST so every module can use it
from app.config import load_settings
from app.version import __version__
@@ -18,6 +20,7 @@
# Import the new database functions
from app.core.database import connect_to_mongo, close_mongo_connection
+from app.core.bootstrap import _backend_health_loop
# Import all route modules
from app.api.routes import router as main_router
@@ -41,9 +44,11 @@ async def lifespan(app: FastAPI):
from app.core.brainforge_sync import async_sync_brainforge_personas, periodic_sync_loop
await async_sync_brainforge_personas(chat_orchestrator)
sync_task = asyncio.create_task(periodic_sync_loop(chat_orchestrator))
+ health_task = asyncio.create_task(_backend_health_loop())
yield
# Shutdown
sync_task.cancel()
+ health_task.cancel()
await close_mongo_connection()
app = FastAPI(
@@ -119,6 +124,8 @@ def get_public_config():
"dark_color": colors["dark_color"],
"dark_bg_color": colors["dark_bg_color"],
"image": "icon://Brain",
+ "backend_locked": True,
+ "default_backend": "brainforge",
})
return config
diff --git a/multi_llm_chatbot_backend/app/models/persona.py b/multi_llm_chatbot_backend/app/models/persona.py
index ed34caf9..2a19034a 100644
--- a/multi_llm_chatbot_backend/app/models/persona.py
+++ b/multi_llm_chatbot_backend/app/models/persona.py
@@ -245,10 +245,14 @@ def __init__(self, id: str, name: str, system_prompt: str, llm: LLMClient, tempe
self.llm = llm
self.temperature = temperature
- async def respond(self, context: List[Dict], response_length: str = "medium") -> str:
+ async def respond(self, context: List[Dict], response_length: str = "medium",
+ llm: LLMClient = None) -> str:
"""Generate a compact, well-formed Markdown response suitable for the UI.
- Returns the compact Markdown string (backward compatible with previous callers).
+
+ *llm* overrides the default client for this call (used for per-user
+ backend selection). Falls back to ``self.llm`` when not provided.
"""
+ effective_llm = llm or self.llm
max_tokens = MAX_TOKENS_MAP.get(response_length, 500)
structure_hint = STRUCTURE_HINTS.get(response_length, STRUCTURE_HINTS["medium"])
temp_scaled = round(self.temperature / 10, 2)
@@ -259,7 +263,7 @@ async def respond(self, context: List[Dict], response_length: str = "medium") ->
f"{structure_hint}"
)
- raw_text = await self.llm.generate(
+ raw_text = await effective_llm.generate(
system_prompt=full_prompt,
context=context,
temperature=temp_scaled,
diff --git a/multi_llm_chatbot_backend/app/models/user.py b/multi_llm_chatbot_backend/app/models/user.py
index 24b08957..32ecd31f 100644
--- a/multi_llm_chatbot_backend/app/models/user.py
+++ b/multi_llm_chatbot_backend/app/models/user.py
@@ -1,8 +1,36 @@
from pydantic import BaseModel, EmailStr, Field, ConfigDict, model_validator
-from typing import Literal, Optional, List, Any
+from typing import Dict, Literal, Optional, List, Any, get_args
from datetime import datetime
from bson import ObjectId
+BackendName = Literal["gemini", "ollama", "vllm"]
+LLM_BACKENDS = get_args(BackendName)
+
+
+class UserLLMConfig(BaseModel):
+    """Per-user LLM provider configuration.
+
+    Uniform mode: all advisors and the orchestrator use ``default_backend``.
+    Hybrid mode: each advisor can use a different backend; ``default_backend``
+    is the fallback for any persona not explicitly mapped.
+    """
+
+    # Declared after the docstring so the string above is the class
+    # docstring and not a stray expression. Unknown keys are rejected.
+    model_config = ConfigDict(extra="forbid")
+
+    mode: Literal["uniform", "hybrid"] = "uniform"
+    default_backend: BackendName = "gemini"
+    orchestrator_backend: Optional[BackendName] = None
+    persona_backends: Optional[Dict[str, BackendName]] = None
+
+    @model_validator(mode="after")
+    def _validate_hybrid_fields(self):
+        """Normalise the hybrid-only fields to match ``mode``.
+
+        - Hybrid with neither ``orchestrator_backend`` nor
+          ``persona_backends`` set: the orchestrator is pinned to
+          ``default_backend``.
+        - Any other mode: both hybrid-only fields are cleared to ``None``.
+        """
+        if self.mode == "hybrid":
+            if not self.orchestrator_backend and not self.persona_backends:
+                self.orchestrator_backend = self.default_backend
+        else:
+            self.orchestrator_backend = None
+            self.persona_backends = None
+        return self
+
class PyObjectId(ObjectId):
@classmethod
def __get_validators__(cls):
@@ -48,6 +76,7 @@ class User(BaseModel):
academicStage: Optional[str] = None
researchArea: Optional[str] = None
disabled_advisors: Optional[List[str]] = None
+ llm_config: Optional[UserLLMConfig] = None
created_at: datetime = Field(default_factory=datetime.utcnow)
last_login: Optional[datetime] = None
is_active: bool = True
diff --git a/multi_llm_chatbot_backend/app/tests/unit/conftest.py b/multi_llm_chatbot_backend/app/tests/unit/conftest.py
index 1f539bd3..09b73bad 100644
--- a/multi_llm_chatbot_backend/app/tests/unit/conftest.py
+++ b/multi_llm_chatbot_backend/app/tests/unit/conftest.py
@@ -1,37 +1,38 @@
-"""Session-wide stubs for modules that do heavy work at import time.
-
-``app.api.routes.__init__`` eagerly imports every sibling route, and
-``app.api.routes.provider`` instantiates real LLM clients at module
-load. ``app.core.bootstrap`` and ``app.core.rag_manager`` likewise
-start NLTK, ChromaDB, and the LLM stack the moment they are imported.
-
-We install harmless ``MagicMock`` substitutes for those modules once,
-before any test file in this directory is collected, so every test
-gets a consistent, importable view of ``app.api.routes`` without
-having to reproduce the same stubbing recipe in every test module.
-
-Tests that want to exercise the real version of a specific route
-module (for example, ``test_version.py`` wanting the real
-``app.api.routes.root``) can still pop their target out of
-``sys.modules`` in their own setup -- they no longer have to
-coordinate cleanup with peer test modules.
+"""Session-wide stubs for heavy-import modules.
+
+``app.core.rag_manager`` starts NLTK / ChromaDB the moment it is
+imported, so we replace it with a ``MagicMock`` before any test is
+collected.
+
+``app.core.bootstrap`` (and the route modules that import it) can load
+normally because we pre-set ``GEMINI_API_KEY`` and ``CONFIG_PATH``
+before any import occurs. This lets ``get_settings()``, the LLM-client
+constructors, and the orchestrator initialise without real credentials
+or config files.
+
+Route modules that are *not* under direct test (documents, sessions,
+debug, phd_canvas) are still replaced with lightweight stubs so their
+dependency trees are never pulled in.
"""
+import os
import sys
from unittest.mock import MagicMock
from fastapi import APIRouter
+os.environ.setdefault("GEMINI_API_KEY", "fake-test-key")
+os.environ.setdefault("CONFIG_PATH", "")
-for _name in ("app.core.bootstrap", "app.core.rag_manager"):
- sys.modules.setdefault(_name, MagicMock())
+# rag_manager triggers NLTK / ChromaDB on import — stub it unless already imported.
+sys.modules.setdefault("app.core.rag_manager", MagicMock())
+# Stub route modules that are NOT under direct test to avoid pulling
+# in their full dependency trees when app.api.routes.__init__ runs.
_stub_router_module = MagicMock(router=APIRouter())
for _name in (
- "app.api.routes.chat",
"app.api.routes.documents",
"app.api.routes.sessions",
- "app.api.routes.provider",
"app.api.routes.debug",
"app.api.routes.phd_canvas",
):
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py b/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py
new file mode 100644
index 00000000..c9bb362b
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py
@@ -0,0 +1,114 @@
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from app.core.bootstrap import (
+ get_available_backends,
+ refresh_available_backends,
+ AVAILABLE_BACKENDS,
+ LLM_BACKENDS,
+)
+
+
+class TestGetAvailableBackends(unittest.TestCase):
+ """Sync config-only check used at startup."""
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_all_configured(self, mock_get_client):
+ mock_get_client.return_value = MagicMock()
+ result = get_available_backends()
+ self.assertEqual(result, ["gemini", "ollama", "vllm"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_vllm_not_configured(self, mock_get_client):
+ def side_effect(backend):
+ if backend == "vllm":
+ raise ValueError("No vLLM endpoint configured.")
+ return MagicMock()
+ mock_get_client.side_effect = side_effect
+ result = get_available_backends()
+ self.assertEqual(result, ["gemini", "ollama"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_gemini_not_configured(self, mock_get_client):
+ def side_effect(backend):
+ if backend == "gemini":
+ raise ValueError("No Gemini endpoint configured.")
+ return MagicMock()
+ mock_get_client.side_effect = side_effect
+ result = get_available_backends()
+ self.assertEqual(result, ["ollama", "vllm"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_none_configured(self, mock_get_client):
+ mock_get_client.side_effect = ValueError("not configured")
+ result = get_available_backends()
+ self.assertEqual(result, [])
+
+
+class TestRefreshAvailableBackends(unittest.TestCase):
+ """Async health-check refresh."""
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_all_healthy(self, mock_get_client):
+ mock_client = MagicMock()
+ mock_client.health_check = AsyncMock(return_value=True)
+ mock_get_client.return_value = mock_client
+
+ asyncio.run(refresh_available_backends())
+ self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama", "vllm"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_vllm_unhealthy(self, mock_get_client):
+ def side_effect(backend):
+ client = MagicMock()
+ client.health_check = AsyncMock(return_value=(backend != "vllm"))
+ return client
+ mock_get_client.side_effect = side_effect
+
+ asyncio.run(refresh_available_backends())
+ self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_health_check_exception(self, mock_get_client):
+ def side_effect(backend):
+ client = MagicMock()
+ if backend == "ollama":
+ client.health_check = AsyncMock(side_effect=Exception("timeout"))
+ else:
+ client.health_check = AsyncMock(return_value=True)
+ return client
+ mock_get_client.side_effect = side_effect
+
+ asyncio.run(refresh_available_backends())
+ self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "vllm"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_unconfigured_backend_excluded(self, mock_get_client):
+ def side_effect(backend):
+ if backend == "vllm":
+ raise ValueError("No vLLM endpoint configured.")
+ client = MagicMock()
+ client.health_check = AsyncMock(return_value=True)
+ return client
+ mock_get_client.side_effect = side_effect
+
+ asyncio.run(refresh_available_backends())
+ self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama"])
+
+ @patch("app.core.bootstrap.get_llm_client")
+ def test_gemini_unconfigured_excluded(self, mock_get_client):
+ def side_effect(backend):
+ if backend == "gemini":
+ raise ValueError("No Gemini endpoint configured.")
+ client = MagicMock()
+ client.health_check = AsyncMock(return_value=True)
+ return client
+ mock_get_client.side_effect = side_effect
+
+ asyncio.run(refresh_available_backends())
+ self.assertEqual(list(AVAILABLE_BACKENDS), ["ollama", "vllm"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py b/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py
new file mode 100644
index 00000000..bcac69d9
--- /dev/null
+++ b/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py
@@ -0,0 +1,364 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from bson import ObjectId
+from fastapi import HTTPException
+from pydantic import ValidationError
+
+from app.api.routes.chat import resolve_llm_clients
+from app.api.routes.provider import switch_provider
+from app.models.user import User, UserLLMConfig
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_user(llm_config=None):
+ return User(
+ _id=ObjectId(),
+ firstName="Test",
+ lastName="User",
+ email="test@example.com",
+ hashed_password="fakehash",
+ llm_config=llm_config,
+ )
+
+
+# ===================================================================
+# 1. UserLLMConfig — Pydantic model validation
+# ===================================================================
+
+class TestUserLLMConfig(unittest.TestCase):
+ """Validate the UserLLMConfig Pydantic model and its _validate_hybrid_fields
+ model validator."""
+
+ def test_defaults(self):
+ cfg = UserLLMConfig()
+ self.assertEqual(cfg.mode, "uniform")
+ self.assertEqual(cfg.default_backend, "gemini")
+ self.assertIsNone(cfg.orchestrator_backend)
+ self.assertIsNone(cfg.persona_backends)
+
+ def test_uniform_strips_hybrid_fields(self):
+ cfg = UserLLMConfig(
+ mode="uniform",
+ default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"x": "vllm"},
+ )
+ self.assertIsNone(cfg.orchestrator_backend)
+ self.assertIsNone(cfg.persona_backends)
+
+ def test_hybrid_without_overrides_falls_back_to_default(self):
+ cfg = UserLLMConfig(mode="hybrid", default_backend="gemini")
+ self.assertEqual(cfg.orchestrator_backend, "gemini")
+ self.assertIsNone(cfg.persona_backends)
+
+ def test_hybrid_with_orchestrator_only(self):
+ cfg = UserLLMConfig(
+ mode="hybrid",
+ default_backend="gemini",
+ orchestrator_backend="ollama",
+ )
+ self.assertEqual(cfg.orchestrator_backend, "ollama")
+ self.assertIsNone(cfg.persona_backends)
+
+ def test_hybrid_with_persona_backends_only(self):
+ cfg = UserLLMConfig(
+ mode="hybrid",
+ default_backend="gemini",
+ persona_backends={"advisor_1": "vllm"},
+ )
+ self.assertIsNone(cfg.orchestrator_backend)
+ self.assertEqual(cfg.persona_backends, {"advisor_1": "vllm"})
+
+ def test_hybrid_with_both(self):
+ cfg = UserLLMConfig(
+ mode="hybrid",
+ default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"advisor_1": "vllm"},
+ )
+ self.assertEqual(cfg.orchestrator_backend, "ollama")
+ self.assertEqual(cfg.persona_backends, {"advisor_1": "vllm"})
+
+ def test_rejects_unknown_backend_name(self):
+ with self.assertRaises(ValidationError):
+ UserLLMConfig(default_backend="claude")
+
+ def test_extra_fields_forbidden(self):
+ with self.assertRaises(ValidationError):
+ UserLLMConfig(default_backend="gemini", surprise="boom")
+
+ def test_model_dump_roundtrip(self):
+ original = UserLLMConfig(
+ mode="hybrid",
+ default_backend="ollama",
+ orchestrator_backend="gemini",
+ persona_backends={"a": "vllm", "b": "gemini"},
+ )
+ restored = UserLLMConfig(**original.model_dump())
+ self.assertEqual(original.model_dump(), restored.model_dump())
+
+
+# ===================================================================
+# 2. resolve_llm_clients — chat routing logic
+# ===================================================================
+
+class TestResolveLlmClients(unittest.TestCase):
+ """Verify that resolve_llm_clients maps a user's stored config to the
+ correct LLM client instances."""
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_no_config_returns_nones(self, mock_get, mock_orch):
+ result = resolve_llm_clients(_make_user(llm_config=None))
+ self.assertIsNone(result["orchestrator"])
+ self.assertIsNone(result["personas"])
+ mock_get.assert_not_called()
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_uniform_same_client_for_all(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock(), "b": MagicMock(), "c": MagicMock()}
+ sentinel = MagicMock(name="shared_client")
+ mock_get.return_value = sentinel
+
+ user = _make_user(UserLLMConfig(mode="uniform", default_backend="ollama"))
+ result = resolve_llm_clients(user)
+
+ mock_get.assert_called_once_with("ollama")
+ self.assertIs(result["orchestrator"], sentinel)
+ for pid in ("a", "b", "c"):
+ self.assertIs(result["personas"][pid], sentinel)
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_hybrid_orchestrator_override(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock()}
+ clients = {"gemini": MagicMock(), "ollama": MagicMock()}
+ mock_get.side_effect = lambda b: clients[b]
+
+ user = _make_user(UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"a": "gemini"},
+ ))
+ result = resolve_llm_clients(user)
+ self.assertIs(result["orchestrator"], clients["ollama"])
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_hybrid_orchestrator_falls_back_to_default(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock()}
+ sentinel = MagicMock()
+ mock_get.return_value = sentinel
+
+ user = _make_user(UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"a": "gemini"},
+ ))
+ result = resolve_llm_clients(user)
+ self.assertIs(result["orchestrator"], sentinel)
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_hybrid_persona_override(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock(), "b": MagicMock()}
+ clients = {"gemini": MagicMock(), "vllm": MagicMock()}
+ mock_get.side_effect = lambda b: clients[b]
+
+ user = _make_user(UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"a": "vllm", "b": "gemini"},
+ ))
+ result = resolve_llm_clients(user)
+ self.assertIs(result["personas"]["a"], clients["vllm"])
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_hybrid_unmapped_persona_uses_default(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock(), "unmapped": MagicMock()}
+ clients = {"gemini": MagicMock(), "vllm": MagicMock()}
+ mock_get.side_effect = lambda b: clients[b]
+
+ user = _make_user(UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"a": "vllm"},
+ ))
+ result = resolve_llm_clients(user)
+ self.assertIs(result["personas"]["unmapped"], clients["gemini"])
+
+ @patch("app.api.routes.chat.chat_orchestrator")
+ @patch("app.api.routes.chat.get_llm_client")
+ def test_hybrid_mixed(self, mock_get, mock_orch):
+ mock_orch.personas = {"a": MagicMock(), "b": MagicMock()}
+ clients = {"gemini": MagicMock(), "ollama": MagicMock(), "vllm": MagicMock()}
+ mock_get.side_effect = lambda b: clients[b]
+
+ user = _make_user(UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"a": "vllm"},
+ ))
+ result = resolve_llm_clients(user)
+ self.assertIs(result["orchestrator"], clients["ollama"])
+ self.assertIs(result["personas"]["a"], clients["vllm"])
+ self.assertIs(result["personas"]["b"], clients["gemini"])
+
+
+# ===================================================================
+# 3. switch_provider — endpoint validation
+# ===================================================================
+
+class TestSwitchProvider(unittest.IsolatedAsyncioTestCase):
+ """Validate the switch_provider endpoint's guard logic."""
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_rejects_unknown_persona_id(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {"known": MagicMock()}
+ mock_get.return_value = MagicMock()
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"unknown_id": "gemini"},
+ )
+ with self.assertRaises(HTTPException) as ctx:
+ await switch_provider(cfg, _make_user())
+ self.assertEqual(ctx.exception.status_code, 400)
+ self.assertIn("Unknown persona IDs", ctx.exception.detail)
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_accepts_known_persona_ids(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {"a": MagicMock(), "b": MagicMock()}
+ mock_get.return_value = MagicMock()
+ mock_db.return_value.users.update_one = AsyncMock()
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"a": "gemini", "b": "gemini"},
+ )
+ result = await switch_provider(cfg, _make_user())
+ self.assertEqual(result["message"], "LLM configuration updated")
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_rejects_unconfigured_default_backend(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {}
+ mock_get.side_effect = ValueError("not configured")
+
+ cfg = UserLLMConfig(mode="uniform", default_backend="ollama")
+ with self.assertRaises(HTTPException) as ctx:
+ await switch_provider(cfg, _make_user())
+ self.assertEqual(ctx.exception.status_code, 400)
+ self.assertIn("not configured", ctx.exception.detail)
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_rejects_unconfigured_orchestrator_backend(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {}
+
+ def side_effect(backend):
+ if backend == "vllm":
+ raise ValueError("no vLLM endpoint")
+ return MagicMock()
+ mock_get.side_effect = side_effect
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ orchestrator_backend="vllm",
+ )
+ with self.assertRaises(HTTPException) as ctx:
+ await switch_provider(cfg, _make_user())
+ self.assertEqual(ctx.exception.status_code, 400)
+ self.assertIn("not configured", ctx.exception.detail)
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_rejects_unconfigured_persona_backend(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {"a": MagicMock()}
+
+ def side_effect(backend):
+ if backend == "vllm":
+ raise ValueError("no vLLM endpoint")
+ return MagicMock()
+ mock_get.side_effect = side_effect
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ persona_backends={"a": "vllm"},
+ )
+ with self.assertRaises(HTTPException) as ctx:
+ await switch_provider(cfg, _make_user())
+ self.assertEqual(ctx.exception.status_code, 400)
+ self.assertIn("not configured", ctx.exception.detail)
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_checks_all_distinct_backends(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {"a": MagicMock(), "b": MagicMock()}
+ mock_get.return_value = MagicMock()
+ mock_db.return_value.users.update_one = AsyncMock()
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"a": "vllm", "b": "gemini"},
+ )
+ await switch_provider(cfg, _make_user())
+
+ checked = {call.args[0] for call in mock_get.call_args_list}
+ self.assertEqual(checked, {"gemini", "ollama", "vllm"})
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_persists_to_database(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {}
+ mock_get.return_value = MagicMock()
+ mock_collection = MagicMock()
+ mock_collection.update_one = AsyncMock()
+ mock_db.return_value.users = mock_collection
+
+ user = _make_user()
+ cfg = UserLLMConfig(mode="uniform", default_backend="ollama")
+ await switch_provider(cfg, user)
+
+ mock_collection.update_one.assert_awaited_once()
+ call_args = mock_collection.update_one.call_args
+ self.assertEqual(call_args[0][0], {"_id": user.id})
+ self.assertEqual(
+ call_args[0][1],
+ {"$set": {"llm_config": cfg.model_dump()}},
+ )
+
+ @patch("app.api.routes.provider.get_database")
+ @patch("app.api.routes.provider.get_llm_client")
+ @patch("app.api.routes.provider.chat_orchestrator")
+ async def test_returns_updated_config(self, mock_orch, mock_get, mock_db):
+ mock_orch.personas = {"a": MagicMock()}
+ mock_get.return_value = MagicMock()
+ mock_db.return_value.users.update_one = AsyncMock()
+
+ cfg = UserLLMConfig(
+ mode="hybrid", default_backend="gemini",
+ orchestrator_backend="ollama",
+ persona_backends={"a": "vllm"},
+ )
+ result = await switch_provider(cfg, _make_user())
+
+ self.assertEqual(result["message"], "LLM configuration updated")
+ self.assertEqual(result["llm_config"], cfg.model_dump())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/phd-advisor-frontend/src/components/AdvisorConfigPanel.js b/phd-advisor-frontend/src/components/AdvisorConfigPanel.js
new file mode 100644
index 00000000..90f0d153
--- /dev/null
+++ b/phd-advisor-frontend/src/components/AdvisorConfigPanel.js
@@ -0,0 +1,193 @@
+import React, { useEffect, useMemo, useState } from 'react';
+
+// Reusable per-advisor backend configuration panel.
+//
+// Used in two places:
+// - Welcome-state "Advanced" expander on ChatPage
+// - "Advisor Config" tab inside SettingsModal (lives on feat/UI-for-User-Account-updates;
+// drop this component in once branches merge)
+//
+// Controlled component. Parent owns the config object and decides when to persist.
+//
+// Shape of `value`:
+// { default_backend, orchestrator_backend, persona_backends: { [personaId]: backend } }
+//
+// The orchestrator and each advisor can be set to DEFAULT_BACKEND ("Default"),
+// meaning "follow default_backend". Call stripDefaultBackends() before persisting:
+// it removes those sentinel values (the orchestrator becomes null) so they fall
+// back to default_backend and keep following it whenever the default changes.
+
+export const DEFAULT_BACKEND = '__default__';
+
+export const stripDefaultBackends = (config) => {
+ if (!config) return config;
+ const source = config.persona_backends || {};
+ const cleaned = {};
+ for (const [id, backend] of Object.entries(source)) {
+ if (backend && backend !== DEFAULT_BACKEND) cleaned[id] = backend;
+ }
+ const orchestrator =
+ config.orchestrator_backend && config.orchestrator_backend !== DEFAULT_BACKEND
+ ? config.orchestrator_backend
+ : null;
+ return { ...config, orchestrator_backend: orchestrator, persona_backends: cleaned };
+};
+
+const rowStyle = {
+ display: 'grid', gridTemplateColumns: '1fr 180px', alignItems: 'center',
+ gap: 12, padding: '10px 0', borderBottom: '1px solid var(--border-primary)',
+};
+
+const selectStyle = {
+ padding: '8px 10px', borderRadius: 8, border: '1px solid var(--border-primary)',
+ background: 'var(--bg-secondary)', color: 'var(--text-primary)', fontSize: 13.5,
+ width: '100%', colorScheme: 'light dark',
+};
+
+// Native dropdown option list ignores the
);
diff --git a/phd-advisor-frontend/src/styles/ChatPage.css b/phd-advisor-frontend/src/styles/ChatPage.css
index a74bdd2f..4509e691 100644
--- a/phd-advisor-frontend/src/styles/ChatPage.css
+++ b/phd-advisor-frontend/src/styles/ChatPage.css
@@ -911,6 +911,12 @@
.provider-badge.gemini { color: #4285f4; }
.provider-badge.ollama { color: #10b981; }
.provider-badge.vllm { color: #ef4444; }
+.provider-badge.hybrid { color: #a855f7; }
+
+[data-theme="dark"] .provider-badge.hybrid,
+[data-theme="dark"] .provider-option-badge.hybrid {
+ color: #ffffff;
+}
.clarification-message-container {
display: flex;
@@ -1147,6 +1153,34 @@
color: #ef4444;
}
+.provider-option-badge.hybrid {
+ background: rgba(168, 85, 247, 0.15);
+ color: #a855f7;
+}
+
+.provider-option-configure {
+ margin-left: 8px;
+ background: transparent;
+ border: 1px solid var(--border-color, #d1d5db);
+ border-radius: 6px;
+ padding: 4px 6px;
+ cursor: pointer;
+ display: inline-flex;
+ align-items: center;
+ color: var(--text-secondary, #6b7280);
+}
+
+.provider-option-configure:hover {
+ background: var(--accent-primary);
+ color: #fff;
+ border-color: var(--accent-primary);
+}
+
+[data-theme="dark"] .provider-option-configure {
+ color: #ffffff;
+ border-color: rgba(255, 255, 255, 0.25);
+}
+
.provider-option-description {
font-size: 11px;
color: var(--text-tertiary);
diff --git a/phd_config.yaml b/phd_config.yaml
index 121605cd..bef68cf0 100644
--- a/phd_config.yaml
+++ b/phd_config.yaml
@@ -173,11 +173,15 @@ mongodb:
database_name: "phd_advisor"
llm:
+ default_backend: gemini
gemini:
+ enabled: true
model: "gemini-2.5-flash"
ollama:
+ enabled: false
model: "llama3.2:1b"
vllm:
+ enabled: true
api_url: https://rtx6000blackwell-1.neonaiservices2.com/vllm0
brainforge:
api_url: https://hana.neonaialpha.com
From a3a71bb0f03694efd1d3fb22606b271cf2e2579b Mon Sep 17 00:00:00 2001
From: NeonDaniel