diff --git a/multi_llm_chatbot_backend/app/api/routes/chat.py b/multi_llm_chatbot_backend/app/api/routes/chat.py index aaad4e0c..f01eb1f9 100644 --- a/multi_llm_chatbot_backend/app/api/routes/chat.py +++ b/multi_llm_chatbot_backend/app/api/routes/chat.py @@ -13,7 +13,7 @@ from app.api.utils import get_or_create_session_for_request_async from app.core.auth import get_current_active_user from app.config import get_settings -from app.core.bootstrap import chat_orchestrator +from app.core.bootstrap import chat_orchestrator, get_llm_client from app.core.database import get_database from app.core.persona_filter import get_available_persona_ids from app.core.session_manager import get_session_manager @@ -24,6 +24,41 @@ router = APIRouter() session_manager = get_session_manager() + +def resolve_llm_clients(user: User) -> Dict[str, Any]: + """Resolve LLM clients from a user's stored configuration. + + Returns ``{"orchestrator": LLMClient | None, "personas": {id: LLMClient} | None}``. + + - No saved config: both values are ``None``; callers fall back to + orchestrator/persona defaults. + - Uniform mode: the same cached client is returned for the orchestrator + and every persona. + - Hybrid mode: the orchestrator and each persona may receive different + clients based on the user's per-persona mapping. + """ + config = user.llm_config + if config is None: + return {"orchestrator": None, "personas": None} + + if config.mode == "uniform": + client = get_llm_client(config.default_backend) + persona_clients = { + pid: client for pid in chat_orchestrator.personas + } + return {"orchestrator": client, "personas": persona_clients} + + # Hybrid mode + orchestrator_backend = config.orchestrator_backend or config.default_backend + orchestrator_client = get_llm_client(orchestrator_backend) + + persona_clients = {} + for pid in chat_orchestrator.personas: + backend = (config.persona_backends or {}).get(pid, config.default_backend) + persona_clients[pid] = get_llm_client(backend) + + return {"orchestrator": orchestrator_client, "personas": persona_clients} + # Enhanced data models class UserInput(BaseModel): user_input: str @@ -81,6 +116,11 @@ async def chat_stream( async def _event_generator(): try: + # Resolve per-user LLM clients from their stored config + llm_clients = resolve_llm_clients(current_user) + orchestrator_llm = llm_clients["orchestrator"] + persona_llms = llm_clients["personas"] + # Load or create the in-memory session if message.chat_session_id: sid = f"chat_{message.chat_session_id}" @@ -107,7 +147,9 @@ async def _event_generator(): ).to_ndjson() if await chat_orchestrator.needs_clarification_improved(session, message.user_input): - clar = await chat_orchestrator.generate_contextual_clarification(message.user_input) + clar = await chat_orchestrator.generate_contextual_clarification( + message.user_input, llm_client=orchestrator_llm, + ) yield ChatStreamLine( type="clarification", data={ @@ -123,7 +165,9 @@ async def _event_generator(): # If an enabled tool can handle this query, return its response # directly and skip persona generation. - tool_result = await chat_orchestrator.get_tool_response(message.user_input) + tool_result = await chat_orchestrator.get_tool_response( + message.user_input, llm_client=orchestrator_llm, + ) if tool_result.used_tool: # Append user message to in-memory session and persist to MongoDB session.append_message("orchestrator", tool_result.text) @@ -164,6 +208,7 @@ async def _event_generator(): top_personas = await chat_orchestrator.get_top_personas( session_id=sid, allowed_ids=available, + llm_client=orchestrator_llm, ) # Guard against race condition where all selected advisors @@ -210,9 +255,11 @@ async def _run(pid: str) -> None: "document_chunks_used": 0, }) return + persona_llm = (persona_llms or {}).get(pid) result = await chat_orchestrator.generate_single_persona_response( session, persona, message.response_length or "medium", + llm_client=persona_llm, ) session.append_message(pid, result["response"]) await done_queue.put(result) @@ -390,7 +437,10 @@ async def create_new_chat( raise HTTPException(status_code=500, detail="Failed to create new chat") @router.post("/chat/{persona_id}") -async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: Request): +async def chat_with_specific_advisor( + persona_id: str, input: UserInput, request: Request, + current_user: User = Depends(get_current_active_user), +): """Chat with a specific advisor - UPDATED""" try: if persona_id not in chat_orchestrator.personas: @@ -408,11 +458,15 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: isExpandRequest=True, ), ) + + llm_clients = resolve_llm_clients(current_user) + persona_llm = (llm_clients["personas"] or {}).get(persona_id) result = await chat_orchestrator.chat_with_persona( user_input=input.user_input, persona_id=persona_id, - session_id=session_id + session_id=session_id, + llm_client=persona_llm, ) # Handle response structure @@ -479,7 +533,10 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: } @router.post("/reply-to-advisor") -async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): +async def reply_to_advisor( + reply: ReplyToAdvisor, request: Request, + current_user: User = Depends(get_current_active_user), +): """Reply to a specific advisor with proper context - UPDATED""" try: if reply.advisor_id not in chat_orchestrator.personas: @@ -520,10 +577,14 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): if original_message: contextual_input = f"[Replying to your previous message: '{original_message[:100]}...'] {reply.user_input}" + llm_clients = resolve_llm_clients(current_user) + advisor_llm = (llm_clients["personas"] or {}).get(reply.advisor_id) + result = await chat_orchestrator.chat_with_persona( user_input=contextual_input, persona_id=reply.advisor_id, - session_id=session_id + session_id=session_id, + llm_client=advisor_llm, ) # Handle response structure @@ -600,15 +661,22 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): } @router.post("/ask/") -async def ask_question(query: PersonaQuery, request: Request): +async def ask_question( + query: PersonaQuery, request: Request, + current_user: User = Depends(get_current_active_user), +): """Ask question - UPDATED""" try: session_id = await get_or_create_session_for_request_async(request) + llm_clients = resolve_llm_clients(current_user) + persona_llm = (llm_clients["personas"] or {}).get(query.persona) + result = await chat_orchestrator.chat_with_persona( user_input=query.question, persona_id=query.persona, - session_id=session_id + session_id=session_id, + llm_client=persona_llm, ) if result["type"] == "single_persona_response": diff --git a/multi_llm_chatbot_backend/app/api/routes/provider.py b/multi_llm_chatbot_backend/app/api/routes/provider.py index b7760e74..7d8b45f4 100644 --- a/multi_llm_chatbot_backend/app/api/routes/provider.py +++ b/multi_llm_chatbot_backend/app/api/routes/provider.py @@ -1,108 +1,72 @@ -from fastapi import APIRouter, Body, HTTPException -from app.config import get_settings -from app.llm.improved_gemini_client import ImprovedGeminiClient -from app.llm.improved_ollama_client import ImprovedOllamaClient -from app.llm.improved_vllm_client import ImprovedVllmClient -from app.models.default_personas import get_default_personas -from app.core.bootstrap import chat_orchestrator, llm, current_provider, available_providers -from app.core.brainforge_sync import BRAINFORGE_PERSONA_PREFIX -from pydantic import BaseModel -import os +from fastapi import APIRouter, Depends, HTTPException, status +from app.core.auth import get_current_active_user +from app.core.bootstrap import ( + chat_orchestrator, get_llm_client, AVAILABLE_BACKENDS, _is_backend_enabled, +) +from app.core.database import get_database +from app.models.user import User, UserLLMConfig import logging logger = logging.getLogger(__name__) router = APIRouter() -def create_llm_client(provider: str = None): - global current_provider - if provider is None: - provider = current_provider - - if provider == "gemini": - try: - return ImprovedGeminiClient(model_name=os.getenv("GEMINI_MODEL")) - except ValueError as e: - logger.warning(f"Gemini API key not found, falling back to Ollama: {e}") - return ImprovedOllamaClient(model_name="llama3.2:1b") - elif provider == "ollama": - return ImprovedOllamaClient(model_name="llama3.2:1b") - elif provider == "vllm": - settings = get_settings() - if not settings.llm.vllm.api_url: - raise ValueError("No vLLM endpoint configured. Set llm.vllm.api_url in your config.") - return ImprovedVllmClient( - api_url=settings.llm.vllm.api_url, - api_key=settings.llm.vllm.api_key, - ) - else: - raise ValueError(f"Unknown provider: {provider}") - -# Initialize LLM and personas -llm = create_llm_client(current_provider) -DEFAULT_PERSONAS = get_default_personas(llm) -for persona in DEFAULT_PERSONAS: - chat_orchestrator.register_persona(persona) - -class ProviderSwitch(BaseModel): - provider: str @router.get("/current-provider") -async def get_current_provider(): +async def get_current_provider( + current_user: User = Depends(get_current_active_user), +): + """Return the authenticated user's LLM configuration.""" + config = current_user.llm_config or UserLLMConfig() return { - "current_provider": current_provider, - "available_providers": available_providers, - "model_info": { - "name": llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash", - "provider": current_provider - } + "llm_config": config.model_dump(), + "available_backends": AVAILABLE_BACKENDS, } -@router.post("/switch-provider") -async def switch_provider(provider_data: ProviderSwitch): - global current_provider, llm - - if provider_data.provider not in available_providers: - raise HTTPException(status_code=400, detail=f"Unknown provider: {provider_data.provider}. Available: {available_providers}") - - try: - current_provider = provider_data.provider - new_llm = create_llm_client(current_provider) - llm = new_llm - chat_orchestrator.llm_client = new_llm - - new_personas = get_default_personas(new_llm) - # Clear only non-BrainForge personas; BF advisors have their own LLM clients - non_bf_ids = [pid for pid in chat_orchestrator.personas if not pid.startswith(f"{BRAINFORGE_PERSONA_PREFIX}_")] - for pid in non_bf_ids: - chat_orchestrator.unregister_persona(pid) - for persona in new_personas: - chat_orchestrator.register_persona(persona) - - return { - "message": f"Successfully switched to {current_provider}", - "current_provider": current_provider, - "model_info": { - "name": new_llm.model_name if hasattr(new_llm, 'model_name') else "gemini-2.0-flash", - "provider": current_provider - } - } - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to switch to {provider_data.provider}: {str(e)}") - -@router.post("/switch-model") -async def switch_model(model_name: str = Body(...)): - if "gemini" in model_name.lower(): - return await switch_provider(ProviderSwitch(provider="gemini")) - else: - return await switch_provider(ProviderSwitch(provider="ollama")) +@router.post("/switch-provider") +async def switch_provider( + llm_config: UserLLMConfig, + current_user: User = Depends(get_current_active_user), +): + """Persist the user's LLM configuration to their profile.""" + if llm_config.mode == "hybrid" and llm_config.persona_backends: + registered = set(chat_orchestrator.personas.keys()) + unknown = set(llm_config.persona_backends.keys()) - registered + if unknown: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown persona IDs: {sorted(unknown)}. " + f"Valid IDs: {sorted(registered)}", + ) + + backends_to_check = {llm_config.default_backend} + if llm_config.orchestrator_backend: + backends_to_check.add(llm_config.orchestrator_backend) + if llm_config.persona_backends: + backends_to_check.update(llm_config.persona_backends.values()) + + for backend in backends_to_check: + if not _is_backend_enabled(backend): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Backend {backend!r} is disabled by the administrator.", + ) + try: + get_llm_client(backend) + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Backend {backend!r} is not configured: {exc}", + ) + + db = get_database() + await db.users.update_one( + {"_id": current_user.id}, + {"$set": {"llm_config": llm_config.model_dump()}}, + ) -@router.get("/current-model") -async def get_current_model(): - model_name = llm.model_name if hasattr(llm, 'model_name') else "gemini-2.0-flash" return { - "model": model_name, - "provider": current_provider + "message": "LLM configuration updated", + "llm_config": llm_config.model_dump(), } diff --git a/multi_llm_chatbot_backend/app/config.py b/multi_llm_chatbot_backend/app/config.py index 753a1048..c8154837 100644 --- a/multi_llm_chatbot_backend/app/config.py +++ b/multi_llm_chatbot_backend/app/config.py @@ -253,6 +253,7 @@ def _warn_connection_envvar(self): class GeminiConfig(BaseModel): + enabled: bool = True api_key: str = Field(default=os.getenv("GEMINI_API_KEY")) model: str = "gemini-2.5-flash" @@ -272,12 +273,14 @@ def _warn_gemini_envvar(self): class OllamaConfig(BaseModel): + enabled: bool = True model: str = "llama3.2:1b" # TODO: Drop support for `OLLAMA_BASE_URL` envvar handling base_url: str = Field(default=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")) class VllmConfig(BaseModel): + enabled: bool = True api_url: str = "" api_key: str = Field(default=os.getenv("VLLM_API_KEY", "")) @@ -290,10 +293,12 @@ class BrainForgeConfig(BaseModel): class LLMConfig(BaseModel): + default_backend: str = "" gemini: GeminiConfig = GeminiConfig() ollama: OllamaConfig = OllamaConfig() vllm: VllmConfig = VllmConfig() brainforge: BrainForgeConfig = BrainForgeConfig() + health_check_interval_seconds: int = 300 class RAGConfig(BaseModel): diff --git a/multi_llm_chatbot_backend/app/core/bootstrap.py b/multi_llm_chatbot_backend/app/core/bootstrap.py index c08d873e..99283643 100644 --- a/multi_llm_chatbot_backend/app/core/bootstrap.py +++ b/multi_llm_chatbot_backend/app/core/bootstrap.py @@ -1,22 +1,34 @@ # app/core/bootstrap.py +import asyncio +import logging + from app.config import get_settings from app.llm.improved_gemini_client import ImprovedGeminiClient from app.llm.improved_ollama_client import ImprovedOllamaClient from app.llm.improved_vllm_client import ImprovedVllmClient from app.core.improved_orchestrator import ImprovedChatOrchestrator from app.models.default_personas import get_default_personas +from app.models.user import LLM_BACKENDS +from app.llm.llm_client import LLMClient + +logger = logging.getLogger(__name__) settings = get_settings() -current_provider = "gemini" -available_providers = ["ollama", "gemini", "vllm"] +_client_cache = {} -def create_llm_client(provider=None): - if provider is None: - provider = current_provider - if provider == "gemini": + +def create_llm_client(backend: str = None): + """Create an LLM client for the given backend name.""" + if backend is None: + backend = settings.llm.default_backend + if backend not in LLM_BACKENDS: + raise ValueError( + f"Unknown backend {backend!r}. Must be one of {LLM_BACKENDS}" + ) + if backend == "gemini": return ImprovedGeminiClient(model_name=settings.llm.gemini.model) - elif provider == "vllm": + elif backend == "vllm": if not settings.llm.vllm.api_url: raise ValueError("No vLLM endpoint configured. Set llm.vllm.api_url in your config.") return ImprovedVllmClient( @@ -29,7 +41,74 @@ def create_llm_client(provider=None): base_url=settings.llm.ollama.base_url, ) -llm = create_llm_client() + +def get_llm_client(backend: str) -> LLMClient: + """Return a cached LLM client for *backend*, creating it on first access.""" + if backend not in _client_cache: + _client_cache[backend] = create_llm_client(backend) + return _client_cache[backend] + + +def _is_backend_enabled(backend: str) -> bool: + """Check whether *backend* is enabled in the admin config.""" + backend_config = getattr(settings.llm, backend, None) + return getattr(backend_config, "enabled", True) + + +def get_available_backends() -> list: + """Return backends that are enabled and properly configured (sync, used at startup).""" + available = [] + for backend in LLM_BACKENDS: + if not _is_backend_enabled(backend): + continue + try: + get_llm_client(backend) + available.append(backend) + except Exception: + pass + return available + + +async def refresh_available_backends(): + """Re-check which backends are enabled, configured, and reachable.""" + available = [] + for backend in LLM_BACKENDS: + if not _is_backend_enabled(backend): + continue + try: + client = get_llm_client(backend) + if await client.health_check(): + available.append(backend) + except Exception: + pass + AVAILABLE_BACKENDS[:] = available + + +async def _backend_health_loop(): + """Background task that periodically refreshes AVAILABLE_BACKENDS.""" + interval = settings.llm.health_check_interval_seconds + while True: + await refresh_available_backends() + await asyncio.sleep(interval) + + +# Resolve the default backend: prefer the configured default, fall back to the first available. +AVAILABLE_BACKENDS = get_available_backends() +if settings.llm.default_backend in AVAILABLE_BACKENDS: + DEFAULT_BACKEND = settings.llm.default_backend +elif AVAILABLE_BACKENDS: + DEFAULT_BACKEND = AVAILABLE_BACKENDS[0] + logger.warning( + "Configured default_backend %r is not available; falling back to %r", + settings.llm.default_backend, DEFAULT_BACKEND, + ) +else: + raise RuntimeError( + "No LLM backends are available. Check your config.yaml — " + "at least one backend must be enabled and properly configured." + ) +llm = create_llm_client(DEFAULT_BACKEND) +_client_cache[DEFAULT_BACKEND] = llm chat_orchestrator = ImprovedChatOrchestrator(llm_client=llm) DEFAULT_PERSONAS = get_default_personas(llm) diff --git a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py index 12f63a78..774c8687 100644 --- a/multi_llm_chatbot_backend/app/core/improved_orchestrator.py +++ b/multi_llm_chatbot_backend/app/core/improved_orchestrator.py @@ -45,7 +45,8 @@ def list_personas(self) -> List[str]: """List all available persona IDs""" return list(self.personas.keys()) - async def get_tool_response(self, user_message: str) -> ToolCallResult: + async def get_tool_response(self, user_message: str, + llm_client: LLMClient = None) -> ToolCallResult: """Check whether a tool can handle *user_message*. If tools are disabled in config, no LLM client is available, or the @@ -53,7 +54,8 @@ async def get_tool_response(self, user_message: str) -> ToolCallResult: ``ToolCallResult(used_tool=False)``. Otherwise executes the tool and returns the grounded response with ``used_tool=True``. """ - if self.llm_client is None: + effective_llm = llm_client or self.llm_client + if effective_llm is None: return ToolCallResult(text="", used_tool=False) settings = get_settings() @@ -80,7 +82,7 @@ async def get_tool_response(self, user_message: str) -> ToolCallResult: "to present structured data like course listings or professor ratings." ) - return await self.llm_client.generate_with_tools( + return await effective_llm.generate_with_tools( system_prompt=system_prompt, user_message=user_message, tool_definitions=tool_definitions, @@ -325,7 +327,8 @@ async def needs_clarification_improved(self, session: ConversationContext, user_ logger.warning("Falling back to rule-based clarification check") return self.needs_clarification(session, user_input) - async def generate_contextual_clarification(self, user_input: str) -> Dict[str, Any]: + async def generate_contextual_clarification(self, user_input: str, + llm_client: LLMClient = None) -> Dict[str, Any]: """ Use the LLM to produce a clarification question and clickable suggestions that are tailored to what the user actually typed. @@ -355,10 +358,8 @@ async def generate_contextual_clarification(self, user_input: str) -> Dict[str, ) try: - # Use the orchestrator's own LLM rather than a persona's — BrainForge - # persona LLMs may not support the prompt format used here. - llm = self.llm_client - raw = await llm.generate( + effective_llm = llm_client or self.llm_client + raw = await effective_llm.generate( system_prompt=system_prompt, context=[{"role": "user", "content": user_prompt}], temperature=0.4, @@ -391,9 +392,14 @@ async def generate_contextual_clarification(self, user_input: str) -> Dict[str, "suggestions": fallback_suggestions, } - async def generate_persona_responses(self, session: ConversationContext, response_length: str = "medium"): + async def generate_persona_responses(self, session: ConversationContext, + response_length: str = "medium", + llm_clients: Dict[str, LLMClient] = None): """ - Generate responses from all personas with enhanced RAG integration + Generate responses from all personas with enhanced RAG integration. + + *llm_clients* maps persona IDs to the LLM client each should use. + Personas not present in the dict fall back to their default client. """ responses = [] @@ -401,7 +407,10 @@ async def generate_persona_responses(self, session: ConversationContext, respons logger.info(f"Generating response for {persona_id} with enhanced RAG") # Generate persona response with enhanced RAG - response_data = await self.generate_single_persona_response(session, persona, response_length) + persona_llm = (llm_clients or {}).get(persona_id) + response_data = await self.generate_single_persona_response( + session, persona, response_length, llm_client=persona_llm, + ) # Add persona response to session context session.append_message(persona_id, response_data["response"]) @@ -410,9 +419,14 @@ async def generate_persona_responses(self, session: ConversationContext, respons return responses - async def generate_single_persona_response(self, session, persona, response_length: str = "medium"): + async def generate_single_persona_response(self, session, persona, + response_length: str = "medium", + llm_client: LLMClient = None): """ - Enhanced version - Generate response from a single persona with enhanced RAG integration + Enhanced version - Generate response from a single persona with enhanced RAG integration. + + *llm_client* is forwarded to ``persona.respond()``; when ``None`` the + persona uses its default (system-default) client. """ try: # Get the user's latest message for document retrieval @@ -441,7 +455,7 @@ async def generate_single_persona_response(self, session, persona, response_leng ) # Generate response with enhanced context - response = await persona.respond(enhanced_context, response_length) + response = await persona.respond(enhanced_context, response_length, llm=llm_client) # Validate and improve response quality if not self._is_valid_response(response, persona.id): @@ -813,9 +827,13 @@ def _get_persona_context_keywords(self, persona_id: str) -> str: """ return self._get_enhanced_persona_context_keywords(persona_id) - async def chat_with_persona(self, user_input: str, persona_id: str, session_id: str, response_length: str = "medium") -> Dict[str, Any]: + async def chat_with_persona(self, user_input: str, persona_id: str, + session_id: str, response_length: str = "medium", + llm_client: LLMClient = None) -> Dict[str, Any]: """ - Chat with a specific persona directly - FIXED for consistent document access + Chat with a specific persona directly - FIXED for consistent document access. + + *llm_client* is forwarded to the persona's response generation. """ try: persona = self.get_persona(persona_id) @@ -838,7 +856,9 @@ async def chat_with_persona(self, user_input: str, persona_id: str, session_id: logger.info(f"Generating response for {persona_id} with session {session_id}") # Generate response from single persona using consistent session ID - response_data = await self.generate_single_persona_response(session, persona, response_length) + response_data = await self.generate_single_persona_response( + session, persona, response_length, llm_client=llm_client, + ) # Add response to session session.append_message(persona_id, response_data["response"]) @@ -884,7 +904,8 @@ async def chat_with_persona(self, user_input: str, persona_id: str, session_id: async def get_top_personas(self, session_id: str, k: int = 3, - allowed_ids: Optional[List[str]] = None) -> List[str]: + allowed_ids: Optional[List[str]] = None, + llm_client: LLMClient = None) -> List[str]: """ Use the LLM to rank personas based on current session context. Falls back to default persona order if LLM fails or returns invalid data. @@ -902,9 +923,7 @@ async def get_top_personas(self, session_id: str, k: int = 3, logger.warning("No personas available after filtering.") return [] - # Use the orchestrator's own LLM rather than a persona's — BrainForge - # persona LLMs may not support the prompt format used here. - llm = self.llm_client + effective_llm = llm_client or self.llm_client # Use recent conversation context (last 5 messages) recent_context = "\n".join( @@ -935,7 +954,7 @@ async def get_top_personas(self, session_id: str, k: int = 3, {persona_descriptions} """.strip() - llm_response = await llm.generate( + llm_response = await effective_llm.generate( system_prompt=f"You are an assistant that selects the best advisors for a user of {app_title}.", context=[{"role": "user", "content": prompt}], temperature=0.4, diff --git a/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py b/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py index 63074523..ccf0d322 100644 --- a/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py +++ b/multi_llm_chatbot_backend/app/llm/improved_ollama_client.py @@ -110,4 +110,12 @@ def _is_poor_quality(self, response: str) -> bool: len(response.split()) > 150, # Too verbose response.count("?") > 3, # Too many questions ] - return any(poor_indicators) \ No newline at end of file + return any(poor_indicators) + + async def health_check(self) -> bool: + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{self.base_url}/api/tags") + return resp.is_success + except Exception: + return False \ No newline at end of file diff --git a/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py b/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py index de9a927b..5b568358 100644 --- a/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py +++ b/multi_llm_chatbot_backend/app/llm/improved_vllm_client.py @@ -187,4 +187,14 @@ async def generate_with_tools( used_tool=False, ) + async def health_check(self) -> bool: + import httpx + try: + headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{self.api_url}/v1/models", headers=headers) + return resp.is_success + except Exception: + return False + diff --git a/multi_llm_chatbot_backend/app/llm/llm_client.py b/multi_llm_chatbot_backend/app/llm/llm_client.py index 995500da..74ee8f86 100644 --- a/multi_llm_chatbot_backend/app/llm/llm_client.py +++ b/multi_llm_chatbot_backend/app/llm/llm_client.py @@ -67,6 +67,15 @@ async def generate_with_tools( ) return ToolCallResult(text=text, used_tool=False) + async def health_check(self) -> bool: + """Check if the backend service is reachable. + + Subclasses for self-hosted services should override this with a + lightweight network probe. The default returns True (suitable for + cloud APIs where the config check is sufficient). + """ + return True + def _clean_response(self, response: str) -> str: """Clean up response text, preserving Markdown formatting.""" response = response.replace("\r\n", "\n").replace("\r", "\n") diff --git a/multi_llm_chatbot_backend/app/main.py b/multi_llm_chatbot_backend/app/main.py index 677df8f2..d8214fcb 100644 --- a/multi_llm_chatbot_backend/app/main.py +++ b/multi_llm_chatbot_backend/app/main.py @@ -11,6 +11,8 @@ from fastapi.staticfiles import StaticFiles from contextlib import asynccontextmanager +import asyncio + # Load configuration FIRST so every module can use it from app.config import load_settings from app.version import __version__ @@ -18,6 +20,7 @@ # Import the new database functions from app.core.database import connect_to_mongo, close_mongo_connection +from app.core.bootstrap import _backend_health_loop # Import all route modules from app.api.routes import router as main_router @@ -41,9 +44,11 @@ async def lifespan(app: FastAPI): from app.core.brainforge_sync import async_sync_brainforge_personas, periodic_sync_loop await async_sync_brainforge_personas(chat_orchestrator) sync_task = asyncio.create_task(periodic_sync_loop(chat_orchestrator)) + health_task = asyncio.create_task(_backend_health_loop()) yield # Shutdown sync_task.cancel() + health_task.cancel() await close_mongo_connection() app = FastAPI( @@ -119,6 +124,8 @@ def get_public_config(): "dark_color": colors["dark_color"], "dark_bg_color": colors["dark_bg_color"], "image": "icon://Brain", + "backend_locked": True, + "default_backend": "brainforge", }) return config diff --git a/multi_llm_chatbot_backend/app/models/persona.py b/multi_llm_chatbot_backend/app/models/persona.py index ed34caf9..2a19034a 100644 --- a/multi_llm_chatbot_backend/app/models/persona.py +++ b/multi_llm_chatbot_backend/app/models/persona.py @@ -245,10 +245,14 @@ def __init__(self, id: str, name: str, system_prompt: str, llm: LLMClient, tempe self.llm = llm self.temperature = temperature - async def respond(self, context: List[Dict], response_length: str = "medium") -> str: + async def respond(self, context: List[Dict], response_length: str = "medium", + llm: LLMClient = None) -> str: """Generate a compact, well-formed Markdown response suitable for the UI. - Returns the compact Markdown string (backward compatible with previous callers). + + *llm* overrides the default client for this call (used for per-user + backend selection). Falls back to ``self.llm`` when not provided. """ + effective_llm = llm or self.llm max_tokens = MAX_TOKENS_MAP.get(response_length, 500) structure_hint = STRUCTURE_HINTS.get(response_length, STRUCTURE_HINTS["medium"]) temp_scaled = round(self.temperature / 10, 2) @@ -259,7 +263,7 @@ async def respond(self, context: List[Dict], response_length: str = "medium") -> f"{structure_hint}" ) - raw_text = await self.llm.generate( + raw_text = await effective_llm.generate( system_prompt=full_prompt, context=context, temperature=temp_scaled, diff --git a/multi_llm_chatbot_backend/app/models/user.py b/multi_llm_chatbot_backend/app/models/user.py index 24b08957..32ecd31f 100644 --- a/multi_llm_chatbot_backend/app/models/user.py +++ b/multi_llm_chatbot_backend/app/models/user.py @@ -1,8 +1,36 @@ from pydantic import BaseModel, EmailStr, Field, ConfigDict, model_validator -from typing import Literal, Optional, List, Any +from typing import Dict, Literal, Optional, List, Any, get_args from datetime import datetime from bson import ObjectId +BackendName = Literal["gemini", "ollama", "vllm"] +LLM_BACKENDS = get_args(BackendName) + + +class UserLLMConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + """Per-user LLM provider configuration. + + Uniform mode: all advisors and the orchestrator use ``default_backend``. + Hybrid mode: each advisor can use a different backend; ``default_backend`` + is the fallback for any persona not explicitly mapped. + """ + mode: Literal["uniform", "hybrid"] = "uniform" + default_backend: BackendName = "gemini" + orchestrator_backend: Optional[BackendName] = None + persona_backends: Optional[Dict[str, BackendName]] = None + + @model_validator(mode="after") + def _validate_hybrid_fields(self): + if self.mode == "hybrid": + if not self.orchestrator_backend and not self.persona_backends: + self.orchestrator_backend = self.default_backend + else: + self.orchestrator_backend = None + self.persona_backends = None + return self + class PyObjectId(ObjectId): @classmethod def __get_validators__(cls): @@ -48,6 +76,7 @@ class User(BaseModel): academicStage: Optional[str] = None researchArea: Optional[str] = None disabled_advisors: Optional[List[str]] = None + llm_config: Optional[UserLLMConfig] = None created_at: datetime = Field(default_factory=datetime.utcnow) last_login: Optional[datetime] = None is_active: bool = True diff --git a/multi_llm_chatbot_backend/app/tests/unit/conftest.py b/multi_llm_chatbot_backend/app/tests/unit/conftest.py index 1f539bd3..09b73bad 100644 --- a/multi_llm_chatbot_backend/app/tests/unit/conftest.py +++ b/multi_llm_chatbot_backend/app/tests/unit/conftest.py @@ -1,37 +1,38 @@ -"""Session-wide stubs for modules that do heavy work at import time. - -``app.api.routes.__init__`` eagerly imports every sibling route, and -``app.api.routes.provider`` instantiates real LLM clients at module -load. ``app.core.bootstrap`` and ``app.core.rag_manager`` likewise -start NLTK, ChromaDB, and the LLM stack the moment they are imported. - -We install harmless ``MagicMock`` substitutes for those modules once, -before any test file in this directory is collected, so every test -gets a consistent, importable view of ``app.api.routes`` without -having to reproduce the same stubbing recipe in every test module. - -Tests that want to exercise the real version of a specific route -module (for example, ``test_version.py`` wanting the real -``app.api.routes.root``) can still pop their target out of -``sys.modules`` in their own setup -- they no longer have to -coordinate cleanup with peer test modules. +"""Session-wide stubs for heavy-import modules. + +``app.core.rag_manager`` starts NLTK / ChromaDB the moment it is +imported, so we replace it with a ``MagicMock`` before any test is +collected. + +``app.core.bootstrap`` (and the route modules that import it) can load +normally because we pre-set ``GEMINI_API_KEY`` and ``CONFIG_PATH`` +before any import occurs. This lets ``get_settings()``, the LLM-client +constructors, and the orchestrator initialise without real credentials +or config files. + +Route modules that are *not* under direct test (documents, sessions, +debug, phd_canvas) are still replaced with lightweight stubs so their +dependency trees are never pulled in. """ +import os import sys from unittest.mock import MagicMock from fastapi import APIRouter +os.environ.setdefault("GEMINI_API_KEY", "fake-test-key") +os.environ.setdefault("CONFIG_PATH", "") -for _name in ("app.core.bootstrap", "app.core.rag_manager"): - sys.modules.setdefault(_name, MagicMock()) +# rag_manager triggers NLTK / ChromaDB on import — always stub it. +sys.modules.setdefault("app.core.rag_manager", MagicMock()) +# Stub route modules that are NOT under direct test to avoid pulling +# in their full dependency trees when app.api.routes.__init__ runs. _stub_router_module = MagicMock(router=APIRouter()) for _name in ( - "app.api.routes.chat", "app.api.routes.documents", "app.api.routes.sessions", - "app.api.routes.provider", "app.api.routes.debug", "app.api.routes.phd_canvas", ): diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py b/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py new file mode 100644 index 00000000..c9bb362b --- /dev/null +++ b/multi_llm_chatbot_backend/app/tests/unit/test_available_backends.py @@ -0,0 +1,114 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.core.bootstrap import ( + get_available_backends, + refresh_available_backends, + AVAILABLE_BACKENDS, + LLM_BACKENDS, +) + + +class TestGetAvailableBackends(unittest.TestCase): + """Sync config-only check used at startup.""" + + @patch("app.core.bootstrap.get_llm_client") + def test_all_configured(self, mock_get_client): + mock_get_client.return_value = MagicMock() + result = get_available_backends() + self.assertEqual(result, ["gemini", "ollama", "vllm"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_vllm_not_configured(self, mock_get_client): + def side_effect(backend): + if backend == "vllm": + raise ValueError("No vLLM endpoint configured.") + return MagicMock() + mock_get_client.side_effect = side_effect + result = get_available_backends() + self.assertEqual(result, ["gemini", "ollama"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_gemini_not_configured(self, mock_get_client): + def side_effect(backend): + if backend == "gemini": + raise ValueError("No Gemini endpoint configured.") + return MagicMock() + mock_get_client.side_effect = side_effect + result = get_available_backends() + self.assertEqual(result, ["ollama", "vllm"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_none_configured(self, mock_get_client): + mock_get_client.side_effect = ValueError("not configured") + result = get_available_backends() + self.assertEqual(result, []) + + +class TestRefreshAvailableBackends(unittest.TestCase): + """Async health-check refresh.""" + + @patch("app.core.bootstrap.get_llm_client") + def test_all_healthy(self, mock_get_client): + mock_client = MagicMock() + mock_client.health_check = AsyncMock(return_value=True) + mock_get_client.return_value = mock_client + + asyncio.run(refresh_available_backends()) + self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama", "vllm"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_vllm_unhealthy(self, mock_get_client): + def side_effect(backend): + client = MagicMock() + client.health_check = AsyncMock(return_value=(backend != "vllm")) + return client + mock_get_client.side_effect = side_effect + + asyncio.run(refresh_available_backends()) + self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_health_check_exception(self, mock_get_client): + def side_effect(backend): + client = MagicMock() + if backend == "ollama": + client.health_check = AsyncMock(side_effect=Exception("timeout")) + else: + client.health_check = AsyncMock(return_value=True) + return client + mock_get_client.side_effect = side_effect + + asyncio.run(refresh_available_backends()) + self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "vllm"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_unconfigured_backend_excluded(self, mock_get_client): + def side_effect(backend): + if backend == "vllm": + raise ValueError("No vLLM endpoint configured.") + client = MagicMock() + client.health_check = AsyncMock(return_value=True) + return client + mock_get_client.side_effect = side_effect + + asyncio.run(refresh_available_backends()) + self.assertEqual(list(AVAILABLE_BACKENDS), ["gemini", "ollama"]) + + @patch("app.core.bootstrap.get_llm_client") + def test_gemini_unconfigured_excluded(self, mock_get_client): + def side_effect(backend): + if backend == "gemini": + raise ValueError("No Gemini endpoint configured.") + client = MagicMock() + client.health_check = AsyncMock(return_value=True) + return client + mock_get_client.side_effect = side_effect + + asyncio.run(refresh_available_backends()) + self.assertEqual(list(AVAILABLE_BACKENDS), ["ollama", "vllm"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py b/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py new file mode 100644 index 00000000..bcac69d9 --- /dev/null +++ b/multi_llm_chatbot_backend/app/tests/unit/test_llm_provider_config.py @@ -0,0 +1,364 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bson import ObjectId +from fastapi import HTTPException +from pydantic import ValidationError + +from app.api.routes.chat import resolve_llm_clients +from app.api.routes.provider import switch_provider +from app.models.user import User, UserLLMConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_user(llm_config=None): + return User( + _id=ObjectId(), + firstName="Test", + lastName="User", + email="test@example.com", + hashed_password="fakehash", + llm_config=llm_config, + ) + + +# =================================================================== +# 1. UserLLMConfig — Pydantic model validation +# =================================================================== + +class TestUserLLMConfig(unittest.TestCase): + """Validate the UserLLMConfig Pydantic model and its _validate_hybrid_fields + model validator.""" + + def test_defaults(self): + cfg = UserLLMConfig() + self.assertEqual(cfg.mode, "uniform") + self.assertEqual(cfg.default_backend, "gemini") + self.assertIsNone(cfg.orchestrator_backend) + self.assertIsNone(cfg.persona_backends) + + def test_uniform_strips_hybrid_fields(self): + cfg = UserLLMConfig( + mode="uniform", + default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"x": "vllm"}, + ) + self.assertIsNone(cfg.orchestrator_backend) + self.assertIsNone(cfg.persona_backends) + + def test_hybrid_without_overrides_falls_back_to_default(self): + cfg = UserLLMConfig(mode="hybrid", default_backend="gemini") + self.assertEqual(cfg.orchestrator_backend, "gemini") + self.assertIsNone(cfg.persona_backends) + + def test_hybrid_with_orchestrator_only(self): + cfg = UserLLMConfig( + mode="hybrid", + default_backend="gemini", + orchestrator_backend="ollama", + ) + self.assertEqual(cfg.orchestrator_backend, "ollama") + self.assertIsNone(cfg.persona_backends) + + def test_hybrid_with_persona_backends_only(self): + cfg = UserLLMConfig( + mode="hybrid", + default_backend="gemini", + persona_backends={"advisor_1": "vllm"}, + ) + self.assertIsNone(cfg.orchestrator_backend) + self.assertEqual(cfg.persona_backends, {"advisor_1": "vllm"}) + + def test_hybrid_with_both(self): + cfg = UserLLMConfig( + mode="hybrid", + default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"advisor_1": "vllm"}, + ) + self.assertEqual(cfg.orchestrator_backend, "ollama") + self.assertEqual(cfg.persona_backends, {"advisor_1": "vllm"}) + + def test_rejects_unknown_backend_name(self): + with self.assertRaises(ValidationError): + UserLLMConfig(default_backend="claude") + + def test_extra_fields_forbidden(self): + with self.assertRaises(ValidationError): + UserLLMConfig(default_backend="gemini", surprise="boom") + + def test_model_dump_roundtrip(self): + original = UserLLMConfig( + mode="hybrid", + default_backend="ollama", + orchestrator_backend="gemini", + persona_backends={"a": "vllm", "b": "gemini"}, + ) + restored = UserLLMConfig(**original.model_dump()) + self.assertEqual(original.model_dump(), restored.model_dump()) + + +# =================================================================== +# 2. resolve_llm_clients — chat routing logic +# =================================================================== + +class TestResolveLlmClients(unittest.TestCase): + """Verify that resolve_llm_clients maps a user's stored config to the + correct LLM client instances.""" + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_no_config_returns_nones(self, mock_get, mock_orch): + result = resolve_llm_clients(_make_user(llm_config=None)) + self.assertIsNone(result["orchestrator"]) + self.assertIsNone(result["personas"]) + mock_get.assert_not_called() + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_uniform_same_client_for_all(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock(), "b": MagicMock(), "c": MagicMock()} + sentinel = MagicMock(name="shared_client") + mock_get.return_value = sentinel + + user = _make_user(UserLLMConfig(mode="uniform", default_backend="ollama")) + result = resolve_llm_clients(user) + + mock_get.assert_called_once_with("ollama") + self.assertIs(result["orchestrator"], sentinel) + for pid in ("a", "b", "c"): + self.assertIs(result["personas"][pid], sentinel) + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_hybrid_orchestrator_override(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock()} + clients = {"gemini": MagicMock(), "ollama": MagicMock()} + mock_get.side_effect = lambda b: clients[b] + + user = _make_user(UserLLMConfig( + mode="hybrid", default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"a": "gemini"}, + )) + result = resolve_llm_clients(user) + self.assertIs(result["orchestrator"], clients["ollama"]) + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_hybrid_orchestrator_falls_back_to_default(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock()} + sentinel = MagicMock() + mock_get.return_value = sentinel + + user = _make_user(UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"a": "gemini"}, + )) + result = resolve_llm_clients(user) + self.assertIs(result["orchestrator"], sentinel) + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_hybrid_persona_override(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock(), "b": MagicMock()} + clients = {"gemini": MagicMock(), "vllm": MagicMock()} + mock_get.side_effect = lambda b: clients[b] + + user = _make_user(UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"a": "vllm", "b": "gemini"}, + )) + result = resolve_llm_clients(user) + self.assertIs(result["personas"]["a"], clients["vllm"]) + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_hybrid_unmapped_persona_uses_default(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock(), "unmapped": MagicMock()} + clients = {"gemini": MagicMock(), "vllm": MagicMock()} + mock_get.side_effect = lambda b: clients[b] + + user = _make_user(UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"a": "vllm"}, + )) + result = resolve_llm_clients(user) + self.assertIs(result["personas"]["unmapped"], clients["gemini"]) + + @patch("app.api.routes.chat.chat_orchestrator") + @patch("app.api.routes.chat.get_llm_client") + def test_hybrid_mixed(self, mock_get, mock_orch): + mock_orch.personas = {"a": MagicMock(), "b": MagicMock()} + clients = {"gemini": MagicMock(), "ollama": MagicMock(), "vllm": MagicMock()} + mock_get.side_effect = lambda b: clients[b] + + user = _make_user(UserLLMConfig( + mode="hybrid", default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"a": "vllm"}, + )) + result = resolve_llm_clients(user) + self.assertIs(result["orchestrator"], clients["ollama"]) + self.assertIs(result["personas"]["a"], clients["vllm"]) + self.assertIs(result["personas"]["b"], clients["gemini"]) + + +# =================================================================== +# 3. switch_provider — endpoint validation +# =================================================================== + +class TestSwitchProvider(unittest.IsolatedAsyncioTestCase): + """Validate the switch_provider endpoint's guard logic.""" + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_rejects_unknown_persona_id(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {"known": MagicMock()} + mock_get.return_value = MagicMock() + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"unknown_id": "gemini"}, + ) + with self.assertRaises(HTTPException) as ctx: + await switch_provider(cfg, _make_user()) + self.assertEqual(ctx.exception.status_code, 400) + self.assertIn("Unknown persona IDs", ctx.exception.detail) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_accepts_known_persona_ids(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {"a": MagicMock(), "b": MagicMock()} + mock_get.return_value = MagicMock() + mock_db.return_value.users.update_one = AsyncMock() + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"a": "gemini", "b": "gemini"}, + ) + result = await switch_provider(cfg, _make_user()) + self.assertEqual(result["message"], "LLM configuration updated") + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_rejects_unconfigured_default_backend(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {} + mock_get.side_effect = ValueError("not configured") + + cfg = UserLLMConfig(mode="uniform", default_backend="ollama") + with self.assertRaises(HTTPException) as ctx: + await switch_provider(cfg, _make_user()) + self.assertEqual(ctx.exception.status_code, 400) + self.assertIn("not configured", ctx.exception.detail) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_rejects_unconfigured_orchestrator_backend(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {} + + def side_effect(backend): + if backend == "vllm": + raise ValueError("no vLLM endpoint") + return MagicMock() + mock_get.side_effect = side_effect + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + orchestrator_backend="vllm", + ) + with self.assertRaises(HTTPException) as ctx: + await switch_provider(cfg, _make_user()) + self.assertEqual(ctx.exception.status_code, 400) + self.assertIn("not configured", ctx.exception.detail) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_rejects_unconfigured_persona_backend(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {"a": MagicMock()} + + def side_effect(backend): + if backend == "vllm": + raise ValueError("no vLLM endpoint") + return MagicMock() + mock_get.side_effect = side_effect + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + persona_backends={"a": "vllm"}, + ) + with self.assertRaises(HTTPException) as ctx: + await switch_provider(cfg, _make_user()) + self.assertEqual(ctx.exception.status_code, 400) + self.assertIn("not configured", ctx.exception.detail) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_checks_all_distinct_backends(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {"a": MagicMock(), "b": MagicMock()} + mock_get.return_value = MagicMock() + mock_db.return_value.users.update_one = AsyncMock() + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"a": "vllm", "b": "gemini"}, + ) + await switch_provider(cfg, _make_user()) + + checked = {call.args[0] for call in mock_get.call_args_list} + self.assertEqual(checked, {"gemini", "ollama", "vllm"}) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_persists_to_database(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {} + mock_get.return_value = MagicMock() + mock_collection = MagicMock() + mock_collection.update_one = AsyncMock() + mock_db.return_value.users = mock_collection + + user = _make_user() + cfg = UserLLMConfig(mode="uniform", default_backend="ollama") + await switch_provider(cfg, user) + + mock_collection.update_one.assert_awaited_once() + call_args = mock_collection.update_one.call_args + self.assertEqual(call_args[0][0], {"_id": user.id}) + self.assertEqual( + call_args[0][1], + {"$set": {"llm_config": cfg.model_dump()}}, + ) + + @patch("app.api.routes.provider.get_database") + @patch("app.api.routes.provider.get_llm_client") + @patch("app.api.routes.provider.chat_orchestrator") + async def test_returns_updated_config(self, mock_orch, mock_get, mock_db): + mock_orch.personas = {"a": MagicMock()} + mock_get.return_value = MagicMock() + mock_db.return_value.users.update_one = AsyncMock() + + cfg = UserLLMConfig( + mode="hybrid", default_backend="gemini", + orchestrator_backend="ollama", + persona_backends={"a": "vllm"}, + ) + result = await switch_provider(cfg, _make_user()) + + self.assertEqual(result["message"], "LLM configuration updated") + self.assertEqual(result["llm_config"], cfg.model_dump()) + + +if __name__ == "__main__": + unittest.main() diff --git a/phd-advisor-frontend/src/components/AdvisorConfigPanel.js b/phd-advisor-frontend/src/components/AdvisorConfigPanel.js new file mode 100644 index 00000000..90f0d153 --- /dev/null +++ b/phd-advisor-frontend/src/components/AdvisorConfigPanel.js @@ -0,0 +1,193 @@ +import React, { useEffect, useMemo, useState } from 'react'; + +// Reusable per-advisor backend configuration panel. +// +// Used in two places: +// - Welcome-state "Advanced" expander on ChatPage +// - "Advisor Config" tab inside SettingsModal (lives on feat/UI-for-User-Account-updates; +// drop this component in once branches merge) +// +// Controlled component. Parent owns the config object and decides when to persist. +// +// Shape of `value`: +// { default_backend, orchestrator_backend, persona_backends: { [personaId]: backend } } +// +// The orchestrator and each advisor can be set to DEFAULT_BACKEND ("Default"), +// meaning "follow default_backend". Call stripDefaultBackends() before persisting: +// it drops those sentinels so the backend falls through to default_backend (and +// moves them automatically when the default changes). + +export const DEFAULT_BACKEND = '__default__'; + +export const stripDefaultBackends = (config) => { + if (!config) return config; + const source = config.persona_backends || {}; + const cleaned = {}; + for (const [id, backend] of Object.entries(source)) { + if (backend && backend !== DEFAULT_BACKEND) cleaned[id] = backend; + } + const orchestrator = + config.orchestrator_backend && config.orchestrator_backend !== DEFAULT_BACKEND + ? config.orchestrator_backend + : null; + return { ...config, orchestrator_backend: orchestrator, persona_backends: cleaned }; +}; + +const rowStyle = { + display: 'grid', gridTemplateColumns: '1fr 180px', alignItems: 'center', + gap: 12, padding: '10px 0', borderBottom: '1px solid var(--border-primary)', +}; + +const selectStyle = { + padding: '8px 10px', borderRadius: 8, border: '1px solid var(--border-primary)', + background: 'var(--bg-secondary)', color: 'var(--text-primary)', fontSize: 13.5, + width: '100%', colorScheme: 'light dark', +}; + +// Native dropdown option list ignores the setDefault(e.target.value)} + > + {availableBackends.map(b => )} + + + )} + + {!hideOrchestrator && ( +
+
+
Orchestrator
+
+ Routes user input across advisors. +
+
+ +
+ )} + + {personaIds.map((id) => { + const advisor = advisors[id]; + const locked = advisor?.backendLocked; + const personaValue = locked && advisor?.defaultBackend + ? advisor.defaultBackend + : (config.persona_backends?.[id] || DEFAULT_BACKEND); + return ( +
+
+
{advisor?.name || id}
+
+ {advisor?.role || id}{locked ? ' · backend locked' : ''} +
+
+ {locked ? ( +
+ {personaValue} +
+ ) : ( + + )} +
+ ); + })} + + ); +}; + +export default AdvisorConfigPanel; diff --git a/phd-advisor-frontend/src/components/AvatarPickerModal 2.js b/phd-advisor-frontend/src/components/AvatarPickerModal 2.js new file mode 100644 index 00000000..a7ac07a4 --- /dev/null +++ b/phd-advisor-frontend/src/components/AvatarPickerModal 2.js @@ -0,0 +1,75 @@ +import React from 'react'; +import ReactDOM from 'react-dom'; +import { X } from 'lucide-react'; +import { useAppConfig } from '../contexts/AppConfigContext'; + +const API = process.env.REACT_APP_API_URL || ''; + +const BUNDLED = [ + 'advisor1.png','advisor2.png','advisor3.png','advisor4.png', + 'advisor5.png','advisor6.png','advisor7.png', +]; + +const overlay = { + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.5)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, +}; + +const modal = { + background: 'var(--bg-primary)', borderRadius: 16, padding: 24, width: 480, + maxWidth: '95vw', maxHeight: '85vh', overflowY: 'auto', + boxShadow: 'var(--shadow-xl)', +}; + +const AvatarPickerModal = ({ advisorId, advisorName, onClose }) => { + const { setAdvisorAvatar } = useAppConfig(); + + const select = (url) => { + setAdvisorAvatar(advisorId, url || ''); + onClose(); + }; + + return ReactDOM.createPortal( +
e.target === e.currentTarget && onClose()} onMouseDown={(e) => e.stopPropagation()}> +
+
+

+ Choose Avatar — {advisorName} +

+ +
+ +

Pre-made Avatars

+
+ {BUNDLED.map((file) => ( + {file} select(`${API}/api/avatars/bundled/${file}`)} + style={{ width: '100%', aspectRatio: '1', borderRadius: '50%', objectFit: 'cover', cursor: 'pointer', border: '2px solid transparent', transition: 'border-color 0.15s' }} + onMouseEnter={e => e.target.style.borderColor = 'var(--accent-primary)'} + onMouseLeave={e => e.target.style.borderColor = 'transparent'} + /> + ))} +
+ + +
+
, + document.body + ); +}; + +export default AvatarPickerModal; diff --git a/phd-advisor-frontend/src/components/HybridConfigModal.js b/phd-advisor-frontend/src/components/HybridConfigModal.js new file mode 100644 index 00000000..9da0629d --- /dev/null +++ b/phd-advisor-frontend/src/components/HybridConfigModal.js @@ -0,0 +1,95 @@ +import React, { useMemo, useState } from 'react'; +import ReactDOM from 'react-dom'; +import { X, Layers } from 'lucide-react'; +import AdvisorConfigPanel, { DEFAULT_BACKEND, stripDefaultBackends } from './AdvisorConfigPanel'; + +const overlay = { + position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.5)', + display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000, +}; + +const modal = { + background: 'var(--bg-primary)', borderRadius: 16, padding: 24, width: 560, + maxWidth: '95vw', maxHeight: '85vh', overflowY: 'auto', + boxShadow: 'var(--shadow-xl)', color: 'var(--text-primary)', +}; + +const seedConfig = (initialConfig, personaIds, availableBackends) => { + const fallback = initialConfig?.default_backend || availableBackends[0]; + const seedPersonas = initialConfig?.persona_backends || {}; + const personas = {}; + for (const id of personaIds) { + personas[id] = seedPersonas[id] || DEFAULT_BACKEND; + } + return { + default_backend: fallback, + orchestrator_backend: initialConfig?.orchestrator_backend || DEFAULT_BACKEND, + persona_backends: personas, + }; +}; + +const HybridConfigModal = ({ + advisors, + availableBackends, + initialConfig, + isSaving, + onSubmit, + onClose, +}) => { + const personaIds = useMemo(() => Object.keys(advisors || {}), [advisors]); + const [config, setConfig] = useState(() => seedConfig(initialConfig, personaIds, availableBackends)); + + return ReactDOM.createPortal( +
e.target === e.currentTarget && onClose()}> +
+
+
+ +

Hybrid LLM Configuration

+
+ +
+ + + +
+ + +
+
+
, + document.body + ); +}; + +export default HybridConfigModal; diff --git a/phd-advisor-frontend/src/components/ProviderDropdown.js b/phd-advisor-frontend/src/components/ProviderDropdown.js index bbadf8f8..d2a6fcf3 100644 --- a/phd-advisor-frontend/src/components/ProviderDropdown.js +++ b/phd-advisor-frontend/src/components/ProviderDropdown.js @@ -1,9 +1,9 @@ // src/components/ProviderDropdown.js import React, { useState, useRef, useEffect } from 'react'; -import { ChevronDown, Cpu, Cloud, Server, Loader2 } from 'lucide-react'; +import { ChevronDown, Cpu, Cloud, Server, Loader2, Layers, Settings2 } from 'lucide-react'; import { useTheme } from '../contexts/ThemeContext'; -const ProviderDropdown = ({ currentProvider, onProviderChange, isLoading = false }) => { +const ProviderDropdown = ({ currentProvider, onProviderChange, isLoading = false, onConfigureHybrid }) => { const [isOpen, setIsOpen] = useState(false); const dropdownRef = useRef(null); const { isDark } = useTheme(); @@ -29,6 +29,13 @@ const ProviderDropdown = ({ currentProvider, onProviderChange, isLoading = false description: 'vLLM inference endpoint', icon: Server, badge: 'API' + }, + { + id: 'hybrid', + name: 'Hybrid', + description: 'Per-advisor backend selection', + icon: Layers, + badge: 'Mixed' } ]; @@ -49,12 +56,25 @@ const ProviderDropdown = ({ currentProvider, onProviderChange, isLoading = false }, []); const handleProviderSelect = (providerId) => { - if (providerId !== currentProvider && !isLoading) { + if (isLoading) return; + if (providerId === 'hybrid') { + onProviderChange(providerId); + if (onConfigureHybrid) onConfigureHybrid(); + setIsOpen(false); + return; + } + if (providerId !== currentProvider) { onProviderChange(providerId); setIsOpen(false); } }; + const handleConfigureClick = (event) => { + event.stopPropagation(); + if (onConfigureHybrid) onConfigureHybrid(); + setIsOpen(false); + }; + const toggleDropdown = () => { if (!isLoading) { setIsOpen(!isOpen); @@ -111,6 +131,16 @@ const ProviderDropdown = ({ currentProvider, onProviderChange, isLoading = false {isSelected && (
)} + {provider.id === 'hybrid' && isSelected && onConfigureHybrid && ( + + )} ); })} diff --git a/phd-advisor-frontend/src/components/SettingsModal.js b/phd-advisor-frontend/src/components/SettingsModal.js index e05cafc3..c8815a4f 100644 --- a/phd-advisor-frontend/src/components/SettingsModal.js +++ b/phd-advisor-frontend/src/components/SettingsModal.js @@ -1,8 +1,9 @@ -import React, { useState, useRef, useEffect } from 'react'; +import React, { useEffect, useMemo, useRef, useState } from 'react'; import ReactDOM from 'react-dom'; -import { X, User as UserIcon, Lock, Trash2, AlertTriangle, Users } from 'lucide-react'; +import { X, User as UserIcon, Lock, Trash2, AlertTriangle, Users, Layers } from 'lucide-react'; import Toggle from './Toggle'; import { useAppConfig } from '../contexts/AppConfigContext'; +import AdvisorConfigPanel, { DEFAULT_BACKEND, stripDefaultBackends } from './AdvisorConfigPanel'; const overlay = { position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.5)', @@ -10,7 +11,7 @@ const overlay = { }; const modal = { - background: 'var(--bg-primary)', borderRadius: 16, padding: 0, width: 560, + background: 'var(--bg-primary)', borderRadius: 16, padding: 0, width: 640, maxWidth: '95vw', maxHeight: '85vh', overflow: 'hidden', boxShadow: 'var(--shadow-xl)', display: 'flex', flexDirection: 'column', }; @@ -67,27 +68,31 @@ const miniBtn = { fontFamily: 'inherit', }; -const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => { +const SettingsModal = ({ + user, + authToken, + onUserUpdate, + onSignOut, + onClose, + advisors, + availableBackends, + llmConfig, + isSaving, + onSubmitConfig, +}) => { const [activeTab, setActiveTab] = useState('profile'); const { - advisors, isAdvisorEnabled, setAdvisorEnabled, setAllAdvisorsEnabled, hydrateAdvisorPreferences, } = useAppConfig(); - // Reconcile with the backend whenever the user opens Settings (covers fresh - // logins and changes made on another device). useEffect(() => { hydrateAdvisorPreferences(); // eslint-disable-next-line react-hooks/exhaustive-deps }, []); - // Track where the mouse went DOWN so we don't close the modal when a user - // drags to select text inside an input and the mouseup happens outside the modal. - // (React's onClick fires on the common ancestor of down+up, which can be the - // overlay itself — causing accidental close on text selection.) const mouseDownOnOverlay = useRef(false); const handleOverlayMouseDown = (e) => { mouseDownOnOverlay.current = e.target === e.currentTarget; @@ -110,6 +115,19 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => const [message, setMessage] = useState(null); const [isSubmitting, setIsSubmitting] = useState(false); + const personaIds = useMemo(() => Object.keys(advisors || {}), [advisors]); + const [modelDraft, setModelDraft] = useState(() => { + const fallback = llmConfig?.default_backend || availableBackends?.[0]; + const seed = llmConfig?.persona_backends || {}; + const personas = {}; + for (const id of personaIds) personas[id] = seed[id] || DEFAULT_BACKEND; + return { + default_backend: fallback, + orchestrator_backend: llmConfig?.orchestrator_backend || DEFAULT_BACKEND, + persona_backends: personas, + }; + }); + const apiUrl = process.env.REACT_APP_API_URL; const extractError = (data, fallback) => { @@ -230,6 +248,11 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => } }; + const handleModelSave = async () => { + if (!onSubmitConfig) return; + await onSubmitConfig(stripDefaultBackends(modelDraft)); + }; + const messageStyle = (type) => ({ padding: '10px 12px', borderRadius: 8, marginBottom: 16, fontSize: 13, background: type === 'error' @@ -255,9 +278,6 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => const enabledCount = advisorEntries.filter(([id]) => isAdvisorEnabled(id)).length; const setAll = (enabled) => setAllAdvisorsEnabled(enabled); - // pendingDisable: { type: 'all' } | { type: 'single', id } — set when the - // user is about to leave zero advisors enabled. Confirming runs the action; - // "Go back" leaves state untouched. const [pendingDisable, setPendingDisable] = useState(null); const handleDisableAllClick = () => { @@ -302,6 +322,9 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => + @@ -423,6 +446,43 @@ const SettingsModal = ({ user, authToken, onUserUpdate, onSignOut, onClose }) => )} + {activeTab === 'model' && ( + <> + +
+ + +
+ + )} + {activeTab === 'danger' && (
{ const { config } = useAppConfig(); const canvasLabel = config?.app?.title ? `${config.app.title} Canvas` : 'Canvas'; @@ -245,7 +246,10 @@ const Sidebar = ({
+ ); + })} +
+ + + + {advancedOpen && ( +
+ +
+ +
+
+ )} +
+ ); +}; + +export default WelcomeModelPicker; diff --git a/phd-advisor-frontend/src/contexts/AppConfigContext.js b/phd-advisor-frontend/src/contexts/AppConfigContext.js index d205e089..97c4ca6e 100644 --- a/phd-advisor-frontend/src/contexts/AppConfigContext.js +++ b/phd-advisor-frontend/src/contexts/AppConfigContext.js @@ -61,6 +61,8 @@ const buildAdvisors = (personaItems, overrides = {}) => { darkBgColor: p.dark_bg_color || '#374151', icon: resolveIcon(isIcon ? image.replace('icon://', '') : null), avatarUrl, + defaultBackend: p.default_backend || null, + backendLocked: Boolean(p.brainforge || p.backend_locked), }; } return advisors; diff --git a/phd-advisor-frontend/src/pages/ChatPage.js b/phd-advisor-frontend/src/pages/ChatPage.js index 053811f8..6bbb6464 100644 --- a/phd-advisor-frontend/src/pages/ChatPage.js +++ b/phd-advisor-frontend/src/pages/ChatPage.js @@ -7,7 +7,7 @@ import MessageBubble from '../components/MessageBubble'; import ThinkingIndicator from '../components/ThinkingIndicator'; import SuggestionsPanel from '../components/SuggestionsPanel'; import ThemeToggle from '../components/ThemeToggle'; -import ProviderDropdown from '../components/ProviderDropdown'; +import SettingsModal from '../components/SettingsModal'; import ExportButton from '../components/ExportButton'; import Sidebar from '../components/Sidebar'; import { useAppConfig } from '../contexts/AppConfigContext'; @@ -25,8 +25,15 @@ const ChatPage = ({ user, authToken, onNavigateToHome, onNavigateToCanvas, onSig const [thinkingAdvisors, setThinkingAdvisors] = useState([]); const [collectedInfo, setCollectedInfo] = useState({}); const [replyingTo, setReplyingTo] = useState(null); - const [currentProvider, setCurrentProvider] = useState('gemini'); + const [llmConfig, setLlmConfig] = useState({ + mode: 'uniform', + default_backend: null, + orchestrator_backend: null, + persona_backends: null, + }); + const [availableBackends, setAvailableBackends] = useState([]); const [isProviderSwitching, setIsProviderSwitching] = useState(false); + const [isSettingsOpen, setIsSettingsOpen] = useState(false); const [uploadedDocuments, setUploadedDocuments] = useState([]); const messagesEndRef = useRef(null); const { isDark } = useTheme(); @@ -53,16 +60,18 @@ const ChatPage = ({ user, authToken, onNavigateToHome, onNavigateToCanvas, onSig }, [messages, thinkingAdvisors]); useEffect(() => { - fetchCurrentProvider(); - }, []); + if (authToken) fetchCurrentProvider(); + }, [authToken]); const fetchCurrentProvider = async () => { try { - const response = await fetch(`${process.env.REACT_APP_API_URL}/current-provider`); + const response = await fetch(`${process.env.REACT_APP_API_URL}/current-provider`, { + headers: { 'Authorization': `Bearer ${authToken}` }, + }); if (response.ok) { const data = await response.json(); - setCurrentProvider(data.current_provider); - console.log('Loaded provider:', data.current_provider, 'Available:', data.available_providers); + if (data.llm_config) setLlmConfig(data.llm_config); + if (Array.isArray(data.available_backends)) setAvailableBackends(data.available_backends); } } catch (error) { console.error('Error fetching current provider:', error); @@ -71,57 +80,67 @@ const ChatPage = ({ user, authToken, onNavigateToHome, onNavigateToCanvas, onSig - const handleProviderSwitch = async (newProvider) => { - if (newProvider === currentProvider || isProviderSwitching) return; - + const submitProviderConfig = async (payload, label) => { setIsProviderSwitching(true); try { const response = await fetch(`${process.env.REACT_APP_API_URL}/switch-provider`, { method: 'POST', headers: { 'Content-Type': 'application/json', + 'Authorization': `Bearer ${authToken}`, }, - body: JSON.stringify({ - provider: newProvider - }), + body: JSON.stringify(payload), }); if (response.ok) { const data = await response.json(); - setCurrentProvider(newProvider); - - const switchMessage = { + if (data.llm_config) { + setLlmConfig(data.llm_config); + } else { + setLlmConfig(payload); + } + + setMessages(prev => [...prev, { id: generateMessageId(), type: 'system', - content: `✨ Switched to ${newProvider.charAt(0).toUpperCase() + newProvider.slice(1)} provider. Your advisors are now ready with the new AI model.`, + content: `✨ Switched to ${label}. Your advisors are now ready with the new configuration.`, timestamp: new Date() - }; - setMessages(prev => [...prev, switchMessage]); - } else { - const error = await response.json(); - console.error('Failed to switch provider:', error); - const errorMessage = { - id: generateMessageId(), - type: 'error', - content: `Failed to switch to ${newProvider}: ${error.detail || 'Unknown error'}`, - timestamp: new Date() - }; - setMessages(prev => [...prev, errorMessage]); + }]); + return true; } + + const error = await response.json().catch(() => ({})); + console.error('Failed to switch provider:', error); + setMessages(prev => [...prev, { + id: generateMessageId(), + type: 'error', + content: `Failed to switch to ${label}: ${error.detail || 'Unknown error'}`, + timestamp: new Date() + }]); + return false; } catch (error) { console.error('Error switching provider:', error); - const errorMessage = { + setMessages(prev => [...prev, { id: generateMessageId(), type: 'error', - content: `Error switching to ${newProvider}. Please try again.`, + content: `Error switching to ${label}. Please try again.`, timestamp: new Date() - }; - setMessages(prev => [...prev, errorMessage]); + }]); + return false; } finally { setIsProviderSwitching(false); } }; + const handleHybridSubmit = async (hybridConfig) => { + const ok = await submitProviderConfig( + { mode: 'hybrid', ...hybridConfig }, + 'Hybrid configuration' + ); + if (ok) setIsSettingsOpen(false); + return ok; + }; + const generateMessageId = () => { return Date.now().toString() + Math.random().toString(36).substr(2, 9); }; @@ -517,6 +536,7 @@ const handleNewChat = async (sessionId = null) => { method: 'POST', headers: { 'Content-Type': 'application/json', + 'Authorization': `Bearer ${authToken}`, }, body: JSON.stringify({ user_input: inputMessage, @@ -594,6 +614,7 @@ const handleNewChat = async (sessionId = null) => { method: 'POST', headers: { 'Content-Type': 'application/json', + 'Authorization': `Bearer ${authToken}`, }, body: JSON.stringify({ user_input: expandPrompt, @@ -731,6 +752,7 @@ const handleNewChat = async (sessionId = null) => { onMobileToggle={setIsMobileMenuOpen} onNavigateToCanvas={onNavigateToCanvas} refreshTrigger={sidebarRefreshTrigger} + onOpenSettings={() => setIsSettingsOpen(true)} />
@@ -781,13 +803,6 @@ const handleNewChat = async (sessionId = null) => { authToken={authToken} /> - {/* Provider Dropdown */} - - {/* Theme Toggle */} @@ -960,6 +975,21 @@ const handleNewChat = async (sessionId = null) => {
+ + {isSettingsOpen && ( + setIsSettingsOpen(false)} + /> + )} ); diff --git a/phd-advisor-frontend/src/styles/ChatPage.css b/phd-advisor-frontend/src/styles/ChatPage.css index a74bdd2f..4509e691 100644 --- a/phd-advisor-frontend/src/styles/ChatPage.css +++ b/phd-advisor-frontend/src/styles/ChatPage.css @@ -911,6 +911,12 @@ .provider-badge.gemini { color: #4285f4; } .provider-badge.ollama { color: #10b981; } .provider-badge.vllm { color: #ef4444; } +.provider-badge.hybrid { color: #a855f7; } + +[data-theme="dark"] .provider-badge.hybrid, +[data-theme="dark"] .provider-option-badge.hybrid { + color: #ffffff; +} .clarification-message-container { display: flex; @@ -1147,6 +1153,34 @@ color: #ef4444; } +.provider-option-badge.hybrid { + background: rgba(168, 85, 247, 0.15); + color: #a855f7; +} + +.provider-option-configure { + margin-left: 8px; + background: transparent; + border: 1px solid var(--border-color, #d1d5db); + border-radius: 6px; + padding: 4px 6px; + cursor: pointer; + display: inline-flex; + align-items: center; + color: var(--text-secondary, #6b7280); +} + +.provider-option-configure:hover { + background: var(--accent-primary); + color: #fff; + border-color: var(--accent-primary); +} + +[data-theme="dark"] .provider-option-configure { + color: #ffffff; + border-color: rgba(255, 255, 255, 0.25); +} + .provider-option-description { font-size: 11px; color: var(--text-tertiary); diff --git a/phd_config.yaml b/phd_config.yaml index 121605cd..bef68cf0 100644 --- a/phd_config.yaml +++ b/phd_config.yaml @@ -173,11 +173,15 @@ mongodb: database_name: "phd_advisor" llm: + default_backend: gemini gemini: + enabled: true model: "gemini-2.5-flash" ollama: + enabled: false model: "llama3.2:1b" vllm: + enabled: true api_url: https://rtx6000blackwell-1.neonaiservices2.com/vllm0 brainforge: api_url: https://hana.neonaialpha.com