Skip to content

Feat/dl add emotion detection enhancements#150

Open
uelkerd wants to merge 19 commits into
mainfrom
feat/dl-add-emotion-detection-enhancements
Open

Feat/dl add emotion detection enhancements#150
uelkerd wants to merge 19 commits into
mainfrom
feat/dl-add-emotion-detection-enhancements

Conversation

@uelkerd
Copy link
Copy Markdown
Owner

@uelkerd uelkerd commented Sep 10, 2025

Summary by Sourcery

Add a new SAMO-enhanced BERT-based multi-label emotion detection pipeline including model implementation, label definitions, configuration, and standalone tests

New Features:

  • Introduce SAMOBERTEmotionClassifier with temperature scaling, dropout regularization, layer freezing, and multi-label prediction API
  • Provide WeightedBCELoss and EmotionDataset utilities for training and data handling
  • Add emotion_labels module defining 28 GoEmotions categories with groupings, descriptions, synonyms, and lookup functions
  • Supply a standalone test script to validate model initialization, emotion prediction (including threshold and temperature scaling), and performance across text lengths
  • Include a comprehensive YAML configuration for model, training, evaluation, logging, and optimization parameters

Tests:

  • Add standalone test suite covering model initialization, parameter counts, emotion label utilities, prediction functionality (batch, threshold, and temperature variations), and performance profiling

Summary by CodeRabbit

  • New Features

    • Multi-label BERT emotion detection: two model variants with calibrated temperature scaling, optional layer freezing, pooled/classifier outputs, top-k and thresholded predictions, structured prediction results, batch inference, performance metrics, and save/load support.
    • Comprehensive emotion taxonomy with 28 labels, descriptions, synonyms, and valence/arousal/dominance groupings.
    • YAML-backed, strongly-validated configuration system covering model, training, evaluation, logging, performance, security, and runtime options.
  • Tests

    • End-to-end and enhanced test suites exercising initialization, single/batch predictions, threshold/temperature sweeps, edge cases, error handling, and performance.

- Enhanced BERT-based emotion classifier for journal entries
- Multi-label emotion classification (28 emotions from GoEmotions)
- Temperature scaling for calibrated predictions
- Comprehensive emotion labels and descriptions
- Standalone test script for validation
- Configuration file with SAMO-specific optimizations
- Error handling and logging improvements

Files:
- src/models/emotion_detection/samo_bert_emotion_classifier.py
- src/models/emotion_detection/emotion_labels.py
- configs/samo_emotion_detection_config.yaml
- test_samo_emotion_detection_standalone.py

This completes PR-3 of the surgical breakdown plan.
- Fixed NameError for classifier_dropout_prob and freeze_bert_layers
- Updated constructor to use self.classifier_dropout_prob and self.freeze_bert_layers
- Model now initializes correctly and passes standalone tests
- Maintains 110M parameters with 66M frozen BERT layers

Part of PR-3: Emotion Detection model completion
@uelkerd uelkerd self-assigned this Sep 10, 2025
Copilot AI review requested due to automatic review settings September 10, 2025 10:11
@sourcery-ai
Copy link
Copy Markdown
Contributor

sourcery-ai Bot commented Sep 10, 2025

Reviewer's Guide

This PR introduces a complete emotion detection pipeline featuring a new BERT-based multi-label classifier with calibration and regularization, a support module for emotion labels and groupings, a standalone test harness, and a centralized YAML configuration for all related parameters.

File-Level Changes

Change Details Files
Introduce SAMOBERTEmotionClassifier and related training utilities
  • Define classifier with BERT backbone, dropout, two-layer head, and sigmoid outputs
  • Add temperature scaling, freeze/unfreeze layer support, and parameter counters
  • Implement predict_emotions with threshold and top_k options
  • Add WeightedBCELoss, EmotionDataset classes, factory and evaluation functions
src/models/emotion_detection/samo_bert_emotion_classifier.py
Add emotion_labels module with categorization and helpers
  • Define GOEMOTIONS list and valence/arousal/dominance groupings
  • Provide functions to get index/name, groups, descriptions, and synonyms
  • Add validation and statistics utilities
  • Include main test block
src/models/emotion_detection/emotion_labels.py
Introduce standalone test script
  • Test model initialization, parameter counts, and device assignment
  • Validate label utilities and sample emotion lists
  • Test single and batch predictions, temperature scaling, and threshold sweeps
  • Measure performance on varied text lengths
test_samo_emotion_detection_standalone.py
Add YAML configuration for emotion detection pipeline
  • Centralize model, architecture, training, and evaluation parameters
  • Include logging, performance, and SAMO-specific optimization settings
  • Define error handling, security, and privacy options
configs/samo_emotion_detection_config.yaml

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Sep 10, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a SAMO emotion-detection subsystem: YAML configuration, two config modules/managers, a GoEmotions label utility, two BERT-based multi-label classifiers (SAMO and Enhanced) with training/eval helpers, loss/dataset utilities, save/load and performance hooks, and two end-to-end test scripts.

Changes

Cohort / File(s) Summary of modifications
Top-level YAML config
configs/samo_emotion_detection_config.yaml
New comprehensive YAML configuration covering model, emotion_detection, architecture, training, data, evaluation, logging, model_saving, performance, samo_optimizations, error_handling, security, and development sections with explicit defaults.
Programmatic config & manager
src/models/emotion_detection/enhanced_config.py, src/models/emotion_detection/config.py
New strongly-typed dataclasses and EnhancedConfigManager providing YAML loading/validation/defaults/serialization; plus a simpler singleton-based runtime config API with update/reset and helper getters.
Emotion taxonomy & helpers
src/models/emotion_detection/emotion_labels.py
New GoEmotions-derived label set (28 labels) plus valence/arousal/dominance groupings, intensity levels, descriptions, synonyms, and a case-insensitive public API for lookups, validation, stats, and queries.
SAMO BERT classifier & training utilities
src/models/emotion_detection/samo_bert_emotion_classifier.py
New SAMOBERTEmotionClassifier (BERT backbone, learnable temperature, layer-freezing, pooling options), WeightedBCELoss, EmotionDataset, factory create_samo_bert_emotion_classifier, evaluation routine, and prediction API returning labels/probabilities/metadata.
Enhanced BERT classifier (rich outputs)
src/models/emotion_detection/enhanced_bert_classifier.py
New EnhancedBERTEmotionClassifier with EmotionPrediction dataclass, device/mixed-precision support, inference optimizations, batch/metadata-aware predict_emotions, performance metrics, save/load, and model inspection helpers.
Standalone test suites
test_samo_emotion_detection_standalone.py, test_samo_emotion_detection_enhanced.py
New end-to-end test scripts exercising initialization, parameter reporting, label APIs, single/batch predictions, temperature scaling, threshold sweeps, error handling, performance metrics, and config manager interactions.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User / App
  participant CFG as Config Manager / YAML
  participant F as Factory
  participant C as Classifier (SAMO/Enhanced)
  participant T as Tokenizer
  participant B as BERT Encoder
  participant H as Classifier Head
  participant P as Post-process

  U->>CFG: load / update config
  CFG-->>F: provide hyperparams, thresholds, device
  F-->>C: init model (bert, head, temp, freeze layers)
  U->>C: predict_emotions(texts, threshold/top_k, batch_size)
  C->>T: tokenize(texts, max_length, truncation/padding)
  T-->>C: input tensors
  C->>B: encode(inputs)
  B-->>C: pooled_output / cls_repr
  C->>H: classifier head -> logits
  H-->>C: logits
  C->>P: apply temperature, sigmoid, threshold/top-k
  P-->>U: labels, probabilities, metadata
  Note over C: optional mixed-precision, caching, performance logging
Loading
sequenceDiagram
  autonumber
  participant Trainer as Training Loop
  participant D as DataLoader / EmotionDataset
  participant M as Model (SAMO/Enhanced)
  participant L as WeightedBCELoss
  participant E as Evaluator

  Trainer->>D: load batches (texts, multi-label targets)
  D-->>Trainer: batched tensors
  Trainer->>M: forward(input tensors)
  M-->>Trainer: logits
  Trainer->>L: compute loss(logits, targets)
  L-->>Trainer: loss scalar
  Trainer->>M: backward() / optimizer step
  Trainer->>E: evaluate_emotion_classifier(model, val_dataloader)
  E-->>Trainer: metrics (precision, recall, f1_micro, f1_macro)
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "Feat/dl add emotion detection enhancements" succinctly indicates a feature-level change in the deep-learning area and aligns with the PR’s main intent to add an emotion-detection pipeline (SAMOBERTEmotionClassifier, label taxonomy, configs, and tests). It is concise and topical rather than generic or misleading. The only minor issue is the informal "dl" abbreviation and the slash, which slightly reduce clarity for readers scanning history.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

Poem

A twitch, a hop, I sniff the code,
New labels sprout along the road.
BERT and SAMO hum in tune,
Top-k twinkles like the moon.
Rabbit cheers — predictions bloom! 🐇✨

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch feat/dl-add-emotion-detection-enhancements

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @uelkerd, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new, advanced emotion detection capability into the SAMO-DL system. It establishes a foundational framework for analyzing emotional states in text, providing a highly customizable and robust solution for journal entry analysis and similar applications. The changes ensure that the system can accurately identify and categorize a wide range of human emotions.

Highlights

  • New Emotion Detection System: Introduces a complete, enhanced emotion detection system for the SAMO-DL project, leveraging a BERT-based classifier.
  • Comprehensive Emotion Taxonomy: Adds a dedicated module defining 28 emotion categories based on the GoEmotions dataset, including valence, arousal, and dominance groupings, along with descriptions and synonyms.
  • Configurable Model and Training Parameters: Provides a detailed YAML configuration file for the emotion detection model, allowing fine-tuning of BERT architecture, training parameters, data processing, and evaluation metrics.
  • Robust BERT-based Classifier: Implements a multi-label BERT emotion classifier with features like temperature scaling for calibrated predictions, dropout regularization, and configurable BERT layer freezing.
  • Standalone Test Suite: Includes a comprehensive standalone test script to validate the model's initialization, prediction accuracy, batch processing, temperature scaling, and performance across various text lengths.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@deepsource-io
Copy link
Copy Markdown
Contributor

deepsource-io Bot commented Sep 10, 2025

Here's the code health analysis summary for commits 69ec243..ca09ca1. View details on DeepSource ↗.

Analysis Summary

AnalyzerStatusSummaryLink
DeepSource Test coverage LogoTest coverage⚠️ Artifact not reportedTimed out: Artifact was never reportedView Check ↗
DeepSource Python LogoPython❌ Failure
❗ 66 occurences introduced
View Check ↗
DeepSource Terraform LogoTerraform✅ SuccessView Check ↗
DeepSource Secrets LogoSecrets✅ SuccessView Check ↗
DeepSource Shell LogoShell✅ SuccessView Check ↗
DeepSource Docker LogoDocker✅ SuccessView Check ↗

💡 If you’re a repository administrator, you can configure the quality gates from the settings.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements a comprehensive emotion detection system for the SAMO-DL project, adding BERT-based emotion classification capabilities with temperature scaling and multi-label support. The system is designed to detect 28 emotion categories (27 GoEmotions + neutral) from journal entries and other textual content.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
test_samo_emotion_detection_standalone.py Standalone test script for comprehensive model validation
src/models/emotion_detection/samo_bert_emotion_classifier.py Core BERT-based emotion classifier implementation
src/models/emotion_detection/emotion_labels.py Emotion labels and categorization system based on GoEmotions
configs/samo_emotion_detection_config.yaml Configuration file for model parameters and training settings

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

Comment thread test_samo_emotion_detection_standalone.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there - I've reviewed your changes and they look great!

Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments

### Comment 1
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:177` </location>
<code_context>
+        )
+
+        # Use [CLS] token representation for classification
+        pooled_output = bert_outputs.pooler_output
+
+        # Pass through classification head
</code_context>

<issue_to_address>
Using pooler_output may not be robust for all transformer models.

Some models may not have pooler_output or it could be None. Add a check and consider using the first token's hidden state as a fallback.
</issue_to_address>

### Comment 2
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:255` </location>
<code_context>
+                batch_probabilities = probabilities.cpu().numpy()
+
+                # Get emotion names for predictions
+                for pred in batch_predictions:
+                    emotions = [
+                        f"emotion_{i}" for i, p in enumerate(pred) if p > 0
</code_context>

<issue_to_address>
Emotion names are generic and not mapped to actual labels.

Consider mapping prediction indices to actual emotion labels from the emotion_labels module to improve interpretability.

Suggested implementation:

```python
                # Get emotion names for predictions
                from emotion_labels import EMOTION_LABELS  # Make sure this is the correct import path and variable name
                for pred in batch_predictions:
                    emotions = [
                        EMOTION_LABELS[i] for i, p in enumerate(pred) if p > 0
                    ]
                    all_emotions.append(emotions)

```

- Ensure that the `emotion_labels` module is available and contains a list or tuple named `EMOTION_LABELS` that maps indices to emotion names.
- If the import path or variable name is different, adjust `from emotion_labels import EMOTION_LABELS` accordingly.
</issue_to_address>

### Comment 3
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:435` </location>
<code_context>
+    model: SAMOBERTEmotionClassifier,
+    dataloader: DataLoader,
+    device: torch.device,
+    threshold: float = 0.2,  # Lowered from 0.5 to capture more predictions
+) -> Dict[str, float]:
+    """
</code_context>

<issue_to_address>
Evaluation threshold is hardcoded and may not match training threshold.

Recommend sourcing the evaluation threshold from configuration or matching it to the training threshold for consistency.

Suggested implementation:

```python
from config import EMOTION_CLASSIFICATION_THRESHOLD

def evaluate_emotion_classifier(
    model: SAMOBERTEmotionClassifier,
    dataloader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """

```

```python
    # ... inside the function ...
    preds = (outputs > EMOTION_CLASSIFICATION_THRESHOLD).float()

```

If you do not already have a configuration module or object, you will need to create one (e.g., `config.py`) and ensure the training code also uses `EMOTION_CLASSIFICATION_THRESHOLD` for consistency. You may also need to update any calls to `evaluate_emotion_classifier` to remove the `threshold` argument.
</issue_to_address>

### Comment 4
<location> `test_samo_emotion_detection_standalone.py:50` </location>
<code_context>
+def test_emotion_predictions(model, all_emotions):
</code_context>

<issue_to_address>
Consider adding tests for edge cases such as empty strings, non-English text, and texts with ambiguous or mixed emotions.

Please include tests for empty input, long texts, ambiguous emotion, non-English text, and mixed emotions to better assess model robustness.
</issue_to_address>

### Comment 5
<location> `test_samo_emotion_detection_standalone.py:87` </location>
<code_context>
+            print(f"     {emotion_name}: {prob:.3f}")
+
+
+def test_batch_predictions(model, test_texts):
+    """Test batch prediction functionality."""
+    print("\n5. Testing batch prediction...")
+    batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)
+    
+    print(f"   Batch size: {len(batch_results['emotions'])}")
+    print(f"   All predictions successful: {len(batch_results['emotions']) == 3}")
+
+
</code_context>

<issue_to_address>
Add assertions to batch prediction tests to automatically verify expected output structure and values.

Please add assertions to verify output structure, prediction count, non-empty emotion lists, and valid probability ranges, rather than relying on print statements.
</issue_to_address>

<suggested_fix>
<<<<<<< SEARCH
def test_batch_predictions(model, test_texts):
    """Test batch prediction functionality."""
    print("\n5. Testing batch prediction...")
    batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)

    print(f"   Batch size: {len(batch_results['emotions'])}")
    print(f"   All predictions successful: {len(batch_results['emotions']) == 3}")
=======
def test_batch_predictions(model, test_texts):
    """Test batch prediction functionality with assertions."""
    batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)

    # Assert output structure
    assert isinstance(batch_results, dict), "Batch results should be a dictionary."
    assert "emotions" in batch_results, "'emotions' key missing in batch results."
    assert isinstance(batch_results["emotions"], list), "'emotions' should be a list."

    # Assert prediction count
    assert len(batch_results["emotions"]) == 3, "Batch size should be 3."

    # Assert non-empty emotion lists and valid probability ranges
    for i, emotion_list in enumerate(batch_results["emotions"]):
        assert isinstance(emotion_list, list), f"Prediction {i} should be a list of emotions."
        assert len(emotion_list) > 0, f"Prediction {i} should not be empty."
        for emotion in emotion_list:
            assert "name" in emotion, f"Emotion dict missing 'name' key in prediction {i}."
            assert "probability" in emotion, f"Emotion dict missing 'probability' key in prediction {i}."
            prob = emotion["probability"]
            assert 0.0 <= prob <= 1.0, f"Probability {prob} out of range in prediction {i}."
>>>>>>> REPLACE

</suggested_fix>

### Comment 6
<location> `test_samo_emotion_detection_standalone.py:113` </location>
<code_context>
+    print(f"   Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions")
+
+
+def test_prediction_thresholds(model):
+    """Test different prediction thresholds."""
+    print("\n7. Testing different prediction thresholds...")
+    test_text = "I feel both happy and sad about this situation."
+    
+    for threshold in [0.1, 0.3, 0.5, 0.7]:
+        results = model.predict_emotions(test_text, threshold=threshold)
+        emotions = results['emotions'][0]
+        print(f"   Threshold {threshold}: {len(emotions)} emotions - {emotions}")
+
+
</code_context>

<issue_to_address>
Add assertions to verify that increasing the threshold reduces the number of detected emotions.

Please add assertions to ensure that increasing the threshold does not result in more detected emotions, confirming the expected behavior.
</issue_to_address>

<suggested_fix>
<<<<<<< SEARCH
def test_prediction_thresholds(model):
    """Test different prediction thresholds."""
    print("\n7. Testing different prediction thresholds...")
    test_text = "I feel both happy and sad about this situation."

    for threshold in [0.1, 0.3, 0.5, 0.7]:
        results = model.predict_emotions(test_text, threshold=threshold)
        emotions = results['emotions'][0]
        print(f"   Threshold {threshold}: {len(emotions)} emotions - {emotions}")
=======
def test_prediction_thresholds(model):
    """Test different prediction thresholds."""
    print("\n7. Testing different prediction thresholds...")
    test_text = "I feel both happy and sad about this situation."

    thresholds = [0.1, 0.3, 0.5, 0.7]
    num_emotions = []
    for threshold in thresholds:
        results = model.predict_emotions(test_text, threshold=threshold)
        emotions = results['emotions'][0]
        num_emotions.append(len(emotions))
        print(f"   Threshold {threshold}: {len(emotions)} emotions - {emotions}")

    # Assert that increasing the threshold does not increase the number of detected emotions
    for i in range(1, len(num_emotions)):
        assert num_emotions[i] <= num_emotions[i-1], (
            f"Number of emotions at threshold {thresholds[i]} ({num_emotions[i]}) "
            f"should not be greater than at threshold {thresholds[i-1]} ({num_emotions[i-1]})"
        )
>>>>>>> REPLACE

</suggested_fix>

### Comment 7
<location> `test_samo_emotion_detection_standalone.py:169` </location>
<code_context>
+def test_performance():
</code_context>

<issue_to_address>
Add assertions or checks for performance test to ensure model handles very long texts and edge cases gracefully.

Add assertions to verify the model does not crash, timeout, or produce invalid results with very long, short, or empty texts.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
batch_probabilities = probabilities.cpu().numpy()

# Get emotion names for predictions
for pred in batch_predictions:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Emotion names are generic and not mapped to actual labels.

Consider mapping prediction indices to actual emotion labels from the emotion_labels module to improve interpretability.

Suggested implementation:

                # Get emotion names for predictions
                from emotion_labels import EMOTION_LABELS  # Make sure this is the correct import path and variable name
                for pred in batch_predictions:
                    emotions = [
                        EMOTION_LABELS[i] for i, p in enumerate(pred) if p > 0
                    ]
                    all_emotions.append(emotions)
  • Ensure that the emotion_labels module is available and contains a list or tuple named EMOTION_LABELS that maps indices to emotion names.
  • If the import path or variable name is different, adjust from emotion_labels import EMOTION_LABELS accordingly.

Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Comment thread test_samo_emotion_detection_standalone.py
Comment thread test_samo_emotion_detection_standalone.py
Comment thread test_samo_emotion_detection_standalone.py Outdated
Comment thread test_samo_emotion_detection_standalone.py Outdated
Comment thread test_samo_emotion_detection_standalone.py Outdated
Comment thread src/models/emotion_detection/emotion_labels.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a comprehensive emotion detection feature, including a new BERT-based classifier, configuration, emotion label definitions, and standalone tests. The implementation is solid, with a well-structured configuration and detailed emotion label definitions. My review focuses on improving the model's implementation for better numerical stability, more interpretable outputs, and enhanced maintainability. I've also suggested a fix for the test script to correctly interpret the model's output.

Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Comment thread test_samo_emotion_detection_standalone.py Outdated
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
configs/samo_emotion_detection_config.yaml (1)

1-187: Add YAML document start marker
Insert --- on the first line of configs/samo_emotion_detection_config.yaml to satisfy the document-start rule in yamllint.

🧹 Nitpick comments (20)
src/models/emotion_detection/emotion_labels.py (7)

1-1: Remove shebang or make file executable.

This is a library module, not a script. Drop the shebang to satisfy EXE001.

-#!/usr/bin/env python3

12-12: Drop unused import.

Tuple is unused.

-from typing import List, Dict, Tuple
+from typing import List, Dict

14-44: Make labels immutable and prepare O(1) lookups.

Use a tuple for immutability and add an index map for constant-time get_emotion_index.

-# GoEmotions emotion categories (27 emotions + neutral = 28 total)
-GOEMOTIONS_EMOTIONS = [
+# GoEmotions emotion categories (27 emotions + neutral = 28 total)
+GOEMOTIONS_EMOTIONS = (
     "admiration",      # 0
     "amusement",       # 1
     "anger",          # 2
     "annoyance",      # 3
     "approval",       # 4
     "caring",         # 5
     "confusion",      # 6
     "curiosity",      # 7
     "desire",         # 8
     "disappointment", # 9
     "disapproval",    # 10
     "disgust",        # 11
     "embarrassment",  # 12
     "excitement",     # 13
     "fear",           # 14
     "gratitude",      # 15
     "grief",          # 16
     "joy",            # 17
     "love",           # 18
     "nervousness",    # 19
     "optimism",       # 20
     "pride",          # 21
     "realization",    # 22
     "relief",         # 23
     "remorse",        # 24
     "sadness",        # 25
     "surprise",       # 26
     "neutral",        # 27
-]
+)
+
+# Name → index map for fast lookup
+EMOTION_TO_INDEX: Dict[str, int] = {name: i for i, name in enumerate(GOEMOTIONS_EMOTIONS)}

199-203: Tighten the IndexError message (TRY003).

Shorter, standard message.

-    else:
-        raise IndexError(f"Index {index} out of range for GoEmotions list")
+    else:
+        raise IndexError("emotion index out of range")

215-216: Make group lookups case-insensitive.

Aligns with case-insensitive emotion lookups elsewhere.

-    return EMOTION_VALENCE_GROUPS.get(valence, [])
+    return EMOTION_VALENCE_GROUPS.get(valence.lower(), [])
-    return EMOTION_AROUSAL_GROUPS.get(arousal, [])
+    return EMOTION_AROUSAL_GROUPS.get(arousal.lower(), [])
-    return EMOTION_DOMINANCE_GROUPS.get(dominance, [])
+    return EMOTION_DOMINANCE_GROUPS.get(dominance.lower(), [])

Also applies to: 228-229, 241-242


337-337: Remove unnecessary f-string.

Fixes F541.

-    print(f"\nEmotion descriptions:")
+    print("\nEmotion descriptions:")

270-277: Return a tuple for get_all_emotions to reflect immutability.

Optional, but consistent with making the constant a tuple.

-def get_all_emotions() -> List[str]:
+def get_all_emotions() -> List[str]:
@@
-    return GOEMOTIONS_EMOTIONS.copy()
+    return list(GOEMOTIONS_EMOTIONS)
configs/samo_emotion_detection_config.yaml (3)

11-22: Thresholds: confirm intent.

You set prediction_threshold to 0.6 (inference) but evaluation.threshold to 0.2. If unintentional, harmonize; if intentional, add a brief comment noting the rationale.

 emotion_detection:
   # Number of emotion categories (27 GoEmotions + neutral)
   num_emotions: 28
-  
   # Prediction threshold for binary classification
-  prediction_threshold: 0.6  # Updated from 0.5 for better calibration
+  prediction_threshold: 0.6  # Inference default. Keep distinct from evaluation.threshold if desired.

39-59: Add training seed for reproducibility.

Helps stabilize experiments.

 training:
@@
   warmup_steps: 100
+  seed: 42

113-127: Optionally enable deterministic cuDNN toggle.

Only if you care about exact reproducibility across runs.

 performance:
@@
   gradient_checkpointing: false  # Can be enabled for memory savings
+  cudnn_deterministic: true      # Optional: reproducibility over speed
test_samo_emotion_detection_standalone.py (7)

1-1: Remove shebang or mark file executable.

For a test script run via python, the shebang isn’t needed (EXE001).

-#!/usr/bin/env python3

13-15: Harden sys.path injection.

Resolve absolute path and avoid duplicate entries.

-# Add src to path
-sys.path.insert(0, str(Path(__file__).parent / "src"))
+# Add src to path
+src_path = str((Path(__file__).resolve().parent / "src"))
+if src_path not in sys.path:
+    sys.path.insert(0, src_path)

72-76: Show human-readable emotion names, not emotion_{i}.

predict_emotions currently returns generic label ids; map to actual names for display.

-        emotions = results['emotions'][0]
+        pred_vec = results['predictions'][0]
+        emotions = [all_emotions[i] for i, p in enumerate(pred_vec) if p > 0]

If you prefer to fix this at the source, update predict_emotions to map indices via emotion_labels.get_emotion_name(i).


188-189: Move time import to top-level.

Minor cleanliness.

-            import time

Add near the other imports:

+import time

155-168: Avoid catching blind Exception (BLE001).

Either let failures surface or narrow the exceptions.

-    try:
+    try:
         model, all_emotions, trainable_params = run_all_tests()
@@
-    except Exception as e:
+    except (RuntimeError, OSError, ValueError) as e:
         print(f"❌ Error testing emotion classifier: {e}")
         import traceback
         traceback.print_exc()
         raise

200-202: Same here: narrow the except clause.

Consistent with above.

-    except Exception as e:
+    except (RuntimeError, OSError, ValueError) as e:
         print(f"❌ Error in performance test: {e}")

50-63: Sanity-check thresholds vs config.

Tests hardcode threshold=0.3 while config default is 0.6. If you want tests to reflect defaults, consider pulling from config or a single source of truth.

src/models/emotion_detection/samo_bert_emotion_classifier.py (3)

86-86: Document the calibrated threshold value

The prediction threshold is hardcoded to 0.6 with a comment mentioning it's "based on calibration", but there's no documentation about how this calibration was performed or why this specific value was chosen.

Consider adding more detailed documentation or making this configurable:

-        self.prediction_threshold = 0.6  # Updated from 0.5 to 0.6 based on calibration
+        # Default prediction threshold calibrated on validation data
+        # Higher threshold (0.6) reduces false positives for emotion detection
+        self.prediction_threshold = config.get("prediction_threshold", 0.6)

244-248: Optimize top-k selection logic

The current implementation creates a zero tensor and then scatters ones, which is inefficient. Consider using a more direct approach.

                 # Get top-k if specified
                 if top_k is not None:
-                    _, top_k_indices = torch.topk(probabilities, top_k, dim=1)
-                    predictions = torch.zeros_like(probabilities)
-                    predictions.scatter_(1, top_k_indices, 1.0)
+                    _, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)), dim=1)
+                    predictions = torch.zeros_like(probabilities)
+                    predictions.scatter_(1, top_k_indices, 1.0)

1-1: Remove shebang from non-executable module file

The shebang line is present but the file is not marked as executable. Since this is a module file that will be imported rather than executed directly, the shebang should be removed.

-#!/usr/bin/env python3
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 69ec243 and 56a5d15.

📒 Files selected for processing (4)
  • configs/samo_emotion_detection_config.yaml (1 hunks)
  • src/models/emotion_detection/emotion_labels.py (1 hunks)
  • src/models/emotion_detection/samo_bert_emotion_classifier.py (1 hunks)
  • test_samo_emotion_detection_standalone.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
  • create_samo_bert_emotion_classifier (388-428)
  • count_parameters (275-277)
  • count_frozen_parameters (279-281)
  • predict_emotions (187-268)
  • set_temperature (270-273)
src/models/emotion_detection/emotion_labels.py (2)
  • get_all_emotions (270-277)
  • get_emotion_description (244-254)
🪛 Ruff (0.12.2)
src/models/emotion_detection/emotion_labels.py

1-1: Shebang is present but file is not executable

(EXE001)


183-183: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


183-183: Avoid specifying long messages outside the exception class

(TRY003)


202-202: Avoid specifying long messages outside the exception class

(TRY003)


337-337: f-string without any placeholders

Remove extraneous f prefix

(F541)

test_samo_emotion_detection_standalone.py

1-1: Shebang is present but file is not executable

(EXE001)


200-200: Do not catch blind exception: Exception

(BLE001)

src/models/emotion_detection/samo_bert_emotion_classifier.py

1-1: Shebang is present but file is not executable

(EXE001)


190-190: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🪛 YAMLlint (1.37.1)
configs/samo_emotion_detection_config.yaml

[error] 13-13: trailing spaces

(trailing-spaces)


[error] 16-16: trailing spaces

(trailing-spaces)


[error] 19-19: trailing spaces

(trailing-spaces)


[error] 28-28: trailing spaces

(trailing-spaces)


[error] 31-31: trailing spaces

(trailing-spaces)


[error] 34-34: trailing spaces

(trailing-spaces)


[error] 43-43: trailing spaces

(trailing-spaces)


[error] 47-47: trailing spaces

(trailing-spaces)


[error] 51-51: trailing spaces

(trailing-spaces)


[error] 55-55: trailing spaces

(trailing-spaces)


[error] 66-66: trailing spaces

(trailing-spaces)


[error] 69-69: trailing spaces

(trailing-spaces)


[error] 83-83: trailing spaces

(trailing-spaces)


[error] 86-86: trailing spaces

(trailing-spaces)


[error] 96-96: trailing spaces

(trailing-spaces)


[error] 105-105: trailing spaces

(trailing-spaces)


[error] 108-108: trailing spaces

(trailing-spaces)


[error] 117-117: trailing spaces

(trailing-spaces)


[error] 121-121: trailing spaces

(trailing-spaces)


[error] 124-124: trailing spaces

(trailing-spaces)


[error] 132-132: trailing spaces

(trailing-spaces)


[error] 135-135: trailing spaces

(trailing-spaces)


[error] 138-138: trailing spaces

(trailing-spaces)


[error] 141-141: trailing spaces

(trailing-spaces)


[error] 150-150: trailing spaces

(trailing-spaces)


[error] 154-154: trailing spaces

(trailing-spaces)


[error] 163-163: trailing spaces

(trailing-spaces)


[error] 166-166: trailing spaces

(trailing-spaces)


[error] 169-169: trailing spaces

(trailing-spaces)


[error] 177-177: trailing spaces

(trailing-spaces)


[error] 180-180: trailing spaces

(trailing-spaces)


[error] 183-183: trailing spaces

(trailing-spaces)

🔇 Additional comments (4)
configs/samo_emotion_detection_config.yaml (1)

160-172: Verify feature flags are implemented.

Keys like sanitize_input, filter_sensitive_emotions, and anonymize_predictions must be honored by the pipeline; otherwise drop them from config to avoid drift.

Would you like a quick repo scan script to report where these flags are read?

test_samo_emotion_detection_standalone.py (1)

96-111: Guard temperature usage.

If the model lacks a learnable temperature, these calls will fail. Add a simple guard.

-    original_temp = model.temperature.item()
+    original_temp = getattr(model, "temperature", None)
+    if original_temp is None:
+        print("   Temperature scaling not available; skipping.")
+        return
+    original_temp = model.temperature.item()
src/models/emotion_detection/samo_bert_emotion_classifier.py (2)

489-520: LGTM! Well-structured test code

The main block provides a good testing framework with clear output and proper error handling. The test demonstrates the core functionality effectively.


383-383: Fix potential None value handling for token_type_ids

The current implementation may cause issues if token_type_ids is not present in the encoding. Using torch.zeros_like on the entire encoding dictionary will fail.

-            "token_type_ids": encoding.get("token_type_ids", torch.zeros_like(encoding["input_ids"])).squeeze(0),
+            "token_type_ids": encoding["token_type_ids"].squeeze(0) if "token_type_ids" in encoding else torch.zeros_like(encoding["input_ids"]).squeeze(0),

Likely an incorrect or invalid review comment.

Comment thread src/models/emotion_detection/emotion_labels.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
uelkerd and others added 6 commits September 11, 2025 00:01
- Fix pooler_output robustness with fallback to first token hidden state
- Map prediction indices to actual emotion labels from GOEMOTIONS_EMOTIONS
- Create centralized configuration module for threshold management
- Add comprehensive edge case testing (empty strings, non-English, mixed emotions)
- Add assertions to batch prediction tests for output structure validation
- Add threshold behavior assertions to verify decreasing emotion count
- Add performance test assertions for processing time and error handling

All tests passing with enhanced robustness and validation.
Resolved issues in the following files with DeepSource Autofix:
1. src/models/emotion_detection/config.py
2. src/models/emotion_detection/emotion_labels.py
3. src/models/emotion_detection/enhanced_bert_classifier.py
4. src/models/emotion_detection/enhanced_config.py
5. src/models/emotion_detection/samo_bert_emotion_classifier.py
6. test_samo_emotion_detection_enhanced.py
7. test_samo_emotion_detection_standalone.py
- Use F.binary_cross_entropy_with_logits for numerical stability in BCE loss
- Fix return type hint for predict_emotions to reflect batch structure
- Simplify config dictionary to only override necessary parameters
- Remove redundant default values from create_samo_bert_emotion_classifier

All tests passing with improved numerical stability and cleaner code.
…om:uelkerd/SAMO--DL into feat/dl-add-emotion-detection-enhancements
- Extract hardcoded test texts into module-level constant SAMPLE_TEST_TEXTS
- Fix inefficient token_type_ids handling in EmotionDataset to use None instead of zeros
- Remove code duplication and improve maintainability
- BERT models can handle missing token type IDs efficiently

All tests passing with cleaner, more maintainable code.
- Replace unused variable with underscore to indicate intentional non-use
- Fix PYL-W0612 linting error for unused variable
- Add explanatory comment for clarity
- All tests still passing with cleaner code
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

♻️ Duplicate comments (4)
test_samo_emotion_detection_standalone.py (1)

219-234: Deduplicate hardcoded test texts across files.

Extract common sample texts to a module-level constant (or a helper module) and reuse in both standalone and enhanced tests to reduce drift.

src/models/emotion_detection/samo_bert_emotion_classifier.py (3)

328-331: Fix potential device/broadcast mismatch when applying class weights.

Multiplying CPU weights with GPU loss tensors will raise, and unsqueeze may not broadcast as intended.

         # Apply class weights if provided
         if self.class_weights is not None:
-            bce_loss = bce_loss * self.class_weights.unsqueeze(0)
+            if self.class_weights.device != bce_loss.device:
+                self.class_weights = self.class_weights.to(bce_loss.device)
+            bce_loss = bce_loss * self.class_weights.view(1, -1).expand_as(bce_loss)

320-336: Use BCEWithLogits for numerical stability.

Avoid explicit sigmoid before BCE; it can cause instability and gradient issues.

-        # Apply sigmoid to get probabilities
-        probabilities = torch.sigmoid(logits)
-
-        # Compute BCE loss
-        bce_loss = F.binary_cross_entropy(
-            probabilities, targets.float(), reduction="none"
-        )
+        # Compute BCE with logits directly for numerical stability
+        bce_loss = F.binary_cross_entropy_with_logits(
+            logits, targets.float(), reduction="none"
+        )

192-199: Fix type hints: threshold should be Optional and return types are nested lists.

Current hints are inaccurate and may mislead consumers and tooling.

     def predict_emotions(
         self,
         texts: Union[str, List[str]],
-        threshold: float = None,
+        threshold: Optional[float] = None,
         top_k: Optional[int] = None,
         batch_size: int = 32,
-    ) -> Dict[str, Union[List[str], List[float], List[List[int]]]]:
+    ) -> Dict[str, Union[List[List[str]], List[List[float]], List[List[float]]]]:
🧹 Nitpick comments (31)
src/models/emotion_detection/config.py (4)

1-1: Remove non-executable shebang or make the file executable.

The module isn’t meant to be invoked directly; drop the shebang to satisfy linters.

-#!/usr/bin/env python3

71-75: from_dict expects flattened keys only.

Callers will likely pass nested dicts (model., training.). Either document this clearly or support nested input. I recommend minimal nested support.

 @classmethod
 def from_dict(cls, config_dict: Dict[str, Any]) -> 'EmotionDetectionConfig':
-    """Create config from dictionary."""
-    return cls(**config_dict)
+    """Create config from dictionary."""
+    flat = {}
+    # allow nested { "model": {"name": ...}, ... }
+    for k, v in config_dict.items():
+        if isinstance(v, dict):
+            flat.update(v)
+        else:
+            flat[k] = v
+    return cls(**flat)

101-111: update_config replaces state instead of merging with current.

Current behavior discards prior runtime changes. Merge into existing _config instead.

 def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
     """Get configuration from dictionary with defaults."""
-    default_config = get_default_config()
+    default_config = get_config()

122-126: Make update_config merge rather than reset.

Preserve current values when partial updates come in.

 def update_config(config_dict: Dict[str, Any]) -> None:
     """Update global configuration."""
     global _config
-    _config = get_config_from_dict(config_dict)
+    _config = get_config_from_dict(config_dict)
src/models/emotion_detection/enhanced_bert_classifier.py (5)

392-403: Report calibrated temperature in metadata.

If you switch to softplus in forward, match here too.

-                "temperature": float(self.temperature.item()),
+                "temperature": float(F.softplus(self.temperature).item()),

429-441: Empty prediction is internally inconsistent.

top_k_emotions shows 1.0 while all scores are 0 and confidence is 0. Return empty top_k or use consistent zeros.

         return EmotionPrediction(
             emotions=emotions,
             primary_emotion="neutral",
             confidence=0.0,
             emotional_intensity="very_low",
-            top_k_emotions=[("neutral", 1.0)],
+            top_k_emotions=[],
             prediction_metadata=metadata
         )

135-136: Use logging.exception() in except blocks and avoid long externalized messages.

Improve traceability and satisfy linters.

-            logger.error(f"Failed to load BERT model: {e}")
+            logger.exception("Failed to load BERT model")
@@
-            logger.exception("Forward pass failed")
+            logger.exception("Forward pass failed")
@@
-                logger.error(f"Failed to load tokenizer: {e}")
+                logger.exception("Failed to load tokenizer")
@@
-            logger.error(f"Failed to save model: {e}")
+            logger.exception("Failed to save model")
@@
-            logger.error(f"Failed to load model: {e}")
+            logger.exception("Failed to load model")

Also applies to: 207-212, 450-452, 517-519, 533-535


459-491: Minor: remove unused imports and members or wire them up.

sklearn metrics, Dataset/DataLoader, warnings aren’t used; class_weights is accepted but unused in loss. Either remove or integrate.


229-246: Consider honoring prediction_threshold in outputs.

You expose prediction_threshold but don’t use it to filter emotions. Either document it as metadata-only or apply it to zero out sub-threshold emotions (post top-k).

src/models/emotion_detection/enhanced_config.py (3)

217-224: Prefer logging.exception and narrower except for YAML load.

Gives stack traces and avoids blind catches.

-        except yaml.YAMLError as e:
-            logger.error(f"YAML parsing error: {e}")
+        except yaml.YAMLError:
+            logger.exception("YAML parsing error")
             logger.warning("Using default configuration due to YAML error")
             return self._create_default_config()
-        except Exception as e:
-            logger.error(f"Configuration loading failed: {e}")
+        except (OSError, ValueError):
+            logger.exception("Configuration loading failed")
             logger.warning("Using default configuration due to loading error")
             return self._create_default_config()

537-556: Save the full configuration, not just a subset.

Serializing only two sections drops user changes. Serialize all dataclasses (asdict).

-    def _config_to_dict(self) -> Dict[str, Any]:
-        """Convert configuration to dictionary."""
-        # This would need proper serialization logic
-        # For now, return a basic structure
-        return {
-            "model": {
-                "name": self.config.model.name,
-                "device": self.config.model.device,
-                "use_mixed_precision": self.config.model.use_mixed_precision,
-                "cache_embeddings": self.config.model.cache_embeddings,
-                "max_sequence_length": self.config.model.max_sequence_length,
-            },
-            "emotion_detection": {
-                "num_emotions": self.config.emotion_detection.num_emotions,
-                "prediction_threshold": self.config.emotion_detection.prediction_threshold,
-                "temperature": self.config.emotion_detection.temperature,
-                "top_k": self.config.emotion_detection.top_k,
-            },
-            # Add other sections as needed
-        }
+    def _config_to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary."""
+        from dataclasses import asdict
+        return asdict(self.config)

397-404: Security config is defined but unused.

If sanitize_input is expected, integrate with your serving layer (e.g., deployment/cloud-run/secure_api_server.sanitize_input) or drop the option.

test_samo_emotion_detection_enhanced.py (6)

15-17: Make sys.path insertion robust to different test locations.

Resolve the path and avoid relying on the current working directory.

-# Add src to path for standalone testing
-sys.path.insert(0, str(Path(__file__).parent / "src"))
+# Add src to path for standalone testing
+repo_root = Path(__file__).resolve().parent
+sys.path.insert(0, str((repo_root / "src").resolve()))

25-40: Avoid blind except and prefer try/except/else pattern.

Catching Exception hides real failures and triggers linter BLE001/TRY300. Narrow the exceptions and move success prints to else.

 def test_model_initialization():
     """Test enhanced model initialization."""
     print("1. Initializing Enhanced SAMO BERT Emotion Classifier...")
-    try:
-        model = EnhancedBERTEmotionClassifier(
+    try:
+        model = EnhancedBERTEmotionClassifier(
             model_name="bert-base-uncased",
             num_emotions=28,
             use_mixed_precision=True,
             cache_embeddings=True
         )
-        print("✅ Enhanced classifier initialized successfully")
-        return model
-    except Exception as e:
+    except (OSError, ValueError, RuntimeError) as e:
         print(f"❌ Model initialization failed: {e}")
         return None
+    else:
+        print("✅ Enhanced classifier initialized successfully")
+        return model

Apply the same pattern to similar try blocks below (Lines 42-57, 59-75, 118-147, 152-175, 177-203, 205-224, 226-287).


130-147: Strengthen batch prediction test with minimal assertions.

Add basic structural assertions to prevent silent regressions.

         predictions = model.predict_emotions(
             batch_texts, 
             top_k=2, 
             return_metadata=True,
             batch_size=2
         )
         end_time = time.time()
         
         print(f"   Processed {len(batch_texts)} texts in {end_time - start_time:.3f}s")
-        
-        for i, prediction in enumerate(predictions):
+        assert isinstance(predictions, list) and len(predictions) == len(batch_texts), "Batch size mismatch"
+        for i, prediction in enumerate(predictions):
+            assert hasattr(prediction, "top_k_emotions") and len(prediction.top_k_emotions) >= 2, "Missing top-k"
+            assert 0.0 <= prediction.confidence <= 1.0, "Confidence out of range"
             print(f"   Text {i+1}: {prediction.primary_emotion} ({prediction.confidence:.3f})")

152-171: Assert performance metrics invariants.

Validate keys and basic ranges to catch metric accumulation issues early.

         metrics = model.get_performance_metrics()
-        
+        required = {"total_inferences","total_inference_time","average_inference_time","error_count","error_rate","last_error"}
+        assert required.issubset(metrics.keys()), "Missing performance metric keys"
+        assert metrics["total_inferences"] >= 1, "No inferences recorded"
+        assert 0.0 <= metrics["error_rate"] <= 1.0, "Invalid error rate"
         print(f"   Total inferences: {metrics['total_inferences']}")
         print(f"   Average inference time: {metrics['average_inference_time']:.3f}s")
         print(f"   Error count: {metrics['error_count']}")
         print(f"   Error rate: {metrics['error_rate']:.3f}")

177-200: Catch specific exceptions and add light assertions in error-handling test.

Limit exception type and ensure outputs look sane when no error is raised.

     try:
         # Test with None input
         try:
             model.predict_emotions(None)
-        except Exception as e:
+        except (TypeError, ValueError, RuntimeError) as e:
             print(f"   ✅ Handled None input: {type(e).__name__}")
         
         # Test with very long text
         very_long_text = "This is a test. " * 1000
         prediction = model.predict_emotions(very_long_text)
+        assert prediction.primary_emotion and 0.0 <= prediction.confidence <= 1.0
         print(f"   ✅ Handled very long text: {len(very_long_text)} chars")
         
         # Test with special characters
         special_text = "I'm feeling 😊🎉 excited! @#$%^&*()"
         prediction = model.predict_emotions(special_text)
+        assert prediction.primary_emotion and 0.0 <= prediction.confidence <= 1.0
         print(f"   ✅ Handled special characters: {prediction.primary_emotion}")

1-1: Shebang without executable bit.

Either remove the shebang or mark the file executable (git chmod +x) to satisfy EXE001.

test_samo_emotion_detection_standalone.py (5)

13-16: Harden sys.path handling for portability.

Resolve the absolute path to src to avoid surprises when invoked from other CWDs.

-# Add src to path
-sys.path.insert(0, str(Path(__file__).parent / "src"))
+# Add src to path
+repo_root = Path(__file__).resolve().parent
+sys.path.insert(0, str((repo_root / "src").resolve()))

72-87: Empty-text branch: narrow exception type and assert behavior.

Catching Exception masks unrelated failures. Add a minimal assertion to document expected behavior.

-        if not text.strip():
+        if not text.strip():
             print("   (Empty text - testing edge case)")
             try:
                 results = model.predict_emotions(text, threshold=0.3, top_k=3)
                 emotions = results['emotions'][0]
                 probabilities = results['probabilities'][0]
+                assert isinstance(emotions, list), "Emotions should be list for empty input"
                 print(f"   Detected emotions: {emotions}")
                 print("   ✅ Empty text handled gracefully")
-            except Exception as e:
+            except (ValueError, RuntimeError) as e:
                 print(f"   ❌ Empty text caused error: {e}")
             continue

170-185: Temperature scaling: add a simple monotonicity check.

Lower temperature should not increase entropy; check that cold yields ≥ confidence or ≤ count than hot under a fixed threshold.

     model.set_temperature(0.5)  # Lower temperature = more confident
     results_cold = model.predict_emotions("I am very happy!", threshold=0.3)
     
     model.set_temperature(2.0)  # Higher temperature = less confident
     results_hot = model.predict_emotions("I am very happy!", threshold=0.3)
     
     model.set_temperature(original_temp)  # Reset
     
     print(f"   Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions")
     print(f"   Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions")
+    assert len(results_cold['emotions'][0]) <= len(results_hot['emotions'][0]), "Cold temperature should not increase detections"

255-306: Performance threshold may be flaky on CPU.

The fixed 10s limit for all cases can fail on CI/CPU. Consider relaxing for “Very long text” or gating by device.

-                assert processing_time < 10.0, f"Processing time {processing_time:.3f}s too slow for {name}"
+                # Allow more time for very long input or CPU-only environments
+                limit = 15.0 if (name == "Very long text" or not torch.cuda.is_available()) else 10.0
+                assert processing_time < limit, f"Processing time {processing_time:.3f}s too slow for {name}"

1-1: Shebang without executable bit.

Either remove the shebang or mark executable to satisfy EXE001.

src/models/emotion_detection/samo_bert_emotion_classifier.py (8)

30-36: Avoid configuring global logging and warning filters in a library module.

Leave logging configuration to applications; setting basicConfig and global filters has side effects for downstream users.

-# Configure logging
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-
-# Suppress warnings for cleaner output
-warnings.filterwarnings("ignore", category=UserWarning)

74-78: Condense config merge.

Small readability cleanup; keeps behavior identical.

-        if config is None:
-            config = default_config
-        else:
-            config = {**default_config, **config}
+        config = default_config if config is None else {**default_config, **config}

86-86: Make prediction threshold configurable.

Hardcoding 0.6 limits flexibility and can drift from evaluation settings. Read from config with a default.

-        self.prediction_threshold = 0.6  # Updated from 0.5 to 0.6 based on calibration
+        self.prediction_threshold = float(config.get("prediction_threshold", 0.6))

127-151: Layer freeze/unfreeze semantics can surprise.

_unfreeze_bert_layers currently sets embeddings to requires_grad=True as a side effect. Consider explicit freeze/unfreeze helpers to prevent inconsistent states.

-    def _set_bert_layers_grad(self, num_layers: int, requires_grad: bool) -> None:
+    def _set_bert_layers_grad(self, num_layers: int, requires_grad: bool) -> None:
         """Set gradient requirements for BERT layers."""
         if num_layers <= 0:
             return
 
-        # Set embeddings
-        for param in self.bert.embeddings.parameters():
-            param.requires_grad = requires_grad
+        # Set embeddings consistently with encoder 0
+        for p in self.bert.embeddings.parameters():
+            p.requires_grad = requires_grad
 
         # Set encoder layers
         for i in range(min(num_layers, len(self.bert.encoder.layer))):
             for param in self.bert.encoder.layer[i].parameters():
                 param.requires_grad = requires_grad

Optionally add an explicit freeze_all/unfreeze_all if you need broader control.


259-266: Avoid per-call imports; cache emotion labels once.

Importing labels inside the loop adds overhead. Cache labels on init and reference self.emotion_labels.

@@
-        # Classification head
         self.classifier = nn.Sequential(
@@
         )
@@
+        # Cache emotion labels for mapping indices to names
+        try:
+            from .emotion_labels import get_all_emotions
+            self.emotion_labels = get_all_emotions()
+        except Exception:
+            self.emotion_labels = [f"emotion_{i}" for i in range(self.num_emotions)]
@@
-                # Get emotion names for predictions
-                from .emotion_labels import GOEMOTIONS_EMOTIONS
-                for pred in batch_predictions:
-                    emotions = [
-                        GOEMOTIONS_EMOTIONS[i] for i, p in enumerate(pred) if p > 0
-                    ]
-                    all_emotions.append(emotions)
+                # Get emotion names for predictions
+                for pred in batch_predictions:
+                    emotions = [self.emotion_labels[i] for i, p in enumerate(pred) if p > 0]
+                    all_emotions.append(emotions)

Also applies to: 98-106


369-391: Return token_type_ids only when present.

Avoid allocating large zero tensors; downstream already uses .get(...).

     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
@@
-        return {
-            "input_ids": encoding["input_ids"].squeeze(0),
-            "attention_mask": encoding["attention_mask"].squeeze(0),
-            "token_type_ids": encoding.get("token_type_ids", torch.zeros_like(encoding["input_ids"])).squeeze(0),
-            "labels": label_tensor,
-        }
+        item = {
+            "input_ids": encoding["input_ids"].squeeze(0),
+            "attention_mask": encoding["attention_mask"].squeeze(0),
+            "labels": label_tensor,
+        }
+        if "token_type_ids" in encoding:
+            item["token_type_ids"] = encoding["token_type_ids"].squeeze(0)
+        return item

417-424: Avoid duplicating defaults in factory config.

Let the model’s default_config handle defaults; only override what’s variable here.

-    # Create model with default config
-    config = {
-        "hidden_dropout_prob": 0.3,
-        "classifier_dropout_prob": 0.5,
-        "freeze_bert_layers": freeze_bert_layers,
-        "temperature": 1.0,
-    }
+    # Only override what differs from model defaults
+    config = {"freeze_bert_layers": freeze_bert_layers}

1-1: Shebang without executable bit.

Either remove the shebang or mark executable to satisfy EXE001.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 56a5d15 and 3c73dd8.

⛔ Files ignored due to path filters (2)
  • src/models/emotion_detection/__pycache__/emotion_labels.cpython-38.pyc is excluded by !**/*.pyc
  • src/models/emotion_detection/__pycache__/samo_bert_emotion_classifier.cpython-38.pyc is excluded by !**/*.pyc
📒 Files selected for processing (6)
  • src/models/emotion_detection/config.py (1 hunks)
  • src/models/emotion_detection/enhanced_bert_classifier.py (1 hunks)
  • src/models/emotion_detection/enhanced_config.py (1 hunks)
  • src/models/emotion_detection/samo_bert_emotion_classifier.py (1 hunks)
  • test_samo_emotion_detection_enhanced.py (1 hunks)
  • test_samo_emotion_detection_standalone.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
test_samo_emotion_detection_enhanced.py (3)
src/models/emotion_detection/enhanced_bert_classifier.py (4)
  • EmotionPrediction (32-39)
  • get_model_info (475-491)
  • predict_emotions (229-270)
  • get_performance_metrics (459-473)
src/models/emotion_detection/enhanced_config.py (2)
  • create_enhanced_config_manager (559-568)
  • get_config (510-512)
src/models/emotion_detection/emotion_labels.py (2)
  • get_all_emotions (270-277)
  • get_emotion_description (244-254)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (7)
  • _init_classification_layers (120-125)
  • _freeze_bert_layers (144-146)
  • forward (152-190)
  • forward (309-338)
  • predict_emotions (192-274)
  • count_parameters (281-283)
  • count_frozen_parameters (285-287)
src/models/emotion_detection/enhanced_config.py (1)
deployment/cloud-run/secure_api_server.py (1)
  • sanitize_input (158-172)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/enhanced_bert_classifier.py (6)
  • _init_classification_layers (169-174)
  • _freeze_bert_layers (176-190)
  • forward (192-211)
  • predict_emotions (229-270)
  • count_parameters (493-495)
  • count_frozen_parameters (497-499)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
  • create_samo_bert_emotion_classifier (394-434)
  • count_parameters (281-283)
  • count_frozen_parameters (285-287)
  • predict_emotions (192-274)
  • set_temperature (276-279)
src/models/emotion_detection/emotion_labels.py (2)
  • get_all_emotions (270-277)
  • get_emotion_description (244-254)
🪛 Ruff (0.12.2)
src/models/emotion_detection/config.py

1-1: Shebang is present but file is not executable

(EXE001)

test_samo_emotion_detection_enhanced.py

1-1: Shebang is present but file is not executable

(EXE001)


36-36: Consider moving this statement to an else block

(TRY300)


37-37: Do not catch blind exception: Exception

(BLE001)


53-53: Consider moving this statement to an else block

(TRY300)


54-54: Do not catch blind exception: Exception

(BLE001)


71-71: Consider moving this statement to an else block

(TRY300)


72-72: Do not catch blind exception: Exception

(BLE001)


112-112: Consider moving this statement to an else block

(TRY300)


113-113: Do not catch blind exception: Exception

(BLE001)


146-146: Consider moving this statement to an else block

(TRY300)


147-147: Do not catch blind exception: Exception

(BLE001)


171-171: Consider moving this statement to an else block

(TRY300)


172-172: Do not catch blind exception: Exception

(BLE001)


185-185: Do not catch blind exception: Exception

(BLE001)


199-199: Consider moving this statement to an else block

(TRY300)


200-200: Do not catch blind exception: Exception

(BLE001)


220-220: Consider moving this statement to an else block

(TRY300)


221-221: Do not catch blind exception: Exception

(BLE001)


281-281: Consider moving this statement to an else block

(TRY300)

src/models/emotion_detection/enhanced_bert_classifier.py

118-118: Consider moving this statement to an else block

(TRY300)


119-119: Do not catch blind exception: Exception

(BLE001)


135-135: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


136-136: Avoid specifying long messages outside the exception class

(TRY003)


205-205: Consider moving this statement to an else block

(TRY300)


211-211: Avoid specifying long messages outside the exception class

(TRY003)


263-263: Avoid specifying long messages outside the exception class

(TRY003)


450-450: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


451-451: Avoid specifying long messages outside the exception class

(TRY003)


517-517: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


518-518: Avoid specifying long messages outside the exception class

(TRY003)


531-531: Consider moving this statement to an else block

(TRY300)


533-533: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


534-534: Avoid specifying long messages outside the exception class

(TRY003)

src/models/emotion_detection/enhanced_config.py

218-218: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


221-221: Do not catch blind exception: Exception

(BLE001)


222-222: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


278-278: Do not catch blind exception: Exception

(BLE001)


279-279: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


438-438: Consider moving this statement to an else block

(TRY300)


450-450: Consider moving this statement to an else block

(TRY300)


462-462: Consider moving this statement to an else block

(TRY300)


474-474: Consider moving this statement to an else block

(TRY300)


520-520: Do not catch blind exception: Exception

(BLE001)


521-521: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


534-534: Do not catch blind exception: Exception

(BLE001)


535-535: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

src/models/emotion_detection/samo_bert_emotion_classifier.py

1-1: Shebang is present but file is not executable

(EXE001)


195-195: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

test_samo_emotion_detection_standalone.py

1-1: Shebang is present but file is not executable

(EXE001)


84-84: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)

176-186: Good fallback when pooler_output is missing.

Using [CLS] when pooler_output is None improves model compatibility.

Comment thread src/models/emotion_detection/config.py Outdated
Comment thread src/models/emotion_detection/enhanced_bert_classifier.py Outdated
Comment thread src/models/emotion_detection/enhanced_bert_classifier.py
Comment thread src/models/emotion_detection/enhanced_bert_classifier.py Outdated
Comment thread src/models/emotion_detection/enhanced_bert_classifier.py
Comment on lines +213 to +228
@contextmanager
def inference_mode(self):
"""Context manager for inference mode with optimizations."""
was_training = self.training
self.eval()
try:
with torch.no_grad():
if self.use_mixed_precision:
with torch.cuda.amp.autocast():
yield
else:
yield
finally:
if was_training:
self.train()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use device-aware autocast, not torch.cuda.amp on non-CUDA devices.

Avoid CUDA-specific autocast on CPU/MPS.

     @contextmanager
     def inference_mode(self):
         """Context manager for inference mode with optimizations."""
         was_training = self.training
         self.eval()
         try:
             with torch.no_grad():
-                if self.use_mixed_precision:
-                    with torch.cuda.amp.autocast():
-                        yield
-                else:
-                    yield
+                device_type = self.device.type
+                enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu")
+                if enable_amp:
+                    with torch.autocast(device_type=device_type):
+                        yield
+                else:
+                    yield
         finally:
             if was_training:
                 self.train()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@contextmanager
def inference_mode(self):
"""Context manager for inference mode with optimizations."""
was_training = self.training
self.eval()
try:
with torch.no_grad():
if self.use_mixed_precision:
with torch.cuda.amp.autocast():
yield
else:
yield
finally:
if was_training:
self.train()
@contextmanager
def inference_mode(self):
"""Context manager for inference mode with optimizations."""
was_training = self.training
self.eval()
try:
with torch.no_grad():
device_type = self.device.type
enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu")
if enable_amp:
with torch.autocast(device_type=device_type):
yield
else:
yield
finally:
if was_training:
self.train()
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 213 to
228, the context manager uses torch.cuda.amp.autocast which is CUDA-only and
will fail or be suboptimal on CPU/MPS; change it to a device-aware autocast by
using torch.autocast(device_type=self.device.type) (or conditionally call the
appropriate autocast based on self.device.type) so mixed-precision is applied
correctly on CUDA and safely skipped or handled on CPU/MPS; keep the same yield
structure and restore training state in finally.

Comment thread src/models/emotion_detection/enhanced_bert_classifier.py
Comment thread src/models/emotion_detection/enhanced_config.py Outdated
Comment thread src/models/emotion_detection/enhanced_config.py Outdated
Comment thread src/models/emotion_detection/samo_bert_emotion_classifier.py Outdated
- Rename 'model' parameter to 'emotion_model' in evaluate_emotion_classifier
- Rename 'model' variable to 'emotion_classifier' in create_samo_bert_emotion_classifier
- Rename 'text' variable to 'sample_text' in EmotionDataset.__getitem__
- Rename 'i' variables to 'layer_idx' and 'batch_idx' in loops
- Rename 'emotion' variable to 'emotion_name' in emotion_labels.py test loop

All 9 PYL-W0621 linting errors resolved while maintaining functionality.
All tests passing with cleaner, more readable code.
- Fix lambda closure issue by using default parameter to capture probabilities value
- Change 'lambda i: probabilities[i]' to 'lambda i, probs=probabilities: probs[i]'
- Prevents all closures from using the same reference to the loop variable
- Eliminates potential bug where all lambdas would use the final value of probabilities

This fixes the classic Python closure problem where variables are captured by reference
rather than by value, ensuring each lambda gets its own copy of the probabilities list.
- Replace 37 f-string logging calls with lazy % formatting for performance
- Fix logging calls in samo_bert_emotion_classifier.py (3 calls)
- Fix logging calls in enhanced_config.py (34 calls)
- Prevents unnecessary string formatting when logging is disabled
- Improves performance by deferring string conversion until needed

All tests passing with optimized logging performance.
…ier.py

- Convert remaining 11 f-string logging calls to lazy % formatting
- Fix device setup, model loading, tokenizer loading, and model saving logging
- Complete logging performance optimization across all emotion detection modules
- All logging calls now use lazy evaluation for better performance

All tests passing with fully optimized logging system.
- Remove loop in test_emotion_descriptions() using list comprehension
- Add explicit raise from previous error in get_emotion_index()
- Improve code quality and maintainability
- Follow Python best practices for error handling and test structure

All tests passing with improved code quality.
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

♻️ Duplicate comments (17)
src/models/emotion_detection/emotion_labels.py (1)

167-184: O(1) lookup for emotion index + cleaner exception chaining.
Use a precomputed dict and raise from None. Also reuse the map in validate_emotion. (Same ask was raised before.)

@@
 GOEMOTIONS_EMOTIONS = [
@@
 ]
+
+# Precomputed mapping for O(1) lookups
+EMOTION_TO_INDEX = {e.lower(): i for i, e in enumerate(GOEMOTIONS_EMOTIONS)}
@@
 def get_emotion_index(emotion: str) -> int:
@@
-    try:
-        return GOEMOTIONS_EMOTIONS.index(emotion.lower())
-    except ValueError as e:
-        raise ValueError(f"Emotion '{emotion}' not found in GoEmotions list") from e
+    key = emotion.lower()
+    try:
+        return EMOTION_TO_INDEX[key]
+    except KeyError:
+        raise ValueError(f"Unknown emotion: {emotion!r}") from None
@@
-    return emotion.lower() in GOEMOTIONS_EMOTIONS
+    return emotion.lower() in EMOTION_TO_INDEX

Also applies to: 299-300, 44-46

src/models/emotion_detection/config.py (1)

38-43: Avoid import-time “constants” for thresholds; expose getters tied to runtime config.
Current constants won’t reflect update_config() changes. Provide getters that read _config.

-# Emotion classification threshold (used in evaluation)
-EMOTION_CLASSIFICATION_THRESHOLD = DEFAULT_CONFIG["evaluation"]["threshold"]
-
-# Prediction threshold (used in inference)
-EMOTION_PREDICTION_THRESHOLD = DEFAULT_CONFIG["prediction"]["threshold"]
+def get_evaluation_threshold() -> float:
+    """Current evaluation threshold."""
+    return _config.evaluation_threshold
+
+def get_prediction_threshold() -> float:
+    """Current prediction threshold."""
+    return _config.prediction_threshold
src/models/emotion_detection/enhanced_config.py (1)

417-425: _validate_string returns “default” but yields ""; accept a default parameter and pass real defaults.
Prevents silently writing empty strings for paths/labels.

@@
-    def _validate_string(value: Any, field_name: str) -> str:
+    def _validate_string(value: Any, field_name: str, default: str = "") -> str:
         """Validate string value."""
         if not isinstance(value, str):
-            logger.warning("Invalid string value for %s: %s, using default", field_name, value)
-            return ""
+            logger.warning("Invalid string value for %s: %s, using default", field_name, value)
+            return default
         return value
@@
-            name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name"),
+            name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased"),
@@
-            padding=self._validate_string(data.get("padding", "max_length"), "data.padding"),
+            padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"),
@@
-            level=self._validate_log_level(data.get("level", "INFO"), "logging.level"),
+            level=self._validate_log_level(data.get("level", "INFO"), "logging.level"),
             log_interval=self._validate_positive_int(data.get("log_interval", 100), "logging.log_interval"),
             save_interval=self._validate_positive_int(data.get("save_interval", 1000), "logging.save_interval"),
             enable_tensorboard=self._validate_bool(data.get("enable_tensorboard", True), "logging.enable_tensorboard"),
-            log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir"),
+            log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"),
@@
-            save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir"),
-            save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric"),
+            save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"),
+            save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"),
@@
-            error_log_file=self._validate_string(data.get("error_log_file", "logs/emotion_detection_errors.log"), "error_handling.error_log_file"),
+            error_log_file=self._validate_string(data.get("error_log_file", "logs/emotion_detection_errors.log"), "error_handling.error_log_file", "logs/emotion_detection_errors.log"),

Also applies to: 287-292, 331-337, 351-356, 361-365, 395-396

test_samo_emotion_detection_standalone.py (3)

56-75: De-duplicate test texts by reusing SAMPLE_TEST_TEXTS.

Avoid divergence of fixtures; extend with edge cases locally.

-    test_texts = [
-        "I am so happy and excited about this amazing opportunity!",
-        "I feel really sad and disappointed about what happened today.",
-        "I'm feeling anxious and worried about the upcoming presentation.",
-        "I love spending time with my family and friends.",
-        "I'm angry and frustrated with this situation.",
-        "I feel grateful and thankful for all the support I've received.",
-        "I'm confused and don't understand what's going on.",
-        "I feel proud of my accomplishments and achievements.",
+    test_texts = [
+        *SAMPLE_TEST_TEXTS,
+        "I love spending time with my family and friends.",
+        "I'm angry and frustrated with this situation.",
+        "I feel grateful and thankful for all the support I've received.",
+        "I'm confused and don't understand what's going on.",
+        "I feel proud of my accomplishments and achievements.",

124-141: Good: batch structure and value checks are asserted.

These assertions address prior feedback on validating outputs.


193-214: Good: threshold monotonicity assertions.

This test precisely captures expected behavior.

src/models/emotion_detection/enhanced_bert_classifier.py (6)

19-20: Broken import: wrong module name.

Importing from .labels will fail; use emotion_labels.

-from .labels import GOEMOTIONS_EMOTIONS
+from .emotion_labels import GOEMOTIONS_EMOTIONS

84-91: Temperature overwritten later; initialize once as learnable with provided value.

Preserve user-supplied value and correct type.

-        self.temperature = temperature
+        # Learnable temperature initialized from argument
+        self.temperature = nn.Parameter(torch.tensor([float(temperature)], dtype=torch.float32))

146-148: Remove redundant temperature re-assignment.

Prevents clobbering provided value.

-        self.temperature = nn.Parameter(torch.ones(1))
         self._init_classification_layers()

190-203: Forward: handle missing pooler_output and enforce positive temperature.

Robust across models and avoids divide-by-zero/negative temps.

-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
@@
-            bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
-            pooled_output = bert_outputs.pooler_output
+            bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+            pooled_output = getattr(bert_outputs, "pooler_output", None)
+            if pooled_output is None:
+                pooled_output = bert_outputs.last_hidden_state[:, 0, :]
@@
-            logits = self.classifier(pooled_output)
+            logits = self.classifier(pooled_output)
@@
-            logits = logits / self.temperature
+            temp = F.softplus(self.temperature) + 1e-6
+            logits = logits / temp

211-223: Use device-aware autocast; current CUDA-only autocast breaks on CPU/MPS.

Switch to torch.autocast with device_type.

         try:
             with torch.no_grad():
-                if self.use_mixed_precision:
-                    with torch.cuda.amp.autocast():
-                        yield
-                else:
-                    yield
+                device_type = self.device.type
+                enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu", "mps")
+                if enable_amp:
+                    with torch.autocast(device_type=device_type):
+                        yield
+                else:
+                    yield

327-358: Batch: preserve input order and length when filtering empty/whitespace texts.

Current code drops entries; reconstruct full result list.

-        # Filter out empty texts
-        valid_texts = [text for text in texts if text and text.strip()]
-        if not valid_texts:
-            return [self._create_empty_prediction(return_metadata) for _ in texts]
+        # Build index map preserving positions; filter invalid
+        index_map, valid_texts = [], []
+        for idx, t in enumerate(texts):
+            if t and t.strip():
+                index_map.append(idx)
+                valid_texts.append(t)
+        if not valid_texts:
+            return [self._create_empty_prediction(return_metadata) for _ in texts]
@@
-        # Process results
-        results = []
-        for i, prob in enumerate(probabilities):
-            result = self._process_prediction_results(
-                prob, top_k, return_metadata, valid_texts[i]
-            )
-            results.append(result)
-
-        return results
+        # Process results and reassemble in original order
+        results = [self._create_empty_prediction(return_metadata) for _ in texts]
+        for i, prob in enumerate(probabilities):
+            results[index_map[i]] = self._process_prediction_results(
+                prob, top_k, return_metadata, valid_texts[i]
+            )
+        return results
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)

126-151: Layer freeze/unfreeze symmetry

unfreeze_bert_layers() only toggles the first N layers/embeddings; previously frozen later layers remain frozen. Consider a symmetrical approach (unfreeze all, then selectively (re)freeze) or track frozen ranges.


175-183: Good fallback for models without pooler_output

CLS fallback is correctly implemented.


225-243: Token type IDs handling is fine

Passing None when absent avoids unnecessary zero tensors and is supported by HF models.


258-263: Map to labels once; guard for length mismatches

Importing every batch is inefficient and assumes 28 labels. Store labels on init and guard index overflow to avoid IndexError if num_emotions differs.

@@
-                # Get emotion names for predictions
-                from .emotion_labels import GOEMOTIONS_EMOTIONS
-                for pred in batch_predictions:
-                    emotions = [
-                        GOEMOTIONS_EMOTIONS[i] for i, p in enumerate(pred) if p > 0
-                    ]
+                # Get emotion names for predictions
+                for pred in batch_predictions:
+                    labels = getattr(self, "emotion_labels", [f"emotion_{i}" for i in range(self.num_emotions)])
+                    emotions = [labels[i] for i, p in enumerate(pred) if p > 0 and i < len(labels)]
                     all_emotions.append(emotions)

And initialize once:

@@
         self.prediction_threshold = 0.6  # Updated from 0.5 to 0.6 based on calibration
+        # Emotion labels (fallback to placeholders if module unavailable)
+        try:
+            from .emotion_labels import GOEMOTIONS_EMOTIONS
+            self.emotion_labels = list(GOEMOTIONS_EMOTIONS)
+        except Exception:
+            self.emotion_labels = [f"emotion_{i}" for i in range(num_emotions)]

319-333: Device/broadcasting bug when applying class weights

On CUDA, multiplying CPU class_weights with GPU loss will error; unsqueeze(0) also may not broadcast for batch>1. Fix device + shape handling (or register as buffer).

         # Apply class weights if provided
         if self.class_weights is not None:
-            bce_loss = bce_loss * self.class_weights.unsqueeze(0)
+            cw = self.class_weights
+            if cw.device != bce_loss.device:
+                cw = cw.to(bce_loss.device)
+            if cw.ndim == 1:
+                cw = cw.view(1, -1)
+            bce_loss = bce_loss * cw.expand_as(bce_loss)

Optional: make weights track device automatically.

@@ class WeightedBCELoss(nn.Module):
-        self.class_weights = class_weights
+        if class_weights is not None:
+            self.register_buffer("class_weights", class_weights.float())
+        else:
+            self.class_weights = None
🧹 Nitpick comments (24)
src/models/emotion_detection/emotion_labels.py (2)

1-1: Remove non-executable shebang or make file executable.
The shebang triggers Ruff EXE001. Either drop it or chmod +x the file.

-#!/usr/bin/env python3

214-214: Make group lookups case-insensitive.
Lower the key before dict access.

-    return EMOTION_VALENCE_GROUPS.get(valence, [])
+    return EMOTION_VALENCE_GROUPS.get(valence.lower(), [])
@@
-    return EMOTION_AROUSAL_GROUPS.get(arousal, [])
+    return EMOTION_AROUSAL_GROUPS.get(arousal.lower(), [])
@@
-    return EMOTION_DOMINANCE_GROUPS.get(dominance, [])
+    return EMOTION_DOMINANCE_GROUPS.get(dominance.lower(), [])

Also applies to: 227-227, 240-240

src/models/emotion_detection/config.py (2)

1-1: Drop non-executable shebang or make executable.
Matches Ruff EXE001.

-#!/usr/bin/env python3

12-36: Single source of truth for defaults.
DEFAULT_CONFIG isn’t used to build EmotionDetectionConfig; either wire get_default_config() from DEFAULT_CONFIG or remove DEFAULT_CONFIG to prevent drift.

Also applies to: 96-99

src/models/emotion_detection/enhanced_config.py (2)

216-223: Use logging.exception and avoid blind excepts where feasible.
Switch to exception() to capture tracebacks; narrow exceptions where practical.

-        except yaml.YAMLError as e:
-            logger.error("YAML parsing error: %s", e)
+        except yaml.YAMLError:
+            logger.exception("YAML parsing error")
             logger.warning("Using default configuration due to YAML error")
             return self._create_default_config()
-        except Exception as e:
-            logger.error("Configuration loading failed: %s", e)
+        except Exception:
+            logger.exception("Configuration loading failed")
             logger.warning("Using default configuration due to loading error")
             return self._create_default_config()
@@
-        except Exception as e:
-            logger.error("Configuration parsing failed: %s", e)
+        except Exception:
+            logger.exception("Configuration parsing failed")
             logger.warning("Using default configuration due to parsing error")
             return self._create_default_config()
@@
-        except Exception as e:
-            logger.error("Configuration update failed: %s", e)
+        except Exception:
+            logger.exception("Configuration update failed")
@@
-        except Exception as e:
-            logger.error("Failed to save configuration: %s", e)
+        except Exception:
+            logger.exception("Failed to save configuration")

Also applies to: 279-283, 531-533, 545-547


524-533: Convert update_config to an instance method and implement a safe shallow merge.
Allow runtime updates section-by-section.

-    @staticmethod
-    def update_config(updates: Dict[str, Any]) -> None:
+    def update_config(self, updates: Dict[str, Any]) -> None:
         """Update configuration with new values."""
         try:
-            # This would need more sophisticated merging logic
-            # For now, just log the attempt
-            logger.info("Configuration update requested: %s", updates)
+            logger.info("Configuration update requested: %s", updates)
+            for section, vals in updates.items():
+                if hasattr(self.config, section) and isinstance(vals, dict):
+                    section_obj = getattr(self.config, section)
+                    for k, v in vals.items():
+                        if hasattr(section_obj, k):
+                            setattr(section_obj, k, v)
         except Exception:
             logger.exception("Configuration update failed")
test_samo_emotion_detection_enhanced.py (5)

1-1: Remove non-executable shebang or make the script executable.

Either drop the shebang or set the exec bit in git to satisfy linters.

Apply this diff to remove it:

-#!/usr/bin/env python3

15-15: Harden sys.path injection.

Avoid duplicate inserts and resolve path for robustness.

-sys.path.insert(0, str(Path(__file__).parent / "src"))
+src_path = str((Path(__file__).resolve().parent / "src"))
+if src_path not in sys.path:
+    sys.path.insert(0, src_path)

24-39: Narrow exception handling in initialization.

Catching Exception hides actionable failures; return in else for clarity.

 def test_model_initialization():
@@
-    try:
+    try:
         model = EnhancedBERTEmotionClassifier(
@@
-        print("✅ Enhanced classifier initialized successfully")
-        return model
-    except Exception as e:
+        print("✅ Enhanced classifier initialized successfully")
+    except (RuntimeError, OSError, ValueError) as e:
         print(f"❌ Model initialization failed: {e}")
-        return None
+        return None
+    else:
+        return model

58-74: Config access looks correct. Add one sanity assertion.

Guard against missing/invalid config fields to fail fast.

-        print(f"   Model name: {config.model.name}")
+        print(f"   Model name: {config.model.name}")
+        assert config.emotion_detection.num_emotions > 0, "num_emotions must be > 0"

76-115: Add lightweight assertions to validate EmotionPrediction structure.

Current test only prints. Validate types and ranges.

@@
-            prediction = model.predict_emotions(
+            prediction = model.predict_emotions(
                 text, 
                 top_k=3, 
                 return_metadata=True
             )
@@
             print(f"   Top emotions: {prediction.top_k_emotions[:3]}")
@@
             if prediction.prediction_metadata:
                 print(f"   Text length: {prediction.prediction_metadata.get('text_length', 'N/A')}")
+            # Assertions
+            assert isinstance(prediction.primary_emotion, str)
+            assert 0.0 <= prediction.confidence <= 1.0
+            assert isinstance(prediction.top_k_emotions, list) and len(prediction.top_k_emotions) <= 3
test_samo_emotion_detection_standalone.py (3)

1-1: Remove non-executable shebang or mark executable.

Same as the enhanced test script.

-#!/usr/bin/env python3

176-191: Temperature test: consider asserting relative confidence/entropy.

Increase signal by checking that colder temperature increases top-1 confidence.

@@
-    print(f"   Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions")
-    print(f"   Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions")
+    print(f"   Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions")
+    print(f"   Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions")
+    # Optional: assert colder => not fewer detected emotions than hotter for same threshold
+    assert len(results_cold['emotions'][0]) >= len(results_hot['emotions'][0])

263-317: Reduce flakiness in performance test.

First-run HF downloads and CPU can exceed 10s. Warm up and relax/tier SLA.

@@
-        model, _ = create_samo_bert_emotion_classifier()
+        model, _ = create_samo_bert_emotion_classifier()
+        # Warm-up (tokenizer + model graph)
+        _ = model.predict_emotions("warmup", threshold=0.3)
@@
-            try:
-                results = model.predict_emotions(text, threshold=0.3)
+            try:
+                results = model.predict_emotions(text, threshold=0.3)
                 end_time = time.time()
@@
-                assert processing_time < 10.0, f"Processing time {processing_time:.3f}s too slow for {name}"
+                assert processing_time < 15.0, f"Processing time {processing_time:.3f}s too slow for {name}"
-                assert isinstance(emotions, list), f"Emotions should be a list for {name}"
-                assert len(emotions) >= 0, f"Emotions count should be non-negative for {name}"
+                assert isinstance(emotions, list), f"Emotions should be a list for {name}"
+                # Optional: at least one or zero depending on threshold/text; keep non-negative check meaningful
+                assert len(emotions) >= 0
src/models/emotion_detection/enhanced_bert_classifier.py (4)

131-135: Prefer logger.exception in except paths.

Improves tracebacks per Ruff TRY400.

-        except Exception as e:
-            logger.error("Failed to load BERT model: %s", e)
-            raise RuntimeError(f"BERT model initialization failed: {e}") from e
+        except Exception as e:
+            logger.exception("Failed to load BERT model")
+            raise RuntimeError(f"BERT model initialization failed: {e}") from e

446-450: Tokenizer load: use logger.exception for traceback.

Minor logging improvement.

-            except Exception as e:
-                logger.error("Failed to load tokenizer: %s", e)
+            except Exception as e:
+                logger.exception("Failed to load tokenizer")
                 raise RuntimeError(f"Tokenizer loading failed: {e}") from e

499-517: Use logger.exception in save_model failure path.

Consistency with other error logs.

-        except Exception as e:
-            logger.error("Failed to save model: %s", e)
+        except Exception as e:
+            logger.exception("Failed to save model")
             raise RuntimeError(f"Model saving failed: {e}") from e

518-533: Use logger.exception in load_model failure path and else for success log.

Minor but cleaner flow.

-        try:
+        try:
             checkpoint = torch.load(path, map_location=device or 'cpu')
@@
-            logger.info("Model loaded from: %s", path)
-            return model
-        except Exception as e:
-            logger.error("Failed to load model: %s", e)
+            logger.info("Model loaded from: %s", path)
+            return model
+        except Exception as e:
+            logger.exception("Failed to load model")
             raise RuntimeError(f"Model loading failed: {e}") from e
src/models/emotion_detection/samo_bert_emotion_classifier.py (6)

29-35: Avoid global logging/warning configuration in a library module

Remove logging.basicConfig(...) and the global warnings.filterwarnings(...) to prevent side effects on host apps/tests. Keep a module logger only.

-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Suppress warnings for cleaner output
-warnings.filterwarnings("ignore", category=UserWarning)
+logger = logging.getLogger(__name__)

84-84: Remove unused self.class_weights field

It’s not used by this class (loss handles weights). Drop it to avoid confusion.

-        self.class_weights = None

191-198: Fix implicit Optional type for threshold (RUF013)

Use Optional[float] to satisfy PEP 484 and ruff.

     def predict_emotions(
         self,
         texts: Union[str, List[str]],
-        threshold: float = None,
+        threshold: Optional[float] = None,
         top_k: Optional[int] = None,
         batch_size: int = 32,
     ) -> Dict[str, Union[List[List[str]], List[List[float]], List[List[float]]]]:

370-375: Use tokenizer’s model_max_length instead of hardcoded 512

Aligns with the chosen backbone and avoids silent truncation mismatches.

-                    max_length=512,
+                    max_length=self.tokenizer.model_max_length,
-            max_length=self.max_length,
+            max_length=self.max_length,

(Consider defaulting max_length to tokenizer.model_max_length in the dataset’s constructor.)

Also applies to: 230-231


445-447: Make evaluation threshold robust to missing config

Import may fail if constant is absent; fall back to model’s threshold.

-    from .config import EMOTION_CLASSIFICATION_THRESHOLD
-    threshold = EMOTION_CLASSIFICATION_THRESHOLD
+    try:
+        from .config import EMOTION_CLASSIFICATION_THRESHOLD as threshold  # type: ignore
+    except Exception:
+        threshold = getattr(emotion_model, "prediction_threshold", 0.5)

If you want, I can add a small verification script to confirm the constant exists in the repo.


1-1: Shebang on a library file

Either remove the shebang or mark the file executable. It isn’t needed for an importable module.

-#!/usr/bin/env python3
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3c73dd8 and fa1a1dd.

⛔ Files ignored due to path filters (2)
  • src/models/emotion_detection/__pycache__/emotion_labels.cpython-38.pyc is excluded by !**/*.pyc
  • src/models/emotion_detection/__pycache__/samo_bert_emotion_classifier.cpython-38.pyc is excluded by !**/*.pyc
📒 Files selected for processing (7)
  • src/models/emotion_detection/config.py (1 hunks)
  • src/models/emotion_detection/emotion_labels.py (1 hunks)
  • src/models/emotion_detection/enhanced_bert_classifier.py (1 hunks)
  • src/models/emotion_detection/enhanced_config.py (1 hunks)
  • src/models/emotion_detection/samo_bert_emotion_classifier.py (1 hunks)
  • test_samo_emotion_detection_enhanced.py (1 hunks)
  • test_samo_emotion_detection_standalone.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/enhanced_bert_classifier.py (6)
  • _init_classification_layers (167-172)
  • _freeze_bert_layers (174-188)
  • forward (190-209)
  • predict_emotions (227-268)
  • count_parameters (491-493)
  • count_frozen_parameters (495-497)
src/models/emotion_detection/config.py (1)
src/models/emotion_detection/enhanced_config.py (3)
  • EmotionDetectionConfig (46-51)
  • get_config (520-522)
  • update_config (525-532)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
  • create_samo_bert_emotion_classifier (389-426)
  • count_parameters (280-282)
  • count_frozen_parameters (284-286)
  • predict_emotions (191-273)
  • set_temperature (275-278)
src/models/emotion_detection/emotion_labels.py (2)
  • get_all_emotions (269-276)
  • get_emotion_description (243-253)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (7)
  • _init_classification_layers (119-124)
  • _freeze_bert_layers (143-145)
  • forward (151-189)
  • forward (308-333)
  • predict_emotions (191-273)
  • count_parameters (280-282)
  • count_frozen_parameters (284-286)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
  • EmotionDetectionConfig (46-93)
  • get_config (117-119)
  • update_config (122-125)
deployment/cloud-run/secure_api_server.py (1)
  • sanitize_input (158-172)
test_samo_emotion_detection_enhanced.py (3)
src/models/emotion_detection/enhanced_bert_classifier.py (4)
  • EnhancedBERTEmotionClassifier (39-532)
  • get_model_info (473-489)
  • predict_emotions (227-268)
  • get_performance_metrics (457-471)
src/models/emotion_detection/enhanced_config.py (2)
  • create_enhanced_config_manager (570-579)
  • get_config (520-522)
src/models/emotion_detection/emotion_labels.py (2)
  • get_all_emotions (269-276)
  • get_emotion_description (243-253)
🪛 Ruff (0.12.2)
src/models/emotion_detection/samo_bert_emotion_classifier.py

1-1: Shebang is present but file is not executable

(EXE001)


194-194: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

src/models/emotion_detection/config.py

1-1: Shebang is present but file is not executable

(EXE001)

test_samo_emotion_detection_standalone.py

1-1: Shebang is present but file is not executable

(EXE001)


90-90: Do not catch blind exception: Exception

(BLE001)

src/models/emotion_detection/emotion_labels.py

1-1: Shebang is present but file is not executable

(EXE001)


183-183: Avoid specifying long messages outside the exception class

(TRY003)


201-201: Avoid specifying long messages outside the exception class

(TRY003)

src/models/emotion_detection/enhanced_bert_classifier.py

116-116: Consider moving this statement to an else block

(TRY300)


117-117: Do not catch blind exception: Exception

(BLE001)


133-133: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


134-134: Avoid specifying long messages outside the exception class

(TRY003)


203-203: Consider moving this statement to an else block

(TRY300)


209-209: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


448-448: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


449-449: Avoid specifying long messages outside the exception class

(TRY003)


515-515: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


516-516: Avoid specifying long messages outside the exception class

(TRY003)


529-529: Consider moving this statement to an else block

(TRY300)


531-531: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


532-532: Avoid specifying long messages outside the exception class

(TRY003)

src/models/emotion_detection/enhanced_config.py

217-217: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


220-220: Do not catch blind exception: Exception

(BLE001)


221-221: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


279-279: Do not catch blind exception: Exception

(BLE001)


280-280: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


442-442: Consider moving this statement to an else block

(TRY300)


455-455: Consider moving this statement to an else block

(TRY300)


468-468: Consider moving this statement to an else block

(TRY300)


481-481: Consider moving this statement to an else block

(TRY300)


531-531: Do not catch blind exception: Exception

(BLE001)


532-532: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


545-545: Do not catch blind exception: Exception

(BLE001)


546-546: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

test_samo_emotion_detection_enhanced.py

1-1: Shebang is present but file is not executable

(EXE001)


35-35: Consider moving this statement to an else block

(TRY300)


36-36: Do not catch blind exception: Exception

(BLE001)


52-52: Consider moving this statement to an else block

(TRY300)


53-53: Do not catch blind exception: Exception

(BLE001)


70-70: Consider moving this statement to an else block

(TRY300)


71-71: Do not catch blind exception: Exception

(BLE001)


111-111: Consider moving this statement to an else block

(TRY300)


112-112: Do not catch blind exception: Exception

(BLE001)


145-145: Consider moving this statement to an else block

(TRY300)


146-146: Do not catch blind exception: Exception

(BLE001)


170-170: Consider moving this statement to an else block

(TRY300)


171-171: Do not catch blind exception: Exception

(BLE001)


184-184: Do not catch blind exception: Exception

(BLE001)


198-198: Consider moving this statement to an else block

(TRY300)


199-199: Do not catch blind exception: Exception

(BLE001)


219-219: Consider moving this statement to an else block

(TRY300)


220-220: Do not catch blind exception: Exception

(BLE001)


280-280: Consider moving this statement to an else block

(TRY300)

🔇 Additional comments (2)
test_samo_emotion_detection_enhanced.py (2)

151-171: Performance metrics won’t update on CPU with current classifier timing.

predict_emotions only tracks CUDA-event timings; on CPU you’ll see zeros. After fixing the classifier (see model review), consider asserting >0 to catch regressions.

Add this assertion post-fix:

@@
-        metrics = model.get_performance_metrics()
+        metrics = model.get_performance_metrics()
         print(f"   Total inferences: {metrics['total_inferences']}")
@@
-        print("✅ Performance metrics working")
+        assert metrics['total_inferences'] >= len(test_texts), "Inference count not tracked"
+        print("✅ Performance metrics working")

176-199: Nice coverage of edge cases.

Positive checks for None, very long text, and special chars. Keep.

Comment on lines +122 to +126
def update_config(config_dict: Dict[str, Any]) -> None:
"""Update global configuration."""
global _config
_config = get_config_from_dict(config_dict)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

update_config discards existing values; perform in-place shallow merge.
Rebuilding from defaults nukes prior updates.

-def update_config(config_dict: Dict[str, Any]) -> None:
-    """Update global configuration."""
-    global _config
-    _config = get_config_from_dict(config_dict)
+def update_config(config_dict: Dict[str, Any]) -> None:
+    """Update global configuration in place (shallow merge)."""
+    global _config
+    for key, value in config_dict.items():
+        if hasattr(_config, key):
+            setattr(_config, key, value)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def update_config(config_dict: Dict[str, Any]) -> None:
"""Update global configuration."""
global _config
_config = get_config_from_dict(config_dict)
def update_config(config_dict: Dict[str, Any]) -> None:
"""Update global configuration in place (shallow merge)."""
global _config
for key, value in config_dict.items():
if hasattr(_config, key):
setattr(_config, key, value)
🤖 Prompt for AI Agents
In src/models/emotion_detection/config.py around lines 122 to 126, update_config
currently replaces the global _config with get_config_from_dict(config_dict),
discarding prior values; change it to perform an in-place shallow merge: compute
updates = get_config_from_dict(config_dict), ensure _config is initialized
(e.g., to get_default_config() or an empty dict if None), then iterate
updates.items() and assign each key into _config ( _config[key] = value ) so
existing keys not present in updates are preserved; keep the function signature
and avoid reassigning the _config reference.

Comment on lines +14 to +18
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add functional API import used in temperature handling.

Required by softplus in forward/info.

 import torch.nn as nn
+import torch.nn.functional as F
 from transformers import AutoConfig, AutoModel, AutoTokenizer
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 14 to
18, the file uses the functional API (e.g., softplus) in forward/info but does
not import it; add the torch.nn.functional import (for example import
torch.nn.functional as F) near the other imports so softplus (F.softplus) is
available.

Comment on lines +245 to +269
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
if start_time:
start_time.record()

try:
# Handle single text input
if isinstance(texts, str):
return self._predict_single_text(texts, top_k, return_metadata)

# Handle multiple texts
return self._predict_batch_texts(texts, top_k, return_metadata, batch_size)

except Exception as e:
logger.exception("Emotion prediction failed")
self._error_count += 1
self._last_error = str(e)
raise RuntimeError(f"Emotion prediction failed: {e}") from e
finally:
if start_time:
end_time = torch.cuda.Event(enable_timing=True)
end_time.record()
torch.cuda.synchronize()
inference_time = start_time.elapsed_time(end_time) / 1000.0
self._update_performance_metrics(inference_time)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Performance metrics currently update only on CUDA; add CPU/MPS timing fallback.

Ensure inference count/time tracked on all devices.

-        start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
-        if start_time:
-            start_time.record()
+        use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda"
+        if use_cuda_timing:
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+        else:
+            import time as _time
+            _t0 = _time.perf_counter()
@@
-        finally:
-            if start_time:
-                end_time = torch.cuda.Event(enable_timing=True)
-                end_time.record()
-                torch.cuda.synchronize()
-                inference_time = start_time.elapsed_time(end_time) / 1000.0
-                self._update_performance_metrics(inference_time)
+        finally:
+            if use_cuda_timing:
+                end_event.record()
+                torch.cuda.synchronize()
+                inference_time = start_event.elapsed_time(end_event) / 1000.0
+            else:
+                inference_time = (_time.perf_counter() - _t0)
+            self._update_performance_metrics(inference_time)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
if start_time:
start_time.record()
try:
# Handle single text input
if isinstance(texts, str):
return self._predict_single_text(texts, top_k, return_metadata)
# Handle multiple texts
return self._predict_batch_texts(texts, top_k, return_metadata, batch_size)
except Exception as e:
logger.exception("Emotion prediction failed")
self._error_count += 1
self._last_error = str(e)
raise RuntimeError(f"Emotion prediction failed: {e}") from e
finally:
if start_time:
end_time = torch.cuda.Event(enable_timing=True)
end_time.record()
torch.cuda.synchronize()
inference_time = start_time.elapsed_time(end_time) / 1000.0
self._update_performance_metrics(inference_time)
# Start timing, with CUDA or CPU/MPS fallback
use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda"
if use_cuda_timing:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
else:
import time as _time
_t0 = _time.perf_counter()
try:
# Handle single text input
if isinstance(texts, str):
return self._predict_single_text(texts, top_k, return_metadata)
# Handle multiple texts
return self._predict_batch_texts(texts, top_k, return_metadata, batch_size)
except Exception as e:
logger.exception("Emotion prediction failed")
self._error_count += 1
self._last_error = str(e)
raise RuntimeError(f"Emotion prediction failed: {e}") from e
finally:
if use_cuda_timing:
end_event.record()
torch.cuda.synchronize()
inference_time = start_event.elapsed_time(end_event) / 1000.0
else:
inference_time = _time.perf_counter() - _t0
self._update_performance_metrics(inference_time)
🧰 Tools
🪛 Ruff (0.12.2)

261-261: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 245 to
269, the current timing logic only initializes CUDA events and leaves start_time
None on CPU/MPS, so performance metrics aren't updated for non-CUDA runs; change
to use torch.cuda.Event timing when torch.cuda.is_available(), otherwise record
a CPU/MPS start timestamp with time.perf_counter() (ensure time is imported),
and in the finally block compute elapsed via CUDA event elapsed_time when using
CUDA (with torch.cuda.synchronize()), otherwise compute elapsed as
time.perf_counter() - cpu_start; always call
self._update_performance_metrics(inference_time) with seconds and ensure
inference count/time bookkeeping runs for all devices.

Comment thread src/models/emotion_detection/enhanced_config.py
Comment on lines +83 to +86
self.temperature = nn.Parameter(torch.ones(1) * config["temperature"])
self.class_weights = None
self.prediction_threshold = 0.6 # Updated from 0.5 to 0.6 based on calibration

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enforce positive temperature; prevent divide-by-zero/NaN

Guard set_temperature to reject non-positive or non-finite values; current code could yield inf/NaN logits.

 def set_temperature(self, temperature: float) -> None:
     """Set temperature scaling parameter."""
-    self.temperature.data.fill_(temperature)
+    if temperature <= 0.0 or not np.isfinite(temperature):
+        raise ValueError("temperature must be a positive, finite float")
+    self.temperature.data.fill_(float(temperature))
     logger.info("Set temperature to %s", temperature)

Also applies to: 275-279

🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines
83-86 (and similarly update lines 275-279), the temperature parameter is
accepted without validation which can produce zero, negative or non-finite
values leading to inf/NaN logits; update the set_temperature logic to validate
the incoming value is a finite positive scalar (e.g., > 0 and torch.isfinite or
math.isfinite), reject or raise a ValueError for non-positive or non-finite
inputs, and only assign torch.nn.Parameter(torch.tensor(value)) when valid; also
ensure any public setters or config loaders that set temperature use the same
validation to prevent divide-by-zero issues.

Comment on lines +248 to +253
# Get top-k if specified
if top_k is not None:
_, top_k_indices = torch.topk(probabilities, top_k, dim=1)
predictions = torch.zeros_like(probabilities)
predictions.scatter_(1, top_k_indices, 1.0)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Clamp top_k and ensure k≥1

Prevent runtime errors when top_k > num classes or 0; clarify precedence vs threshold.

-                if top_k is not None:
-                    _, top_k_indices = torch.topk(probabilities, top_k, dim=1)
+                if top_k is not None:
+                    k = int(max(1, min(top_k, probabilities.size(1))))
+                    _, top_k_indices = torch.topk(probabilities, k, dim=1)
                     predictions = torch.zeros_like(probabilities)
                     predictions.scatter_(1, top_k_indices, 1.0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Get top-k if specified
if top_k is not None:
_, top_k_indices = torch.topk(probabilities, top_k, dim=1)
predictions = torch.zeros_like(probabilities)
predictions.scatter_(1, top_k_indices, 1.0)
# Get top-k if specified
if top_k is not None:
# Clamp top_k to [1, num_classes] to avoid invalid k values
k = int(max(1, min(top_k, probabilities.size(1))))
_, top_k_indices = torch.topk(probabilities, k, dim=1)
predictions = torch.zeros_like(probabilities)
predictions.scatter_(1, top_k_indices, 1.0)
🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines 248
to 253, the top-k logic can raise errors when top_k <= 0 or top_k > number of
classes; clamp and validate top_k before using it: coerce top_k to an int, set
top_k = max(1, min(top_k, num_classes)) where num_classes =
probabilities.size(1), and then compute top_k_indices with the clamped value;
also explicitly document that when top_k is provided it takes precedence over
any probability-threshold branch so the top-k selection is applied regardless of
threshold.

Comment on lines +364 to +386
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Get item at index."""
sample_text = self.texts[idx]
labels = self.labels[idx]

# Tokenize text
encoding = self.tokenizer(
sample_text,
truncation=True,
padding="max_length",
max_length=self.max_length,
return_tensors="pt",
)

# Convert labels to tensor
label_tensor = torch.tensor(labels, dtype=torch.float)

return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"token_type_ids": encoding.get("token_type_ids", None),
"labels": label_tensor,
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix None/shape issues for token_type_ids in Dataset items

Returning token_type_ids=None breaks default collate; not squeezing yields shape (B,1,L). Only include the key when present and squeeze.

-        return {
-            "input_ids": encoding["input_ids"].squeeze(0),
-            "attention_mask": encoding["attention_mask"].squeeze(0),
-            "token_type_ids": encoding.get("token_type_ids", None),
-            "labels": label_tensor,
-        }
+        item = {
+            "input_ids": encoding["input_ids"].squeeze(0),
+            "attention_mask": encoding["attention_mask"].squeeze(0),
+            "labels": label_tensor,
+        }
+        if "token_type_ids" in encoding and encoding["token_type_ids"] is not None:
+            item["token_type_ids"] = encoding["token_type_ids"].squeeze(0)
+        return item
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Get item at index."""
sample_text = self.texts[idx]
labels = self.labels[idx]
# Tokenize text
encoding = self.tokenizer(
sample_text,
truncation=True,
padding="max_length",
max_length=self.max_length,
return_tensors="pt",
)
# Convert labels to tensor
label_tensor = torch.tensor(labels, dtype=torch.float)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"token_type_ids": encoding.get("token_type_ids", None),
"labels": label_tensor,
}
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Get item at index."""
sample_text = self.texts[idx]
labels = self.labels[idx]
# Tokenize text
encoding = self.tokenizer(
sample_text,
truncation=True,
padding="max_length",
max_length=self.max_length,
return_tensors="pt",
)
# Convert labels to tensor
label_tensor = torch.tensor(labels, dtype=torch.float)
item = {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"labels": label_tensor,
}
if "token_type_ids" in encoding and encoding["token_type_ids"] is not None:
item["token_type_ids"] = encoding["token_type_ids"].squeeze(0)
return item
🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines 364
to 386, the Dataset __getitem__ currently returns token_type_ids=None (breaking
default collate) and leaves input tensors with an extra leading dimension;
change it so you only include the "token_type_ids" key when encoding provides
it, and when present squeeze the 0th dim like the other inputs (i.e., if
"token_type_ids" in encoding: include "token_type_ids":
encoding["token_type_ids"].squeeze(0)); ensure input_ids and attention_mask
continue to be squeezed and labels tensor has correct shape/dtype.

- Fix FLK-E501 line length violations in enhanced_config.py
- Break long lines in configuration parsing methods
- Fix temperature parameter overwrite in enhanced_bert_classifier.py
- Preserve user-provided temperature values during initialization
- Improve code readability and maintainability
- Replace EMOTION_CLASSIFICATION_THRESHOLD and EMOTION_PREDICTION_THRESHOLD constants
- Add get_evaluation_threshold() and get_prediction_threshold() functions
- Functions now read from global config to reflect runtime updates
- Update usage in samo_bert_emotion_classifier.py
- Ensures threshold values stay synchronized with config changes
- Change import from .labels to .emotion_labels in enhanced_bert_classifier.py
- Standardize on emotion_labels.py as the canonical source for GOEMOTIONS_EMOTIONS
- Fixes broken import path issue identified by code review
- Maintains consistency with other modules using emotion_labels.py
- Move class weights tensor creation after model initialization
- Use model.device to ensure tensor is on same device as model
- Prevents device mismatch errors during training
- Maintains compatibility with both CPU and CUDA devices
- Improves robustness of loss function creation
- Move GOEMOTIONS_EMOTIONS import from function to module level
- Improves performance by avoiding repeated imports
- Makes code cleaner and more maintainable
- Emotion mapping already working correctly with proper labels
- Maintains existing functionality while improving code structure
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (15)
src/models/emotion_detection/config.py (1)

126-130: update_config should shallow-merge, not replace the singleton

Reassigning _config discards prior fields not present in the update. Merge in place.

Apply this diff:

 def update_config(config_dict: Dict[str, Any]) -> None:
-    """Update global configuration."""
+    """Update global configuration in place (shallow merge)."""
     global _config
-    _config = get_config_from_dict(config_dict)
+    if _config is None:
+        _config = get_default_config()
+    updates = get_config_from_dict(config_dict)
+    for k, v in updates.__dict__.items():
+        setattr(_config, k, v)
src/models/emotion_detection/enhanced_bert_classifier.py (5)

84-87: Avoid double-assigning temperature; initialize once as nn.Parameter

Set the learnable temperature in init and remove the reassignment in _initialize_classifier.

Apply this diff:

@@
-        self.temperature = temperature
+        self.temperature = nn.Parameter(torch.tensor(float(temperature), dtype=torch.float32))
@@
-        self.temperature = nn.Parameter(torch.ones(1) * self.temperature)
+        # temperature already initialized in __init__

Also applies to: 146-147


211-226: Use device-aware autocast; avoid CUDA-only context on CPU/MPS

Switch to torch.autocast with device_type guard.

Apply this diff:

     @contextmanager
     def inference_mode(self):
         """Context manager for inference mode with optimizations."""
         was_training = self.training
         self.eval()
         try:
             with torch.no_grad():
-                if self.use_mixed_precision:
-                    with torch.cuda.amp.autocast():
-                        yield
-                else:
-                    yield
+                device_type = self.device.type
+                enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu")
+                if enable_amp:
+                    with torch.autocast(device_type=device_type):
+                        yield
+                else:
+                    yield
         finally:
             if was_training:
                 self.train()

245-269: Performance metrics not tracked on CPU/MPS

Only CUDA path updates timing. Add CPU/MPS fallback.

Apply this diff:

-        start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
-        if start_time:
-            start_time.record()
+        use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda"
+        if use_cuda_timing:
+            _start = torch.cuda.Event(enable_timing=True)
+            _end = torch.cuda.Event(enable_timing=True)
+            _start.record()
+        else:
+            import time as _time
+            _t0 = _time.perf_counter()
@@
-        finally:
-            if start_time:
-                end_time = torch.cuda.Event(enable_timing=True)
-                end_time.record()
-                torch.cuda.synchronize()
-                inference_time = start_time.elapsed_time(end_time) / 1000.0
-                self._update_performance_metrics(inference_time)
+        finally:
+            if use_cuda_timing:
+                _end.record()
+                torch.cuda.synchronize()
+                inference_time = _start.elapsed_time(_end) / 1000.0
+            else:
+                inference_time = (_time.perf_counter() - _t0)
+            self._update_performance_metrics(inference_time)

190-203: Make forward robust: handle missing pooler_output and enforce positive temperature

Fallback to CLS when pooler_output is None and ensure temperature > 0 via softplus; import F.

Apply this diff:

@@
-import torch.nn as nn
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoConfig, AutoModel, AutoTokenizer
@@
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         """Forward pass with error handling and optimizations."""
         try:
             # Get BERT outputs
             bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
-            pooled_output = bert_outputs.pooler_output
+            pooled_output = getattr(bert_outputs, "pooler_output", None)
+            if pooled_output is None:
+                pooled_output = bert_outputs.last_hidden_state[:, 0, :]
 
             # Classification head
             logits = self.classifier(pooled_output)
 
             # Apply temperature scaling
-            logits = logits / self.temperature
+            temp = F.softplus(self.temperature) + 1e-6
+            logits = logits / temp

Also applies to: 14-18


320-359: Batch results lose alignment when inputs contain empty strings

Preserve original order by inserting empty predictions for invalid entries.

Apply this diff:

-        # Filter out empty texts
-        valid_texts = [text for text in texts if text and text.strip()]
-        if not valid_texts:
-            return [self._create_empty_prediction(return_metadata) for _ in texts]
+        # Build index map preserving positions
+        index_map: List[int] = []
+        valid_texts: List[str] = []
+        for idx, t in enumerate(texts):
+            if t and t.strip():
+                index_map.append(idx)
+                valid_texts.append(t)
+        if not valid_texts:
+            return [self._create_empty_prediction(return_metadata) for _ in texts]
@@
-        # Process results
-        results = []
-        for i, prob in enumerate(probabilities):
-            result = self._process_prediction_results(
-                prob, top_k, return_metadata, valid_texts[i]
-            )
-            results.append(result)
-
-        return results
+        # Reconstruct results in original order
+        results: List[EmotionPrediction] = [
+            self._create_empty_prediction(return_metadata) for _ in texts
+        ]
+        for i, prob in enumerate(probabilities):
+            results[index_map[i]] = self._process_prediction_results(
+                prob, top_k, return_metadata, valid_texts[i]
+            )
+        return results
src/models/emotion_detection/samo_bert_emotion_classifier.py (6)

193-201: Type hint uses implicit Optional

Change to Optional[float] for threshold.

Apply this diff:

-        threshold: float = None,
+        threshold: Optional[float] = None,

250-254: Clamp top_k to valid range and ensure k ≥ 1

Prevents runtime errors when top_k ≤ 0 or > num_classes.

Apply this diff:

-                if top_k is not None:
-                    _, top_k_indices = torch.topk(probabilities, top_k, dim=1)
+                if top_k is not None:
+                    k = int(max(1, min(top_k, probabilities.size(1))))
+                    _, top_k_indices = torch.topk(probabilities, k, dim=1)
                     predictions = torch.zeros_like(probabilities)
                     predictions.scatter_(1, top_k_indices, 1.0)

276-279: Validate temperature setter

Reject non-positive or non-finite values.

Apply this diff:

 def set_temperature(self, temperature: float) -> None:
     """Set temperature scaling parameter."""
-    self.temperature.data.fill_(temperature)
+    if temperature <= 0.0 or not np.isfinite(temperature):
+        raise ValueError("temperature must be a positive, finite float")
+    self.temperature.data.fill_(float(temperature))
     logger.info("Set temperature to %s", temperature)

325-331: Class weights may be on wrong device and not broadcast correctly

Ensure device alignment and expand to batch shape.

Apply this diff:

         # Apply class weights if provided
         if self.class_weights is not None:
-            bce_loss = bce_loss * self.class_weights.unsqueeze(0)
+            if self.class_weights.device != bce_loss.device:
+                self.class_weights = self.class_weights.to(bce_loss.device)
+            bce_loss = bce_loss * self.class_weights.unsqueeze(0).expand_as(bce_loss)

382-387: Returning token_type_ids=None breaks default collate

Only include token_type_ids when present; squeeze to match shapes.

Apply this diff:

-        return {
-            "input_ids": encoding["input_ids"].squeeze(0),
-            "attention_mask": encoding["attention_mask"].squeeze(0),
-            "token_type_ids": encoding.get("token_type_ids", None),
-            "labels": label_tensor,
-        }
+        item = {
+            "input_ids": encoding["input_ids"].squeeze(0),
+            "attention_mask": encoding["attention_mask"].squeeze(0),
+            "labels": label_tensor,
+        }
+        ttids = encoding.get("token_type_ids", None)
+        if ttids is not None:
+            item["token_type_ids"] = ttids.squeeze(0)
+        return item

188-191: Enforce positive temperature to avoid divide-by-zero/NaN

Use softplus on the learnable temperature before scaling logits.

Apply this diff:

-        # Apply temperature scaling
-        logits = logits / self.temperature
+        # Apply temperature scaling (ensure strictly positive temperature)
+        temp = F.softplus(self.temperature) + 1e-6
+        logits = logits / temp
src/models/emotion_detection/enhanced_config.py (3)

12-12: Persist all sections when saving; use asdict + safe_dump and ensure the directory exists.

Current save drops most sections and uses yaml.dump. Persist the full dataclass tree, create parent dirs, and use safe_dump.

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict
@@
     def save_config(self, path: Optional[Union[str, Path]] = None) -> None:
         """Save current configuration to file."""
         if path is None:
             path = self.config_path or "configs/samo_emotion_detection_config.yaml"
 
         try:
-            # Convert config to dictionary and save as YAML
-            config_dict = self._config_to_dict()
-            with open(path, 'w', encoding='utf-8') as f:
-                yaml.dump(config_dict, f, default_flow_style=False, indent=2)
+            # Convert config to dictionary and save as YAML
+            path = Path(path)
+            path.parent.mkdir(parents=True, exist_ok=True)
+            config_dict = self._config_to_dict()
+            with open(path, 'w', encoding='utf-8') as f:
+                yaml.safe_dump(config_dict, f, default_flow_style=False, indent=2, sort_keys=False)
             logger.info("Configuration saved to: %s", path)
-        except Exception as e:
-            logger.error("Failed to save configuration: %s", e)
+        except Exception as e:
+            logger.exception("Failed to save configuration: %s", e)
@@
-    def _config_to_dict(self) -> Dict[str, Any]:
-        """Convert configuration to dictionary."""
-        # This would need proper serialization logic
-        # For now, return a basic structure
-        return {
-            "model": {
-                "name": self.config.model.name,
-                "device": self.config.model.device,
-                "use_mixed_precision": self.config.model.use_mixed_precision,
-                "cache_embeddings": self.config.model.cache_embeddings,
-                "max_sequence_length": self.config.model.max_sequence_length,
-            },
-            "emotion_detection": {
-                "num_emotions": self.config.emotion_detection.num_emotions,
-                "prediction_threshold": (
-                    self.config.emotion_detection.prediction_threshold
-                ),
-                "temperature": self.config.emotion_detection.temperature,
-                "top_k": self.config.emotion_detection.top_k,
-            },
-            # Add other sections as needed
-        }
+    def _config_to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary (all sections)."""
+        return asdict(self.config)

Also applies to: 649-662, 663-685


506-516: Validators should return true defaults, not placeholders; thread defaults through call sites.

Returning "", False, 1, etc. silently loses intended defaults when types are wrong. Accept a default param and return it; pass explicit defaults at call sites.

-    def _validate_string(value: Any, field_name: str) -> str:
+    def _validate_string(value: Any, field_name: str, default: str = "") -> str:
@@
-            return ""
+            return default
         return value
@@
-    def _validate_bool(value: Any, field_name: str) -> bool:
+    def _validate_bool(value: Any, field_name: str, default: bool = False) -> bool:
@@
-            return False
+            return default
         return value
@@
-    def _validate_positive_int(value: Any, field_name: str) -> int:
+    def _validate_positive_int(value: Any, field_name: str, default: int = 1) -> int:
@@
-                return 1
+                return default
@@
-            return 1
+            return default
@@
-    def _validate_non_negative_int(value: Any, field_name: str) -> int:
+    def _validate_non_negative_int(value: Any, field_name: str, default: int = 0) -> int:
@@
-                return 0
+                return default
@@
-            return 0
+            return default
@@
-    def _validate_positive_float(value: Any, field_name: str) -> float:
+    def _validate_positive_float(value: Any, field_name: str, default: float = 1.0) -> float:
@@
-                return 1.0
+                return default
@@
-            return 1.0
+            return default
@@
-    def _validate_float_range(
-        value: Any, min_val: float, max_val: float, field_name: str
-    ) -> float:
+    def _validate_float_range(
+        value: Any, min_val: float, max_val: float, field_name: str, default: Optional[float] = None
+    ) -> float:
@@
-                return (min_val + max_val) / 2
+                return default if default is not None else (min_val + max_val) / 2
@@
-            return (min_val + max_val) / 2
+            return default if default is not None else (min_val + max_val) / 2
@@
-    def _validate_list(value: Any, field_name: str) -> List:
+    def _validate_list(value: Any, field_name: str, default: Optional[List] = None) -> List:
@@
-            return []
+            return default if default is not None else []
         return value

Pass defaults in representative call sites:

-            name=self._validate_string(
-                data.get("name", "bert-base-uncased"), "model.name"
-            ),
+            name=self._validate_string(
+                data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased"
+            ),
@@
-            prediction_threshold=self._validate_float_range(
-                data.get("prediction_threshold", 0.6), 0.0, 1.0, 
-                "emotion_detection.prediction_threshold"
-            ),
+            prediction_threshold=self._validate_float_range(
+                data.get("prediction_threshold", 0.6), 0.0, 1.0,
+                "emotion_detection.prediction_threshold", 0.6
+            ),
@@
-            padding=self._validate_string(data.get("padding", "max_length"), "data.padding"),
+            padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"),
@@
-            metrics=self._validate_list(data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), "evaluation.metrics"),
+            metrics=self._validate_list(
+                data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]),
+                "evaluation.metrics",
+                ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]
+            ),
@@
-            top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values"),
+            top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values", [1, 3, 5]),
@@
-            log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir"),
+            log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"),
@@
-            save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir"),
+            save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"),
@@
-            save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric"),
+            save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"),
@@
-            error_log_file=self._validate_string(
-                data.get("error_log_file", "logs/emotion_detection_errors.log"),
-                "error_handling.error_log_file"
-            ),
+            error_log_file=self._validate_string(
+                data.get("error_log_file", "logs/emotion_detection_errors.log"),
+                "error_handling.error_log_file",
+                "logs/emotion_detection_errors.log"
+            ),

Also applies to: 517-526, 527-543, 544-560, 561-577, 579-596, 627-634, 324-341, 385-395, 396-404, 405-414, 415-423


639-648: update_config is a no-op and static; make it an instance method with a safe shallow merge.

Implement a minimal, documented, section-wise merge and drop the blind except.

-    @staticmethod
-    def update_config(updates: Dict[str, Any]) -> None:
+    def update_config(self, updates: Dict[str, Any]) -> None:
         """Update configuration with new values."""
-        try:
-            # This would need more sophisticated merging logic
-            # For now, just log the attempt
-            logger.info("Configuration update requested: %s", updates)
-        except Exception as e:
-            logger.error("Configuration update failed: %s", e)
+        logger.info("Configuration update requested: %s", updates)
+        if not isinstance(updates, dict):
+            raise TypeError("updates must be a dict of sections -> dict")
+        for section, vals in updates.items():
+            if not hasattr(self.config, section):
+                logger.warning("Unknown config section: %s (skipped)", section)
+                continue
+            section_obj = getattr(self.config, section)
+            if not isinstance(vals, dict):
+                logger.warning("Section %s update must be a dict (skipped)", section)
+                continue
+            for k, v in vals.items():
+                if hasattr(section_obj, k):
+                    setattr(section_obj, k, v)
+                else:
+                    logger.warning("Unknown field %s.%s (skipped)", section, k)
🧹 Nitpick comments (8)
src/models/emotion_detection/config.py (1)

1-1: Remove shebang or mark file executable

Library modules don’t need a shebang; Ruff EXE001. Remove it to silence the lint.

Apply this diff:

-#!/usr/bin/env python3
src/models/emotion_detection/samo_bert_emotion_classifier.py (3)

1-1: Remove shebang or make file executable

Library module doesn’t need a shebang; fixes Ruff EXE001.

Apply this diff:

-#!/usr/bin/env python3

31-37: Avoid configuring global logging and blanket warning suppression in a library

Remove basicConfig and scope warning filter to transformers only.

Apply this diff:

-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Suppress warnings for cleaner output
-warnings.filterwarnings("ignore", category=UserWarning)
+# Configure module logger (do not set global handlers in libraries)
+logger = logging.getLogger(__name__)
+# Suppress noisy HF warnings only
+warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

212-214: Source prediction threshold from config when available

Use the runtime getter to keep inference aligned with configuration.

Apply this diff:

-        if threshold is None:
-            threshold = self.prediction_threshold
+        if threshold is None:
+            try:
+                from .config import get_prediction_threshold
+                threshold = get_prediction_threshold()
+            except Exception:
+                threshold = self.prediction_threshold
src/models/emotion_detection/enhanced_config.py (4)

230-237: Prefer logging.exception and narrow exception types in except blocks.

Improves diagnostics and satisfies linter (TRY400/BLE001).

-        except yaml.YAMLError as e:
-            logger.error("YAML parsing error: %s", e)
+        except yaml.YAMLError as e:
+            logger.exception("YAML parsing error: %s", e)
             logger.warning("Using default configuration due to YAML error")
             return self._create_default_config()
-        except Exception as e:
-            logger.error("Configuration loading failed: %s", e)
+        except (OSError, UnicodeDecodeError) as e:
+            logger.exception("Configuration loading failed: %s", e)
             logger.warning("Using default configuration due to loading error")
             return self._create_default_config()
@@
-        except Exception as e:
-            logger.error("Configuration parsing failed: %s", e)
+        except (ValueError, TypeError, KeyError) as e:
+            logger.exception("Configuration parsing failed: %s", e)
             logger.warning("Using default configuration due to parsing error")
             return self._create_default_config()

Also applies to: 319-323


36-43: Device handling: return a concrete default (“auto”) instead of None; accept DeviceType.

None vs “auto” is ambiguous and can leak into downstream logic. Standardize on the string.

-@dataclass
-class ModelConfig:
+@dataclass
+class ModelConfig:
@@
-    device: Optional[str] = None
+    device: str = "auto"
-    def _validate_device(value: Any) -> Optional[str]:
+    def _validate_device(value: Any, default: str = "auto") -> str:
         """Validate device value."""
-        if value is None:
-            return None
-        if not isinstance(value, str):
-            logger.warning("Invalid device value: %s, using auto", value)
-            return None
+        if value is None:
+            return default
+        if isinstance(value, DeviceType):
+            value = value.value
+        if not isinstance(value, str):
+            logger.warning("Invalid device value: %s, using %s", value, default)
+            return default
         valid_devices = ["auto", "cpu", "cuda", "mps"]
         if value.lower() not in valid_devices:
-            logger.warning("Invalid device: %s, using auto", value)
-            return None
+            logger.warning("Invalid device: %s, using %s", value, default)
+            return default
         return value.lower()

Would any caller currently rely on device=None to imply “auto”? If so, I’ll adapt the downstream code accordingly.

Also applies to: 597-610


260-303: Add cross-field consistency checks (top_k bounds, split sums).

Prevents invalid-but-parseable configs from slipping through.

     def _parse_config(self, config_data: Dict[str, Any]) -> EnhancedEmotionDetectionConfig:
         """Parse configuration data into structured format."""
         try:
@@
-            return EnhancedEmotionDetectionConfig(
+            # Cross-field normalization
+            if emotion_config.top_k > emotion_config.num_emotions:
+                logger.warning(
+                    "emotion_detection.top_k (%d) > num_emotions (%d); clamping to %d",
+                    emotion_config.top_k, emotion_config.num_emotions, emotion_config.num_emotions
+                )
+                emotion_config.top_k = emotion_config.num_emotions
+            if any(k > emotion_config.num_emotions for k in evaluation_config.top_k_values):
+                evaluation_config.top_k_values = [
+                    min(k, emotion_config.num_emotions) for k in evaluation_config.top_k_values
+                ]
+            if (data_config.validation_split + data_config.test_split) >= 1.0:
+                logger.warning(
+                    "data.validation_split + data.test_split >= 1.0; resetting to 0.2/0.1"
+                )
+                data_config.validation_split = 0.2
+                data_config.test_split = 0.1
+
+            return EnhancedEmotionDetectionConfig(
                 model=model_config,
                 emotion_detection=emotion_config,
                 architecture=architecture_config,
                 training=training_config,
                 data=data_config,
                 evaluation=evaluation_config,
                 logging=logging_config,
                 model_saving=model_saving_config,
                 performance=performance_config,
                 samo_optimizations=samo_optimizations,
                 error_handling=error_handling,
                 security=security_config,
                 development=development_config,
             )

612-626: Optionally accept LogLevel enums in _validate_log_level (or remove unused LogLevel).

Minor DX polish; keeps API consistent with the defined enum.

     def _validate_log_level(value: Any, field_name: str) -> str:
         """Validate log level value."""
+        if isinstance(value, LogLevel):
+            return value.value
         if not isinstance(value, str):
             logger.warning(
                 "Invalid log level for %s: %s, using default", field_name, value
             )
             return "INFO"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fa1a1dd and 04436cb.

📒 Files selected for processing (4)
  • src/models/emotion_detection/config.py (1 hunks)
  • src/models/emotion_detection/enhanced_bert_classifier.py (1 hunks)
  • src/models/emotion_detection/enhanced_config.py (1 hunks)
  • src/models/emotion_detection/samo_bert_emotion_classifier.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (4)
  • _freeze_bert_layers (145-147)
  • forward (153-191)
  • forward (309-334)
  • predict_emotions (193-274)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/config.py (1)
  • get_evaluation_threshold (39-41)
src/models/emotion_detection/config.py (1)
src/models/emotion_detection/enhanced_config.py (3)
  • EmotionDetectionConfig (46-51)
  • get_config (635-637)
  • update_config (640-647)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
  • EmotionDetectionConfig (50-97)
  • get_config (121-123)
  • update_config (126-129)
deployment/cloud-run/secure_api_server.py (1)
  • sanitize_input (158-172)
🪛 Ruff (0.12.2)
src/models/emotion_detection/enhanced_bert_classifier.py

116-116: Consider moving this statement to an else block

(TRY300)


117-117: Do not catch blind exception: Exception

(BLE001)


133-133: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


134-134: Avoid specifying long messages outside the exception class

(TRY003)


203-203: Consider moving this statement to an else block

(TRY300)


209-209: Avoid specifying long messages outside the exception class

(TRY003)


261-261: Avoid specifying long messages outside the exception class

(TRY003)


448-448: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


449-449: Avoid specifying long messages outside the exception class

(TRY003)


515-515: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


516-516: Avoid specifying long messages outside the exception class

(TRY003)


529-529: Consider moving this statement to an else block

(TRY300)


531-531: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


532-532: Avoid specifying long messages outside the exception class

(TRY003)

src/models/emotion_detection/samo_bert_emotion_classifier.py

1-1: Shebang is present but file is not executable

(EXE001)


196-196: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

src/models/emotion_detection/config.py

1-1: Shebang is present but file is not executable

(EXE001)

src/models/emotion_detection/enhanced_config.py

231-231: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


234-234: Do not catch blind exception: Exception

(BLE001)


235-235: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


319-319: Do not catch blind exception: Exception

(BLE001)


320-320: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


537-537: Consider moving this statement to an else block

(TRY300)


554-554: Consider moving this statement to an else block

(TRY300)


571-571: Consider moving this statement to an else block

(TRY300)


590-590: Consider moving this statement to an else block

(TRY300)


646-646: Do not catch blind exception: Exception

(BLE001)


647-647: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


660-660: Do not catch blind exception: Exception

(BLE001)


661-661: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

🔇 Additional comments (1)
src/models/emotion_detection/config.py (1)

39-46: Good fix: thresholds now reflect runtime config

Switching to getters removes the import-time binding issue and aligns with runtime updates.

Comment on lines +105 to +114
def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
"""Get configuration from dictionary with defaults."""
default_config = get_default_config()

# Update with provided values
for key, value in config_dict.items():
if hasattr(default_config, key):
setattr(default_config, key, value)

return default_config
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Config update ignores nested YAML structure; implement nested → flat mapping

Passing a YAML-like dict with sections (model/training/evaluation/prediction) won’t update the dataclass fields. Add flattening so nested keys map to attributes.

Apply this diff:

 def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
     """Get configuration from dictionary with defaults."""
-    default_config = get_default_config()
-
-    # Update with provided values
-    for key, value in config_dict.items():
-        if hasattr(default_config, key):
-            setattr(default_config, key, value)
-
-    return default_config
+    cfg = get_default_config()
+
+    # Map nested sections to dataclass fields
+    section_map = {
+        "model": {
+            "name": "model_name",
+            "num_emotions": "num_emotions",
+            "hidden_dropout_prob": "hidden_dropout_prob",
+            "classifier_dropout_prob": "classifier_dropout_prob",
+            "freeze_bert_layers": "freeze_bert_layers",
+            "temperature": "temperature",
+        },
+        "training": {
+            "batch_size": "batch_size",
+            "learning_rate": "learning_rate",
+            "num_epochs": "num_epochs",
+            "weight_decay": "weight_decay",
+        },
+        "evaluation": {
+            "threshold": "evaluation_threshold",
+            "top_k": "top_k",
+        },
+        "prediction": {
+            "threshold": "prediction_threshold",
+            "max_length": "max_length",
+        },
+    }
+
+    # Apply nested updates
+    for section, mapping in section_map.items():
+        data = config_dict.get(section)
+        if isinstance(data, dict):
+            for k, v in data.items():
+                attr = mapping.get(k)
+                if attr and hasattr(cfg, attr):
+                    setattr(cfg, attr, v)
+
+    # Apply any flat overrides
+    for key, value in config_dict.items():
+        if not isinstance(value, dict) and hasattr(cfg, key):
+            setattr(cfg, key, value)
+
+    return cfg
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
"""Get configuration from dictionary with defaults."""
default_config = get_default_config()
# Update with provided values
for key, value in config_dict.items():
if hasattr(default_config, key):
setattr(default_config, key, value)
return default_config
def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
"""Get configuration from dictionary with defaults."""
cfg = get_default_config()
# Map nested sections to dataclass fields
section_map = {
"model": {
"name": "model_name",
"num_emotions": "num_emotions",
"hidden_dropout_prob": "hidden_dropout_prob",
"classifier_dropout_prob": "classifier_dropout_prob",
"freeze_bert_layers": "freeze_bert_layers",
"temperature": "temperature",
},
"training": {
"batch_size": "batch_size",
"learning_rate": "learning_rate",
"num_epochs": "num_epochs",
"weight_decay": "weight_decay",
},
"evaluation": {
"threshold": "evaluation_threshold",
"top_k": "top_k",
},
"prediction": {
"threshold": "prediction_threshold",
"max_length": "max_length",
},
}
# Apply nested updates
for section, mapping in section_map.items():
data = config_dict.get(section)
if isinstance(data, dict):
for k, v in data.items():
attr = mapping.get(k)
if attr and hasattr(cfg, attr):
setattr(cfg, attr, v)
# Apply any flat overrides
for key, value in config_dict.items():
if not isinstance(value, dict) and hasattr(cfg, key):
setattr(cfg, key, value)
return cfg
🤖 Prompt for AI Agents
In src/models/emotion_detection/config.py around lines 105 to 114, the current
get_config_from_dict only updates top-level keys and ignores nested YAML
sections (model/training/evaluation/prediction). Modify the function to flatten
one-level (or recursive) nested dicts into a single mapping where nested keys
map to dataclass attribute names (e.g., config_dict["model"]["lr"] -> "lr"),
then iterate that flattened mapping and setattr on default_config for keys that
exist; ignore unknown keys. Ensure nested dict values override defaults and
preserve non-dict top-level keys.

- Break long lines in configuration parsing methods
- Fix SAMO optimizations, performance, logging, model saving configs
- Fix data and evaluation config parsing methods
- Improve code readability and maintainability
- All lines now comply with 88-character limit
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (7)
src/models/emotion_detection/enhanced_config.py (7)

565-575: Regression note: matches prior feedback.

This validator behavior was previously flagged. The above validator+call-site fixes align with that guidance.

Run to spot any remaining 2-arg validator calls lacking explicit defaults:

#!/bin/bash
rg -nP 'self\._validate_(string|bool|positive_int|non_negative_int|positive_float|log_level|list)\([^,]+,\s*["'\"'"][^"'\"'"]+["'\"'"]\s*\)' src/models/emotion_detection/enhanced_config.py

708-744: Persist the full config; use asdict, safe_dump, and ensure target dir exists.

Current save drops most sections and uses yaml.dump. This risks config loss and unsafe YAML emission.

Apply:

@@
-    def save_config(self, path: Optional[Union[str, Path]] = None) -> None:
+    def save_config(self, path: Optional[Union[str, Path]] = None) -> None:
         """Save current configuration to file."""
         if path is None:
             path = self.config_path or "configs/samo_emotion_detection_config.yaml"
 
         try:
-            # Convert config to dictionary and save as YAML
-            config_dict = self._config_to_dict()
-            with open(path, 'w', encoding='utf-8') as f:
-                yaml.dump(config_dict, f, default_flow_style=False, indent=2)
+            dest = Path(path)
+            dest.parent.mkdir(parents=True, exist_ok=True)
+            config_dict = self._config_to_dict()
+            with dest.open('w', encoding='utf-8') as f:
+                yaml.safe_dump(config_dict, f, default_flow_style=False, indent=2, sort_keys=False)
             logger.info("Configuration saved to: %s", path)
-        except Exception as e:
-            logger.error("Failed to save configuration: %s", e)
+        except (OSError, TypeError) as e:
+            logger.exception("Failed to save configuration: %s", e)
@@
-    def _config_to_dict(self) -> Dict[str, Any]:
-        """Convert configuration to dictionary."""
-        # This would need proper serialization logic
-        # For now, return a basic structure
-        return {
-            "model": {
-                "name": self.config.model.name,
-                "device": self.config.model.device,
-                "use_mixed_precision": self.config.model.use_mixed_precision,
-                "cache_embeddings": self.config.model.cache_embeddings,
-                "max_sequence_length": self.config.model.max_sequence_length,
-            },
-            "emotion_detection": {
-                "num_emotions": self.config.emotion_detection.num_emotions,
-                "prediction_threshold": (
-                    self.config.emotion_detection.prediction_threshold
-                ),
-                "temperature": self.config.emotion_detection.temperature,
-                "top_k": self.config.emotion_detection.top_k,
-            },
-            # Add other sections as needed
-        }
+    def _config_to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary (all sections)."""
+        return asdict(self.config)

And add the import:

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict

566-693: Validators return placeholders; return true defaults via a default= parameter.

Log says “using default” but returns "", 0, 1, etc. Make validators accept a default and return it; propagate at call sites.

-    def _validate_string(value: Any, field_name: str) -> str:
+    def _validate_string(value: Any, field_name: str, default: str = "") -> str:
@@
-            return ""
+            return default
         return value
@@
-    def _validate_bool(value: Any, field_name: str) -> bool:
+    def _validate_bool(value: Any, field_name: str, default: bool = False) -> bool:
@@
-            return False
+            return default
         return value
@@
-    def _validate_positive_int(value: Any, field_name: str) -> int:
+    def _validate_positive_int(value: Any, field_name: str, default: int = 1) -> int:
@@
-                return 1
+                return default
             return int_val
         except (ValueError, TypeError):
@@
-            return 1
+            return default
@@
-    def _validate_non_negative_int(value: Any, field_name: str) -> int:
+    def _validate_non_negative_int(value: Any, field_name: str, default: int = 0) -> int:
@@
-                return 0
+                return default
             return int_val
         except (ValueError, TypeError):
@@
-            return 0
+            return default
@@
-    def _validate_positive_float(value: Any, field_name: str) -> float:
+    def _validate_positive_float(value: Any, field_name: str, default: float = 1.0) -> float:
@@
-                return 1.0
+                return default
             return float_val
         except (ValueError, TypeError):
@@
-            return 1.0
+            return default
@@
-    def _validate_float_range(
-        value: Any, min_val: float, max_val: float, field_name: str
-    ) -> float:
+    def _validate_float_range(
+        value: Any, min_val: float, max_val: float, field_name: str, default: Optional[float] = None
+    ) -> float:
@@
-                return (min_val + max_val) / 2
+                return default if default is not None else (min_val + max_val) / 2
             return float_val
         except (ValueError, TypeError):
@@
-            return (min_val + max_val) / 2
+            return default if default is not None else (min_val + max_val) / 2
@@
-    def _validate_device(value: Any) -> Optional[str]:
+    def _validate_device(value: Any, default: Optional[str] = None) -> Optional[str]:
@@
-            return None
+            return default
         if not isinstance(value, str):
-            logger.warning("Invalid device value: %s, using auto", value)
-            return None
+            logger.warning("Invalid device value: %s, using default", value)
+            return default
         valid_devices = ["auto", "cpu", "cuda", "mps"]
         if value.lower() not in valid_devices:
-            logger.warning("Invalid device: %s, using auto", value)
-            return None
+            logger.warning("Invalid device: %s, using default", value)
+            return default
         return value.lower()
@@
-    def _validate_log_level(value: Any, field_name: str) -> str:
+    def _validate_log_level(value: Any, field_name: str, default: str = "INFO") -> str:
@@
-            return "INFO"
+            return default
         valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
         if value.upper() not in valid_levels:
@@
-            return "INFO"
+            return default
         return value.upper()
@@
-    def _validate_list(value: Any, field_name: str) -> List:
+    def _validate_list(value: Any, field_name: str, default: Optional[List] = None) -> List:
@@
-            return []
+            return default if default is not None else []
         return value

698-707: update_config is a no-op and static; implement a shallow, validated merge.

Make it an instance method that merges top-level sections by re-parsing with validators.

-    @staticmethod
-    def update_config(updates: Dict[str, Any]) -> None:
+    def update_config(self, updates: Dict[str, Any]) -> None:
         """Update configuration with new values."""
         try:
-            # This would need more sophisticated merging logic
-            # For now, just log the attempt
-            logger.info("Configuration update requested: %s", updates)
-        except Exception as e:
-            logger.error("Configuration update failed: %s", e)
+            logger.info("Configuration update requested: %s", updates)
+            parsers = {
+                "model": self._parse_model_config,
+                "emotion_detection": self._parse_emotion_detection_config,
+                "architecture": self._parse_architecture_config,
+                "training": self._parse_training_config,
+                "data": self._parse_data_config,
+                "evaluation": self._parse_evaluation_config,
+                "logging": self._parse_logging_config,
+                "model_saving": self._parse_model_saving_config,
+                "performance": self._parse_performance_config,
+                "samo_optimizations": self._parse_samo_optimizations,
+                "error_handling": self._parse_error_handling,
+                "security": self._parse_security_config,
+                "development": self._parse_development_config,
+            }
+            for section, vals in (updates or {}).items():
+                if hasattr(self.config, section) and isinstance(vals, dict) and section in parsers:
+                    current = getattr(self.config, section)
+                    merged = {**asdict(current), **vals}
+                    setattr(self.config, section, parsers[section](merged))
+        except (TypeError, ValueError, AttributeError) as e:
+            logger.exception("Configuration update failed: %s", e)

324-340: Pass explicit defaults to validators (ModelConfig).

Ensure invalid user input falls back to true defaults instead of placeholders.

         return ModelConfig(
-            name=self._validate_string(
-                data.get("name", "bert-base-uncased"), "model.name"
-            ),
-            device=self._validate_device(data.get("device")),
-            use_mixed_precision=self._validate_bool(
-                data.get("use_mixed_precision", True), "model.use_mixed_precision"
-            ),
-            cache_embeddings=self._validate_bool(
-                data.get("cache_embeddings", False), "model.cache_embeddings"
-            ),
-            max_sequence_length=self._validate_positive_int(
-                data.get("max_sequence_length", 512), "model.max_sequence_length"
-            ),
+            name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased"),
+            device=self._validate_device(data.get("device"), None),
+            use_mixed_precision=self._validate_bool(data.get("use_mixed_precision", True), "model.use_mixed_precision", True),
+            cache_embeddings=self._validate_bool(data.get("cache_embeddings", False), "model.cache_embeddings", False),
+            max_sequence_length=self._validate_positive_int(data.get("max_sequence_length", 512), "model.max_sequence_length", 512),
         )

342-358: Clamp top_k to num_emotions and pass explicit defaults.

Prevents invalid combinations (e.g., top_k > num_emotions) and aligns with config semantics.

-        return EmotionDetectionConfig(
-            num_emotions=self._validate_positive_int(
-                data.get("num_emotions", 28), "emotion_detection.num_emotions"
-            ),
-            prediction_threshold=self._validate_float_range(
-                data.get("prediction_threshold", 0.6), 0.0, 1.0, 
-                "emotion_detection.prediction_threshold"
-            ),
-            temperature=self._validate_positive_float(
-                data.get("temperature", 1.0), "emotion_detection.temperature"
-            ),
-            top_k=self._validate_positive_int(
-                data.get("top_k", 5), "emotion_detection.top_k"
-            ),
-        )
+        num_emotions = self._validate_positive_int(data.get("num_emotions", 28), "emotion_detection.num_emotions", 28)
+        prediction_threshold = self._validate_float_range(
+            data.get("prediction_threshold", 0.6), 0.0, 1.0, "emotion_detection.prediction_threshold", 0.6
+        )
+        temperature = self._validate_positive_float(data.get("temperature", 1.0), "emotion_detection.temperature", 1.0)
+        top_k = self._validate_positive_int(data.get("top_k", 5), "emotion_detection.top_k", 5)
+        if top_k > num_emotions:
+            logger.warning("emotion_detection.top_k (%s) > num_emotions (%s); clamping to num_emotions", top_k, num_emotions)
+            top_k = num_emotions
+        return EmotionDetectionConfig(
+            num_emotions=num_emotions,
+            prediction_threshold=prediction_threshold,
+            temperature=temperature,
+            top_k=top_k,
+        )

426-461: Logging and model_saving: pass explicit defaults to validators.

         return LoggingConfig(
-            level=self._validate_log_level(
-                data.get("level", "INFO"), "logging.level"
-            ),
+            level=self._validate_log_level(data.get("level", "INFO"), "logging.level", "INFO"),
@@
-            log_dir=self._validate_string(
-                data.get("log_dir", "logs/emotion_detection"), "logging.log_dir"
-            ),
+            log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"),
         )
@@
         return ModelSavingConfig(
-            save_dir=self._validate_string(
-                data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir"
-            ),
-            save_best_metric=self._validate_string(
-                data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric"
-            ),
+            save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"),
+            save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"),
@@
-            save_checkpoints=self._validate_bool(
-                data.get("save_checkpoints", True), "model_saving.save_checkpoints"
-            ),
-            checkpoint_interval=self._validate_positive_int(
-                data.get("checkpoint_interval", 1), "model_saving.checkpoint_interval"
-            ),
+            save_checkpoints=self._validate_bool(data.get("save_checkpoints", True), "model_saving.save_checkpoints", True),
+            checkpoint_interval=self._validate_positive_int(data.get("checkpoint_interval", 1), "model_saving.checkpoint_interval", 1),
         )
🧹 Nitpick comments (6)
src/models/emotion_detection/enhanced_config.py (6)

408-424: Use defaults for metrics/top_k_values and threshold.

Ensures invalid inputs fall back to configured defaults.

-            metrics=self._validate_list(
-                data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), 
-                "evaluation.metrics"
-            ),
-            threshold=self._validate_float_range(
-                data.get("threshold", 0.2), 0.0, 1.0, "evaluation.threshold"
-            ),
+            metrics=self._validate_list(
+                data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]),
+                "evaluation.metrics",
+                ["precision", "recall", "f1_micro", "f1_macro", "accuracy"],
+            ),
+            threshold=self._validate_float_range(
+                data.get("threshold", 0.2), 0.0, 1.0, "evaluation.threshold", 0.2
+            ),
@@
-            top_k_values=self._validate_list(
-                data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values"
-            ),
+            top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values", [1, 3, 5]),

463-563: Propagate explicit defaults across remaining parse_ sections.*

Keeps behavior consistent with validator changes; representative changes below.

-            use_amp=self._validate_bool(
-                data.get("use_amp", True), "performance.use_amp"
-            ),
+            use_amp=self._validate_bool(data.get("use_amp", True), "performance.use_amp", True),
@@
-            num_workers=self._validate_non_negative_int(
-                data.get("num_workers", 4), "performance.num_workers"
-            ),
+            num_workers=self._validate_non_negative_int(data.get("num_workers", 4), "performance.num_workers", 4),
@@
-            pin_memory=self._validate_bool(
-                data.get("pin_memory", True), "performance.pin_memory"
-            ),
+            pin_memory=self._validate_bool(data.get("pin_memory", True), "performance.pin_memory", True),
@@
-            gradient_checkpointing=self._validate_bool(
-                data.get("gradient_checkpointing", False), "performance.gradient_checkpointing"
-            ),
+            gradient_checkpointing=self._validate_bool(data.get("gradient_checkpointing", False), "performance.gradient_checkpointing", False),
@@
-            use_torchscript=self._validate_bool(
-                data.get("use_torchscript", False), "performance.use_torchscript"
-            ),
+            use_torchscript=self._validate_bool(data.get("use_torchscript", False), "performance.use_torchscript", False),

If helpful, I can push a mechanical patch that updates all remaining call sites in one go.


230-237: Avoid blind except; log stack traces with logger.exception.

Improve debuggability and satisfy linters without changing behavior.

-        except yaml.YAMLError as e:
-            logger.error("YAML parsing error: %s", e)
+        except yaml.YAMLError as e:
+            logger.exception("YAML parsing error: %s", e)
@@
-        except Exception as e:
-            logger.error("Configuration loading failed: %s", e)
+        except (OSError, TypeError, ValueError) as e:
+            logger.exception("Configuration loading failed: %s", e)
@@
-        except Exception as e:
-            logger.error("Configuration parsing failed: %s", e)
+        except (TypeError, ValueError, KeyError) as e:
+            logger.exception("Configuration parsing failed: %s", e)

Also applies to: 319-323


18-33: Enums are unused; remove or integrate.

Either wire LogLevel/DeviceType into validators and dataclasses or drop them to avoid dead code.


441-443: Minor: path-like fields should be normalized to str and directories ensured on use.

You already ensure dirs in save_config; consider normalizing log_dir/error_log_file to str at parse time for consistency.

Also applies to: 521-524


260-318: Sanity checks across sections (top_k vs num_emotions already handled).

Optionally verify evaluation.top_k_values ≤ emotion_detection.num_emotions and >0.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 04436cb and ca09ca1.

📒 Files selected for processing (1)
  • src/models/emotion_detection/enhanced_config.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
  • EmotionDetectionConfig (50-97)
  • get_config (121-123)
  • update_config (126-129)
deployment/cloud-run/secure_api_server.py (1)
  • sanitize_input (158-172)
🪛 Ruff (0.12.2)
src/models/emotion_detection/enhanced_config.py

231-231: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


234-234: Do not catch blind exception: Exception

(BLE001)


235-235: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


319-319: Do not catch blind exception: Exception

(BLE001)


320-320: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


596-596: Consider moving this statement to an else block

(TRY300)


613-613: Consider moving this statement to an else block

(TRY300)


630-630: Consider moving this statement to an else block

(TRY300)


649-649: Consider moving this statement to an else block

(TRY300)


705-705: Do not catch blind exception: Exception

(BLE001)


706-706: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


719-719: Do not catch blind exception: Exception

(BLE001)


720-720: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

Comment on lines +385 to +406
def _parse_data_config(self, data: Dict[str, Any]) -> DataConfig:
"""Parse data configuration with validation."""
return DataConfig(
max_length=self._validate_positive_int(
data.get("max_length", 512), "data.max_length"
),
truncation=self._validate_bool(
data.get("truncation", True), "data.truncation"
),
padding=self._validate_string(
data.get("padding", "max_length"), "data.padding"
),
enable_augmentation=self._validate_bool(
data.get("enable_augmentation", False), "data.enable_augmentation"
),
validation_split=self._validate_float_range(
data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split"
),
test_split=self._validate_float_range(
data.get("test_split", 0.1), 0.0, 1.0, "data.test_split"
),
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Validate data splits sum and pass explicit defaults.

Avoids val+test ≥ 1.0, which would leave no training data.

-        return DataConfig(
-            max_length=self._validate_positive_int(
-                data.get("max_length", 512), "data.max_length"
-            ),
-            truncation=self._validate_bool(
-                data.get("truncation", True), "data.truncation"
-            ),
-            padding=self._validate_string(
-                data.get("padding", "max_length"), "data.padding"
-            ),
-            enable_augmentation=self._validate_bool(
-                data.get("enable_augmentation", False), "data.enable_augmentation"
-            ),
-            validation_split=self._validate_float_range(
-                data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split"
-            ),
-            test_split=self._validate_float_range(
-                data.get("test_split", 0.1), 0.0, 1.0, "data.test_split"
-            ),
-        )
+        val_split = self._validate_float_range(data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split", 0.2)
+        test_split = self._validate_float_range(data.get("test_split", 0.1), 0.0, 1.0, "data.test_split", 0.1)
+        if val_split + test_split >= 1.0:
+            new_test = max(0.0, 1.0 - val_split)
+            logger.warning("data.validation_split + data.test_split >= 1.0; adjusting test_split to %.3f", new_test)
+            test_split = new_test
+        return DataConfig(
+            max_length=self._validate_positive_int(data.get("max_length", 512), "data.max_length", 512),
+            truncation=self._validate_bool(data.get("truncation", True), "data.truncation", True),
+            padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"),
+            enable_augmentation=self._validate_bool(data.get("enable_augmentation", False), "data.enable_augmentation", False),
+            validation_split=val_split,
+            test_split=test_split,
+        )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_config.py around lines 385 to 406, the
data split parsing must ensure validation_split and test_split are explicitly
defaulted and their sum is less than 1.0; change the method to read each split
into local variables using the intended defaults (0.2 and 0.1), validate each
via the existing _validate_float_range, then check if validation_split +
test_split >= 1.0 and raise a clear ValueError (or config validation exception)
if so; return the DataConfig with those validated values.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants