From e6ab55123b82a4550f429494c20180a35b550b52 Mon Sep 17 00:00:00 2001 From: ShafathZ Date: Sun, 12 Apr 2026 04:07:07 -0400 Subject: [PATCH 1/5] Constructed new easily scalable and modifiable backend structure --- backend/inference_manager.py | 106 ++++++++++++++++++++++++ backend/models.py | 152 +++++++++++++++++++++++++++++++++++ backend/reranker.py | 13 +++ 3 files changed, 271 insertions(+) create mode 100644 backend/inference_manager.py create mode 100644 backend/models.py create mode 100644 backend/reranker.py diff --git a/backend/inference_manager.py b/backend/inference_manager.py new file mode 100644 index 0000000..d22a216 --- /dev/null +++ b/backend/inference_manager.py @@ -0,0 +1,106 @@ +import os +import time +from typing import List, Dict + +from dotenv import load_dotenv + +from backend.mongo.AniZenithMongoClient import AniZenithMongoClient +from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult +from backend.prometheus_utils import * +from backend.models import HFInferenceClientModel, HFLocalModel, Model +from backend.reranker import AniZenithReranker +from constants import * + +load_dotenv() +# TODO: Move to config management system +MONGO_CONN_STRING = os.getenv("ATLAS_URI") + +VECTOR_SEARCH_LIMIT = 20 +RERANK_LIMIT = 5 + +MODEL_DOWNTIME_SECONDS = 120 + + +class InferenceManager: + + def __init__(self): + # Initialize Models for use in descending order of importance + self.models: List[Model] = [HFInferenceClientModel(), HFLocalModel()] + self.current_model_idx = 0 # Current model idx being used + self.model_available_at = [0.0 for _ in self.models] # Controls fallback timer in case error occurs + + # Load a DB Client Instance + self.db_client = AniZenithMongoClient(MONGO_CONN_STRING) + + # Load a Reranker model instance + self.reranker = AniZenithReranker() + + # Gets the current most prioritized model for use + def get_best_model(self) -> Model: + now = time.time() + + for i, model in enumerate(self.models): + if now >= self.model_available_at[i]: + self.current_model_idx = i + return model + + # If all models are cooling down, throw error models not available + # TODO: Make custom exception + raise Exception("No model available") + + + def chat(self, messages: List[Dict[str, str]], user_id: str = None): + """ + Enhanced inference chat function with retrieval and reranking. + Steps: + 1. Retrieve relevant documents from MongoDB + 2. Rerank them with the reranker + 3. Build LLM messages prompt + 4. Stream the model output + """ + # TODO: Make this an agentic framework using LangChain + # TODO: Use user_id in logging + # TODO: Add queue system to make blocking better + current_model = self.get_best_model() + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="full_pipeline").time(): + # 1) Retrieve results from DB Client + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="db_retrieval").time(): + # Use the last message as the user message (it should always be a user message) + user_query = messages[-1]['content'] + retrieved_docs: List[AniZenithVectorSearchResult] = self.db_client.perform_vector_search(user_query, limit=VECTOR_SEARCH_LIMIT) + + # 2) Rerank results using the reranker based on document info and user query + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranker").time(): + recommended_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) + + # 3) Construct system prompt with recommended docs + system_prompt = self._build_system_prompt(recommended_docs) + + # 4) Insert system prompt into messages + messages.insert(0, {"role": "system", "content": system_prompt}) + + # 5) Stream output of the model using the stream method + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="model_generation").time(): + try: + for token in current_model.stream(messages): + yield token + + except Exception as e: + # TODO: Log error + # Sets next available time to current time in seconds + downtime + self.model_available_at[self.current_model_idx] = time.time() + MODEL_DOWNTIME_SECONDS + + def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult]) -> str: + lines = [] + + # Add base system prompt + lines.append(SYSTEM_PROMPT) + lines.append("Here are the recommendation system's top shows:\n") + + # Add recommendation docs + # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object + recommendations = [recommendation.model_dump() for recommendation in recommendations] + recommendation_string = "\n\n".join(recommendations) if recommendations else "No good recommendations found." + lines.append(recommendation_string) + + return "\n".join(lines) diff --git a/backend/models.py b/backend/models.py new file mode 100644 index 0000000..a1ccb0d --- /dev/null +++ b/backend/models.py @@ -0,0 +1,152 @@ +import os +from abc import ABC, abstractmethod +from threading import Thread +from typing import Iterator, Dict, Any, List + +import torch +from huggingface_hub import InferenceClient +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer + +from backend.constants import MAX_NEW_TOKENS, TEMPERATURE, TOP_P + +# TODO: Move to config management system +HF_TOKEN = os.getenv('HF_TOKEN') +local_model_id = "Qwen/Qwen3-0.6B" +external_model_id = "openai/gpt-oss-20b" + +class Model(ABC): + """ + Abstract base model. + Enforces streaming + usage stats. + """ + def __init__(self): + self._usage: Dict[str, Any] = {} + + # Each subclass must implement streaming + @abstractmethod + def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]: + pass + + # Each subclass must define usage stats + @abstractmethod + def get_usage(self) -> Dict[str, Any]: + pass + + @abstractmethod + def get_name(self) -> str: + pass + + # Generate runs stream and accumulates, then returns + def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: + """ + Runs stream(), accumulates output, returns final text + usage. + """ + output = [] + + for chunk in self.stream(messages): + output.append(chunk) + + result_text = "".join(output) + + return { + "text": result_text, + "usage": self.get_usage() + } + +class HFLocalModel(Model): + def __init__(self): + super().__init__() + + # Load local model and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(local_model_id) + self.model = AutoModelForCausalLM.from_pretrained(local_model_id, device_map="auto", torch_dtype=torch.float16) + + self._usage_data = None + + def stream(self, messages: List[Dict[str, str]]): + self._usage_data = None + # Apply chat template & tokenize input + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_tensors="pt", + add_generation_prompt=True + ).to(self.model.device) + + # Initialize streamer + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True + ) + + # Make generation config + gen_kwargs = dict( + input_ids=inputs, + streamer=streamer, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=True, + temperature=TEMPERATURE, + top_p=TOP_P + ) + + # Start another thread to run generation for streaming + thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread.start() + + # Accumulate usage + input_tokens = inputs["input_ids"].shape[-1] + output_tokens = 0 + for text in streamer: + yield text + output_tokens += 1 + + # Add usage data + self._usage_data = {"input_tokens": input_tokens, "output_tokens": output_tokens} + + def get_usage(self): + return { + "model_name": self.get_name(), + "input_tokens": self._usage_data["input_tokens"], + "output_tokens": self._usage_data["output_tokens"], + } + + def get_name(self): + return local_model_id + + +class HFInferenceClientModel(Model): + def __init__(self): + super().__init__() + + self.client = InferenceClient( + model=external_model_id, + token=HF_TOKEN, + ) + + self._usage_data = None + + def stream(self, messages: List[Dict[str, str]]): + self._usage_data = None + + # Use built in InferenceClient chat completion + for chunk in self.client.chat_completion( + messages=messages, + max_tokens=MAX_NEW_TOKENS, + stream=True, + temperature=TEMPERATURE, + top_p=TOP_P, + ): + yield chunk + if hasattr(chunk, 'usage') and chunk.usage: + self._usage_data = chunk.usage + + def get_name(self): + return external_model_id + + def get_usage(self): + return { + "model_name": self.get_name(), + "input_tokens": self._usage_data.prompt_tokens, + "output_tokens": self._usage_data.completion_tokens, + } \ No newline at end of file diff --git a/backend/reranker.py b/backend/reranker.py new file mode 100644 index 0000000..0209cab --- /dev/null +++ b/backend/reranker.py @@ -0,0 +1,13 @@ +from typing import List + +from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult + + +class AniZenithReranker: + + def __init__(self): + pass + + def rerank(self, user_query: str, results: List[AniZenithVectorSearchResult], limit=5): + # TODO: Add to this this framework + return results \ No newline at end of file From 0610e9ec301a4c8ad9acffeb9de4b6515ef7b67f Mon Sep 17 00:00:00 2001 From: ShafathZ Date: Sun, 12 Apr 2026 06:02:14 -0400 Subject: [PATCH 2/5] Fixed backend local model code issues, working but slow on CPU --- backend/app.py | 9 ++-- backend/inference_manager.py | 26 ++++++++-- backend/models.py | 89 ++++++++++++++++++-------------- backend/reranker.py | 2 +- frontend/app.py | 3 ++ frontend/static/js/chat_utils.js | 2 +- 6 files changed, 83 insertions(+), 48 deletions(-) diff --git a/backend/app.py b/backend/app.py index d439d73..e1ff41f 100644 --- a/backend/app.py +++ b/backend/app.py @@ -7,7 +7,7 @@ from backend.AniZenithExchange import AniZenithRequest, AniZenithResponse from prometheus.prometheus_middleware import PrometheusMiddleware, prometheus_router from backend.validation_utils import validate_anizenith_request -from backend.backend_utils import chat_with_llm +from backend.inference_manager import InferenceManager from starlette.middleware.sessions import SessionMiddleware import logging @@ -29,6 +29,9 @@ app.include_router(prometheus_router) #app.include_router(auth_router) +# Initialize Backend Inference Manager +inference_manager = InferenceManager() + # ┌───────────────────────────────────────────────┐ # │ BACKEND API ENDPOINTS │ # └───────────────────────────────────────────────┘ @@ -47,8 +50,8 @@ async def handle_chat_request(request: AniZenithRequest): # Chat with LLM using the messages in the request assistant_message = "" - for streamed_response in chat_with_llm(request.messages, request.use_local): - assistant_message = streamed_response + for streamed_response in inference_manager.chat(request.messages, "null-user"): + assistant_message += streamed_response # Construct an AniZenithResponse based on Assistant Message # Copy the old set of messages diff --git a/backend/inference_manager.py b/backend/inference_manager.py index d22a216..f6e2593 100644 --- a/backend/inference_manager.py +++ b/backend/inference_manager.py @@ -1,3 +1,4 @@ +import json import os import time from typing import List, Dict @@ -11,7 +12,7 @@ from backend.reranker import AniZenithReranker from constants import * -load_dotenv() +load_dotenv(".env") # TODO: Move to config management system MONGO_CONN_STRING = os.getenv("ATLAS_URI") @@ -20,12 +21,15 @@ MODEL_DOWNTIME_SECONDS = 120 +local_model_id = "Qwen/Qwen3-0.6B" +external_model_id = "openai/gpt-oss-20b" + class InferenceManager: def __init__(self): # Initialize Models for use in descending order of importance - self.models: List[Model] = [HFInferenceClientModel(), HFLocalModel()] + self.models: List[Model] = [HFInferenceClientModel(external_model_id), HFLocalModel(local_model_id)] self.current_model_idx = 0 # Current model idx being used self.model_available_at = [0.0 for _ in self.models] # Controls fallback timer in case error occurs @@ -61,6 +65,7 @@ def chat(self, messages: List[Dict[str, str]], user_id: str = None): # TODO: Make this an agentic framework using LangChain # TODO: Use user_id in logging # TODO: Add queue system to make blocking better + # TODO: Replace all print statements with logging current_model = self.get_best_model() with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="full_pipeline").time(): # 1) Retrieve results from DB Client @@ -68,28 +73,42 @@ def chat(self, messages: List[Dict[str, str]], user_id: str = None): # Use the last message as the user message (it should always be a user message) user_query = messages[-1]['content'] retrieved_docs: List[AniZenithVectorSearchResult] = self.db_client.perform_vector_search(user_query, limit=VECTOR_SEARCH_LIMIT) + print(f"Retrieved Docs: ({len(retrieved_docs)})") # 2) Rerank results using the reranker based on document info and user query with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranker").time(): recommended_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) + print(f"Reranked Docs: ({len(recommended_docs)})") # 3) Construct system prompt with recommended docs system_prompt = self._build_system_prompt(recommended_docs) # 4) Insert system prompt into messages messages.insert(0, {"role": "system", "content": system_prompt}) + print("Completed System Prompt Building") # 5) Stream output of the model using the stream method + output = "" with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="model_generation").time(): try: for token in current_model.stream(messages): + output += token yield token except Exception as e: # TODO: Log error + print(f"Model Error: {e}") + # Yield model terminated to user + yield "" # Sets next available time to current time in seconds + downtime self.model_available_at[self.current_model_idx] = time.time() + MODEL_DOWNTIME_SECONDS + # Record Usage Metrics + print(f"Streamed output: {output}") + usage = current_model.get_usage() + observe_user_message(user_id, user_query, usage["input_token_count"], current_model.get_name()) + observe_bot_message(user_id, output, usage["output_token_count"], current_model.get_name()) + def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult]) -> str: lines = [] @@ -99,7 +118,8 @@ def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult # Add recommendation docs # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object - recommendations = [recommendation.model_dump() for recommendation in recommendations] + # Dumps JSON as string with indent + recommendations = [json.dumps(recommendation.model_dump(), indent=4) for recommendation in recommendations] recommendation_string = "\n\n".join(recommendations) if recommendations else "No good recommendations found." lines.append(recommendation_string) diff --git a/backend/models.py b/backend/models.py index a1ccb0d..95ba358 100644 --- a/backend/models.py +++ b/backend/models.py @@ -11,16 +11,15 @@ # TODO: Move to config management system HF_TOKEN = os.getenv('HF_TOKEN') -local_model_id = "Qwen/Qwen3-0.6B" -external_model_id = "openai/gpt-oss-20b" class Model(ABC): """ Abstract base model. Enforces streaming + usage stats. """ - def __init__(self): + def __init__(self, name: str): self._usage: Dict[str, Any] = {} + self.name = name # Each subclass must implement streaming @abstractmethod @@ -32,9 +31,8 @@ def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]: def get_usage(self) -> Dict[str, Any]: pass - @abstractmethod - def get_name(self) -> str: - pass + def get_name(self): + return self.name # Generate runs stream and accumulates, then returns def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: @@ -54,17 +52,19 @@ def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: } class HFLocalModel(Model): - def __init__(self): - super().__init__() + def __init__(self, model_id: str): + super().__init__(model_id) - # Load local model and tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(local_model_id) - self.model = AutoModelForCausalLM.from_pretrained(local_model_id, device_map="auto", torch_dtype=torch.float16) + # Load local model and tokenizer (use efficient model params) + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16) + self.model.eval() self._usage_data = None def stream(self, messages: List[Dict[str, str]]): self._usage_data = None + print("Starting tokenize") # Apply chat template & tokenize input inputs = self.tokenizer.apply_chat_template( messages, @@ -80,47 +80,56 @@ def stream(self, messages: List[Dict[str, str]]): skip_special_tokens=True ) - # Make generation config - gen_kwargs = dict( - input_ids=inputs, - streamer=streamer, - max_new_tokens=MAX_NEW_TOKENS, - do_sample=True, - temperature=TEMPERATURE, - top_p=TOP_P - ) - + def generate(): + # Ensure no gradients + with torch.inference_mode(): + try: + self.model.generate(input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'], + max_new_tokens=MAX_NEW_TOKENS, + do_sample=True, + temperature=TEMPERATURE, + top_p=TOP_P, + use_cache=True, # Use KV Cache + streamer=streamer # Adds the text streamer to capture output callback + ) + except Exception as e: + # Stop streamer and propagate error + self._thread_error = e + streamer.end() + + print("Starting stream") # Start another thread to run generation for streaming - thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread = Thread(target=generate) thread.start() # Accumulate usage - input_tokens = inputs["input_ids"].shape[-1] - output_tokens = 0 + input_token_count = inputs['input_ids'].shape[-1] + output_token_count = 0 for text in streamer: yield text - output_tokens += 1 + output_token_count += 1 + + # Clean up thread + thread.join() # Add usage data - self._usage_data = {"input_tokens": input_tokens, "output_tokens": output_tokens} + self._usage_data = {"input_token_count": input_token_count, "output_token_count": output_token_count} def get_usage(self): return { "model_name": self.get_name(), - "input_tokens": self._usage_data["input_tokens"], - "output_tokens": self._usage_data["output_tokens"], + "input_token_count": self._usage_data["input_token_count"], + "output_token_count": self._usage_data["output_token_count"], } - def get_name(self): - return local_model_id - class HFInferenceClientModel(Model): - def __init__(self): - super().__init__() + def __init__(self, model_id: str): + super().__init__(model_id) self.client = InferenceClient( - model=external_model_id, + model=model_id, token=HF_TOKEN, ) @@ -137,16 +146,16 @@ def stream(self, messages: List[Dict[str, str]]): temperature=TEMPERATURE, top_p=TOP_P, ): - yield chunk + if chunk.choices and chunk.choices[0].delta.content: + token = chunk.choices[0].delta.content + yield token + # Add usage data for logging if hasattr(chunk, 'usage') and chunk.usage: self._usage_data = chunk.usage - def get_name(self): - return external_model_id - def get_usage(self): return { "model_name": self.get_name(), - "input_tokens": self._usage_data.prompt_tokens, - "output_tokens": self._usage_data.completion_tokens, + "input_token_count": self._usage_data.prompt_tokens, + "output_token_count": self._usage_data.completion_tokens, } \ No newline at end of file diff --git a/backend/reranker.py b/backend/reranker.py index 0209cab..7a8b2df 100644 --- a/backend/reranker.py +++ b/backend/reranker.py @@ -10,4 +10,4 @@ def __init__(self): def rerank(self, user_query: str, results: List[AniZenithVectorSearchResult], limit=5): # TODO: Add to this this framework - return results \ No newline at end of file + return results[:limit] \ No newline at end of file diff --git a/frontend/app.py b/frontend/app.py index 9aa8301..8fc40aa 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -10,6 +10,9 @@ from fastapi.middleware.cors import CORSMiddleware import logging from prometheus.prometheus_middleware import PrometheusMiddleware, prometheus_router +from dotenv import load_dotenv + +load_dotenv(".env") # Configure logging at Startup logging.basicConfig(level = logging.INFO) diff --git a/frontend/static/js/chat_utils.js b/frontend/static/js/chat_utils.js index 45e4cea..e01147d 100644 --- a/frontend/static/js/chat_utils.js +++ b/frontend/static/js/chat_utils.js @@ -96,7 +96,7 @@ export async function sendMessagesToBackend() { try { // If using local, detect and add additional timeout - const timeout = payload.use_local ? 180.0 : 5.0; + const timeout = payload.use_local ? 180.0 : 25.0; const response = await fetch("/proxy/anizenith/chat", { method: "POST", headers: { From 5e9217e226f611715f47cdabe45aef8c579312b3 Mon Sep 17 00:00:00 2001 From: ShafathZ Date: Sun, 12 Apr 2026 06:27:07 -0400 Subject: [PATCH 3/5] Improved scalability defining usage class, improved comments --- backend/backend_utils.py | 136 ----------------------------------- backend/inference_manager.py | 6 +- backend/models.py | 36 ++++++---- tests/test_chat_models.py | 23 +++--- 4 files changed, 36 insertions(+), 165 deletions(-) delete mode 100644 backend/backend_utils.py diff --git a/backend/backend_utils.py b/backend/backend_utils.py deleted file mode 100644 index 1be8d38..0000000 --- a/backend/backend_utils.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import List, Dict -from huggingface_hub import InferenceClient -from transformers import pipeline -from backend.constants import * -from backend.prometheus_utils import * -from dotenv import load_dotenv -import os -import json -from backend.mongo.AniZenithMongoClient import AniZenithMongoClient -from backend.mongo.AniZenithVectorSearchResult import AniZenithVectorSearchResult - -# Load all Environment Variables -load_dotenv() -HF_TOKEN = os.getenv('HF_TOKEN') - -# Init AniZenithMongoClient -CONN_STRING = os.getenv("ATLAS_URI") -DB_CLIENT = AniZenithMongoClient(CONN_STRING) - - -# Load the Local Pipeline Model at App Startup -PIPELINE_LOCAL_MODEL = pipeline(task='text-generation', - model='Qwen/Qwen3-0.6B', - max_new_tokens=MAX_NEW_TOKENS, - temperature=TEMPERATURE, - do_sample=False, - top_p=TOP_P) - - -# TODO: Make this Method Async Later -def chat_with_llm(messages: List[Dict[str, str]], use_local_model: bool): - - # TODO: Replace with actual IDs from a config - model = "Qwen/Qwen3-0.6B" if use_local_model else "openai/gpt-oss-20b" - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="full_pipeline").time(): - - # Use the last message as the user message (it should always be a user message) - user_query = messages[-1]['content'] - - # Retrieve relevant results from DB using vector search - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="db_retrieval").time(): - - # Perform Vector Search using user query - # This method returns a List[AniZenithVectorSearchResult] - recommendations: List[AniZenithVectorSearchResult] = DB_CLIENT.perform_vector_search(user_query) - - # Convert the list of vector search objects into a list of dicts - # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object - recommendations_dict = [recommendation.model_dump() for recommendation in recommendations] - - # Serialize to a JSON string - recommendations_string = json.dumps(recommendations_dict, indent = 4) - - # Query the model - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=model, stage="model_generation").time(): - for result in query_model(messages, use_local_model, recommendations_string): - yield result - - -# TODO: Make this method Async later -def query_model(messages: List[Dict[str, str]], use_local_model: bool, recommendations_string: str): - # --- Determine System Prompt --- - # Start with the Fixed System Prompt - system_prompt = SYSTEM_PROMPT - - # Append recommendation string data to the System Prompt if it exists - if recommendations_string: - system_prompt += "\nRECOMMENDATION JSON:" + f"\n{recommendations_string}" - - # Add the System Prompt to the Input Messages to the LLM - input_messages = [{"role": "system", "content": system_prompt}] - # Add the rest of the messages - input_messages.extend(messages) - - # --- Determine which model to use (local or external) --- - # Constants for logging - response = "" - input_token_count = 0 - output_token_count = 0 - - # --- Local Model --- - if use_local_model: - # Uses pipeline from transformers library - response = PIPELINE_LOCAL_MODEL(input_messages) - - # Get the response from the local model, parse it, and yield - generated_text = response[0]['generated_text'][-1]['content'].split('')[-1].strip() - yield generated_text - - # Log token counts (there is no clean way to do this besides re-tokenizing) - tokenizer = PIPELINE_LOCAL_MODEL.tokenizer - formatted_input = tokenizer.apply_chat_template( - input_messages, - tokenize=False, - add_generation_prompt=True - ) - input_tokens = tokenizer.encode(formatted_input) - generated_tokens = tokenizer.encode(generated_text) - input_token_count = len(input_tokens) - output_token_count = len(generated_tokens) - - # --- Non-local Model (Use InferenceClient) --- - else: - client = InferenceClient( - token=HF_TOKEN, - model="openai/gpt-oss-20b", - ) - - # Stream inference client output and yield the text chunk - usage = None - for chunk in client.chat_completion( - messages=input_messages, - max_tokens=MAX_NEW_TOKENS, - stream=True, - temperature=TEMPERATURE, - top_p=TOP_P, - ): - if chunk.choices and chunk.choices[0].delta.content: - token = chunk.choices[0].delta.content - response += token - yield response - # Add usage data for logging - if hasattr(chunk, 'usage') and chunk.usage: - usage = chunk.usage - - if usage: - input_token_count = usage.prompt_tokens - output_token_count = usage.completion_tokens - - # Log the model usage output - # TODO: Record the specific model ID once the backend is refactored - model = "Qwen/Qwen3-0.6B" if use_local_model else "openai/gpt-oss-20b" - # TODO: Use real Inference Manager / session ID - observe_user_message(user_id="0", user_message=messages[-1]['content'], token_count=input_token_count, model=model) - observe_bot_message(user_id="0", bot_message=response, token_count=output_token_count, model=model) - diff --git a/backend/inference_manager.py b/backend/inference_manager.py index f6e2593..fdb63b0 100644 --- a/backend/inference_manager.py +++ b/backend/inference_manager.py @@ -100,14 +100,14 @@ def chat(self, messages: List[Dict[str, str]], user_id: str = None): print(f"Model Error: {e}") # Yield model terminated to user yield "" - # Sets next available time to current time in seconds + downtime + # Sets the model's next available time to current time in seconds + downtime self.model_available_at[self.current_model_idx] = time.time() + MODEL_DOWNTIME_SECONDS # Record Usage Metrics print(f"Streamed output: {output}") usage = current_model.get_usage() - observe_user_message(user_id, user_query, usage["input_token_count"], current_model.get_name()) - observe_bot_message(user_id, output, usage["output_token_count"], current_model.get_name()) + observe_user_message(user_id, user_query, usage.input_token_count, current_model.get_name()) + observe_bot_message(user_id, output, usage.output_token_count, current_model.get_name()) def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult]) -> str: lines = [] diff --git a/backend/models.py b/backend/models.py index 95ba358..ba3603e 100644 --- a/backend/models.py +++ b/backend/models.py @@ -6,12 +6,19 @@ import torch from huggingface_hub import InferenceClient from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +from pydantic import BaseModel from backend.constants import MAX_NEW_TOKENS, TEMPERATURE, TOP_P # TODO: Move to config management system HF_TOKEN = os.getenv('HF_TOKEN') +# Usage statistics class to enforce for logging +class ModelUsageStatistics(BaseModel): + model_name: str + input_token_count: int + output_token_count: int + class Model(ABC): """ Abstract base model. @@ -24,11 +31,13 @@ def __init__(self, name: str): # Each subclass must implement streaming @abstractmethod def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]: + """Streams model response as a string generator""" pass # Each subclass must define usage stats @abstractmethod - def get_usage(self) -> Dict[str, Any]: + def get_usage(self) -> ModelUsageStatistics: + """Returns a ModelUsageStatistics with usage statistics for the model""" pass def get_name(self): @@ -37,7 +46,8 @@ def get_name(self): # Generate runs stream and accumulates, then returns def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: """ - Runs stream(), accumulates output, returns final text + usage. + Runs stream() internally and accumulates output + Returns final text + usage. """ output = [] @@ -47,7 +57,7 @@ def generate(self, messages: List[Dict[str, str]]) -> Dict[str, Any]: result_text = "".join(output) return { - "text": result_text, + "generated_text": result_text, "usage": self.get_usage() } @@ -117,11 +127,11 @@ def generate(): self._usage_data = {"input_token_count": input_token_count, "output_token_count": output_token_count} def get_usage(self): - return { - "model_name": self.get_name(), - "input_token_count": self._usage_data["input_token_count"], - "output_token_count": self._usage_data["output_token_count"], - } + return ModelUsageStatistics( + model_name=self.get_name(), + input_token_count=self._usage_data["input_token_count"], + output_token_count=self._usage_data["output_token_count"], + ) class HFInferenceClientModel(Model): @@ -154,8 +164,8 @@ def stream(self, messages: List[Dict[str, str]]): self._usage_data = chunk.usage def get_usage(self): - return { - "model_name": self.get_name(), - "input_token_count": self._usage_data.prompt_tokens, - "output_token_count": self._usage_data.completion_tokens, - } \ No newline at end of file + return ModelUsageStatistics( + model_name=self.get_name(), + input_token_count=self._usage_data.prompt_tokens, + output_token_count=self._usage_data.completion_tokens, + ) \ No newline at end of file diff --git a/tests/test_chat_models.py b/tests/test_chat_models.py index 8eef7fa..565fa9c 100644 --- a/tests/test_chat_models.py +++ b/tests/test_chat_models.py @@ -1,12 +1,15 @@ import os import pytest -from backend.backend_utils import chat_with_llm -import backend.backend_utils as backend_utils +from backend.inference_manager import InferenceManager +# TODO: Need help re-integrating monkeypatch, leave until after config management refactor TEST_SYSTEM_MESSAGE = "You are a friendly chatbot." TEST_USER_MESSAGE = "Hello!" HF_TOKEN = os.getenv("HF_TOKEN") +@pytest.fixture(scope="module") +def get_manager(): + return InferenceManager() def test_HF_token_exists(): token = os.getenv("HF_TOKEN") @@ -14,23 +17,17 @@ def test_HF_token_exists(): assert len(token) > 1 -def test_local_model_runs(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(backend_utils, "SYSTEM_PROMPT", TEST_SYSTEM_MESSAGE) - use_local_model = True +def test_local_model_runs(get_manager): collected_result = "" - for result in chat_with_llm(messages=[{"role": "user", "content": TEST_USER_MESSAGE}], - use_local_model=use_local_model): - collected_result = result + for result in get_manager.chat(messages=[{"role":"user","content": TEST_USER_MESSAGE}]): + collected_result += result assert len(collected_result) > 0 -def test_external_model_runs(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(backend_utils, "SYSTEM_PROMPT", TEST_SYSTEM_MESSAGE) - use_local_model = False +def test_external_model_runs(get_manager): collected_result = "" - for result in chat_with_llm(messages=[{"role": "user", "content": TEST_USER_MESSAGE}], - use_local_model=use_local_model): + for result in get_manager.chat(messages=[{"role": "user", "content": TEST_USER_MESSAGE}]): collected_result = result assert len(collected_result) > 0 From ba723d225e53c912d2bd91668fa4c25fd2e13ba0 Mon Sep 17 00:00:00 2001 From: ShafathZ <53407653+ShafathZ@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:52:23 -0400 Subject: [PATCH 4/5] Addressed some PR comments --- backend/constants.py | 4 ++++ backend/inference_manager.py | 16 +++++++--------- backend/models.py | 16 +++++++++++----- frontend/app.py | 3 --- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/backend/constants.py b/backend/constants.py index 7978a06..ec608fd 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -10,6 +10,10 @@ 3. Ask the user to provide their favorite genre(s) for Anime Recommendations """ +RECOMMENDED_DOCS_PREAMBLE = f""" +Here are the recommendation system's top shows:\n +""" + MAX_NEW_TOKENS = 2048 TEMPERATURE = 0.7 TOP_P = 0.7 \ No newline at end of file diff --git a/backend/inference_manager.py b/backend/inference_manager.py index fdb63b0..5f1086a 100644 --- a/backend/inference_manager.py +++ b/backend/inference_manager.py @@ -13,12 +13,10 @@ from constants import * load_dotenv(".env") -# TODO: Move to config management system +# TODO: Move these to config management system MONGO_CONN_STRING = os.getenv("ATLAS_URI") - VECTOR_SEARCH_LIMIT = 20 RERANK_LIMIT = 5 - MODEL_DOWNTIME_SECONDS = 120 local_model_id = "Qwen/Qwen3-0.6B" @@ -73,15 +71,15 @@ def chat(self, messages: List[Dict[str, str]], user_id: str = None): # Use the last message as the user message (it should always be a user message) user_query = messages[-1]['content'] retrieved_docs: List[AniZenithVectorSearchResult] = self.db_client.perform_vector_search(user_query, limit=VECTOR_SEARCH_LIMIT) - print(f"Retrieved Docs: ({len(retrieved_docs)})") + print(f"Retrieved Docs: ({len(retrieved_docs)}) relevant docs") # 2) Rerank results using the reranker based on document info and user query - with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranker").time(): - recommended_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) - print(f"Reranked Docs: ({len(recommended_docs)})") + with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranking").time(): + reranked_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) + print(f"Reranked Docs: ({len(reranked_docs)})") # 3) Construct system prompt with recommended docs - system_prompt = self._build_system_prompt(recommended_docs) + system_prompt = self._build_system_prompt(reranked_docs) # 4) Insert system prompt into messages messages.insert(0, {"role": "system", "content": system_prompt}) @@ -114,7 +112,7 @@ def _build_system_prompt(self, recommendations: List[AniZenithVectorSearchResult # Add base system prompt lines.append(SYSTEM_PROMPT) - lines.append("Here are the recommendation system's top shows:\n") + lines.append(RECOMMENDED_DOCS_PREAMBLE) # Add recommendation docs # model_dump() is a special Pydantic method to generate a dict representation of any Pydantic object diff --git a/backend/models.py b/backend/models.py index ba3603e..fae3e26 100644 --- a/backend/models.py +++ b/backend/models.py @@ -71,10 +71,10 @@ def __init__(self, model_id: str): self.model.eval() self._usage_data = None + self._thread_error = None def stream(self, messages: List[Dict[str, str]]): self._usage_data = None - print("Starting tokenize") # Apply chat template & tokenize input inputs = self.tokenizer.apply_chat_template( messages, @@ -92,7 +92,7 @@ def stream(self, messages: List[Dict[str, str]]): def generate(): # Ensure no gradients - with torch.inference_mode(): + with torch.no_grad(): try: self.model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], @@ -108,21 +108,27 @@ def generate(): self._thread_error = e streamer.end() - print("Starting stream") # Start another thread to run generation for streaming thread = Thread(target=generate) thread.start() # Accumulate usage - input_token_count = inputs['input_ids'].shape[-1] + input_token_count = inputs['input_ids'].shape[-1] # (B x input_tokens_len) output_token_count = 0 + + # Streamer obtains values sequentially by injecting into generate function in new thread + # Streamer outputs values as Generator for text here for text in streamer: yield text - output_token_count += 1 + output_token_count += 1 # Streamer executes every token event received # Clean up thread thread.join() + # Handle error after joining if it exists + if self._thread_error is not None: + raise self._thread_error + # Add usage data self._usage_data = {"input_token_count": input_token_count, "output_token_count": output_token_count} diff --git a/frontend/app.py b/frontend/app.py index 8fc40aa..9aa8301 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -10,9 +10,6 @@ from fastapi.middleware.cors import CORSMiddleware import logging from prometheus.prometheus_middleware import PrometheusMiddleware, prometheus_router -from dotenv import load_dotenv - -load_dotenv(".env") # Configure logging at Startup logging.basicConfig(level = logging.INFO) From 015f07148a81895da20bf2256d7ba79f74e5eb52 Mon Sep 17 00:00:00 2001 From: ShafathZ <53407653+ShafathZ@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:10:40 -0400 Subject: [PATCH 5/5] Addressed more PR comments --- backend/app.py | 1 + backend/models.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/app.py b/backend/app.py index e1ff41f..b7beaab 100644 --- a/backend/app.py +++ b/backend/app.py @@ -50,6 +50,7 @@ async def handle_chat_request(request: AniZenithRequest): # Chat with LLM using the messages in the request assistant_message = "" + # TODO: Replace null user with real user / session ID for streamed_response in inference_manager.chat(request.messages, "null-user"): assistant_message += streamed_response diff --git a/backend/models.py b/backend/models.py index fae3e26..2a54ab2 100644 --- a/backend/models.py +++ b/backend/models.py @@ -71,7 +71,6 @@ def __init__(self, model_id: str): self.model.eval() self._usage_data = None - self._thread_error = None def stream(self, messages: List[Dict[str, str]]): self._usage_data = None @@ -90,9 +89,11 @@ def stream(self, messages: List[Dict[str, str]]): skip_special_tokens=True ) + generation_err = None def generate(): # Ensure no gradients with torch.no_grad(): + nonlocal generation_err try: self.model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], @@ -105,7 +106,7 @@ def generate(): ) except Exception as e: # Stop streamer and propagate error - self._thread_error = e + generation_err = e streamer.end() # Start another thread to run generation for streaming @@ -126,8 +127,8 @@ def generate(): thread.join() # Handle error after joining if it exists - if self._thread_error is not None: - raise self._thread_error + if generation_err is not None: + raise generation_err # Add usage data self._usage_data = {"input_token_count": input_token_count, "output_token_count": output_token_count}