diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..219f61c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +env +__pycache__ +*.pyc +static/audio/* +.Ds_Store \ No newline at end of file diff --git a/app/data/sentiment_data.py b/app/data/sentiment_data.py index 0910853..a835ec4 100644 --- a/app/data/sentiment_data.py +++ b/app/data/sentiment_data.py @@ -42,6 +42,22 @@ def analyze(self, text: str) -> tuple: print(f"[error] [Data Layer] [SentimentDataLayer] [analyze] An error occurred during sentiment analysis: {str(e)}") return {'error': f'An unexpected error occurred while processing the request.'} # Generic error message + def analyze_batch(self, texts: list) -> list: + """ + Perform sentiment analysis on a list of texts. + :param texts: List of input texts. + :return: List of dictionaries each with predicted label and confidence. + """ + try: + # Call the batch_forward method on the underlying model + results = self.model.batch_forward(texts) + return results + + except Exception as e: + print(f"[error] [Data Layer] [SentimentDataLayer] [analyze_batch] An error occurred: {str(e)}") + # Return an error for each text in case of failure + return [{"error": "An unexpected error occurred while processing batch request."} for _ in texts] + # if __name__ == "__main__": # config = { diff --git a/app/models/bertweet_model.py b/app/models/bertweet_model.py index 3466394..48025d5 100644 --- a/app/models/bertweet_model.py +++ b/app/models/bertweet_model.py @@ -61,6 +61,41 @@ def forward(self,text)->tuple: predicted_label = self.class_labels[predicted_class] return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item() + + def batch_forward(self, texts: list) -> list: + """ + Perform sentiment analysis on a list of texts in batch. + + Args: + texts (list): List of input texts for sentiment analysis. + + Returns: + list: A list of dictionaries with 'label' and 'confidence' for each text. + """ + # batch_size get it from the configuration + batch_size = self.config.get("batch_size", len(texts)) + + results = [] + # If the number of texts exceeds the batch_size, split them + if len(texts) > batch_size: + for i in range(0, len(texts), batch_size): + sub_texts = texts[i : i + batch_size] + results.extend(self.batch_forward(sub_texts)) + return results + + # Otherwise, process the batch at once. + inputs = self.tokenizer(texts, return_tensors="pt", truncation=True, padding=True).to(self.device) + outputs = self.model(**inputs) + probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) + for i in range(probabilities.size(0)): + predicted_class = torch.argmax(probabilities[i]).item() + predicted_label = self.class_labels[predicted_class] + confidence = probabilities[i][predicted_class].item() + results.append({ + "label": predicted_label, + "confidence": confidence + }) + return results if __name__ == "__main__": diff --git a/app/routes/audio_transcript_sentiment_routes.py b/app/routes/audio_transcript_sentiment_routes.py index 856dd65..ba70acd 100644 --- a/app/routes/audio_transcript_sentiment_routes.py +++ b/app/routes/audio_transcript_sentiment_routes.py @@ -100,7 +100,9 @@ def post(self): # Call the service to perform sentiment analysis on the audio transcript - result = service.process(url = url, start_time_ms = start_time_ms, end_time_ms = end_time_ms) + # result = service.process(url = url, start_time_ms = start_time_ms, end_time_ms = end_time_ms) + result = service.process_batch(url = url, start_time_ms = start_time_ms, end_time_ms = end_time_ms) + if 'error' in result: return { diff --git a/app/services/audio_transcription_sentiment_pipeline.py b/app/services/audio_transcription_sentiment_pipeline.py index 58d3a5e..567fe66 100644 --- a/app/services/audio_transcription_sentiment_pipeline.py +++ b/app/services/audio_transcription_sentiment_pipeline.py @@ -122,7 +122,70 @@ def process(self, url: str, start_time_ms: int, end_time_ms: int = None, user_id except Exception as e: print(f"[error] [Service Layer] [AudioTranscriptionSentimentPipeline] [process] An error occurred during processing: {str(e)}") return {'error': 'An unexpected error occurred while processing the request.'} # Generic error message + + + def process_batch(self, url: str, start_time_ms: int, end_time_ms: int = None, user_id: str = None) -> dict: + """ + Process the Video/Audio file by extracting a segment, transcribing it, and performing sentiment analysis. + :param url: URL or local file path to the audio file. + :param start_time_ms: Start time of the segment to extract (in milliseconds). + :param end_time_ms: End time of the segment to extract (in milliseconds). + :param user_id: (Optional) User ID for creating user-specific subdirectories + :return: Transcription, sentiment analysis, and audio segment details + """ + try: + # Step(1) Extract the audio segment + audio_result = self.audio_service.extract_audio(url, start_time_ms, end_time_ms, user_id) + if isinstance(audio_result, dict) and 'error' in audio_result: + return {'error': audio_result["error"]} + + if self.debug: + print("[debug] [Service Layer] [AudioTranscriptionSentimentPipeline] [process] [audio_result]", audio_result) + # Parse the audio segment details + audio_path = audio_result['audio_path'] + start_time_ms = audio_result['start_time_ms'] + end_time_ms = audio_result['end_time_ms'] + + # Step(2) Transcribe the audio segment + transcription_result = self.transcript_service.transcribe(audio_path) + if isinstance(transcription_result, dict) and 'error' in transcription_result: + return {'error': transcription_result['error']} + + if self.debug: + print("[debug] [Service Layer] [AudioTranscriptionSentimentPipeline] [process] [transcription_result]", transcription_result) + + # Parse the transcription details + transcription = transcription_result['transcription'] + chunks = transcription_result['chunks'] # Each chunk: {'timestamp': (,), 'text': ...} + + # Remove the audio file after processing if needed + if self.remove_audio: + print(f"[debug] [Service Layer] [AudioTranscriptionSentimentPipeline] [process] Removing audio file: {audio_path}") + os.remove(audio_path) + + # Step(3) Batch Sentiment Analysis + texts = [chunk['text'] for chunk in chunks] + batch_results = self.sentiment_service.analyze_batch(texts) + # Map batch results back to each chunk + for i, result in enumerate(batch_results): + if isinstance(result, dict) and 'error' in result: + chunks[i]['error'] = result['error'] + else: + chunks[i]['label'] = result['label'] + chunks[i]['confidence'] = result['confidence'] + + # Return the transcription, sentiment analysis, and audio segment details + return { + 'audio_path': audio_path, + 'start_time_ms': start_time_ms, + 'end_time_ms': end_time_ms, + 'transcription': transcription, + 'utterances_sentiment': chunks, + } + except Exception as e: + print(f"[error] [Service Layer] [AudioTranscriptionSentimentPipeline] [process] An error occurred during processing: {str(e)}") + return {'error': 'An unexpected error occurred while processing the request.'} # if __name__ == "__main__": diff --git a/app/services/sentiment_service.py b/app/services/sentiment_service.py index 3e0a1e4..203dd21 100644 --- a/app/services/sentiment_service.py +++ b/app/services/sentiment_service.py @@ -37,6 +37,21 @@ def analyze(self, text: str) -> tuple: except Exception as e: print(f"[error] [Service Layer] [SentimentService] [analyze] An error occurred during sentiment analysis: {str(e)}") return {'error': f'An unexpected error occurred while processing the request.'} # Generic error message + + def analyze_batch(self, texts: list) -> list: + """ + Perform sentiment analysis on a list of texts. + :param texts: List of input texts. + :return: List of dictionaries each with predicted label and confidence + """ + try: + results = self.sentiment_data_layer.analyze_batch(texts) + return results + + except Exception as e: + print(f"[error] [Service Layer] [SentimentService] [analyze_batch] An error occurred: {str(e)}") + return [{"error": "An unexpected error occurred while processing batch request."} for _ in texts] + # if __name__ == "__main__": diff --git a/config.yaml b/config.yaml index 52c24f4..72aa86e 100644 --- a/config.yaml +++ b/config.yaml @@ -28,6 +28,7 @@ sentiment_analysis: bertweet: # Vader-specific configuration model_name: "finiteautomata/bertweet-base-sentiment-analysis" device: 'cpu' # `cpu` for CPU, or `cuda` GPU device + batch_size: 8 # another_model: # Placeholder for another sentiment analysis model's configuration # api_key: "your_api_key" # endpoint: "https://api.example.com/sentiment"