From 09155bea09bab19b8624f3dfda2dd909fdc39636 Mon Sep 17 00:00:00 2001 From: VedaSiddhartha Date: Sat, 28 Mar 2026 22:29:19 +0530 Subject: [PATCH 1/2] refactor(models): standardize outputs to structured dictionary format and add config validation --- .gitignore | 3 ++- app/models/bertweet_model.py | 32 ++++++++++++++++++++++++++++---- app/models/whisper_model.py | 35 ++++++++++++++++++++++++++++++----- 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 564b8d3..4717a04 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ logs/ # Ignore Testing Coverage Results tests/coverage/.coverage -env/ \ No newline at end of file +env/venv/ +venv/ diff --git a/app/models/bertweet_model.py b/app/models/bertweet_model.py index 2342c7c..6058c4e 100644 --- a/app/models/bertweet_model.py +++ b/app/models/bertweet_model.py @@ -5,6 +5,7 @@ import torch.nn as nn from transformers import AutoTokenizer, AutoModelForSequenceClassification +from typing import Dict, Any class BertweetSentiment(nn.Module): def __init__(self,config: dict)->None: @@ -14,7 +15,16 @@ def __init__(self,config: dict)->None: """ self.debug = config.get('debug') - self.config = config.get('sentiment_analysis').get('bertweet') + # ✅ Add null check + sentiment_config = config.get('sentiment_analysis') + if not sentiment_config: + raise ValueError("'sentiment_analysis' not found in config") + + self.config = sentiment_config.get('bertweet') + if not self.config: + raise ValueError("'bertweet' not found in sentiment_analysis config") + + self.model_name = self.config.get('model_name') self.device = self.config.get('device') @@ -35,7 +45,7 @@ def __init__(self,config: dict)->None: else: self.class_labels = None - def forward(self,text)->tuple: + def forward(self,text)-> Dict[str, Any]: """ Perform sentiment analysis on the given text. @@ -43,7 +53,7 @@ def forward(self,text)->tuple: text (str): Input text for sentiment analysis. Returns: - tuple: Model outputs, probabilities, predicted label, and confidence score. + Dict: Model outputs, probabilities, predicted label, and confidence score. """ # Tokenize the input text inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device) @@ -57,10 +67,24 @@ def forward(self,text)->tuple: # Get the predicted sentiment predicted_class = torch.argmax(probabilities, dim=1).item() + # Get the predicted sentiment + # Convert to integer explicitly + predicted_class = int(torch.argmax(probabilities, dim=1).item()) + + # Add null check + if self.class_labels is None: + raise ValueError("Class labels not available") + + # Get the corresponding class label predicted_label = self.class_labels[predicted_class] - return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item() + return { + "logits": outputs.logits.tolist(), + "probabilities": probabilities.tolist(), + "label": predicted_label, + "score": probabilities[0][predicted_class].item() +} # if __name__ == "__main__": diff --git a/app/models/whisper_model.py b/app/models/whisper_model.py index 9217bf2..ca10dd3 100644 --- a/app/models/whisper_model.py +++ b/app/models/whisper_model.py @@ -5,6 +5,7 @@ import torch.nn as nn from transformers import pipeline +from typing import Dict, Any class WhisperTranscript(nn.Module): @@ -14,8 +15,14 @@ def __init__(self, config: dict) -> None: :param config: The configuration object containing model and device info. """ self.debug = config.get('debug') - - self.config = config.get('transcription').get('whisper') + transcription_config = config.get('transcription') + if not transcription_config: + raise ValueError("'transcription' not found in config") + + self.config = transcription_config.get('whisper') + if not self.config: + raise ValueError("'whisper' not found in transcription config") + self.model_size = self.config.get('model_size') self.device = self.config.get('device') self.chunk_length_s = self.config.get('chunk_length_s') @@ -32,7 +39,7 @@ def __init__(self, config: dict) -> None: ) - def forward(self, audio_file: str) -> tuple: + def forward(self, audio_file: str) -> Dict[str, Any]: """ Perform transcription on the given audio file. @@ -40,12 +47,30 @@ def forward(self, audio_file: str) -> tuple: audio_file (str): Path to the audio file. Returns: - tuple: Transcribed text and timestamped chunks. + Dict: Transcribed text and timestamped chunks. """ # Forward pass out = self.pipeline(audio_file, return_timestamps=True) + + # Initialize to avoid "possibly unbound" error + text = "" + chunks = [] + - return out["text"], out["chunks"] + # Extract text and chunks safely + if isinstance(out, dict): + text = out.get("text", "") + chunks = out.get("chunks", []) + else: + # For dict-like objects (not necessarily dict type) + text = getattr(out, "text", "") + chunks = getattr(out, "chunks", []) + + return { + "text": text, + "chunks": chunks +} + # if __name__ == "__main__": # config = { From 7124ad1feb23ab4816651b6d44223d5b87d4c567 Mon Sep 17 00:00:00 2001 From: VedaSiddhartha Date: Sat, 28 Mar 2026 22:42:21 +0530 Subject: [PATCH 2/2] refactor(models): standardize outputs to structured dictionary format and add config validation --- app/models/bertweet_model.py | 20 ++++++++++---------- app/models/whisper_model.py | 3 ++- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/app/models/bertweet_model.py b/app/models/bertweet_model.py index 6058c4e..5d57d73 100644 --- a/app/models/bertweet_model.py +++ b/app/models/bertweet_model.py @@ -32,14 +32,14 @@ def __init__(self,config: dict)->None: # Initialize the Tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - # Initialize the Model + # Initializing the Model self.model= AutoModelForSequenceClassification.from_pretrained(self.model_name) self.model.to(self.device) - # Load the model configuration to get class labels + # Loading the model configuration to get class labels self.model_config = self.model.config - # Get Labels + # Geting the Labels if hasattr(self.model_config, 'id2label'): self.class_labels = [self.model_config.id2label[i] for i in range(len(self.model_config.id2label))] else: @@ -55,28 +55,28 @@ def forward(self,text)-> Dict[str, Any]: Returns: Dict: Model outputs, probabilities, predicted label, and confidence score. """ - # Tokenize the input text + # Tokenizing the input text inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device) # Forward pass outputs = self.model(**inputs) - # Convert logits to probabilities + # Converting logits to probabilities probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) - # Get the predicted sentiment + # to get the predicted sentiment predicted_class = torch.argmax(probabilities, dim=1).item() - # Get the predicted sentiment - # Convert to integer explicitly + + # Converting it to the integer explicitly predicted_class = int(torch.argmax(probabilities, dim=1).item()) - # Add null check + # Adding a null check if self.class_labels is None: raise ValueError("Class labels not available") - # Get the corresponding class label + # Geting the corresponding class label predicted_label = self.class_labels[predicted_class] return { diff --git a/app/models/whisper_model.py b/app/models/whisper_model.py index ca10dd3..c2ea654 100644 --- a/app/models/whisper_model.py +++ b/app/models/whisper_model.py @@ -15,6 +15,7 @@ def __init__(self, config: dict) -> None: :param config: The configuration object containing model and device info. """ self.debug = config.get('debug') + transcription_config = config.get('transcription') if not transcription_config: raise ValueError("'transcription' not found in config") @@ -57,7 +58,7 @@ def forward(self, audio_file: str) -> Dict[str, Any]: chunks = [] - # Extract text and chunks safely + # Extracting the text and chunks safely if isinstance(out, dict): text = out.get("text", "") chunks = out.get("chunks", [])