diff --git a/backend/app.py b/backend/app.py index d439d73..b7beaab 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,7 +7,7 @@ from backend.AniZenithExchange import AniZenithRequest, AniZenithResponse from prometheus.prometheus_middleware import PrometheusMiddleware, prometheus_router from backend.validation_utils import validate_anizenith_request -from backend.backend_utils import chat_with_llm +from backend.inference_manager import InferenceManager from starlette.middleware.sessions import SessionMiddleware import logging @@ -29,6 +29,9 @@ app.include_router(prometheus_router) #app.include_router(auth_router) +# Initialize Backend Inference Manager +inference_manager = InferenceManager() + # ┌───────────────────────────────────────────────┐ # │ BACKEND API ENDPOINTS │ # └───────────────────────────────────────────────┘ @@ -47,8 +50,9 @@ async def handle_chat_request(request: AniZenithRequest): # Chat with LLM using the messages in the request assistant_message = "" - for streamed_response in chat_with_llm(request.messages, request.use_local): - assistant_message = streamed_response + # TODO: Replace null user with real user / session ID + for streamed_response in inference_manager.chat(request.messages, "null-user"): + assistant_message += streamed_response # Construct an AniZenithResponse based on Assistant Message # Copy the old set of messages diff --git a/backend/backend_utils.py b/backend/backend_utils.py deleted file mode 100644 index 1be8d38..0000000 --- a/backend/backend_utils.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import List, Dict -from huggingface_hub import InferenceClient -from transformers import pipeline -from backend.constants import * -from backend.prometheus_utils import * -from dotenv import load_dotenv -import os -import json -from backend.mongo.AniZenithMongoClient import AniZenithMongoClient -from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult - -# Load all Environment Variables -load_dotenv() -HF_TOKEN = os.getenv('HF_TOKEN') - -# Init AniZenithMongoClient -CONN_STRING = os.getenv("ATLAS_URI") -DB_CLIENT = AniZenithMongoClient(CONN_STRING) - - -# Load the Local Pipeline Model at App Startup -PIPELINE_LOCAL_MODEL = pipeline(task='text-generation', - model='Qwen/Qwen3-0.6B', - max_new_tokens=MAX_NEW_TOKENS, - temperature=TEMPERATURE, - do_sample=False, - top_p=TOP_P) - - -# TODO: Make this Method Async Later -def chat_with_llm(messages: List[Dict[str, str]], use_local_model: bool): - - # TODO: Replace with actual IDs from a config - model = "Qwen/Qwen3-0.6B" if use_local_model else "openai/gpt-oss-20b" - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="full_pipeline").time(): - - # Use the last message as the user message (it should always be a user message) - user_query = messages[-1]['content'] - - # Retrieve relevant results from DB using vector search - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="db_retrieval").time(): - - # Perform Vector Search using user query - # This method returns a List[AniZenithVectorSearchResult] - recommendations: List[AniZenithVectorSearchResult] = DB_CLIENT.perform_vector_search(user_query) - - # Convert the list of vector search objects into a list of dicts - # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object - recommendations_dict = [recommendation.model_dump() for recommendation in recommendations] - - # Serialize to a JSON string - recommendations_string = json.dumps(recommendations_dict, indent = 4) - - # Query the model - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="model_generation").time(): - for result in query_model(messages, use_local_model, recommendations_string): - yield result - - -# TODO: Make this method Async later -def query_model(messages: List[Dict[str, str]], use_local_model: bool, recommendations_string: str): - # --- Determine System Prompt --- - # Start with the Fixed System Prompt - system_prompt = SYSTEM_PROMPT - - # Append recommendation string data to the System Prompt if it exists - if recommendations_string: - system_prompt += "\nRECOMMENDATION JSON:" + f"\n{recommendations_string}" - - # Add the System Prompt to the Input Messages to the LLM - input_messages = [{"role": "system", "content": system_prompt}] - # Add the rest of the messages - input_messages.extend(messages) - - # --- Determine which model to use (local or external) --- - # Constants for logging - response = "" - input_token_count = 0 - output_token_count = 0 - - # --- Local Model --- - if use_local_model: - # Uses pipeline from transformers library - response = PIPELINE_LOCAL_MODEL(input_messages) - - # Get the response from the local model, parse it, and yield - generated_text = response[0]['generated_text'][-1]['content'].split('')[-1].strip() - yield generated_text - - # Log token counts (there is no clean way to do this besides re-tokenizing) - tokenizer = PIPELINE_LOCAL_MODEL.tokenizer - formatted_input = tokenizer.apply_chat_template( - input_messages, - tokenize=False, - add_generation_prompt=True - ) - input_tokens = tokenizer.encode(formatted_input) - generated_tokens = tokenizer.encode(generated_text) - input_token_count = len(input_tokens) - output_token_count = len(generated_tokens) - - # --- Non-local Model (Use InferenceClient) --- - else: - client = InferenceClient( - token=HF_TOKEN, - model="openai/gpt-oss-20b", - ) - - # Stream inference client output and yield the text chunk - usage = None - for chunk in client.chat_completion( - messages=input_messages, - max_tokens=MAX_NEW_TOKENS, - stream=True, - temperature=TEMPERATURE, - top_p=TOP_P, - ): - if chunk.choices and chunk.choices[0].delta.content: - token = chunk.choices[0].delta.content - response += token - yield response - # Add usage data for logging - if hasattr(chunk, 'usage') and chunk.usage: - usage = chunk.usage - - if usage: - input_token_count = usage.prompt_tokens - output_token_count = usage.completion_tokens - - # Log the model usage output - # TODO: Record the specific model ID once the backend is refactored - model = "Qwen/Qwen3-0.6B" if use_local_model else "openai/gpt-oss-20b" - # TODO: Use real Inference Manager / session ID - observe_user_message(user_id="0", user_message=messages[-1]['content'], token_count=input_token_count, model=model) - observe_bot_message(user_id="0", bot_message=response, token_count=output_token_count, model=model) - diff --git a/backend/constants.py b/backend/constants.py index 7978a06..ec608fd 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -10,6 +10,10 @@ 3. Ask the user to provide their favorite genre(s) for Anime Recommendations """ +RECOMMENDED_DOCS_PREAMBLE = f""" +Here are the recommendation system's top shows:\n +""" + MAX_NEW_TOKENS = 2048 TEMPERATURE = 0.7 TOP_P = 0.7 \ No newline at end of file diff --git a/backend/inference_manager.py b/backend/inference_manager.py new file mode 100644 index 0000000..5f1086a --- /dev/null +++ b/backend/inference_manager.py @@ -0,0 +1,124 @@ +import json +import os +import time +from typing import List, Dict + +from dotenv import load_dotenv + +from backend.mongo.AniZenithMongoClient import AniZenithMongoClient +from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult +from backend.prometheus_utils import * +from backend.models import HFInferenceClientModel, HFLocalModel, Model +from backend.reranker import AniZenithReranker +from constants import * + +load_dotenv(".env") +# TODO: Move these to config management system +MONGO_CONN_STRING = os.getenv("ATLAS_URI") +VECTOR_SEARCH_LIMIT = 20 +RERANK_LIMIT = 5 +MODEL_DOWNTIME_SECONDS = 120 + +local_model_id = "Qwen/Qwen3-0.6B" +external_model_id = "openai/gpt-oss-20b" + + +class InferenceManager: + + def __init__(self): + # Initialize Models for use in descending order of importance + self.models: List[Model] = [HFInferenceClientModel(external_model_id), HFLocalModel(local_model_id)] + self.current_model_idx = 0 # Current model idx being used + self.model_available_at = [0.0 for _ in self.models] # Controls fallback timer in case error occurs + + # Load a DB Client Instance + self.db_client = AniZenithMongoClient(MONGO_CONN_STRING) + + # Load a Reranker model instance + self.reranker = AniZenithReranker() + + # Gets the current most prioritized model for use + def get_best_model(self) -> Model: + now = time.time() + + for i, model in enumerate(self.models): + if now >= self.model_available_at[i]: + self.current_model_idx = i + return model + + # If all models are cooling down, throw error models not available + # TODO: Make custom exception + raise Exception("No model available") + + + def chat(self, messages: List[Dict[str, str]], user_id: str = None): + """ + Enhanced inference chat function with retrieval and reranking. + Steps: + 1. Retrieve relevant documents from MongoDB + 2. Rerank them with the reranker + 3. Build LLM messages prompt + 4. Stream the model output + """ + # TODO: Make this an agentic framework using LangChain + # TODO: Use user_id in logging + # TODO: Add queue system to make blocking better + # TODO: Replace all print statements with logging + current_model = self.get_best_model() + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="full_pipeline").time(): + # 1) Retrieve results from DB Client + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="db_retrieval").time(): + # Use the last message as the user message (it should always be a user message) + user_query = messages[-1]['content'] + retrieved_docs: List[AniZenithVectorSearchResult] = self.db_client.perform_vector_search(user_query, limit=VECTOR_SEARCH_LIMIT) + print(f"Retrieved Docs: ({len(retrieved_docs)}) relevant docs") + + # 2) Rerank results using the reranker based on document info and user query + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranking").time(): + reranked_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) + print(f"Reranked Docs: ({len(reranked_docs)})") + + # 3) Construct system prompt with recommended docs + system_prompt = self._build_system_prompt(reranked_docs) + + # 4) Insert system prompt into messages + messages.insert(0, {"role": "system", "content": system_prompt}) + print("Completed System Prompt Building") + + # 5) Stream output of the model using the stream method + output = "" + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="model_generation").time(): + try: + for token in current_model.stream(messages): + output += token + yield token + + except Exception as e: + # TODO: Log error + print(f"Model Error: {e}") + # Yield model terminated to user + yield "" + # Sets the model's next available time to current time in seconds + downtime + self.model_available_at[self.current_model_idx] = time.time() + MODEL_DOWNTIME_SECONDS + + # Record Usage Metrics + print(f"Streamed output: {output}") + usage = current_model.get_usage() + observe_user_message(user_id, user_query, usage.input_token_count, current_model.get_name()) + observe_bot_message(user_id, output, usage.output_token_count, current_model.get_name()) + + def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult]) -> str: + lines = [] + + # Add base system prompt + lines.append(SYSTEM_PROMPT) + lines.append(RECOMMENDED_DOCS_PREAMBLE) + + # Add recommendation docs + # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object + # Dumps JSON as string with indent + recommendations = [json.dumps(recommendation.model_dump(), indent=4) for recommendation in recommendations] + recommendation_string = "\n\n".join(recommendations) if recommendations else "No good recommendations found." + lines.append(recommendation_string) + + return "\n".join(lines) diff --git a/backend/models.py b/backend/models.py new file mode 100644 index 0000000..2a54ab2 --- /dev/null +++ b/backend/models.py @@ -0,0 +1,178 @@ +import os +from abc import ABC, abstractmethod +from threading import Thread +from typing import Iterator, Dict, Any, List + +import torch +from huggingface_hub import InferenceClient +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +from pydantic import BaseModel + +from backend.constants import MAX_NEW_TOKENS, TEMPERATURE, TOP_P + +# TODO: Move to config management system +HF_TOKEN = os.getenv('HF_TOKEN') + +# Usage statistics class to enforce for logging +class ModelUsageStatistics(BaseModel): + model_name: str + input_token_count: int + output_token_count: int + +class Model(ABC): + """ + Abstract base model. + Enforces streaming + usage stats. + """ + def __init__(self, name: str): + self._usage: Dict[str, Any] = {} + self.name = name + + # Each subclass must implement streaming + @abstractmethod + def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]: + """Streams model response as a string generator""" + pass + + # Each subclass must define usage stats + @abstractmethod + def get_usage(self) -> ModelUsageStatistics: + """Returns a ModelUsageStatistics with usage statistics for the model""" + pass + + def get_name(self): + return self.name + + # Generate runs stream and accumulates, then returns + def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: + """ + Runs stream() internally and accumulates output + Returns final text + usage. + """ + output = [] + + for chunk in self.stream(messages): + output.append(chunk) + + result_text = "".join(output) + + return { + "generated_text": result_text, + "usage": self.get_usage() + } + +class HFLocalModel(Model): + def __init__(self, model_id: str): + super().__init__(model_id) + + # Load local model and tokenizer (use efficient model params) + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16) + self.model.eval() + + self._usage_data = None + + def stream(self, messages: List[Dict[str, str]]): + self._usage_data = None + # Apply chat template & tokenize input + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_tensors="pt", + add_generation_prompt=True + ).to(self.model.device) + + # Initialize streamer + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True + ) + + generation_err = None + def generate(): + # Ensure no gradients + with torch.no_grad(): + nonlocal generation_err + try: + self.model.generate(input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'], + max_new_tokens=MAX_NEW_TOKENS, + do_sample=True, + temperature=TEMPERATURE, + top_p=TOP_P, + use_cache=True, # Use KV Cache + streamer=streamer # Adds the text streamer to capture output callback + ) + except Exception as e: + # Stop streamer and propagate error + generation_err = e + streamer.end() + + # Start another thread to run generation for streaming + thread = Thread(target=generate) + thread.start() + + # Accumulate usage + input_token_count = inputs['input_ids'].shape[-1] # (B x input_tokens_len) + output_token_count = 0 + + # Streamer obtains values sequentially by injecting into generate function in new thread + # Streamer outputs values as Generator for text here + for text in streamer: + yield text + output_token_count += 1 # Streamer executes every token event received + + # Clean up thread + thread.join() + + # Handle error after joining if it exists + if generation_err is not None: + raise generation_err + + # Add usage data + self._usage_data = {"input_token_count": input_token_count, "output_token_count": output_token_count} + + def get_usage(self): + return ModelUsageStatistics( + model_name=self.get_name(), + input_token_count=self._usage_data["input_token_count"], + output_token_count=self._usage_data["output_token_count"], + ) + + +class HFInferenceClientModel(Model): + def __init__(self, model_id: str): + super().__init__(model_id) + + self.client = InferenceClient( + model=model_id, + token=HF_TOKEN, + ) + + self._usage_data = None + + def stream(self, messages: List[Dict[str, str]]): + self._usage_data = None + + # Use built in InferenceClient chat completion + for chunk in self.client.chat_completion( + messages=messages, + max_tokens=MAX_NEW_TOKENS, + stream=True, + temperature=TEMPERATURE, + top_p=TOP_P, + ): + if chunk.choices and chunk.choices[0].delta.content: + token = chunk.choices[0].delta.content + yield token + # Add usage data for logging + if hasattr(chunk, 'usage') and chunk.usage: + self._usage_data = chunk.usage + + def get_usage(self): + return ModelUsageStatistics( + model_name=self.get_name(), + input_token_count=self._usage_data.prompt_tokens, + output_token_count=self._usage_data.completion_tokens, + ) \ No newline at end of file diff --git a/backend/reranker.py b/backend/reranker.py new file mode 100644 index 0000000..7a8b2df --- /dev/null +++ b/backend/reranker.py @@ -0,0 +1,13 @@ +from typing import List + +from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult + + +class AniZenithReranker: + + def __init__(self): + pass + + def rerank(self, user_query: str, results: List[AniZenithVectorSearchResult], limit=5): + # TODO: Add to this this framework + return results[:limit] \ No newline at end of file diff --git a/frontend/static/js/chat_utils.js b/frontend/static/js/chat_utils.js index 45e4cea..e01147d 100644 --- a/frontend/static/js/chat_utils.js +++ b/frontend/static/js/chat_utils.js @@ -96,7 +96,7 @@ export async function sendMessagesToBackend() { try { // If using local, detect and add additional timeout - const timeout = payload.use_local ? 180.0 : 5.0; + const timeout = payload.use_local ? 180.0 : 25.0; const response = await fetch("/proxy/anizenith/chat", { method: "POST", headers: { diff --git a/tests/test_chat_models.py b/tests/test_chat_models.py index 8eef7fa..565fa9c 100644 --- a/tests/test_chat_models.py +++ b/tests/test_chat_models.py @@ -1,12 +1,15 @@ import os import pytest -from backend.backend_utils import chat_with_llm -import backend.backend_utils as backend_utils +from backend.inference_manager import InferenceManager +# TODO: Need help re-integrating monkeypatch, leave until after config management refactor TEST_SYSTEM_MESSAGE = "You are a friendly chatbot." TEST_USER_MESSAGE = "Hello!" HF_TOKEN = os.getenv("HF_TOKEN") +@pytest.fixture(scope="module") +def get_manager(): + return InferenceManager() def test_HF_token_exists(): token = os.getenv("HF_TOKEN") @@ -14,23 +17,17 @@ def test_HF_token_exists(): assert len(token) > 1 -def test_local_model_runs(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(backend_utils, "SYSTEM_PROMPT", TEST_SYSTEM_MESSAGE) - use_local_model = True +def test_local_model_runs(get_manager): collected_result = "" - for result in chat_with_llm(messages=[{"role": "user", "content": TEST_USER_MESSAGE}], - use_local_model=use_local_model): - collected_result = result + for result in get_manager.chat(messages=[{"role":"user","content": TEST_USER_MESSAGE}]): + collected_result += result assert len(collected_result) > 0 -def test_external_model_runs(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(backend_utils, "SYSTEM_PROMPT", TEST_SYSTEM_MESSAGE) - use_local_model = False +def test_external_model_runs(get_manager): collected_result = "" - for result in chat_with_llm(messages=[{"role": "user", "content": TEST_USER_MESSAGE}], - use_local_model=use_local_model): + for result in get_manager.chat(messages=[{"role": "user", "content": TEST_USER_MESSAGE}]): collected_result = result assert len(collected_result) > 0