Feat/dl add emotion detection enhancements#150
Conversation
- Enhanced BERT-based emotion classifier for journal entries - Multi-label emotion classification (28 emotions from GoEmotions) - Temperature scaling for calibrated predictions - Comprehensive emotion labels and descriptions - Standalone test script for validation - Configuration file with SAMO-specific optimizations - Error handling and logging improvements Files: - src/models/emotion_detection/samo_bert_emotion_classifier.py - src/models/emotion_detection/emotion_labels.py - configs/samo_emotion_detection_config.yaml - test_samo_emotion_detection_standalone.py This completes PR-3 of the surgical breakdown plan.
- Fixed NameError for classifier_dropout_prob and freeze_bert_layers - Updated constructor to use self.classifier_dropout_prob and self.freeze_bert_layers - Model now initializes correctly and passes standalone tests - Maintains 110M parameters with 66M frozen BERT layers Part of PR-3: Emotion Detection model completion
Reviewer's GuideThis PR introduces a complete emotion detection pipeline featuring a new BERT-based multi-label classifier with calibration and regularization, a support module for emotion labels and groupings, a standalone test harness, and a centralized YAML configuration for all related parameters. File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughAdds a SAMO emotion-detection subsystem: YAML configuration, two config modules/managers, a GoEmotions label utility, two BERT-based multi-label classifiers (SAMO and Enhanced) with training/eval helpers, loss/dataset utilities, save/load and performance hooks, and two end-to-end test scripts. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant U as User / App
participant CFG as Config Manager / YAML
participant F as Factory
participant C as Classifier (SAMO/Enhanced)
participant T as Tokenizer
participant B as BERT Encoder
participant H as Classifier Head
participant P as Post-process
U->>CFG: load / update config
CFG-->>F: provide hyperparams, thresholds, device
F-->>C: init model (bert, head, temp, freeze layers)
U->>C: predict_emotions(texts, threshold/top_k, batch_size)
C->>T: tokenize(texts, max_length, truncation/padding)
T-->>C: input tensors
C->>B: encode(inputs)
B-->>C: pooled_output / cls_repr
C->>H: classifier head -> logits
H-->>C: logits
C->>P: apply temperature, sigmoid, threshold/top-k
P-->>U: labels, probabilities, metadata
Note over C: optional mixed-precision, caching, performance logging
sequenceDiagram
autonumber
participant Trainer as Training Loop
participant D as DataLoader / EmotionDataset
participant M as Model (SAMO/Enhanced)
participant L as WeightedBCELoss
participant E as Evaluator
Trainer->>D: load batches (texts, multi-label targets)
D-->>Trainer: batched tensors
Trainer->>M: forward(input tensors)
M-->>Trainer: logits
Trainer->>L: compute loss(logits, targets)
L-->>Trainer: loss scalar
Trainer->>M: backward() / optimizer step
Trainer->>E: evaluate_emotion_classifier(model, val_dataloader)
E-->>Trainer: metrics (precision, recall, f1_micro, f1_macro)
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Pre-merge checks (3 passed)✅ Passed checks (3 passed)
Poem
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Summary of Changes
Hello @uelkerd, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates a new, advanced emotion detection capability into the SAMO-DL system. It establishes a foundational framework for analyzing emotional states in text, providing a highly customizable and robust solution for journal entry analysis and similar applications. The changes ensure that the system can accurately identify and categorize a wide range of human emotions.
Highlights
- New Emotion Detection System: Introduces a complete, enhanced emotion detection system for the SAMO-DL project, leveraging a BERT-based classifier.
- Comprehensive Emotion Taxonomy: Adds a dedicated module defining 28 emotion categories based on the GoEmotions dataset, including valence, arousal, and dominance groupings, along with descriptions and synonyms.
- Configurable Model and Training Parameters: Provides a detailed YAML configuration file for the emotion detection model, allowing fine-tuning of BERT architecture, training parameters, data processing, and evaluation metrics.
- Robust BERT-based Classifier: Implements a multi-label BERT emotion classifier with features like temperature scaling for calibrated predictions, dropout regularization, and configurable BERT layer freezing.
- Standalone Test Suite: Includes a comprehensive standalone test script to validate the model's initialization, prediction accuracy, batch processing, temperature scaling, and performance across various text lengths.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
|
Here's the code health analysis summary for commits Analysis Summary
|
There was a problem hiding this comment.
Pull Request Overview
This PR implements a comprehensive emotion detection system for the SAMO-DL project, adding BERT-based emotion classification capabilities with temperature scaling and multi-label support. The system is designed to detect 28 emotion categories (27 GoEmotions + neutral) from journal entries and other textual content.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
test_samo_emotion_detection_standalone.py |
Standalone test script for comprehensive model validation |
src/models/emotion_detection/samo_bert_emotion_classifier.py |
Core BERT-based emotion classifier implementation |
src/models/emotion_detection/emotion_labels.py |
Emotion labels and categorization system based on GoEmotions |
configs/samo_emotion_detection_config.yaml |
Configuration file for model parameters and training settings |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
There was a problem hiding this comment.
Hey there - I've reviewed your changes and they look great!
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:177` </location>
<code_context>
+ )
+
+ # Use [CLS] token representation for classification
+ pooled_output = bert_outputs.pooler_output
+
+ # Pass through classification head
</code_context>
<issue_to_address>
Using pooler_output may not be robust for all transformer models.
Some models may not have pooler_output or it could be None. Add a check and consider using the first token's hidden state as a fallback.
</issue_to_address>
### Comment 2
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:255` </location>
<code_context>
+ batch_probabilities = probabilities.cpu().numpy()
+
+ # Get emotion names for predictions
+ for pred in batch_predictions:
+ emotions = [
+ f"emotion_{i}" for i, p in enumerate(pred) if p > 0
</code_context>
<issue_to_address>
Emotion names are generic and not mapped to actual labels.
Consider mapping prediction indices to actual emotion labels from the emotion_labels module to improve interpretability.
Suggested implementation:
```python
# Get emotion names for predictions
from emotion_labels import EMOTION_LABELS # Make sure this is the correct import path and variable name
for pred in batch_predictions:
emotions = [
EMOTION_LABELS[i] for i, p in enumerate(pred) if p > 0
]
all_emotions.append(emotions)
```
- Ensure that the `emotion_labels` module is available and contains a list or tuple named `EMOTION_LABELS` that maps indices to emotion names.
- If the import path or variable name is different, adjust `from emotion_labels import EMOTION_LABELS` accordingly.
</issue_to_address>
### Comment 3
<location> `src/models/emotion_detection/samo_bert_emotion_classifier.py:435` </location>
<code_context>
+ model: SAMOBERTEmotionClassifier,
+ dataloader: DataLoader,
+ device: torch.device,
+ threshold: float = 0.2, # Lowered from 0.5 to capture more predictions
+) -> Dict[str, float]:
+ """
</code_context>
<issue_to_address>
Evaluation threshold is hardcoded and may not match training threshold.
Recommend sourcing the evaluation threshold from configuration or matching it to the training threshold for consistency.
Suggested implementation:
```python
from config import EMOTION_CLASSIFICATION_THRESHOLD
def evaluate_emotion_classifier(
model: SAMOBERTEmotionClassifier,
dataloader: DataLoader,
device: torch.device,
) -> Dict[str, float]:
"""
```
```python
# ... inside the function ...
preds = (outputs > EMOTION_CLASSIFICATION_THRESHOLD).float()
```
If you do not already have a configuration module or object, you will need to create one (e.g., `config.py`) and ensure the training code also uses `EMOTION_CLASSIFICATION_THRESHOLD` for consistency. You may also need to update any calls to `evaluate_emotion_classifier` to remove the `threshold` argument.
</issue_to_address>
### Comment 4
<location> `test_samo_emotion_detection_standalone.py:50` </location>
<code_context>
+def test_emotion_predictions(model, all_emotions):
</code_context>
<issue_to_address>
Consider adding tests for edge cases such as empty strings, non-English text, and texts with ambiguous or mixed emotions.
Please include tests for empty input, long texts, ambiguous emotion, non-English text, and mixed emotions to better assess model robustness.
</issue_to_address>
### Comment 5
<location> `test_samo_emotion_detection_standalone.py:87` </location>
<code_context>
+ print(f" {emotion_name}: {prob:.3f}")
+
+
+def test_batch_predictions(model, test_texts):
+ """Test batch prediction functionality."""
+ print("\n5. Testing batch prediction...")
+ batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)
+
+ print(f" Batch size: {len(batch_results['emotions'])}")
+ print(f" All predictions successful: {len(batch_results['emotions']) == 3}")
+
+
</code_context>
<issue_to_address>
Add assertions to batch prediction tests to automatically verify expected output structure and values.
Please add assertions to verify output structure, prediction count, non-empty emotion lists, and valid probability ranges, rather than relying on print statements.
</issue_to_address>
<suggested_fix>
<<<<<<< SEARCH
def test_batch_predictions(model, test_texts):
"""Test batch prediction functionality."""
print("\n5. Testing batch prediction...")
batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)
print(f" Batch size: {len(batch_results['emotions'])}")
print(f" All predictions successful: {len(batch_results['emotions']) == 3}")
=======
def test_batch_predictions(model, test_texts):
"""Test batch prediction functionality with assertions."""
batch_results = model.predict_emotions(test_texts[:3], threshold=0.3)
# Assert output structure
assert isinstance(batch_results, dict), "Batch results should be a dictionary."
assert "emotions" in batch_results, "'emotions' key missing in batch results."
assert isinstance(batch_results["emotions"], list), "'emotions' should be a list."
# Assert prediction count
assert len(batch_results["emotions"]) == 3, "Batch size should be 3."
# Assert non-empty emotion lists and valid probability ranges
for i, emotion_list in enumerate(batch_results["emotions"]):
assert isinstance(emotion_list, list), f"Prediction {i} should be a list of emotions."
assert len(emotion_list) > 0, f"Prediction {i} should not be empty."
for emotion in emotion_list:
assert "name" in emotion, f"Emotion dict missing 'name' key in prediction {i}."
assert "probability" in emotion, f"Emotion dict missing 'probability' key in prediction {i}."
prob = emotion["probability"]
assert 0.0 <= prob <= 1.0, f"Probability {prob} out of range in prediction {i}."
>>>>>>> REPLACE
</suggested_fix>
### Comment 6
<location> `test_samo_emotion_detection_standalone.py:113` </location>
<code_context>
+ print(f" Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions")
+
+
+def test_prediction_thresholds(model):
+ """Test different prediction thresholds."""
+ print("\n7. Testing different prediction thresholds...")
+ test_text = "I feel both happy and sad about this situation."
+
+ for threshold in [0.1, 0.3, 0.5, 0.7]:
+ results = model.predict_emotions(test_text, threshold=threshold)
+ emotions = results['emotions'][0]
+ print(f" Threshold {threshold}: {len(emotions)} emotions - {emotions}")
+
+
</code_context>
<issue_to_address>
Add assertions to verify that increasing the threshold reduces the number of detected emotions.
Please add assertions to ensure that increasing the threshold does not result in more detected emotions, confirming the expected behavior.
</issue_to_address>
<suggested_fix>
<<<<<<< SEARCH
def test_prediction_thresholds(model):
"""Test different prediction thresholds."""
print("\n7. Testing different prediction thresholds...")
test_text = "I feel both happy and sad about this situation."
for threshold in [0.1, 0.3, 0.5, 0.7]:
results = model.predict_emotions(test_text, threshold=threshold)
emotions = results['emotions'][0]
print(f" Threshold {threshold}: {len(emotions)} emotions - {emotions}")
=======
def test_prediction_thresholds(model):
"""Test different prediction thresholds."""
print("\n7. Testing different prediction thresholds...")
test_text = "I feel both happy and sad about this situation."
thresholds = [0.1, 0.3, 0.5, 0.7]
num_emotions = []
for threshold in thresholds:
results = model.predict_emotions(test_text, threshold=threshold)
emotions = results['emotions'][0]
num_emotions.append(len(emotions))
print(f" Threshold {threshold}: {len(emotions)} emotions - {emotions}")
# Assert that increasing the threshold does not increase the number of detected emotions
for i in range(1, len(num_emotions)):
assert num_emotions[i] <= num_emotions[i-1], (
f"Number of emotions at threshold {thresholds[i]} ({num_emotions[i]}) "
f"should not be greater than at threshold {thresholds[i-1]} ({num_emotions[i-1]})"
)
>>>>>>> REPLACE
</suggested_fix>
### Comment 7
<location> `test_samo_emotion_detection_standalone.py:169` </location>
<code_context>
+def test_performance():
</code_context>
<issue_to_address>
Add assertions or checks for performance test to ensure model handles very long texts and edge cases gracefully.
Add assertions to verify the model does not crash, timeout, or produce invalid results with very long, short, or empty texts.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| batch_probabilities = probabilities.cpu().numpy() | ||
|
|
||
| # Get emotion names for predictions | ||
| for pred in batch_predictions: |
There was a problem hiding this comment.
suggestion: Emotion names are generic and not mapped to actual labels.
Consider mapping prediction indices to actual emotion labels from the emotion_labels module to improve interpretability.
Suggested implementation:
# Get emotion names for predictions
from emotion_labels import EMOTION_LABELS # Make sure this is the correct import path and variable name
for pred in batch_predictions:
emotions = [
EMOTION_LABELS[i] for i, p in enumerate(pred) if p > 0
]
all_emotions.append(emotions)- Ensure that the
emotion_labelsmodule is available and contains a list or tuple namedEMOTION_LABELSthat maps indices to emotion names. - If the import path or variable name is different, adjust
from emotion_labels import EMOTION_LABELSaccordingly.
There was a problem hiding this comment.
Code Review
This pull request introduces a comprehensive emotion detection feature, including a new BERT-based classifier, configuration, emotion label definitions, and standalone tests. The implementation is solid, with a well-structured configuration and detailed emotion label definitions. My review focuses on improving the model's implementation for better numerical stability, more interpretable outputs, and enhanced maintainability. I've also suggested a fix for the test script to correctly interpret the model's output.
There was a problem hiding this comment.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
configs/samo_emotion_detection_config.yaml (1)
1-187: Add YAML document start marker
Insert---on the first line ofconfigs/samo_emotion_detection_config.yamlto satisfy thedocument-startrule in yamllint.
🧹 Nitpick comments (20)
src/models/emotion_detection/emotion_labels.py (7)
1-1: Remove shebang or make file executable.This is a library module, not a script. Drop the shebang to satisfy EXE001.
-#!/usr/bin/env python3
12-12: Drop unused import.
Tupleis unused.-from typing import List, Dict, Tuple +from typing import List, Dict
14-44: Make labels immutable and prepare O(1) lookups.Use a tuple for immutability and add an index map for constant-time
get_emotion_index.-# GoEmotions emotion categories (27 emotions + neutral = 28 total) -GOEMOTIONS_EMOTIONS = [ +# GoEmotions emotion categories (27 emotions + neutral = 28 total) +GOEMOTIONS_EMOTIONS = ( "admiration", # 0 "amusement", # 1 "anger", # 2 "annoyance", # 3 "approval", # 4 "caring", # 5 "confusion", # 6 "curiosity", # 7 "desire", # 8 "disappointment", # 9 "disapproval", # 10 "disgust", # 11 "embarrassment", # 12 "excitement", # 13 "fear", # 14 "gratitude", # 15 "grief", # 16 "joy", # 17 "love", # 18 "nervousness", # 19 "optimism", # 20 "pride", # 21 "realization", # 22 "relief", # 23 "remorse", # 24 "sadness", # 25 "surprise", # 26 "neutral", # 27 -] +) + +# Name → index map for fast lookup +EMOTION_TO_INDEX: Dict[str, int] = {name: i for i, name in enumerate(GOEMOTIONS_EMOTIONS)}
199-203: Tighten the IndexError message (TRY003).Shorter, standard message.
- else: - raise IndexError(f"Index {index} out of range for GoEmotions list") + else: + raise IndexError("emotion index out of range")
215-216: Make group lookups case-insensitive.Aligns with case-insensitive emotion lookups elsewhere.
- return EMOTION_VALENCE_GROUPS.get(valence, []) + return EMOTION_VALENCE_GROUPS.get(valence.lower(), [])- return EMOTION_AROUSAL_GROUPS.get(arousal, []) + return EMOTION_AROUSAL_GROUPS.get(arousal.lower(), [])- return EMOTION_DOMINANCE_GROUPS.get(dominance, []) + return EMOTION_DOMINANCE_GROUPS.get(dominance.lower(), [])Also applies to: 228-229, 241-242
337-337: Remove unnecessary f-string.Fixes F541.
- print(f"\nEmotion descriptions:") + print("\nEmotion descriptions:")
270-277: Return a tuple forget_all_emotionsto reflect immutability.Optional, but consistent with making the constant a tuple.
-def get_all_emotions() -> List[str]: +def get_all_emotions() -> List[str]: @@ - return GOEMOTIONS_EMOTIONS.copy() + return list(GOEMOTIONS_EMOTIONS)configs/samo_emotion_detection_config.yaml (3)
11-22: Thresholds: confirm intent.You set prediction_threshold to 0.6 (inference) but evaluation.threshold to 0.2. If unintentional, harmonize; if intentional, add a brief comment noting the rationale.
emotion_detection: # Number of emotion categories (27 GoEmotions + neutral) num_emotions: 28 - # Prediction threshold for binary classification - prediction_threshold: 0.6 # Updated from 0.5 for better calibration + prediction_threshold: 0.6 # Inference default. Keep distinct from evaluation.threshold if desired.
39-59: Add training seed for reproducibility.Helps stabilize experiments.
training: @@ warmup_steps: 100 + seed: 42
113-127: Optionally enable deterministic cuDNN toggle.Only if you care about exact reproducibility across runs.
performance: @@ gradient_checkpointing: false # Can be enabled for memory savings + cudnn_deterministic: true # Optional: reproducibility over speedtest_samo_emotion_detection_standalone.py (7)
1-1: Remove shebang or mark file executable.For a test script run via
python, the shebang isn’t needed (EXE001).-#!/usr/bin/env python3
13-15: Harden sys.path injection.Resolve absolute path and avoid duplicate entries.
-# Add src to path -sys.path.insert(0, str(Path(__file__).parent / "src")) +# Add src to path +src_path = str((Path(__file__).resolve().parent / "src")) +if src_path not in sys.path: + sys.path.insert(0, src_path)
72-76: Show human-readable emotion names, notemotion_{i}.
predict_emotionscurrently returns generic label ids; map to actual names for display.- emotions = results['emotions'][0] + pred_vec = results['predictions'][0] + emotions = [all_emotions[i] for i, p in enumerate(pred_vec) if p > 0]If you prefer to fix this at the source, update
predict_emotionsto map indices viaemotion_labels.get_emotion_name(i).
188-189: Movetimeimport to top-level.Minor cleanliness.
- import timeAdd near the other imports:
+import time
155-168: Avoid catching blindException(BLE001).Either let failures surface or narrow the exceptions.
- try: + try: model, all_emotions, trainable_params = run_all_tests() @@ - except Exception as e: + except (RuntimeError, OSError, ValueError) as e: print(f"❌ Error testing emotion classifier: {e}") import traceback traceback.print_exc() raise
200-202: Same here: narrow theexceptclause.Consistent with above.
- except Exception as e: + except (RuntimeError, OSError, ValueError) as e: print(f"❌ Error in performance test: {e}")
50-63: Sanity-check thresholds vs config.Tests hardcode threshold=0.3 while config default is 0.6. If you want tests to reflect defaults, consider pulling from config or a single source of truth.
src/models/emotion_detection/samo_bert_emotion_classifier.py (3)
86-86: Document the calibrated threshold valueThe prediction threshold is hardcoded to 0.6 with a comment mentioning it's "based on calibration", but there's no documentation about how this calibration was performed or why this specific value was chosen.
Consider adding more detailed documentation or making this configurable:
- self.prediction_threshold = 0.6 # Updated from 0.5 to 0.6 based on calibration + # Default prediction threshold calibrated on validation data + # Higher threshold (0.6) reduces false positives for emotion detection + self.prediction_threshold = config.get("prediction_threshold", 0.6)
244-248: Optimize top-k selection logicThe current implementation creates a zero tensor and then scatters ones, which is inefficient. Consider using a more direct approach.
# Get top-k if specified if top_k is not None: - _, top_k_indices = torch.topk(probabilities, top_k, dim=1) - predictions = torch.zeros_like(probabilities) - predictions.scatter_(1, top_k_indices, 1.0) + _, top_k_indices = torch.topk(probabilities, min(top_k, probabilities.size(1)), dim=1) + predictions = torch.zeros_like(probabilities) + predictions.scatter_(1, top_k_indices, 1.0)
1-1: Remove shebang from non-executable module fileThe shebang line is present but the file is not marked as executable. Since this is a module file that will be imported rather than executed directly, the shebang should be removed.
-#!/usr/bin/env python3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
configs/samo_emotion_detection_config.yaml(1 hunks)src/models/emotion_detection/emotion_labels.py(1 hunks)src/models/emotion_detection/samo_bert_emotion_classifier.py(1 hunks)test_samo_emotion_detection_standalone.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
create_samo_bert_emotion_classifier(388-428)count_parameters(275-277)count_frozen_parameters(279-281)predict_emotions(187-268)set_temperature(270-273)src/models/emotion_detection/emotion_labels.py (2)
get_all_emotions(270-277)get_emotion_description(244-254)
🪛 Ruff (0.12.2)
src/models/emotion_detection/emotion_labels.py
1-1: Shebang is present but file is not executable
(EXE001)
183-183: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
183-183: Avoid specifying long messages outside the exception class
(TRY003)
202-202: Avoid specifying long messages outside the exception class
(TRY003)
337-337: f-string without any placeholders
Remove extraneous f prefix
(F541)
test_samo_emotion_detection_standalone.py
1-1: Shebang is present but file is not executable
(EXE001)
200-200: Do not catch blind exception: Exception
(BLE001)
src/models/emotion_detection/samo_bert_emotion_classifier.py
1-1: Shebang is present but file is not executable
(EXE001)
190-190: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🪛 YAMLlint (1.37.1)
configs/samo_emotion_detection_config.yaml
[error] 13-13: trailing spaces
(trailing-spaces)
[error] 16-16: trailing spaces
(trailing-spaces)
[error] 19-19: trailing spaces
(trailing-spaces)
[error] 28-28: trailing spaces
(trailing-spaces)
[error] 31-31: trailing spaces
(trailing-spaces)
[error] 34-34: trailing spaces
(trailing-spaces)
[error] 43-43: trailing spaces
(trailing-spaces)
[error] 47-47: trailing spaces
(trailing-spaces)
[error] 51-51: trailing spaces
(trailing-spaces)
[error] 55-55: trailing spaces
(trailing-spaces)
[error] 66-66: trailing spaces
(trailing-spaces)
[error] 69-69: trailing spaces
(trailing-spaces)
[error] 83-83: trailing spaces
(trailing-spaces)
[error] 86-86: trailing spaces
(trailing-spaces)
[error] 96-96: trailing spaces
(trailing-spaces)
[error] 105-105: trailing spaces
(trailing-spaces)
[error] 108-108: trailing spaces
(trailing-spaces)
[error] 117-117: trailing spaces
(trailing-spaces)
[error] 121-121: trailing spaces
(trailing-spaces)
[error] 124-124: trailing spaces
(trailing-spaces)
[error] 132-132: trailing spaces
(trailing-spaces)
[error] 135-135: trailing spaces
(trailing-spaces)
[error] 138-138: trailing spaces
(trailing-spaces)
[error] 141-141: trailing spaces
(trailing-spaces)
[error] 150-150: trailing spaces
(trailing-spaces)
[error] 154-154: trailing spaces
(trailing-spaces)
[error] 163-163: trailing spaces
(trailing-spaces)
[error] 166-166: trailing spaces
(trailing-spaces)
[error] 169-169: trailing spaces
(trailing-spaces)
[error] 177-177: trailing spaces
(trailing-spaces)
[error] 180-180: trailing spaces
(trailing-spaces)
[error] 183-183: trailing spaces
(trailing-spaces)
🔇 Additional comments (4)
configs/samo_emotion_detection_config.yaml (1)
160-172: Verify feature flags are implemented.Keys like
sanitize_input,filter_sensitive_emotions, andanonymize_predictionsmust be honored by the pipeline; otherwise drop them from config to avoid drift.Would you like a quick repo scan script to report where these flags are read?
test_samo_emotion_detection_standalone.py (1)
96-111: Guard temperature usage.If the model lacks a learnable
temperature, these calls will fail. Add a simple guard.- original_temp = model.temperature.item() + original_temp = getattr(model, "temperature", None) + if original_temp is None: + print(" Temperature scaling not available; skipping.") + return + original_temp = model.temperature.item()src/models/emotion_detection/samo_bert_emotion_classifier.py (2)
489-520: LGTM! Well-structured test codeThe main block provides a good testing framework with clear output and proper error handling. The test demonstrates the core functionality effectively.
383-383: Fix potential None value handling for token_type_idsThe current implementation may cause issues if
token_type_idsis not present in the encoding. Usingtorch.zeros_likeon the entire encoding dictionary will fail.- "token_type_ids": encoding.get("token_type_ids", torch.zeros_like(encoding["input_ids"])).squeeze(0), + "token_type_ids": encoding["token_type_ids"].squeeze(0) if "token_type_ids" in encoding else torch.zeros_like(encoding["input_ids"]).squeeze(0),Likely an incorrect or invalid review comment.
- Fix pooler_output robustness with fallback to first token hidden state - Map prediction indices to actual emotion labels from GOEMOTIONS_EMOTIONS - Create centralized configuration module for threshold management - Add comprehensive edge case testing (empty strings, non-English, mixed emotions) - Add assertions to batch prediction tests for output structure validation - Add threshold behavior assertions to verify decreasing emotion count - Add performance test assertions for processing time and error handling All tests passing with enhanced robustness and validation.
Resolved issues in the following files with DeepSource Autofix: 1. src/models/emotion_detection/config.py 2. src/models/emotion_detection/emotion_labels.py 3. src/models/emotion_detection/enhanced_bert_classifier.py 4. src/models/emotion_detection/enhanced_config.py 5. src/models/emotion_detection/samo_bert_emotion_classifier.py 6. test_samo_emotion_detection_enhanced.py 7. test_samo_emotion_detection_standalone.py
- Use F.binary_cross_entropy_with_logits for numerical stability in BCE loss - Fix return type hint for predict_emotions to reflect batch structure - Simplify config dictionary to only override necessary parameters - Remove redundant default values from create_samo_bert_emotion_classifier All tests passing with improved numerical stability and cleaner code.
…om:uelkerd/SAMO--DL into feat/dl-add-emotion-detection-enhancements
- Extract hardcoded test texts into module-level constant SAMPLE_TEST_TEXTS - Fix inefficient token_type_ids handling in EmotionDataset to use None instead of zeros - Remove code duplication and improve maintainability - BERT models can handle missing token type IDs efficiently All tests passing with cleaner, more maintainable code.
- Replace unused variable with underscore to indicate intentional non-use - Fix PYL-W0612 linting error for unused variable - Add explanatory comment for clarity - All tests still passing with cleaner code
There was a problem hiding this comment.
Actionable comments posted: 10
♻️ Duplicate comments (4)
test_samo_emotion_detection_standalone.py (1)
219-234: Deduplicate hardcoded test texts across files.Extract common sample texts to a module-level constant (or a helper module) and reuse in both standalone and enhanced tests to reduce drift.
src/models/emotion_detection/samo_bert_emotion_classifier.py (3)
328-331: Fix potential device/broadcast mismatch when applying class weights.Multiplying CPU weights with GPU loss tensors will raise, and unsqueeze may not broadcast as intended.
# Apply class weights if provided if self.class_weights is not None: - bce_loss = bce_loss * self.class_weights.unsqueeze(0) + if self.class_weights.device != bce_loss.device: + self.class_weights = self.class_weights.to(bce_loss.device) + bce_loss = bce_loss * self.class_weights.view(1, -1).expand_as(bce_loss)
320-336: Use BCEWithLogits for numerical stability.Avoid explicit sigmoid before BCE; it can cause instability and gradient issues.
- # Apply sigmoid to get probabilities - probabilities = torch.sigmoid(logits) - - # Compute BCE loss - bce_loss = F.binary_cross_entropy( - probabilities, targets.float(), reduction="none" - ) + # Compute BCE with logits directly for numerical stability + bce_loss = F.binary_cross_entropy_with_logits( + logits, targets.float(), reduction="none" + )
192-199: Fix type hints: threshold should be Optional and return types are nested lists.Current hints are inaccurate and may mislead consumers and tooling.
def predict_emotions( self, texts: Union[str, List[str]], - threshold: float = None, + threshold: Optional[float] = None, top_k: Optional[int] = None, batch_size: int = 32, - ) -> Dict[str, Union[List[str], List[float], List[List[int]]]]: + ) -> Dict[str, Union[List[List[str]], List[List[float]], List[List[float]]]]:
🧹 Nitpick comments (31)
src/models/emotion_detection/config.py (4)
1-1: Remove non-executable shebang or make the file executable.The module isn’t meant to be invoked directly; drop the shebang to satisfy linters.
-#!/usr/bin/env python3
71-75: from_dict expects flattened keys only.Callers will likely pass nested dicts (model., training.). Either document this clearly or support nested input. I recommend minimal nested support.
@classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> 'EmotionDetectionConfig': - """Create config from dictionary.""" - return cls(**config_dict) + """Create config from dictionary.""" + flat = {} + # allow nested { "model": {"name": ...}, ... } + for k, v in config_dict.items(): + if isinstance(v, dict): + flat.update(v) + else: + flat[k] = v + return cls(**flat)
101-111: update_config replaces state instead of merging with current.Current behavior discards prior runtime changes. Merge into existing _config instead.
def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig: """Get configuration from dictionary with defaults.""" - default_config = get_default_config() + default_config = get_config()
122-126: Make update_config merge rather than reset.Preserve current values when partial updates come in.
def update_config(config_dict: Dict[str, Any]) -> None: """Update global configuration.""" global _config - _config = get_config_from_dict(config_dict) + _config = get_config_from_dict(config_dict)src/models/emotion_detection/enhanced_bert_classifier.py (5)
392-403: Report calibrated temperature in metadata.If you switch to softplus in forward, match here too.
- "temperature": float(self.temperature.item()), + "temperature": float(F.softplus(self.temperature).item()),
429-441: Empty prediction is internally inconsistent.top_k_emotions shows 1.0 while all scores are 0 and confidence is 0. Return empty top_k or use consistent zeros.
return EmotionPrediction( emotions=emotions, primary_emotion="neutral", confidence=0.0, emotional_intensity="very_low", - top_k_emotions=[("neutral", 1.0)], + top_k_emotions=[], prediction_metadata=metadata )
135-136: Use logging.exception() in except blocks and avoid long externalized messages.Improve traceability and satisfy linters.
- logger.error(f"Failed to load BERT model: {e}") + logger.exception("Failed to load BERT model") @@ - logger.exception("Forward pass failed") + logger.exception("Forward pass failed") @@ - logger.error(f"Failed to load tokenizer: {e}") + logger.exception("Failed to load tokenizer") @@ - logger.error(f"Failed to save model: {e}") + logger.exception("Failed to save model") @@ - logger.error(f"Failed to load model: {e}") + logger.exception("Failed to load model")Also applies to: 207-212, 450-452, 517-519, 533-535
459-491: Minor: remove unused imports and members or wire them up.sklearn metrics, Dataset/DataLoader, warnings aren’t used; class_weights is accepted but unused in loss. Either remove or integrate.
229-246: Consider honoring prediction_threshold in outputs.You expose prediction_threshold but don’t use it to filter emotions. Either document it as metadata-only or apply it to zero out sub-threshold emotions (post top-k).
src/models/emotion_detection/enhanced_config.py (3)
217-224: Prefer logging.exception and narrower except for YAML load.Gives stack traces and avoids blind catches.
- except yaml.YAMLError as e: - logger.error(f"YAML parsing error: {e}") + except yaml.YAMLError: + logger.exception("YAML parsing error") logger.warning("Using default configuration due to YAML error") return self._create_default_config() - except Exception as e: - logger.error(f"Configuration loading failed: {e}") + except (OSError, ValueError): + logger.exception("Configuration loading failed") logger.warning("Using default configuration due to loading error") return self._create_default_config()
537-556: Save the full configuration, not just a subset.Serializing only two sections drops user changes. Serialize all dataclasses (asdict).
- def _config_to_dict(self) -> Dict[str, Any]: - """Convert configuration to dictionary.""" - # This would need proper serialization logic - # For now, return a basic structure - return { - "model": { - "name": self.config.model.name, - "device": self.config.model.device, - "use_mixed_precision": self.config.model.use_mixed_precision, - "cache_embeddings": self.config.model.cache_embeddings, - "max_sequence_length": self.config.model.max_sequence_length, - }, - "emotion_detection": { - "num_emotions": self.config.emotion_detection.num_emotions, - "prediction_threshold": self.config.emotion_detection.prediction_threshold, - "temperature": self.config.emotion_detection.temperature, - "top_k": self.config.emotion_detection.top_k, - }, - # Add other sections as needed - } + def _config_to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary.""" + from dataclasses import asdict + return asdict(self.config)
397-404: Security config is defined but unused.If sanitize_input is expected, integrate with your serving layer (e.g., deployment/cloud-run/secure_api_server.sanitize_input) or drop the option.
test_samo_emotion_detection_enhanced.py (6)
15-17: Make sys.path insertion robust to different test locations.Resolve the path and avoid relying on the current working directory.
-# Add src to path for standalone testing -sys.path.insert(0, str(Path(__file__).parent / "src")) +# Add src to path for standalone testing +repo_root = Path(__file__).resolve().parent +sys.path.insert(0, str((repo_root / "src").resolve()))
25-40: Avoid blind except and prefer try/except/else pattern.Catching Exception hides real failures and triggers linter BLE001/TRY300. Narrow the exceptions and move success prints to else.
def test_model_initialization(): """Test enhanced model initialization.""" print("1. Initializing Enhanced SAMO BERT Emotion Classifier...") - try: - model = EnhancedBERTEmotionClassifier( + try: + model = EnhancedBERTEmotionClassifier( model_name="bert-base-uncased", num_emotions=28, use_mixed_precision=True, cache_embeddings=True ) - print("✅ Enhanced classifier initialized successfully") - return model - except Exception as e: + except (OSError, ValueError, RuntimeError) as e: print(f"❌ Model initialization failed: {e}") return None + else: + print("✅ Enhanced classifier initialized successfully") + return modelApply the same pattern to similar try blocks below (Lines 42-57, 59-75, 118-147, 152-175, 177-203, 205-224, 226-287).
130-147: Strengthen batch prediction test with minimal assertions.Add basic structural assertions to prevent silent regressions.
predictions = model.predict_emotions( batch_texts, top_k=2, return_metadata=True, batch_size=2 ) end_time = time.time() print(f" Processed {len(batch_texts)} texts in {end_time - start_time:.3f}s") - - for i, prediction in enumerate(predictions): + assert isinstance(predictions, list) and len(predictions) == len(batch_texts), "Batch size mismatch" + for i, prediction in enumerate(predictions): + assert hasattr(prediction, "top_k_emotions") and len(prediction.top_k_emotions) >= 2, "Missing top-k" + assert 0.0 <= prediction.confidence <= 1.0, "Confidence out of range" print(f" Text {i+1}: {prediction.primary_emotion} ({prediction.confidence:.3f})")
152-171: Assert performance metrics invariants.Validate keys and basic ranges to catch metric accumulation issues early.
metrics = model.get_performance_metrics() - + required = {"total_inferences","total_inference_time","average_inference_time","error_count","error_rate","last_error"} + assert required.issubset(metrics.keys()), "Missing performance metric keys" + assert metrics["total_inferences"] >= 1, "No inferences recorded" + assert 0.0 <= metrics["error_rate"] <= 1.0, "Invalid error rate" print(f" Total inferences: {metrics['total_inferences']}") print(f" Average inference time: {metrics['average_inference_time']:.3f}s") print(f" Error count: {metrics['error_count']}") print(f" Error rate: {metrics['error_rate']:.3f}")
177-200: Catch specific exceptions and add light assertions in error-handling test.Limit exception type and ensure outputs look sane when no error is raised.
try: # Test with None input try: model.predict_emotions(None) - except Exception as e: + except (TypeError, ValueError, RuntimeError) as e: print(f" ✅ Handled None input: {type(e).__name__}") # Test with very long text very_long_text = "This is a test. " * 1000 prediction = model.predict_emotions(very_long_text) + assert prediction.primary_emotion and 0.0 <= prediction.confidence <= 1.0 print(f" ✅ Handled very long text: {len(very_long_text)} chars") # Test with special characters special_text = "I'm feeling 😊🎉 excited! @#$%^&*()" prediction = model.predict_emotions(special_text) + assert prediction.primary_emotion and 0.0 <= prediction.confidence <= 1.0 print(f" ✅ Handled special characters: {prediction.primary_emotion}")
1-1: Shebang without executable bit.Either remove the shebang or mark the file executable (git chmod +x) to satisfy EXE001.
test_samo_emotion_detection_standalone.py (5)
13-16: Harden sys.path handling for portability.Resolve the absolute path to src to avoid surprises when invoked from other CWDs.
-# Add src to path -sys.path.insert(0, str(Path(__file__).parent / "src")) +# Add src to path +repo_root = Path(__file__).resolve().parent +sys.path.insert(0, str((repo_root / "src").resolve()))
72-87: Empty-text branch: narrow exception type and assert behavior.Catching Exception masks unrelated failures. Add a minimal assertion to document expected behavior.
- if not text.strip(): + if not text.strip(): print(" (Empty text - testing edge case)") try: results = model.predict_emotions(text, threshold=0.3, top_k=3) emotions = results['emotions'][0] probabilities = results['probabilities'][0] + assert isinstance(emotions, list), "Emotions should be list for empty input" print(f" Detected emotions: {emotions}") print(" ✅ Empty text handled gracefully") - except Exception as e: + except (ValueError, RuntimeError) as e: print(f" ❌ Empty text caused error: {e}") continue
170-185: Temperature scaling: add a simple monotonicity check.Lower temperature should not increase entropy; check that cold yields ≥ confidence or ≤ count than hot under a fixed threshold.
model.set_temperature(0.5) # Lower temperature = more confident results_cold = model.predict_emotions("I am very happy!", threshold=0.3) model.set_temperature(2.0) # Higher temperature = less confident results_hot = model.predict_emotions("I am very happy!", threshold=0.3) model.set_temperature(original_temp) # Reset print(f" Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions") print(f" Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions") + assert len(results_cold['emotions'][0]) <= len(results_hot['emotions'][0]), "Cold temperature should not increase detections"
255-306: Performance threshold may be flaky on CPU.The fixed 10s limit for all cases can fail on CI/CPU. Consider relaxing for “Very long text” or gating by device.
- assert processing_time < 10.0, f"Processing time {processing_time:.3f}s too slow for {name}" + # Allow more time for very long input or CPU-only environments + limit = 15.0 if (name == "Very long text" or not torch.cuda.is_available()) else 10.0 + assert processing_time < limit, f"Processing time {processing_time:.3f}s too slow for {name}"
1-1: Shebang without executable bit.Either remove the shebang or mark executable to satisfy EXE001.
src/models/emotion_detection/samo_bert_emotion_classifier.py (8)
30-36: Avoid configuring global logging and warning filters in a library module.Leave logging configuration to applications; setting basicConfig and global filters has side effects for downstream users.
-# Configure logging -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - -# Suppress warnings for cleaner output -warnings.filterwarnings("ignore", category=UserWarning)
74-78: Condense config merge.Small readability cleanup; keeps behavior identical.
- if config is None: - config = default_config - else: - config = {**default_config, **config} + config = default_config if config is None else {**default_config, **config}
86-86: Make prediction threshold configurable.Hardcoding 0.6 limits flexibility and can drift from evaluation settings. Read from config with a default.
- self.prediction_threshold = 0.6 # Updated from 0.5 to 0.6 based on calibration + self.prediction_threshold = float(config.get("prediction_threshold", 0.6))
127-151: Layer freeze/unfreeze semantics can surprise._unfreeze_bert_layers currently sets embeddings to requires_grad=True as a side effect. Consider explicit freeze/unfreeze helpers to prevent inconsistent states.
- def _set_bert_layers_grad(self, num_layers: int, requires_grad: bool) -> None: + def _set_bert_layers_grad(self, num_layers: int, requires_grad: bool) -> None: """Set gradient requirements for BERT layers.""" if num_layers <= 0: return - # Set embeddings - for param in self.bert.embeddings.parameters(): - param.requires_grad = requires_grad + # Set embeddings consistently with encoder 0 + for p in self.bert.embeddings.parameters(): + p.requires_grad = requires_grad # Set encoder layers for i in range(min(num_layers, len(self.bert.encoder.layer))): for param in self.bert.encoder.layer[i].parameters(): param.requires_grad = requires_gradOptionally add an explicit
freeze_all/unfreeze_allif you need broader control.
259-266: Avoid per-call imports; cache emotion labels once.Importing labels inside the loop adds overhead. Cache labels on init and reference
self.emotion_labels.@@ - # Classification head self.classifier = nn.Sequential( @@ ) @@ + # Cache emotion labels for mapping indices to names + try: + from .emotion_labels import get_all_emotions + self.emotion_labels = get_all_emotions() + except Exception: + self.emotion_labels = [f"emotion_{i}" for i in range(self.num_emotions)] @@ - # Get emotion names for predictions - from .emotion_labels import GOEMOTIONS_EMOTIONS - for pred in batch_predictions: - emotions = [ - GOEMOTIONS_EMOTIONS[i] for i, p in enumerate(pred) if p > 0 - ] - all_emotions.append(emotions) + # Get emotion names for predictions + for pred in batch_predictions: + emotions = [self.emotion_labels[i] for i, p in enumerate(pred) if p > 0] + all_emotions.append(emotions)Also applies to: 98-106
369-391: Return token_type_ids only when present.Avoid allocating large zero tensors; downstream already uses .get(...).
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: @@ - return { - "input_ids": encoding["input_ids"].squeeze(0), - "attention_mask": encoding["attention_mask"].squeeze(0), - "token_type_ids": encoding.get("token_type_ids", torch.zeros_like(encoding["input_ids"])).squeeze(0), - "labels": label_tensor, - } + item = { + "input_ids": encoding["input_ids"].squeeze(0), + "attention_mask": encoding["attention_mask"].squeeze(0), + "labels": label_tensor, + } + if "token_type_ids" in encoding: + item["token_type_ids"] = encoding["token_type_ids"].squeeze(0) + return item
417-424: Avoid duplicating defaults in factory config.Let the model’s default_config handle defaults; only override what’s variable here.
- # Create model with default config - config = { - "hidden_dropout_prob": 0.3, - "classifier_dropout_prob": 0.5, - "freeze_bert_layers": freeze_bert_layers, - "temperature": 1.0, - } + # Only override what differs from model defaults + config = {"freeze_bert_layers": freeze_bert_layers}
1-1: Shebang without executable bit.Either remove the shebang or mark executable to satisfy EXE001.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
src/models/emotion_detection/__pycache__/emotion_labels.cpython-38.pycis excluded by!**/*.pycsrc/models/emotion_detection/__pycache__/samo_bert_emotion_classifier.cpython-38.pycis excluded by!**/*.pyc
📒 Files selected for processing (6)
src/models/emotion_detection/config.py(1 hunks)src/models/emotion_detection/enhanced_bert_classifier.py(1 hunks)src/models/emotion_detection/enhanced_config.py(1 hunks)src/models/emotion_detection/samo_bert_emotion_classifier.py(1 hunks)test_samo_emotion_detection_enhanced.py(1 hunks)test_samo_emotion_detection_standalone.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
test_samo_emotion_detection_enhanced.py (3)
src/models/emotion_detection/enhanced_bert_classifier.py (4)
EmotionPrediction(32-39)get_model_info(475-491)predict_emotions(229-270)get_performance_metrics(459-473)src/models/emotion_detection/enhanced_config.py (2)
create_enhanced_config_manager(559-568)get_config(510-512)src/models/emotion_detection/emotion_labels.py (2)
get_all_emotions(270-277)get_emotion_description(244-254)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (7)
_init_classification_layers(120-125)_freeze_bert_layers(144-146)forward(152-190)forward(309-338)predict_emotions(192-274)count_parameters(281-283)count_frozen_parameters(285-287)
src/models/emotion_detection/enhanced_config.py (1)
deployment/cloud-run/secure_api_server.py (1)
sanitize_input(158-172)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/enhanced_bert_classifier.py (6)
_init_classification_layers(169-174)_freeze_bert_layers(176-190)forward(192-211)predict_emotions(229-270)count_parameters(493-495)count_frozen_parameters(497-499)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
create_samo_bert_emotion_classifier(394-434)count_parameters(281-283)count_frozen_parameters(285-287)predict_emotions(192-274)set_temperature(276-279)src/models/emotion_detection/emotion_labels.py (2)
get_all_emotions(270-277)get_emotion_description(244-254)
🪛 Ruff (0.12.2)
src/models/emotion_detection/config.py
1-1: Shebang is present but file is not executable
(EXE001)
test_samo_emotion_detection_enhanced.py
1-1: Shebang is present but file is not executable
(EXE001)
36-36: Consider moving this statement to an else block
(TRY300)
37-37: Do not catch blind exception: Exception
(BLE001)
53-53: Consider moving this statement to an else block
(TRY300)
54-54: Do not catch blind exception: Exception
(BLE001)
71-71: Consider moving this statement to an else block
(TRY300)
72-72: Do not catch blind exception: Exception
(BLE001)
112-112: Consider moving this statement to an else block
(TRY300)
113-113: Do not catch blind exception: Exception
(BLE001)
146-146: Consider moving this statement to an else block
(TRY300)
147-147: Do not catch blind exception: Exception
(BLE001)
171-171: Consider moving this statement to an else block
(TRY300)
172-172: Do not catch blind exception: Exception
(BLE001)
185-185: Do not catch blind exception: Exception
(BLE001)
199-199: Consider moving this statement to an else block
(TRY300)
200-200: Do not catch blind exception: Exception
(BLE001)
220-220: Consider moving this statement to an else block
(TRY300)
221-221: Do not catch blind exception: Exception
(BLE001)
281-281: Consider moving this statement to an else block
(TRY300)
src/models/emotion_detection/enhanced_bert_classifier.py
118-118: Consider moving this statement to an else block
(TRY300)
119-119: Do not catch blind exception: Exception
(BLE001)
135-135: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
136-136: Avoid specifying long messages outside the exception class
(TRY003)
205-205: Consider moving this statement to an else block
(TRY300)
211-211: Avoid specifying long messages outside the exception class
(TRY003)
263-263: Avoid specifying long messages outside the exception class
(TRY003)
450-450: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
451-451: Avoid specifying long messages outside the exception class
(TRY003)
517-517: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
518-518: Avoid specifying long messages outside the exception class
(TRY003)
531-531: Consider moving this statement to an else block
(TRY300)
533-533: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
534-534: Avoid specifying long messages outside the exception class
(TRY003)
src/models/emotion_detection/enhanced_config.py
218-218: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
221-221: Do not catch blind exception: Exception
(BLE001)
222-222: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
278-278: Do not catch blind exception: Exception
(BLE001)
279-279: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
438-438: Consider moving this statement to an else block
(TRY300)
450-450: Consider moving this statement to an else block
(TRY300)
462-462: Consider moving this statement to an else block
(TRY300)
474-474: Consider moving this statement to an else block
(TRY300)
520-520: Do not catch blind exception: Exception
(BLE001)
521-521: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
534-534: Do not catch blind exception: Exception
(BLE001)
535-535: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
src/models/emotion_detection/samo_bert_emotion_classifier.py
1-1: Shebang is present but file is not executable
(EXE001)
195-195: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
test_samo_emotion_detection_standalone.py
1-1: Shebang is present but file is not executable
(EXE001)
84-84: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
176-186: Good fallback when pooler_output is missing.Using [CLS] when pooler_output is None improves model compatibility.
| @contextmanager | ||
| def inference_mode(self): | ||
| """Context manager for inference mode with optimizations.""" | ||
| was_training = self.training | ||
| self.eval() | ||
| try: | ||
| with torch.no_grad(): | ||
| if self.use_mixed_precision: | ||
| with torch.cuda.amp.autocast(): | ||
| yield | ||
| else: | ||
| yield | ||
| finally: | ||
| if was_training: | ||
| self.train() | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Use device-aware autocast, not torch.cuda.amp on non-CUDA devices.
Avoid CUDA-specific autocast on CPU/MPS.
@contextmanager
def inference_mode(self):
"""Context manager for inference mode with optimizations."""
was_training = self.training
self.eval()
try:
with torch.no_grad():
- if self.use_mixed_precision:
- with torch.cuda.amp.autocast():
- yield
- else:
- yield
+ device_type = self.device.type
+ enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu")
+ if enable_amp:
+ with torch.autocast(device_type=device_type):
+ yield
+ else:
+ yield
finally:
if was_training:
self.train()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @contextmanager | |
| def inference_mode(self): | |
| """Context manager for inference mode with optimizations.""" | |
| was_training = self.training | |
| self.eval() | |
| try: | |
| with torch.no_grad(): | |
| if self.use_mixed_precision: | |
| with torch.cuda.amp.autocast(): | |
| yield | |
| else: | |
| yield | |
| finally: | |
| if was_training: | |
| self.train() | |
| @contextmanager | |
| def inference_mode(self): | |
| """Context manager for inference mode with optimizations.""" | |
| was_training = self.training | |
| self.eval() | |
| try: | |
| with torch.no_grad(): | |
| device_type = self.device.type | |
| enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu") | |
| if enable_amp: | |
| with torch.autocast(device_type=device_type): | |
| yield | |
| else: | |
| yield | |
| finally: | |
| if was_training: | |
| self.train() |
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 213 to
228, the context manager uses torch.cuda.amp.autocast which is CUDA-only and
will fail or be suboptimal on CPU/MPS; change it to a device-aware autocast by
using torch.autocast(device_type=self.device.type) (or conditionally call the
appropriate autocast based on self.device.type) so mixed-precision is applied
correctly on CUDA and safely skipped or handled on CPU/MPS; keep the same yield
structure and restore training state in finally.
- Rename 'model' parameter to 'emotion_model' in evaluate_emotion_classifier - Rename 'model' variable to 'emotion_classifier' in create_samo_bert_emotion_classifier - Rename 'text' variable to 'sample_text' in EmotionDataset.__getitem__ - Rename 'i' variables to 'layer_idx' and 'batch_idx' in loops - Rename 'emotion' variable to 'emotion_name' in emotion_labels.py test loop All 9 PYL-W0621 linting errors resolved while maintaining functionality. All tests passing with cleaner, more readable code.
- Fix lambda closure issue by using default parameter to capture probabilities value - Change 'lambda i: probabilities[i]' to 'lambda i, probs=probabilities: probs[i]' - Prevents all closures from using the same reference to the loop variable - Eliminates potential bug where all lambdas would use the final value of probabilities This fixes the classic Python closure problem where variables are captured by reference rather than by value, ensuring each lambda gets its own copy of the probabilities list.
- Replace 37 f-string logging calls with lazy % formatting for performance - Fix logging calls in samo_bert_emotion_classifier.py (3 calls) - Fix logging calls in enhanced_config.py (34 calls) - Prevents unnecessary string formatting when logging is disabled - Improves performance by deferring string conversion until needed All tests passing with optimized logging performance.
…ier.py - Convert remaining 11 f-string logging calls to lazy % formatting - Fix device setup, model loading, tokenizer loading, and model saving logging - Complete logging performance optimization across all emotion detection modules - All logging calls now use lazy evaluation for better performance All tests passing with fully optimized logging system.
- Remove loop in test_emotion_descriptions() using list comprehension - Add explicit raise from previous error in get_emotion_index() - Improve code quality and maintainability - Follow Python best practices for error handling and test structure All tests passing with improved code quality.
There was a problem hiding this comment.
Actionable comments posted: 7
♻️ Duplicate comments (17)
src/models/emotion_detection/emotion_labels.py (1)
167-184: O(1) lookup for emotion index + cleaner exception chaining.
Use a precomputed dict and raise from None. Also reuse the map in validate_emotion. (Same ask was raised before.)@@ GOEMOTIONS_EMOTIONS = [ @@ ] + +# Precomputed mapping for O(1) lookups +EMOTION_TO_INDEX = {e.lower(): i for i, e in enumerate(GOEMOTIONS_EMOTIONS)} @@ def get_emotion_index(emotion: str) -> int: @@ - try: - return GOEMOTIONS_EMOTIONS.index(emotion.lower()) - except ValueError as e: - raise ValueError(f"Emotion '{emotion}' not found in GoEmotions list") from e + key = emotion.lower() + try: + return EMOTION_TO_INDEX[key] + except KeyError: + raise ValueError(f"Unknown emotion: {emotion!r}") from None @@ - return emotion.lower() in GOEMOTIONS_EMOTIONS + return emotion.lower() in EMOTION_TO_INDEXAlso applies to: 299-300, 44-46
src/models/emotion_detection/config.py (1)
38-43: Avoid import-time “constants” for thresholds; expose getters tied to runtime config.
Current constants won’t reflect update_config() changes. Provide getters that read _config.-# Emotion classification threshold (used in evaluation) -EMOTION_CLASSIFICATION_THRESHOLD = DEFAULT_CONFIG["evaluation"]["threshold"] - -# Prediction threshold (used in inference) -EMOTION_PREDICTION_THRESHOLD = DEFAULT_CONFIG["prediction"]["threshold"] +def get_evaluation_threshold() -> float: + """Current evaluation threshold.""" + return _config.evaluation_threshold + +def get_prediction_threshold() -> float: + """Current prediction threshold.""" + return _config.prediction_thresholdsrc/models/emotion_detection/enhanced_config.py (1)
417-425: _validate_string returns “default” but yields ""; accept a default parameter and pass real defaults.
Prevents silently writing empty strings for paths/labels.@@ - def _validate_string(value: Any, field_name: str) -> str: + def _validate_string(value: Any, field_name: str, default: str = "") -> str: """Validate string value.""" if not isinstance(value, str): - logger.warning("Invalid string value for %s: %s, using default", field_name, value) - return "" + logger.warning("Invalid string value for %s: %s, using default", field_name, value) + return default return value @@ - name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name"), + name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased"), @@ - padding=self._validate_string(data.get("padding", "max_length"), "data.padding"), + padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"), @@ - level=self._validate_log_level(data.get("level", "INFO"), "logging.level"), + level=self._validate_log_level(data.get("level", "INFO"), "logging.level"), log_interval=self._validate_positive_int(data.get("log_interval", 100), "logging.log_interval"), save_interval=self._validate_positive_int(data.get("save_interval", 1000), "logging.save_interval"), enable_tensorboard=self._validate_bool(data.get("enable_tensorboard", True), "logging.enable_tensorboard"), - log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir"), + log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"), @@ - save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir"), - save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric"), + save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"), + save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"), @@ - error_log_file=self._validate_string(data.get("error_log_file", "logs/emotion_detection_errors.log"), "error_handling.error_log_file"), + error_log_file=self._validate_string(data.get("error_log_file", "logs/emotion_detection_errors.log"), "error_handling.error_log_file", "logs/emotion_detection_errors.log"),Also applies to: 287-292, 331-337, 351-356, 361-365, 395-396
test_samo_emotion_detection_standalone.py (3)
56-75: De-duplicate test texts by reusing SAMPLE_TEST_TEXTS.Avoid divergence of fixtures; extend with edge cases locally.
- test_texts = [ - "I am so happy and excited about this amazing opportunity!", - "I feel really sad and disappointed about what happened today.", - "I'm feeling anxious and worried about the upcoming presentation.", - "I love spending time with my family and friends.", - "I'm angry and frustrated with this situation.", - "I feel grateful and thankful for all the support I've received.", - "I'm confused and don't understand what's going on.", - "I feel proud of my accomplishments and achievements.", + test_texts = [ + *SAMPLE_TEST_TEXTS, + "I love spending time with my family and friends.", + "I'm angry and frustrated with this situation.", + "I feel grateful and thankful for all the support I've received.", + "I'm confused and don't understand what's going on.", + "I feel proud of my accomplishments and achievements.",
124-141: Good: batch structure and value checks are asserted.These assertions address prior feedback on validating outputs.
193-214: Good: threshold monotonicity assertions.This test precisely captures expected behavior.
src/models/emotion_detection/enhanced_bert_classifier.py (6)
19-20: Broken import: wrong module name.Importing from .labels will fail; use emotion_labels.
-from .labels import GOEMOTIONS_EMOTIONS +from .emotion_labels import GOEMOTIONS_EMOTIONS
84-91: Temperature overwritten later; initialize once as learnable with provided value.Preserve user-supplied value and correct type.
- self.temperature = temperature + # Learnable temperature initialized from argument + self.temperature = nn.Parameter(torch.tensor([float(temperature)], dtype=torch.float32))
146-148: Remove redundant temperature re-assignment.Prevents clobbering provided value.
- self.temperature = nn.Parameter(torch.ones(1)) self._init_classification_layers()
190-203: Forward: handle missing pooler_output and enforce positive temperature.Robust across models and avoids divide-by-zero/negative temps.
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: @@ - bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) - pooled_output = bert_outputs.pooler_output + bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + pooled_output = getattr(bert_outputs, "pooler_output", None) + if pooled_output is None: + pooled_output = bert_outputs.last_hidden_state[:, 0, :] @@ - logits = self.classifier(pooled_output) + logits = self.classifier(pooled_output) @@ - logits = logits / self.temperature + temp = F.softplus(self.temperature) + 1e-6 + logits = logits / temp
211-223: Use device-aware autocast; current CUDA-only autocast breaks on CPU/MPS.Switch to torch.autocast with device_type.
try: with torch.no_grad(): - if self.use_mixed_precision: - with torch.cuda.amp.autocast(): - yield - else: - yield + device_type = self.device.type + enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu", "mps") + if enable_amp: + with torch.autocast(device_type=device_type): + yield + else: + yield
327-358: Batch: preserve input order and length when filtering empty/whitespace texts.Current code drops entries; reconstruct full result list.
- # Filter out empty texts - valid_texts = [text for text in texts if text and text.strip()] - if not valid_texts: - return [self._create_empty_prediction(return_metadata) for _ in texts] + # Build index map preserving positions; filter invalid + index_map, valid_texts = [], [] + for idx, t in enumerate(texts): + if t and t.strip(): + index_map.append(idx) + valid_texts.append(t) + if not valid_texts: + return [self._create_empty_prediction(return_metadata) for _ in texts] @@ - # Process results - results = [] - for i, prob in enumerate(probabilities): - result = self._process_prediction_results( - prob, top_k, return_metadata, valid_texts[i] - ) - results.append(result) - - return results + # Process results and reassemble in original order + results = [self._create_empty_prediction(return_metadata) for _ in texts] + for i, prob in enumerate(probabilities): + results[index_map[i]] = self._process_prediction_results( + prob, top_k, return_metadata, valid_texts[i] + ) + return resultssrc/models/emotion_detection/samo_bert_emotion_classifier.py (5)
126-151: Layer freeze/unfreeze symmetry
unfreeze_bert_layers()only toggles the first N layers/embeddings; previously frozen later layers remain frozen. Consider a symmetrical approach (unfreeze all, then selectively (re)freeze) or track frozen ranges.
175-183: Good fallback for models without pooler_outputCLS fallback is correctly implemented.
225-243: Token type IDs handling is finePassing
Nonewhen absent avoids unnecessary zero tensors and is supported by HF models.
258-263: Map to labels once; guard for length mismatchesImporting every batch is inefficient and assumes 28 labels. Store labels on init and guard index overflow to avoid IndexError if
num_emotionsdiffers.@@ - # Get emotion names for predictions - from .emotion_labels import GOEMOTIONS_EMOTIONS - for pred in batch_predictions: - emotions = [ - GOEMOTIONS_EMOTIONS[i] for i, p in enumerate(pred) if p > 0 - ] + # Get emotion names for predictions + for pred in batch_predictions: + labels = getattr(self, "emotion_labels", [f"emotion_{i}" for i in range(self.num_emotions)]) + emotions = [labels[i] for i, p in enumerate(pred) if p > 0 and i < len(labels)] all_emotions.append(emotions)And initialize once:
@@ self.prediction_threshold = 0.6 # Updated from 0.5 to 0.6 based on calibration + # Emotion labels (fallback to placeholders if module unavailable) + try: + from .emotion_labels import GOEMOTIONS_EMOTIONS + self.emotion_labels = list(GOEMOTIONS_EMOTIONS) + except Exception: + self.emotion_labels = [f"emotion_{i}" for i in range(num_emotions)]
319-333: Device/broadcasting bug when applying class weightsOn CUDA, multiplying CPU
class_weightswith GPU loss will error;unsqueeze(0)also may not broadcast for batch>1. Fix device + shape handling (or register as buffer).# Apply class weights if provided if self.class_weights is not None: - bce_loss = bce_loss * self.class_weights.unsqueeze(0) + cw = self.class_weights + if cw.device != bce_loss.device: + cw = cw.to(bce_loss.device) + if cw.ndim == 1: + cw = cw.view(1, -1) + bce_loss = bce_loss * cw.expand_as(bce_loss)Optional: make weights track device automatically.
@@ class WeightedBCELoss(nn.Module): - self.class_weights = class_weights + if class_weights is not None: + self.register_buffer("class_weights", class_weights.float()) + else: + self.class_weights = None
🧹 Nitpick comments (24)
src/models/emotion_detection/emotion_labels.py (2)
1-1: Remove non-executable shebang or make file executable.
The shebang triggers Ruff EXE001. Either drop it orchmod +xthe file.-#!/usr/bin/env python3
214-214: Make group lookups case-insensitive.
Lower the key before dict access.- return EMOTION_VALENCE_GROUPS.get(valence, []) + return EMOTION_VALENCE_GROUPS.get(valence.lower(), []) @@ - return EMOTION_AROUSAL_GROUPS.get(arousal, []) + return EMOTION_AROUSAL_GROUPS.get(arousal.lower(), []) @@ - return EMOTION_DOMINANCE_GROUPS.get(dominance, []) + return EMOTION_DOMINANCE_GROUPS.get(dominance.lower(), [])Also applies to: 227-227, 240-240
src/models/emotion_detection/config.py (2)
1-1: Drop non-executable shebang or make executable.
Matches Ruff EXE001.-#!/usr/bin/env python3
12-36: Single source of truth for defaults.
DEFAULT_CONFIG isn’t used to build EmotionDetectionConfig; either wire get_default_config() from DEFAULT_CONFIG or remove DEFAULT_CONFIG to prevent drift.Also applies to: 96-99
src/models/emotion_detection/enhanced_config.py (2)
216-223: Use logging.exception and avoid blind excepts where feasible.
Switch to exception() to capture tracebacks; narrow exceptions where practical.- except yaml.YAMLError as e: - logger.error("YAML parsing error: %s", e) + except yaml.YAMLError: + logger.exception("YAML parsing error") logger.warning("Using default configuration due to YAML error") return self._create_default_config() - except Exception as e: - logger.error("Configuration loading failed: %s", e) + except Exception: + logger.exception("Configuration loading failed") logger.warning("Using default configuration due to loading error") return self._create_default_config() @@ - except Exception as e: - logger.error("Configuration parsing failed: %s", e) + except Exception: + logger.exception("Configuration parsing failed") logger.warning("Using default configuration due to parsing error") return self._create_default_config() @@ - except Exception as e: - logger.error("Configuration update failed: %s", e) + except Exception: + logger.exception("Configuration update failed") @@ - except Exception as e: - logger.error("Failed to save configuration: %s", e) + except Exception: + logger.exception("Failed to save configuration")Also applies to: 279-283, 531-533, 545-547
524-533: Convertupdate_configto an instance method and implement a safe shallow merge.
Allow runtime updates section-by-section.- @staticmethod - def update_config(updates: Dict[str, Any]) -> None: + def update_config(self, updates: Dict[str, Any]) -> None: """Update configuration with new values.""" try: - # This would need more sophisticated merging logic - # For now, just log the attempt - logger.info("Configuration update requested: %s", updates) + logger.info("Configuration update requested: %s", updates) + for section, vals in updates.items(): + if hasattr(self.config, section) and isinstance(vals, dict): + section_obj = getattr(self.config, section) + for k, v in vals.items(): + if hasattr(section_obj, k): + setattr(section_obj, k, v) except Exception: logger.exception("Configuration update failed")test_samo_emotion_detection_enhanced.py (5)
1-1: Remove non-executable shebang or make the script executable.Either drop the shebang or set the exec bit in git to satisfy linters.
Apply this diff to remove it:
-#!/usr/bin/env python3
15-15: Harden sys.path injection.Avoid duplicate inserts and resolve path for robustness.
-sys.path.insert(0, str(Path(__file__).parent / "src")) +src_path = str((Path(__file__).resolve().parent / "src")) +if src_path not in sys.path: + sys.path.insert(0, src_path)
24-39: Narrow exception handling in initialization.Catching Exception hides actionable failures; return in else for clarity.
def test_model_initialization(): @@ - try: + try: model = EnhancedBERTEmotionClassifier( @@ - print("✅ Enhanced classifier initialized successfully") - return model - except Exception as e: + print("✅ Enhanced classifier initialized successfully") + except (RuntimeError, OSError, ValueError) as e: print(f"❌ Model initialization failed: {e}") - return None + return None + else: + return model
58-74: Config access looks correct. Add one sanity assertion.Guard against missing/invalid config fields to fail fast.
- print(f" Model name: {config.model.name}") + print(f" Model name: {config.model.name}") + assert config.emotion_detection.num_emotions > 0, "num_emotions must be > 0"
76-115: Add lightweight assertions to validate EmotionPrediction structure.Current test only prints. Validate types and ranges.
@@ - prediction = model.predict_emotions( + prediction = model.predict_emotions( text, top_k=3, return_metadata=True ) @@ print(f" Top emotions: {prediction.top_k_emotions[:3]}") @@ if prediction.prediction_metadata: print(f" Text length: {prediction.prediction_metadata.get('text_length', 'N/A')}") + # Assertions + assert isinstance(prediction.primary_emotion, str) + assert 0.0 <= prediction.confidence <= 1.0 + assert isinstance(prediction.top_k_emotions, list) and len(prediction.top_k_emotions) <= 3test_samo_emotion_detection_standalone.py (3)
1-1: Remove non-executable shebang or mark executable.Same as the enhanced test script.
-#!/usr/bin/env python3
176-191: Temperature test: consider asserting relative confidence/entropy.Increase signal by checking that colder temperature increases top-1 confidence.
@@ - print(f" Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions") - print(f" Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions") + print(f" Cold temperature (0.5): {len(results_cold['emotions'][0])} emotions") + print(f" Hot temperature (2.0): {len(results_hot['emotions'][0])} emotions") + # Optional: assert colder => not fewer detected emotions than hotter for same threshold + assert len(results_cold['emotions'][0]) >= len(results_hot['emotions'][0])
263-317: Reduce flakiness in performance test.First-run HF downloads and CPU can exceed 10s. Warm up and relax/tier SLA.
@@ - model, _ = create_samo_bert_emotion_classifier() + model, _ = create_samo_bert_emotion_classifier() + # Warm-up (tokenizer + model graph) + _ = model.predict_emotions("warmup", threshold=0.3) @@ - try: - results = model.predict_emotions(text, threshold=0.3) + try: + results = model.predict_emotions(text, threshold=0.3) end_time = time.time() @@ - assert processing_time < 10.0, f"Processing time {processing_time:.3f}s too slow for {name}" + assert processing_time < 15.0, f"Processing time {processing_time:.3f}s too slow for {name}" - assert isinstance(emotions, list), f"Emotions should be a list for {name}" - assert len(emotions) >= 0, f"Emotions count should be non-negative for {name}" + assert isinstance(emotions, list), f"Emotions should be a list for {name}" + # Optional: at least one or zero depending on threshold/text; keep non-negative check meaningful + assert len(emotions) >= 0src/models/emotion_detection/enhanced_bert_classifier.py (4)
131-135: Prefer logger.exception in except paths.Improves tracebacks per Ruff TRY400.
- except Exception as e: - logger.error("Failed to load BERT model: %s", e) - raise RuntimeError(f"BERT model initialization failed: {e}") from e + except Exception as e: + logger.exception("Failed to load BERT model") + raise RuntimeError(f"BERT model initialization failed: {e}") from e
446-450: Tokenizer load: use logger.exception for traceback.Minor logging improvement.
- except Exception as e: - logger.error("Failed to load tokenizer: %s", e) + except Exception as e: + logger.exception("Failed to load tokenizer") raise RuntimeError(f"Tokenizer loading failed: {e}") from e
499-517: Use logger.exception in save_model failure path.Consistency with other error logs.
- except Exception as e: - logger.error("Failed to save model: %s", e) + except Exception as e: + logger.exception("Failed to save model") raise RuntimeError(f"Model saving failed: {e}") from e
518-533: Use logger.exception in load_model failure path and else for success log.Minor but cleaner flow.
- try: + try: checkpoint = torch.load(path, map_location=device or 'cpu') @@ - logger.info("Model loaded from: %s", path) - return model - except Exception as e: - logger.error("Failed to load model: %s", e) + logger.info("Model loaded from: %s", path) + return model + except Exception as e: + logger.exception("Failed to load model") raise RuntimeError(f"Model loading failed: {e}") from esrc/models/emotion_detection/samo_bert_emotion_classifier.py (6)
29-35: Avoid global logging/warning configuration in a library moduleRemove
logging.basicConfig(...)and the globalwarnings.filterwarnings(...)to prevent side effects on host apps/tests. Keep a module logger only.-# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Suppress warnings for cleaner output -warnings.filterwarnings("ignore", category=UserWarning) +logger = logging.getLogger(__name__)
84-84: Remove unusedself.class_weightsfieldIt’s not used by this class (loss handles weights). Drop it to avoid confusion.
- self.class_weights = None
191-198: Fix implicit Optional type forthreshold(RUF013)Use
Optional[float]to satisfy PEP 484 and ruff.def predict_emotions( self, texts: Union[str, List[str]], - threshold: float = None, + threshold: Optional[float] = None, top_k: Optional[int] = None, batch_size: int = 32, ) -> Dict[str, Union[List[List[str]], List[List[float]], List[List[float]]]]:
370-375: Use tokenizer’s model_max_length instead of hardcoded 512Aligns with the chosen backbone and avoids silent truncation mismatches.
- max_length=512, + max_length=self.tokenizer.model_max_length,- max_length=self.max_length, + max_length=self.max_length,(Consider defaulting
max_lengthtotokenizer.model_max_lengthin the dataset’s constructor.)Also applies to: 230-231
445-447: Make evaluation threshold robust to missing configImport may fail if constant is absent; fall back to model’s threshold.
- from .config import EMOTION_CLASSIFICATION_THRESHOLD - threshold = EMOTION_CLASSIFICATION_THRESHOLD + try: + from .config import EMOTION_CLASSIFICATION_THRESHOLD as threshold # type: ignore + except Exception: + threshold = getattr(emotion_model, "prediction_threshold", 0.5)If you want, I can add a small verification script to confirm the constant exists in the repo.
1-1: Shebang on a library fileEither remove the shebang or mark the file executable. It isn’t needed for an importable module.
-#!/usr/bin/env python3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
src/models/emotion_detection/__pycache__/emotion_labels.cpython-38.pycis excluded by!**/*.pycsrc/models/emotion_detection/__pycache__/samo_bert_emotion_classifier.cpython-38.pycis excluded by!**/*.pyc
📒 Files selected for processing (7)
src/models/emotion_detection/config.py(1 hunks)src/models/emotion_detection/emotion_labels.py(1 hunks)src/models/emotion_detection/enhanced_bert_classifier.py(1 hunks)src/models/emotion_detection/enhanced_config.py(1 hunks)src/models/emotion_detection/samo_bert_emotion_classifier.py(1 hunks)test_samo_emotion_detection_enhanced.py(1 hunks)test_samo_emotion_detection_standalone.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/enhanced_bert_classifier.py (6)
_init_classification_layers(167-172)_freeze_bert_layers(174-188)forward(190-209)predict_emotions(227-268)count_parameters(491-493)count_frozen_parameters(495-497)
src/models/emotion_detection/config.py (1)
src/models/emotion_detection/enhanced_config.py (3)
EmotionDetectionConfig(46-51)get_config(520-522)update_config(525-532)
test_samo_emotion_detection_standalone.py (2)
src/models/emotion_detection/samo_bert_emotion_classifier.py (5)
create_samo_bert_emotion_classifier(389-426)count_parameters(280-282)count_frozen_parameters(284-286)predict_emotions(191-273)set_temperature(275-278)src/models/emotion_detection/emotion_labels.py (2)
get_all_emotions(269-276)get_emotion_description(243-253)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (7)
_init_classification_layers(119-124)_freeze_bert_layers(143-145)forward(151-189)forward(308-333)predict_emotions(191-273)count_parameters(280-282)count_frozen_parameters(284-286)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
EmotionDetectionConfig(46-93)get_config(117-119)update_config(122-125)deployment/cloud-run/secure_api_server.py (1)
sanitize_input(158-172)
test_samo_emotion_detection_enhanced.py (3)
src/models/emotion_detection/enhanced_bert_classifier.py (4)
EnhancedBERTEmotionClassifier(39-532)get_model_info(473-489)predict_emotions(227-268)get_performance_metrics(457-471)src/models/emotion_detection/enhanced_config.py (2)
create_enhanced_config_manager(570-579)get_config(520-522)src/models/emotion_detection/emotion_labels.py (2)
get_all_emotions(269-276)get_emotion_description(243-253)
🪛 Ruff (0.12.2)
src/models/emotion_detection/samo_bert_emotion_classifier.py
1-1: Shebang is present but file is not executable
(EXE001)
194-194: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
src/models/emotion_detection/config.py
1-1: Shebang is present but file is not executable
(EXE001)
test_samo_emotion_detection_standalone.py
1-1: Shebang is present but file is not executable
(EXE001)
90-90: Do not catch blind exception: Exception
(BLE001)
src/models/emotion_detection/emotion_labels.py
1-1: Shebang is present but file is not executable
(EXE001)
183-183: Avoid specifying long messages outside the exception class
(TRY003)
201-201: Avoid specifying long messages outside the exception class
(TRY003)
src/models/emotion_detection/enhanced_bert_classifier.py
116-116: Consider moving this statement to an else block
(TRY300)
117-117: Do not catch blind exception: Exception
(BLE001)
133-133: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
134-134: Avoid specifying long messages outside the exception class
(TRY003)
203-203: Consider moving this statement to an else block
(TRY300)
209-209: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
448-448: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
449-449: Avoid specifying long messages outside the exception class
(TRY003)
515-515: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
516-516: Avoid specifying long messages outside the exception class
(TRY003)
529-529: Consider moving this statement to an else block
(TRY300)
531-531: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
532-532: Avoid specifying long messages outside the exception class
(TRY003)
src/models/emotion_detection/enhanced_config.py
217-217: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
220-220: Do not catch blind exception: Exception
(BLE001)
221-221: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
279-279: Do not catch blind exception: Exception
(BLE001)
280-280: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
442-442: Consider moving this statement to an else block
(TRY300)
455-455: Consider moving this statement to an else block
(TRY300)
468-468: Consider moving this statement to an else block
(TRY300)
481-481: Consider moving this statement to an else block
(TRY300)
531-531: Do not catch blind exception: Exception
(BLE001)
532-532: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
545-545: Do not catch blind exception: Exception
(BLE001)
546-546: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
test_samo_emotion_detection_enhanced.py
1-1: Shebang is present but file is not executable
(EXE001)
35-35: Consider moving this statement to an else block
(TRY300)
36-36: Do not catch blind exception: Exception
(BLE001)
52-52: Consider moving this statement to an else block
(TRY300)
53-53: Do not catch blind exception: Exception
(BLE001)
70-70: Consider moving this statement to an else block
(TRY300)
71-71: Do not catch blind exception: Exception
(BLE001)
111-111: Consider moving this statement to an else block
(TRY300)
112-112: Do not catch blind exception: Exception
(BLE001)
145-145: Consider moving this statement to an else block
(TRY300)
146-146: Do not catch blind exception: Exception
(BLE001)
170-170: Consider moving this statement to an else block
(TRY300)
171-171: Do not catch blind exception: Exception
(BLE001)
184-184: Do not catch blind exception: Exception
(BLE001)
198-198: Consider moving this statement to an else block
(TRY300)
199-199: Do not catch blind exception: Exception
(BLE001)
219-219: Consider moving this statement to an else block
(TRY300)
220-220: Do not catch blind exception: Exception
(BLE001)
280-280: Consider moving this statement to an else block
(TRY300)
🔇 Additional comments (2)
test_samo_emotion_detection_enhanced.py (2)
151-171: Performance metrics won’t update on CPU with current classifier timing.predict_emotions only tracks CUDA-event timings; on CPU you’ll see zeros. After fixing the classifier (see model review), consider asserting >0 to catch regressions.
Add this assertion post-fix:
@@ - metrics = model.get_performance_metrics() + metrics = model.get_performance_metrics() print(f" Total inferences: {metrics['total_inferences']}") @@ - print("✅ Performance metrics working") + assert metrics['total_inferences'] >= len(test_texts), "Inference count not tracked" + print("✅ Performance metrics working")
176-199: Nice coverage of edge cases.Positive checks for None, very long text, and special chars. Keep.
| def update_config(config_dict: Dict[str, Any]) -> None: | ||
| """Update global configuration.""" | ||
| global _config | ||
| _config = get_config_from_dict(config_dict) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
update_config discards existing values; perform in-place shallow merge.
Rebuilding from defaults nukes prior updates.
-def update_config(config_dict: Dict[str, Any]) -> None:
- """Update global configuration."""
- global _config
- _config = get_config_from_dict(config_dict)
+def update_config(config_dict: Dict[str, Any]) -> None:
+ """Update global configuration in place (shallow merge)."""
+ global _config
+ for key, value in config_dict.items():
+ if hasattr(_config, key):
+ setattr(_config, key, value)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def update_config(config_dict: Dict[str, Any]) -> None: | |
| """Update global configuration.""" | |
| global _config | |
| _config = get_config_from_dict(config_dict) | |
| def update_config(config_dict: Dict[str, Any]) -> None: | |
| """Update global configuration in place (shallow merge).""" | |
| global _config | |
| for key, value in config_dict.items(): | |
| if hasattr(_config, key): | |
| setattr(_config, key, value) |
🤖 Prompt for AI Agents
In src/models/emotion_detection/config.py around lines 122 to 126, update_config
currently replaces the global _config with get_config_from_dict(config_dict),
discarding prior values; change it to perform an in-place shallow merge: compute
updates = get_config_from_dict(config_dict), ensure _config is initialized
(e.g., to get_default_config() or an empty dict if None), then iterate
updates.items() and assign each key into _config ( _config[key] = value ) so
existing keys not present in updates are preserved; keep the function signature
and avoid reassigning the _config reference.
| import numpy as np | ||
| import torch | ||
| import torch.nn as nn | ||
| from transformers import AutoConfig, AutoModel, AutoTokenizer | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Add functional API import used in temperature handling.
Required by softplus in forward/info.
import torch.nn as nn
+import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer |
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 14 to
18, the file uses the functional API (e.g., softplus) in forward/info but does
not import it; add the torch.nn.functional import (for example import
torch.nn.functional as F) near the other imports so softplus (F.softplus) is
available.
| start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None | ||
| if start_time: | ||
| start_time.record() | ||
|
|
||
| try: | ||
| # Handle single text input | ||
| if isinstance(texts, str): | ||
| return self._predict_single_text(texts, top_k, return_metadata) | ||
|
|
||
| # Handle multiple texts | ||
| return self._predict_batch_texts(texts, top_k, return_metadata, batch_size) | ||
|
|
||
| except Exception as e: | ||
| logger.exception("Emotion prediction failed") | ||
| self._error_count += 1 | ||
| self._last_error = str(e) | ||
| raise RuntimeError(f"Emotion prediction failed: {e}") from e | ||
| finally: | ||
| if start_time: | ||
| end_time = torch.cuda.Event(enable_timing=True) | ||
| end_time.record() | ||
| torch.cuda.synchronize() | ||
| inference_time = start_time.elapsed_time(end_time) / 1000.0 | ||
| self._update_performance_metrics(inference_time) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Performance metrics currently update only on CUDA; add CPU/MPS timing fallback.
Ensure inference count/time tracked on all devices.
- start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
- if start_time:
- start_time.record()
+ use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda"
+ if use_cuda_timing:
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+ start_event.record()
+ else:
+ import time as _time
+ _t0 = _time.perf_counter()
@@
- finally:
- if start_time:
- end_time = torch.cuda.Event(enable_timing=True)
- end_time.record()
- torch.cuda.synchronize()
- inference_time = start_time.elapsed_time(end_time) / 1000.0
- self._update_performance_metrics(inference_time)
+ finally:
+ if use_cuda_timing:
+ end_event.record()
+ torch.cuda.synchronize()
+ inference_time = start_event.elapsed_time(end_event) / 1000.0
+ else:
+ inference_time = (_time.perf_counter() - _t0)
+ self._update_performance_metrics(inference_time)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None | |
| if start_time: | |
| start_time.record() | |
| try: | |
| # Handle single text input | |
| if isinstance(texts, str): | |
| return self._predict_single_text(texts, top_k, return_metadata) | |
| # Handle multiple texts | |
| return self._predict_batch_texts(texts, top_k, return_metadata, batch_size) | |
| except Exception as e: | |
| logger.exception("Emotion prediction failed") | |
| self._error_count += 1 | |
| self._last_error = str(e) | |
| raise RuntimeError(f"Emotion prediction failed: {e}") from e | |
| finally: | |
| if start_time: | |
| end_time = torch.cuda.Event(enable_timing=True) | |
| end_time.record() | |
| torch.cuda.synchronize() | |
| inference_time = start_time.elapsed_time(end_time) / 1000.0 | |
| self._update_performance_metrics(inference_time) | |
| # Start timing, with CUDA or CPU/MPS fallback | |
| use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda" | |
| if use_cuda_timing: | |
| start_event = torch.cuda.Event(enable_timing=True) | |
| end_event = torch.cuda.Event(enable_timing=True) | |
| start_event.record() | |
| else: | |
| import time as _time | |
| _t0 = _time.perf_counter() | |
| try: | |
| # Handle single text input | |
| if isinstance(texts, str): | |
| return self._predict_single_text(texts, top_k, return_metadata) | |
| # Handle multiple texts | |
| return self._predict_batch_texts(texts, top_k, return_metadata, batch_size) | |
| except Exception as e: | |
| logger.exception("Emotion prediction failed") | |
| self._error_count += 1 | |
| self._last_error = str(e) | |
| raise RuntimeError(f"Emotion prediction failed: {e}") from e | |
| finally: | |
| if use_cuda_timing: | |
| end_event.record() | |
| torch.cuda.synchronize() | |
| inference_time = start_event.elapsed_time(end_event) / 1000.0 | |
| else: | |
| inference_time = _time.perf_counter() - _t0 | |
| self._update_performance_metrics(inference_time) |
🧰 Tools
🪛 Ruff (0.12.2)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_bert_classifier.py around lines 245 to
269, the current timing logic only initializes CUDA events and leaves start_time
None on CPU/MPS, so performance metrics aren't updated for non-CUDA runs; change
to use torch.cuda.Event timing when torch.cuda.is_available(), otherwise record
a CPU/MPS start timestamp with time.perf_counter() (ensure time is imported),
and in the finally block compute elapsed via CUDA event elapsed_time when using
CUDA (with torch.cuda.synchronize()), otherwise compute elapsed as
time.perf_counter() - cpu_start; always call
self._update_performance_metrics(inference_time) with seconds and ensure
inference count/time bookkeeping runs for all devices.
| self.temperature = nn.Parameter(torch.ones(1) * config["temperature"]) | ||
| self.class_weights = None | ||
| self.prediction_threshold = 0.6 # Updated from 0.5 to 0.6 based on calibration | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Enforce positive temperature; prevent divide-by-zero/NaN
Guard set_temperature to reject non-positive or non-finite values; current code could yield inf/NaN logits.
def set_temperature(self, temperature: float) -> None:
"""Set temperature scaling parameter."""
- self.temperature.data.fill_(temperature)
+ if temperature <= 0.0 or not np.isfinite(temperature):
+ raise ValueError("temperature must be a positive, finite float")
+ self.temperature.data.fill_(float(temperature))
logger.info("Set temperature to %s", temperature)Also applies to: 275-279
🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines
83-86 (and similarly update lines 275-279), the temperature parameter is
accepted without validation which can produce zero, negative or non-finite
values leading to inf/NaN logits; update the set_temperature logic to validate
the incoming value is a finite positive scalar (e.g., > 0 and torch.isfinite or
math.isfinite), reject or raise a ValueError for non-positive or non-finite
inputs, and only assign torch.nn.Parameter(torch.tensor(value)) when valid; also
ensure any public setters or config loaders that set temperature use the same
validation to prevent divide-by-zero issues.
| # Get top-k if specified | ||
| if top_k is not None: | ||
| _, top_k_indices = torch.topk(probabilities, top_k, dim=1) | ||
| predictions = torch.zeros_like(probabilities) | ||
| predictions.scatter_(1, top_k_indices, 1.0) | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Clamp top_k and ensure k≥1
Prevent runtime errors when top_k > num classes or 0; clarify precedence vs threshold.
- if top_k is not None:
- _, top_k_indices = torch.topk(probabilities, top_k, dim=1)
+ if top_k is not None:
+ k = int(max(1, min(top_k, probabilities.size(1))))
+ _, top_k_indices = torch.topk(probabilities, k, dim=1)
predictions = torch.zeros_like(probabilities)
predictions.scatter_(1, top_k_indices, 1.0)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # Get top-k if specified | |
| if top_k is not None: | |
| _, top_k_indices = torch.topk(probabilities, top_k, dim=1) | |
| predictions = torch.zeros_like(probabilities) | |
| predictions.scatter_(1, top_k_indices, 1.0) | |
| # Get top-k if specified | |
| if top_k is not None: | |
| # Clamp top_k to [1, num_classes] to avoid invalid k values | |
| k = int(max(1, min(top_k, probabilities.size(1)))) | |
| _, top_k_indices = torch.topk(probabilities, k, dim=1) | |
| predictions = torch.zeros_like(probabilities) | |
| predictions.scatter_(1, top_k_indices, 1.0) |
🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines 248
to 253, the top-k logic can raise errors when top_k <= 0 or top_k > number of
classes; clamp and validate top_k before using it: coerce top_k to an int, set
top_k = max(1, min(top_k, num_classes)) where num_classes =
probabilities.size(1), and then compute top_k_indices with the clamped value;
also explicitly document that when top_k is provided it takes precedence over
any probability-threshold branch so the top-k selection is applied regardless of
threshold.
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | ||
| """Get item at index.""" | ||
| sample_text = self.texts[idx] | ||
| labels = self.labels[idx] | ||
|
|
||
| # Tokenize text | ||
| encoding = self.tokenizer( | ||
| sample_text, | ||
| truncation=True, | ||
| padding="max_length", | ||
| max_length=self.max_length, | ||
| return_tensors="pt", | ||
| ) | ||
|
|
||
| # Convert labels to tensor | ||
| label_tensor = torch.tensor(labels, dtype=torch.float) | ||
|
|
||
| return { | ||
| "input_ids": encoding["input_ids"].squeeze(0), | ||
| "attention_mask": encoding["attention_mask"].squeeze(0), | ||
| "token_type_ids": encoding.get("token_type_ids", None), | ||
| "labels": label_tensor, | ||
| } |
There was a problem hiding this comment.
Fix None/shape issues for token_type_ids in Dataset items
Returning token_type_ids=None breaks default collate; not squeezing yields shape (B,1,L). Only include the key when present and squeeze.
- return {
- "input_ids": encoding["input_ids"].squeeze(0),
- "attention_mask": encoding["attention_mask"].squeeze(0),
- "token_type_ids": encoding.get("token_type_ids", None),
- "labels": label_tensor,
- }
+ item = {
+ "input_ids": encoding["input_ids"].squeeze(0),
+ "attention_mask": encoding["attention_mask"].squeeze(0),
+ "labels": label_tensor,
+ }
+ if "token_type_ids" in encoding and encoding["token_type_ids"] is not None:
+ item["token_type_ids"] = encoding["token_type_ids"].squeeze(0)
+ return item📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
| """Get item at index.""" | |
| sample_text = self.texts[idx] | |
| labels = self.labels[idx] | |
| # Tokenize text | |
| encoding = self.tokenizer( | |
| sample_text, | |
| truncation=True, | |
| padding="max_length", | |
| max_length=self.max_length, | |
| return_tensors="pt", | |
| ) | |
| # Convert labels to tensor | |
| label_tensor = torch.tensor(labels, dtype=torch.float) | |
| return { | |
| "input_ids": encoding["input_ids"].squeeze(0), | |
| "attention_mask": encoding["attention_mask"].squeeze(0), | |
| "token_type_ids": encoding.get("token_type_ids", None), | |
| "labels": label_tensor, | |
| } | |
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
| """Get item at index.""" | |
| sample_text = self.texts[idx] | |
| labels = self.labels[idx] | |
| # Tokenize text | |
| encoding = self.tokenizer( | |
| sample_text, | |
| truncation=True, | |
| padding="max_length", | |
| max_length=self.max_length, | |
| return_tensors="pt", | |
| ) | |
| # Convert labels to tensor | |
| label_tensor = torch.tensor(labels, dtype=torch.float) | |
| item = { | |
| "input_ids": encoding["input_ids"].squeeze(0), | |
| "attention_mask": encoding["attention_mask"].squeeze(0), | |
| "labels": label_tensor, | |
| } | |
| if "token_type_ids" in encoding and encoding["token_type_ids"] is not None: | |
| item["token_type_ids"] = encoding["token_type_ids"].squeeze(0) | |
| return item |
🤖 Prompt for AI Agents
In src/models/emotion_detection/samo_bert_emotion_classifier.py around lines 364
to 386, the Dataset __getitem__ currently returns token_type_ids=None (breaking
default collate) and leaves input tensors with an extra leading dimension;
change it so you only include the "token_type_ids" key when encoding provides
it, and when present squeeze the 0th dim like the other inputs (i.e., if
"token_type_ids" in encoding: include "token_type_ids":
encoding["token_type_ids"].squeeze(0)); ensure input_ids and attention_mask
continue to be squeezed and labels tensor has correct shape/dtype.
- Fix FLK-E501 line length violations in enhanced_config.py - Break long lines in configuration parsing methods - Fix temperature parameter overwrite in enhanced_bert_classifier.py - Preserve user-provided temperature values during initialization - Improve code readability and maintainability
- Replace EMOTION_CLASSIFICATION_THRESHOLD and EMOTION_PREDICTION_THRESHOLD constants - Add get_evaluation_threshold() and get_prediction_threshold() functions - Functions now read from global config to reflect runtime updates - Update usage in samo_bert_emotion_classifier.py - Ensures threshold values stay synchronized with config changes
- Change import from .labels to .emotion_labels in enhanced_bert_classifier.py - Standardize on emotion_labels.py as the canonical source for GOEMOTIONS_EMOTIONS - Fixes broken import path issue identified by code review - Maintains consistency with other modules using emotion_labels.py
- Move class weights tensor creation after model initialization - Use model.device to ensure tensor is on same device as model - Prevents device mismatch errors during training - Maintains compatibility with both CPU and CUDA devices - Improves robustness of loss function creation
- Move GOEMOTIONS_EMOTIONS import from function to module level - Improves performance by avoiding repeated imports - Makes code cleaner and more maintainable - Emotion mapping already working correctly with proper labels - Maintains existing functionality while improving code structure
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (15)
src/models/emotion_detection/config.py (1)
126-130: update_config should shallow-merge, not replace the singletonReassigning _config discards prior fields not present in the update. Merge in place.
Apply this diff:
def update_config(config_dict: Dict[str, Any]) -> None: - """Update global configuration.""" + """Update global configuration in place (shallow merge).""" global _config - _config = get_config_from_dict(config_dict) + if _config is None: + _config = get_default_config() + updates = get_config_from_dict(config_dict) + for k, v in updates.__dict__.items(): + setattr(_config, k, v)src/models/emotion_detection/enhanced_bert_classifier.py (5)
84-87: Avoid double-assigning temperature; initialize once as nn.ParameterSet the learnable temperature in init and remove the reassignment in _initialize_classifier.
Apply this diff:
@@ - self.temperature = temperature + self.temperature = nn.Parameter(torch.tensor(float(temperature), dtype=torch.float32)) @@ - self.temperature = nn.Parameter(torch.ones(1) * self.temperature) + # temperature already initialized in __init__Also applies to: 146-147
211-226: Use device-aware autocast; avoid CUDA-only context on CPU/MPSSwitch to torch.autocast with device_type guard.
Apply this diff:
@contextmanager def inference_mode(self): """Context manager for inference mode with optimizations.""" was_training = self.training self.eval() try: with torch.no_grad(): - if self.use_mixed_precision: - with torch.cuda.amp.autocast(): - yield - else: - yield + device_type = self.device.type + enable_amp = self.use_mixed_precision and device_type in ("cuda", "cpu") + if enable_amp: + with torch.autocast(device_type=device_type): + yield + else: + yield finally: if was_training: self.train()
245-269: Performance metrics not tracked on CPU/MPSOnly CUDA path updates timing. Add CPU/MPS fallback.
Apply this diff:
- start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None - if start_time: - start_time.record() + use_cuda_timing = torch.cuda.is_available() and self.device.type == "cuda" + if use_cuda_timing: + _start = torch.cuda.Event(enable_timing=True) + _end = torch.cuda.Event(enable_timing=True) + _start.record() + else: + import time as _time + _t0 = _time.perf_counter() @@ - finally: - if start_time: - end_time = torch.cuda.Event(enable_timing=True) - end_time.record() - torch.cuda.synchronize() - inference_time = start_time.elapsed_time(end_time) / 1000.0 - self._update_performance_metrics(inference_time) + finally: + if use_cuda_timing: + _end.record() + torch.cuda.synchronize() + inference_time = _start.elapsed_time(_end) / 1000.0 + else: + inference_time = (_time.perf_counter() - _t0) + self._update_performance_metrics(inference_time)
190-203: Make forward robust: handle missing pooler_output and enforce positive temperatureFallback to CLS when pooler_output is None and ensure temperature > 0 via softplus; import F.
Apply this diff:
@@ -import torch.nn as nn -from transformers import AutoConfig, AutoModel, AutoTokenizer +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoConfig, AutoModel, AutoTokenizer @@ - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Forward pass with error handling and optimizations.""" try: # Get BERT outputs bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) - pooled_output = bert_outputs.pooler_output + pooled_output = getattr(bert_outputs, "pooler_output", None) + if pooled_output is None: + pooled_output = bert_outputs.last_hidden_state[:, 0, :] # Classification head logits = self.classifier(pooled_output) # Apply temperature scaling - logits = logits / self.temperature + temp = F.softplus(self.temperature) + 1e-6 + logits = logits / tempAlso applies to: 14-18
320-359: Batch results lose alignment when inputs contain empty stringsPreserve original order by inserting empty predictions for invalid entries.
Apply this diff:
- # Filter out empty texts - valid_texts = [text for text in texts if text and text.strip()] - if not valid_texts: - return [self._create_empty_prediction(return_metadata) for _ in texts] + # Build index map preserving positions + index_map: List[int] = [] + valid_texts: List[str] = [] + for idx, t in enumerate(texts): + if t and t.strip(): + index_map.append(idx) + valid_texts.append(t) + if not valid_texts: + return [self._create_empty_prediction(return_metadata) for _ in texts] @@ - # Process results - results = [] - for i, prob in enumerate(probabilities): - result = self._process_prediction_results( - prob, top_k, return_metadata, valid_texts[i] - ) - results.append(result) - - return results + # Reconstruct results in original order + results: List[EmotionPrediction] = [ + self._create_empty_prediction(return_metadata) for _ in texts + ] + for i, prob in enumerate(probabilities): + results[index_map[i]] = self._process_prediction_results( + prob, top_k, return_metadata, valid_texts[i] + ) + return resultssrc/models/emotion_detection/samo_bert_emotion_classifier.py (6)
193-201: Type hint uses implicit OptionalChange to Optional[float] for threshold.
Apply this diff:
- threshold: float = None, + threshold: Optional[float] = None,
250-254: Clamp top_k to valid range and ensure k ≥ 1Prevents runtime errors when top_k ≤ 0 or > num_classes.
Apply this diff:
- if top_k is not None: - _, top_k_indices = torch.topk(probabilities, top_k, dim=1) + if top_k is not None: + k = int(max(1, min(top_k, probabilities.size(1)))) + _, top_k_indices = torch.topk(probabilities, k, dim=1) predictions = torch.zeros_like(probabilities) predictions.scatter_(1, top_k_indices, 1.0)
276-279: Validate temperature setterReject non-positive or non-finite values.
Apply this diff:
def set_temperature(self, temperature: float) -> None: """Set temperature scaling parameter.""" - self.temperature.data.fill_(temperature) + if temperature <= 0.0 or not np.isfinite(temperature): + raise ValueError("temperature must be a positive, finite float") + self.temperature.data.fill_(float(temperature)) logger.info("Set temperature to %s", temperature)
325-331: Class weights may be on wrong device and not broadcast correctlyEnsure device alignment and expand to batch shape.
Apply this diff:
# Apply class weights if provided if self.class_weights is not None: - bce_loss = bce_loss * self.class_weights.unsqueeze(0) + if self.class_weights.device != bce_loss.device: + self.class_weights = self.class_weights.to(bce_loss.device) + bce_loss = bce_loss * self.class_weights.unsqueeze(0).expand_as(bce_loss)
382-387: Returning token_type_ids=None breaks default collateOnly include token_type_ids when present; squeeze to match shapes.
Apply this diff:
- return { - "input_ids": encoding["input_ids"].squeeze(0), - "attention_mask": encoding["attention_mask"].squeeze(0), - "token_type_ids": encoding.get("token_type_ids", None), - "labels": label_tensor, - } + item = { + "input_ids": encoding["input_ids"].squeeze(0), + "attention_mask": encoding["attention_mask"].squeeze(0), + "labels": label_tensor, + } + ttids = encoding.get("token_type_ids", None) + if ttids is not None: + item["token_type_ids"] = ttids.squeeze(0) + return item
188-191: Enforce positive temperature to avoid divide-by-zero/NaNUse softplus on the learnable temperature before scaling logits.
Apply this diff:
- # Apply temperature scaling - logits = logits / self.temperature + # Apply temperature scaling (ensure strictly positive temperature) + temp = F.softplus(self.temperature) + 1e-6 + logits = logits / tempsrc/models/emotion_detection/enhanced_config.py (3)
12-12: Persist all sections when saving; use asdict + safe_dump and ensure the directory exists.Current save drops most sections and uses yaml.dump. Persist the full dataclass tree, create parent dirs, and use safe_dump.
-from dataclasses import dataclass, field +from dataclasses import dataclass, field, asdict @@ def save_config(self, path: Optional[Union[str, Path]] = None) -> None: """Save current configuration to file.""" if path is None: path = self.config_path or "configs/samo_emotion_detection_config.yaml" try: - # Convert config to dictionary and save as YAML - config_dict = self._config_to_dict() - with open(path, 'w', encoding='utf-8') as f: - yaml.dump(config_dict, f, default_flow_style=False, indent=2) + # Convert config to dictionary and save as YAML + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + config_dict = self._config_to_dict() + with open(path, 'w', encoding='utf-8') as f: + yaml.safe_dump(config_dict, f, default_flow_style=False, indent=2, sort_keys=False) logger.info("Configuration saved to: %s", path) - except Exception as e: - logger.error("Failed to save configuration: %s", e) + except Exception as e: + logger.exception("Failed to save configuration: %s", e) @@ - def _config_to_dict(self) -> Dict[str, Any]: - """Convert configuration to dictionary.""" - # This would need proper serialization logic - # For now, return a basic structure - return { - "model": { - "name": self.config.model.name, - "device": self.config.model.device, - "use_mixed_precision": self.config.model.use_mixed_precision, - "cache_embeddings": self.config.model.cache_embeddings, - "max_sequence_length": self.config.model.max_sequence_length, - }, - "emotion_detection": { - "num_emotions": self.config.emotion_detection.num_emotions, - "prediction_threshold": ( - self.config.emotion_detection.prediction_threshold - ), - "temperature": self.config.emotion_detection.temperature, - "top_k": self.config.emotion_detection.top_k, - }, - # Add other sections as needed - } + def _config_to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary (all sections).""" + return asdict(self.config)Also applies to: 649-662, 663-685
506-516: Validators should return true defaults, not placeholders; thread defaults through call sites.Returning "", False, 1, etc. silently loses intended defaults when types are wrong. Accept a default param and return it; pass explicit defaults at call sites.
- def _validate_string(value: Any, field_name: str) -> str: + def _validate_string(value: Any, field_name: str, default: str = "") -> str: @@ - return "" + return default return value @@ - def _validate_bool(value: Any, field_name: str) -> bool: + def _validate_bool(value: Any, field_name: str, default: bool = False) -> bool: @@ - return False + return default return value @@ - def _validate_positive_int(value: Any, field_name: str) -> int: + def _validate_positive_int(value: Any, field_name: str, default: int = 1) -> int: @@ - return 1 + return default @@ - return 1 + return default @@ - def _validate_non_negative_int(value: Any, field_name: str) -> int: + def _validate_non_negative_int(value: Any, field_name: str, default: int = 0) -> int: @@ - return 0 + return default @@ - return 0 + return default @@ - def _validate_positive_float(value: Any, field_name: str) -> float: + def _validate_positive_float(value: Any, field_name: str, default: float = 1.0) -> float: @@ - return 1.0 + return default @@ - return 1.0 + return default @@ - def _validate_float_range( - value: Any, min_val: float, max_val: float, field_name: str - ) -> float: + def _validate_float_range( + value: Any, min_val: float, max_val: float, field_name: str, default: Optional[float] = None + ) -> float: @@ - return (min_val + max_val) / 2 + return default if default is not None else (min_val + max_val) / 2 @@ - return (min_val + max_val) / 2 + return default if default is not None else (min_val + max_val) / 2 @@ - def _validate_list(value: Any, field_name: str) -> List: + def _validate_list(value: Any, field_name: str, default: Optional[List] = None) -> List: @@ - return [] + return default if default is not None else [] return valuePass defaults in representative call sites:
- name=self._validate_string( - data.get("name", "bert-base-uncased"), "model.name" - ), + name=self._validate_string( + data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased" + ), @@ - prediction_threshold=self._validate_float_range( - data.get("prediction_threshold", 0.6), 0.0, 1.0, - "emotion_detection.prediction_threshold" - ), + prediction_threshold=self._validate_float_range( + data.get("prediction_threshold", 0.6), 0.0, 1.0, + "emotion_detection.prediction_threshold", 0.6 + ), @@ - padding=self._validate_string(data.get("padding", "max_length"), "data.padding"), + padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"), @@ - metrics=self._validate_list(data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), "evaluation.metrics"), + metrics=self._validate_list( + data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), + "evaluation.metrics", + ["precision", "recall", "f1_micro", "f1_macro", "accuracy"] + ), @@ - top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values"), + top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values", [1, 3, 5]), @@ - log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir"), + log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"), @@ - save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir"), + save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"), @@ - save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric"), + save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"), @@ - error_log_file=self._validate_string( - data.get("error_log_file", "logs/emotion_detection_errors.log"), - "error_handling.error_log_file" - ), + error_log_file=self._validate_string( + data.get("error_log_file", "logs/emotion_detection_errors.log"), + "error_handling.error_log_file", + "logs/emotion_detection_errors.log" + ),Also applies to: 517-526, 527-543, 544-560, 561-577, 579-596, 627-634, 324-341, 385-395, 396-404, 405-414, 415-423
639-648: update_config is a no-op and static; make it an instance method with a safe shallow merge.Implement a minimal, documented, section-wise merge and drop the blind except.
- @staticmethod - def update_config(updates: Dict[str, Any]) -> None: + def update_config(self, updates: Dict[str, Any]) -> None: """Update configuration with new values.""" - try: - # This would need more sophisticated merging logic - # For now, just log the attempt - logger.info("Configuration update requested: %s", updates) - except Exception as e: - logger.error("Configuration update failed: %s", e) + logger.info("Configuration update requested: %s", updates) + if not isinstance(updates, dict): + raise TypeError("updates must be a dict of sections -> dict") + for section, vals in updates.items(): + if not hasattr(self.config, section): + logger.warning("Unknown config section: %s (skipped)", section) + continue + section_obj = getattr(self.config, section) + if not isinstance(vals, dict): + logger.warning("Section %s update must be a dict (skipped)", section) + continue + for k, v in vals.items(): + if hasattr(section_obj, k): + setattr(section_obj, k, v) + else: + logger.warning("Unknown field %s.%s (skipped)", section, k)
🧹 Nitpick comments (8)
src/models/emotion_detection/config.py (1)
1-1: Remove shebang or mark file executableLibrary modules don’t need a shebang; Ruff EXE001. Remove it to silence the lint.
Apply this diff:
-#!/usr/bin/env python3src/models/emotion_detection/samo_bert_emotion_classifier.py (3)
1-1: Remove shebang or make file executableLibrary module doesn’t need a shebang; fixes Ruff EXE001.
Apply this diff:
-#!/usr/bin/env python3
31-37: Avoid configuring global logging and blanket warning suppression in a libraryRemove basicConfig and scope warning filter to transformers only.
Apply this diff:
-# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Suppress warnings for cleaner output -warnings.filterwarnings("ignore", category=UserWarning) +# Configure module logger (do not set global handlers in libraries) +logger = logging.getLogger(__name__) +# Suppress noisy HF warnings only +warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
212-214: Source prediction threshold from config when availableUse the runtime getter to keep inference aligned with configuration.
Apply this diff:
- if threshold is None: - threshold = self.prediction_threshold + if threshold is None: + try: + from .config import get_prediction_threshold + threshold = get_prediction_threshold() + except Exception: + threshold = self.prediction_thresholdsrc/models/emotion_detection/enhanced_config.py (4)
230-237: Prefer logging.exception and narrow exception types in except blocks.Improves diagnostics and satisfies linter (TRY400/BLE001).
- except yaml.YAMLError as e: - logger.error("YAML parsing error: %s", e) + except yaml.YAMLError as e: + logger.exception("YAML parsing error: %s", e) logger.warning("Using default configuration due to YAML error") return self._create_default_config() - except Exception as e: - logger.error("Configuration loading failed: %s", e) + except (OSError, UnicodeDecodeError) as e: + logger.exception("Configuration loading failed: %s", e) logger.warning("Using default configuration due to loading error") return self._create_default_config() @@ - except Exception as e: - logger.error("Configuration parsing failed: %s", e) + except (ValueError, TypeError, KeyError) as e: + logger.exception("Configuration parsing failed: %s", e) logger.warning("Using default configuration due to parsing error") return self._create_default_config()Also applies to: 319-323
36-43: Device handling: return a concrete default (“auto”) instead of None; accept DeviceType.None vs “auto” is ambiguous and can leak into downstream logic. Standardize on the string.
-@dataclass -class ModelConfig: +@dataclass +class ModelConfig: @@ - device: Optional[str] = None + device: str = "auto"- def _validate_device(value: Any) -> Optional[str]: + def _validate_device(value: Any, default: str = "auto") -> str: """Validate device value.""" - if value is None: - return None - if not isinstance(value, str): - logger.warning("Invalid device value: %s, using auto", value) - return None + if value is None: + return default + if isinstance(value, DeviceType): + value = value.value + if not isinstance(value, str): + logger.warning("Invalid device value: %s, using %s", value, default) + return default valid_devices = ["auto", "cpu", "cuda", "mps"] if value.lower() not in valid_devices: - logger.warning("Invalid device: %s, using auto", value) - return None + logger.warning("Invalid device: %s, using %s", value, default) + return default return value.lower()Would any caller currently rely on device=None to imply “auto”? If so, I’ll adapt the downstream code accordingly.
Also applies to: 597-610
260-303: Add cross-field consistency checks (top_k bounds, split sums).Prevents invalid-but-parseable configs from slipping through.
def _parse_config(self, config_data: Dict[str, Any]) -> EnhancedEmotionDetectionConfig: """Parse configuration data into structured format.""" try: @@ - return EnhancedEmotionDetectionConfig( + # Cross-field normalization + if emotion_config.top_k > emotion_config.num_emotions: + logger.warning( + "emotion_detection.top_k (%d) > num_emotions (%d); clamping to %d", + emotion_config.top_k, emotion_config.num_emotions, emotion_config.num_emotions + ) + emotion_config.top_k = emotion_config.num_emotions + if any(k > emotion_config.num_emotions for k in evaluation_config.top_k_values): + evaluation_config.top_k_values = [ + min(k, emotion_config.num_emotions) for k in evaluation_config.top_k_values + ] + if (data_config.validation_split + data_config.test_split) >= 1.0: + logger.warning( + "data.validation_split + data.test_split >= 1.0; resetting to 0.2/0.1" + ) + data_config.validation_split = 0.2 + data_config.test_split = 0.1 + + return EnhancedEmotionDetectionConfig( model=model_config, emotion_detection=emotion_config, architecture=architecture_config, training=training_config, data=data_config, evaluation=evaluation_config, logging=logging_config, model_saving=model_saving_config, performance=performance_config, samo_optimizations=samo_optimizations, error_handling=error_handling, security=security_config, development=development_config, )
612-626: Optionally accept LogLevel enums in _validate_log_level (or remove unused LogLevel).Minor DX polish; keeps API consistent with the defined enum.
def _validate_log_level(value: Any, field_name: str) -> str: """Validate log level value.""" + if isinstance(value, LogLevel): + return value.value if not isinstance(value, str): logger.warning( "Invalid log level for %s: %s, using default", field_name, value ) return "INFO"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
src/models/emotion_detection/config.py(1 hunks)src/models/emotion_detection/enhanced_bert_classifier.py(1 hunks)src/models/emotion_detection/enhanced_config.py(1 hunks)src/models/emotion_detection/samo_bert_emotion_classifier.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/models/emotion_detection/enhanced_bert_classifier.py (1)
src/models/emotion_detection/samo_bert_emotion_classifier.py (4)
_freeze_bert_layers(145-147)forward(153-191)forward(309-334)predict_emotions(193-274)
src/models/emotion_detection/samo_bert_emotion_classifier.py (1)
src/models/emotion_detection/config.py (1)
get_evaluation_threshold(39-41)
src/models/emotion_detection/config.py (1)
src/models/emotion_detection/enhanced_config.py (3)
EmotionDetectionConfig(46-51)get_config(635-637)update_config(640-647)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
EmotionDetectionConfig(50-97)get_config(121-123)update_config(126-129)deployment/cloud-run/secure_api_server.py (1)
sanitize_input(158-172)
🪛 Ruff (0.12.2)
src/models/emotion_detection/enhanced_bert_classifier.py
116-116: Consider moving this statement to an else block
(TRY300)
117-117: Do not catch blind exception: Exception
(BLE001)
133-133: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
134-134: Avoid specifying long messages outside the exception class
(TRY003)
203-203: Consider moving this statement to an else block
(TRY300)
209-209: Avoid specifying long messages outside the exception class
(TRY003)
261-261: Avoid specifying long messages outside the exception class
(TRY003)
448-448: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
449-449: Avoid specifying long messages outside the exception class
(TRY003)
515-515: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
516-516: Avoid specifying long messages outside the exception class
(TRY003)
529-529: Consider moving this statement to an else block
(TRY300)
531-531: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
532-532: Avoid specifying long messages outside the exception class
(TRY003)
src/models/emotion_detection/samo_bert_emotion_classifier.py
1-1: Shebang is present but file is not executable
(EXE001)
196-196: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
src/models/emotion_detection/config.py
1-1: Shebang is present but file is not executable
(EXE001)
src/models/emotion_detection/enhanced_config.py
231-231: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
234-234: Do not catch blind exception: Exception
(BLE001)
235-235: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
319-319: Do not catch blind exception: Exception
(BLE001)
320-320: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
537-537: Consider moving this statement to an else block
(TRY300)
554-554: Consider moving this statement to an else block
(TRY300)
571-571: Consider moving this statement to an else block
(TRY300)
590-590: Consider moving this statement to an else block
(TRY300)
646-646: Do not catch blind exception: Exception
(BLE001)
647-647: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
660-660: Do not catch blind exception: Exception
(BLE001)
661-661: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
🔇 Additional comments (1)
src/models/emotion_detection/config.py (1)
39-46: Good fix: thresholds now reflect runtime configSwitching to getters removes the import-time binding issue and aligns with runtime updates.
| def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig: | ||
| """Get configuration from dictionary with defaults.""" | ||
| default_config = get_default_config() | ||
|
|
||
| # Update with provided values | ||
| for key, value in config_dict.items(): | ||
| if hasattr(default_config, key): | ||
| setattr(default_config, key, value) | ||
|
|
||
| return default_config |
There was a problem hiding this comment.
Config update ignores nested YAML structure; implement nested → flat mapping
Passing a YAML-like dict with sections (model/training/evaluation/prediction) won’t update the dataclass fields. Add flattening so nested keys map to attributes.
Apply this diff:
def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig:
"""Get configuration from dictionary with defaults."""
- default_config = get_default_config()
-
- # Update with provided values
- for key, value in config_dict.items():
- if hasattr(default_config, key):
- setattr(default_config, key, value)
-
- return default_config
+ cfg = get_default_config()
+
+ # Map nested sections to dataclass fields
+ section_map = {
+ "model": {
+ "name": "model_name",
+ "num_emotions": "num_emotions",
+ "hidden_dropout_prob": "hidden_dropout_prob",
+ "classifier_dropout_prob": "classifier_dropout_prob",
+ "freeze_bert_layers": "freeze_bert_layers",
+ "temperature": "temperature",
+ },
+ "training": {
+ "batch_size": "batch_size",
+ "learning_rate": "learning_rate",
+ "num_epochs": "num_epochs",
+ "weight_decay": "weight_decay",
+ },
+ "evaluation": {
+ "threshold": "evaluation_threshold",
+ "top_k": "top_k",
+ },
+ "prediction": {
+ "threshold": "prediction_threshold",
+ "max_length": "max_length",
+ },
+ }
+
+ # Apply nested updates
+ for section, mapping in section_map.items():
+ data = config_dict.get(section)
+ if isinstance(data, dict):
+ for k, v in data.items():
+ attr = mapping.get(k)
+ if attr and hasattr(cfg, attr):
+ setattr(cfg, attr, v)
+
+ # Apply any flat overrides
+ for key, value in config_dict.items():
+ if not isinstance(value, dict) and hasattr(cfg, key):
+ setattr(cfg, key, value)
+
+ return cfg📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig: | |
| """Get configuration from dictionary with defaults.""" | |
| default_config = get_default_config() | |
| # Update with provided values | |
| for key, value in config_dict.items(): | |
| if hasattr(default_config, key): | |
| setattr(default_config, key, value) | |
| return default_config | |
| def get_config_from_dict(config_dict: Dict[str, Any]) -> EmotionDetectionConfig: | |
| """Get configuration from dictionary with defaults.""" | |
| cfg = get_default_config() | |
| # Map nested sections to dataclass fields | |
| section_map = { | |
| "model": { | |
| "name": "model_name", | |
| "num_emotions": "num_emotions", | |
| "hidden_dropout_prob": "hidden_dropout_prob", | |
| "classifier_dropout_prob": "classifier_dropout_prob", | |
| "freeze_bert_layers": "freeze_bert_layers", | |
| "temperature": "temperature", | |
| }, | |
| "training": { | |
| "batch_size": "batch_size", | |
| "learning_rate": "learning_rate", | |
| "num_epochs": "num_epochs", | |
| "weight_decay": "weight_decay", | |
| }, | |
| "evaluation": { | |
| "threshold": "evaluation_threshold", | |
| "top_k": "top_k", | |
| }, | |
| "prediction": { | |
| "threshold": "prediction_threshold", | |
| "max_length": "max_length", | |
| }, | |
| } | |
| # Apply nested updates | |
| for section, mapping in section_map.items(): | |
| data = config_dict.get(section) | |
| if isinstance(data, dict): | |
| for k, v in data.items(): | |
| attr = mapping.get(k) | |
| if attr and hasattr(cfg, attr): | |
| setattr(cfg, attr, v) | |
| # Apply any flat overrides | |
| for key, value in config_dict.items(): | |
| if not isinstance(value, dict) and hasattr(cfg, key): | |
| setattr(cfg, key, value) | |
| return cfg |
🤖 Prompt for AI Agents
In src/models/emotion_detection/config.py around lines 105 to 114, the current
get_config_from_dict only updates top-level keys and ignores nested YAML
sections (model/training/evaluation/prediction). Modify the function to flatten
one-level (or recursive) nested dicts into a single mapping where nested keys
map to dataclass attribute names (e.g., config_dict["model"]["lr"] -> "lr"),
then iterate that flattened mapping and setattr on default_config for keys that
exist; ignore unknown keys. Ensure nested dict values override defaults and
preserve non-dict top-level keys.
- Break long lines in configuration parsing methods - Fix SAMO optimizations, performance, logging, model saving configs - Fix data and evaluation config parsing methods - Improve code readability and maintainability - All lines now comply with 88-character limit
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (7)
src/models/emotion_detection/enhanced_config.py (7)
565-575: Regression note: matches prior feedback.This validator behavior was previously flagged. The above validator+call-site fixes align with that guidance.
Run to spot any remaining 2-arg validator calls lacking explicit defaults:
#!/bin/bash rg -nP 'self\._validate_(string|bool|positive_int|non_negative_int|positive_float|log_level|list)\([^,]+,\s*["'\"'"][^"'\"'"]+["'\"'"]\s*\)' src/models/emotion_detection/enhanced_config.py
708-744: Persist the full config; use asdict, safe_dump, and ensure target dir exists.Current save drops most sections and uses yaml.dump. This risks config loss and unsafe YAML emission.
Apply:
@@ - def save_config(self, path: Optional[Union[str, Path]] = None) -> None: + def save_config(self, path: Optional[Union[str, Path]] = None) -> None: """Save current configuration to file.""" if path is None: path = self.config_path or "configs/samo_emotion_detection_config.yaml" try: - # Convert config to dictionary and save as YAML - config_dict = self._config_to_dict() - with open(path, 'w', encoding='utf-8') as f: - yaml.dump(config_dict, f, default_flow_style=False, indent=2) + dest = Path(path) + dest.parent.mkdir(parents=True, exist_ok=True) + config_dict = self._config_to_dict() + with dest.open('w', encoding='utf-8') as f: + yaml.safe_dump(config_dict, f, default_flow_style=False, indent=2, sort_keys=False) logger.info("Configuration saved to: %s", path) - except Exception as e: - logger.error("Failed to save configuration: %s", e) + except (OSError, TypeError) as e: + logger.exception("Failed to save configuration: %s", e) @@ - def _config_to_dict(self) -> Dict[str, Any]: - """Convert configuration to dictionary.""" - # This would need proper serialization logic - # For now, return a basic structure - return { - "model": { - "name": self.config.model.name, - "device": self.config.model.device, - "use_mixed_precision": self.config.model.use_mixed_precision, - "cache_embeddings": self.config.model.cache_embeddings, - "max_sequence_length": self.config.model.max_sequence_length, - }, - "emotion_detection": { - "num_emotions": self.config.emotion_detection.num_emotions, - "prediction_threshold": ( - self.config.emotion_detection.prediction_threshold - ), - "temperature": self.config.emotion_detection.temperature, - "top_k": self.config.emotion_detection.top_k, - }, - # Add other sections as needed - } + def _config_to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary (all sections).""" + return asdict(self.config)And add the import:
-from dataclasses import dataclass, field +from dataclasses import dataclass, field, asdict
566-693: Validators return placeholders; return true defaults via a default= parameter.Log says “using default” but returns "", 0, 1, etc. Make validators accept a default and return it; propagate at call sites.
- def _validate_string(value: Any, field_name: str) -> str: + def _validate_string(value: Any, field_name: str, default: str = "") -> str: @@ - return "" + return default return value @@ - def _validate_bool(value: Any, field_name: str) -> bool: + def _validate_bool(value: Any, field_name: str, default: bool = False) -> bool: @@ - return False + return default return value @@ - def _validate_positive_int(value: Any, field_name: str) -> int: + def _validate_positive_int(value: Any, field_name: str, default: int = 1) -> int: @@ - return 1 + return default return int_val except (ValueError, TypeError): @@ - return 1 + return default @@ - def _validate_non_negative_int(value: Any, field_name: str) -> int: + def _validate_non_negative_int(value: Any, field_name: str, default: int = 0) -> int: @@ - return 0 + return default return int_val except (ValueError, TypeError): @@ - return 0 + return default @@ - def _validate_positive_float(value: Any, field_name: str) -> float: + def _validate_positive_float(value: Any, field_name: str, default: float = 1.0) -> float: @@ - return 1.0 + return default return float_val except (ValueError, TypeError): @@ - return 1.0 + return default @@ - def _validate_float_range( - value: Any, min_val: float, max_val: float, field_name: str - ) -> float: + def _validate_float_range( + value: Any, min_val: float, max_val: float, field_name: str, default: Optional[float] = None + ) -> float: @@ - return (min_val + max_val) / 2 + return default if default is not None else (min_val + max_val) / 2 return float_val except (ValueError, TypeError): @@ - return (min_val + max_val) / 2 + return default if default is not None else (min_val + max_val) / 2 @@ - def _validate_device(value: Any) -> Optional[str]: + def _validate_device(value: Any, default: Optional[str] = None) -> Optional[str]: @@ - return None + return default if not isinstance(value, str): - logger.warning("Invalid device value: %s, using auto", value) - return None + logger.warning("Invalid device value: %s, using default", value) + return default valid_devices = ["auto", "cpu", "cuda", "mps"] if value.lower() not in valid_devices: - logger.warning("Invalid device: %s, using auto", value) - return None + logger.warning("Invalid device: %s, using default", value) + return default return value.lower() @@ - def _validate_log_level(value: Any, field_name: str) -> str: + def _validate_log_level(value: Any, field_name: str, default: str = "INFO") -> str: @@ - return "INFO" + return default valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] if value.upper() not in valid_levels: @@ - return "INFO" + return default return value.upper() @@ - def _validate_list(value: Any, field_name: str) -> List: + def _validate_list(value: Any, field_name: str, default: Optional[List] = None) -> List: @@ - return [] + return default if default is not None else [] return value
698-707: update_config is a no-op and static; implement a shallow, validated merge.Make it an instance method that merges top-level sections by re-parsing with validators.
- @staticmethod - def update_config(updates: Dict[str, Any]) -> None: + def update_config(self, updates: Dict[str, Any]) -> None: """Update configuration with new values.""" try: - # This would need more sophisticated merging logic - # For now, just log the attempt - logger.info("Configuration update requested: %s", updates) - except Exception as e: - logger.error("Configuration update failed: %s", e) + logger.info("Configuration update requested: %s", updates) + parsers = { + "model": self._parse_model_config, + "emotion_detection": self._parse_emotion_detection_config, + "architecture": self._parse_architecture_config, + "training": self._parse_training_config, + "data": self._parse_data_config, + "evaluation": self._parse_evaluation_config, + "logging": self._parse_logging_config, + "model_saving": self._parse_model_saving_config, + "performance": self._parse_performance_config, + "samo_optimizations": self._parse_samo_optimizations, + "error_handling": self._parse_error_handling, + "security": self._parse_security_config, + "development": self._parse_development_config, + } + for section, vals in (updates or {}).items(): + if hasattr(self.config, section) and isinstance(vals, dict) and section in parsers: + current = getattr(self.config, section) + merged = {**asdict(current), **vals} + setattr(self.config, section, parsers[section](merged)) + except (TypeError, ValueError, AttributeError) as e: + logger.exception("Configuration update failed: %s", e)
324-340: Pass explicit defaults to validators (ModelConfig).Ensure invalid user input falls back to true defaults instead of placeholders.
return ModelConfig( - name=self._validate_string( - data.get("name", "bert-base-uncased"), "model.name" - ), - device=self._validate_device(data.get("device")), - use_mixed_precision=self._validate_bool( - data.get("use_mixed_precision", True), "model.use_mixed_precision" - ), - cache_embeddings=self._validate_bool( - data.get("cache_embeddings", False), "model.cache_embeddings" - ), - max_sequence_length=self._validate_positive_int( - data.get("max_sequence_length", 512), "model.max_sequence_length" - ), + name=self._validate_string(data.get("name", "bert-base-uncased"), "model.name", "bert-base-uncased"), + device=self._validate_device(data.get("device"), None), + use_mixed_precision=self._validate_bool(data.get("use_mixed_precision", True), "model.use_mixed_precision", True), + cache_embeddings=self._validate_bool(data.get("cache_embeddings", False), "model.cache_embeddings", False), + max_sequence_length=self._validate_positive_int(data.get("max_sequence_length", 512), "model.max_sequence_length", 512), )
342-358: Clamp top_k to num_emotions and pass explicit defaults.Prevents invalid combinations (e.g., top_k > num_emotions) and aligns with config semantics.
- return EmotionDetectionConfig( - num_emotions=self._validate_positive_int( - data.get("num_emotions", 28), "emotion_detection.num_emotions" - ), - prediction_threshold=self._validate_float_range( - data.get("prediction_threshold", 0.6), 0.0, 1.0, - "emotion_detection.prediction_threshold" - ), - temperature=self._validate_positive_float( - data.get("temperature", 1.0), "emotion_detection.temperature" - ), - top_k=self._validate_positive_int( - data.get("top_k", 5), "emotion_detection.top_k" - ), - ) + num_emotions = self._validate_positive_int(data.get("num_emotions", 28), "emotion_detection.num_emotions", 28) + prediction_threshold = self._validate_float_range( + data.get("prediction_threshold", 0.6), 0.0, 1.0, "emotion_detection.prediction_threshold", 0.6 + ) + temperature = self._validate_positive_float(data.get("temperature", 1.0), "emotion_detection.temperature", 1.0) + top_k = self._validate_positive_int(data.get("top_k", 5), "emotion_detection.top_k", 5) + if top_k > num_emotions: + logger.warning("emotion_detection.top_k (%s) > num_emotions (%s); clamping to num_emotions", top_k, num_emotions) + top_k = num_emotions + return EmotionDetectionConfig( + num_emotions=num_emotions, + prediction_threshold=prediction_threshold, + temperature=temperature, + top_k=top_k, + )
426-461: Logging and model_saving: pass explicit defaults to validators.return LoggingConfig( - level=self._validate_log_level( - data.get("level", "INFO"), "logging.level" - ), + level=self._validate_log_level(data.get("level", "INFO"), "logging.level", "INFO"), @@ - log_dir=self._validate_string( - data.get("log_dir", "logs/emotion_detection"), "logging.log_dir" - ), + log_dir=self._validate_string(data.get("log_dir", "logs/emotion_detection"), "logging.log_dir", "logs/emotion_detection"), ) @@ return ModelSavingConfig( - save_dir=self._validate_string( - data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir" - ), - save_best_metric=self._validate_string( - data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric" - ), + save_dir=self._validate_string(data.get("save_dir", "models/emotion_detection"), "model_saving.save_dir", "models/emotion_detection"), + save_best_metric=self._validate_string(data.get("save_best_metric", "f1_macro"), "model_saving.save_best_metric", "f1_macro"), @@ - save_checkpoints=self._validate_bool( - data.get("save_checkpoints", True), "model_saving.save_checkpoints" - ), - checkpoint_interval=self._validate_positive_int( - data.get("checkpoint_interval", 1), "model_saving.checkpoint_interval" - ), + save_checkpoints=self._validate_bool(data.get("save_checkpoints", True), "model_saving.save_checkpoints", True), + checkpoint_interval=self._validate_positive_int(data.get("checkpoint_interval", 1), "model_saving.checkpoint_interval", 1), )
🧹 Nitpick comments (6)
src/models/emotion_detection/enhanced_config.py (6)
408-424: Use defaults for metrics/top_k_values and threshold.Ensures invalid inputs fall back to configured defaults.
- metrics=self._validate_list( - data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), - "evaluation.metrics" - ), - threshold=self._validate_float_range( - data.get("threshold", 0.2), 0.0, 1.0, "evaluation.threshold" - ), + metrics=self._validate_list( + data.get("metrics", ["precision", "recall", "f1_micro", "f1_macro", "accuracy"]), + "evaluation.metrics", + ["precision", "recall", "f1_micro", "f1_macro", "accuracy"], + ), + threshold=self._validate_float_range( + data.get("threshold", 0.2), 0.0, 1.0, "evaluation.threshold", 0.2 + ), @@ - top_k_values=self._validate_list( - data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values" - ), + top_k_values=self._validate_list(data.get("top_k_values", [1, 3, 5]), "evaluation.top_k_values", [1, 3, 5]),
463-563: Propagate explicit defaults across remaining parse_ sections.*Keeps behavior consistent with validator changes; representative changes below.
- use_amp=self._validate_bool( - data.get("use_amp", True), "performance.use_amp" - ), + use_amp=self._validate_bool(data.get("use_amp", True), "performance.use_amp", True), @@ - num_workers=self._validate_non_negative_int( - data.get("num_workers", 4), "performance.num_workers" - ), + num_workers=self._validate_non_negative_int(data.get("num_workers", 4), "performance.num_workers", 4), @@ - pin_memory=self._validate_bool( - data.get("pin_memory", True), "performance.pin_memory" - ), + pin_memory=self._validate_bool(data.get("pin_memory", True), "performance.pin_memory", True), @@ - gradient_checkpointing=self._validate_bool( - data.get("gradient_checkpointing", False), "performance.gradient_checkpointing" - ), + gradient_checkpointing=self._validate_bool(data.get("gradient_checkpointing", False), "performance.gradient_checkpointing", False), @@ - use_torchscript=self._validate_bool( - data.get("use_torchscript", False), "performance.use_torchscript" - ), + use_torchscript=self._validate_bool(data.get("use_torchscript", False), "performance.use_torchscript", False),If helpful, I can push a mechanical patch that updates all remaining call sites in one go.
230-237: Avoid blind except; log stack traces with logger.exception.Improve debuggability and satisfy linters without changing behavior.
- except yaml.YAMLError as e: - logger.error("YAML parsing error: %s", e) + except yaml.YAMLError as e: + logger.exception("YAML parsing error: %s", e) @@ - except Exception as e: - logger.error("Configuration loading failed: %s", e) + except (OSError, TypeError, ValueError) as e: + logger.exception("Configuration loading failed: %s", e) @@ - except Exception as e: - logger.error("Configuration parsing failed: %s", e) + except (TypeError, ValueError, KeyError) as e: + logger.exception("Configuration parsing failed: %s", e)Also applies to: 319-323
18-33: Enums are unused; remove or integrate.Either wire LogLevel/DeviceType into validators and dataclasses or drop them to avoid dead code.
441-443: Minor: path-like fields should be normalized to str and directories ensured on use.You already ensure dirs in save_config; consider normalizing log_dir/error_log_file to str at parse time for consistency.
Also applies to: 521-524
260-318: Sanity checks across sections (top_k vs num_emotions already handled).Optionally verify evaluation.top_k_values ≤ emotion_detection.num_emotions and >0.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/models/emotion_detection/enhanced_config.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/models/emotion_detection/enhanced_config.py (2)
src/models/emotion_detection/config.py (3)
EmotionDetectionConfig(50-97)get_config(121-123)update_config(126-129)deployment/cloud-run/secure_api_server.py (1)
sanitize_input(158-172)
🪛 Ruff (0.12.2)
src/models/emotion_detection/enhanced_config.py
231-231: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
234-234: Do not catch blind exception: Exception
(BLE001)
235-235: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
319-319: Do not catch blind exception: Exception
(BLE001)
320-320: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
596-596: Consider moving this statement to an else block
(TRY300)
613-613: Consider moving this statement to an else block
(TRY300)
630-630: Consider moving this statement to an else block
(TRY300)
649-649: Consider moving this statement to an else block
(TRY300)
705-705: Do not catch blind exception: Exception
(BLE001)
706-706: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
719-719: Do not catch blind exception: Exception
(BLE001)
720-720: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
| def _parse_data_config(self, data: Dict[str, Any]) -> DataConfig: | ||
| """Parse data configuration with validation.""" | ||
| return DataConfig( | ||
| max_length=self._validate_positive_int( | ||
| data.get("max_length", 512), "data.max_length" | ||
| ), | ||
| truncation=self._validate_bool( | ||
| data.get("truncation", True), "data.truncation" | ||
| ), | ||
| padding=self._validate_string( | ||
| data.get("padding", "max_length"), "data.padding" | ||
| ), | ||
| enable_augmentation=self._validate_bool( | ||
| data.get("enable_augmentation", False), "data.enable_augmentation" | ||
| ), | ||
| validation_split=self._validate_float_range( | ||
| data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split" | ||
| ), | ||
| test_split=self._validate_float_range( | ||
| data.get("test_split", 0.1), 0.0, 1.0, "data.test_split" | ||
| ), | ||
| ) |
There was a problem hiding this comment.
🛠️ Refactor suggestion
Validate data splits sum and pass explicit defaults.
Avoids val+test ≥ 1.0, which would leave no training data.
- return DataConfig(
- max_length=self._validate_positive_int(
- data.get("max_length", 512), "data.max_length"
- ),
- truncation=self._validate_bool(
- data.get("truncation", True), "data.truncation"
- ),
- padding=self._validate_string(
- data.get("padding", "max_length"), "data.padding"
- ),
- enable_augmentation=self._validate_bool(
- data.get("enable_augmentation", False), "data.enable_augmentation"
- ),
- validation_split=self._validate_float_range(
- data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split"
- ),
- test_split=self._validate_float_range(
- data.get("test_split", 0.1), 0.0, 1.0, "data.test_split"
- ),
- )
+ val_split = self._validate_float_range(data.get("validation_split", 0.2), 0.0, 1.0, "data.validation_split", 0.2)
+ test_split = self._validate_float_range(data.get("test_split", 0.1), 0.0, 1.0, "data.test_split", 0.1)
+ if val_split + test_split >= 1.0:
+ new_test = max(0.0, 1.0 - val_split)
+ logger.warning("data.validation_split + data.test_split >= 1.0; adjusting test_split to %.3f", new_test)
+ test_split = new_test
+ return DataConfig(
+ max_length=self._validate_positive_int(data.get("max_length", 512), "data.max_length", 512),
+ truncation=self._validate_bool(data.get("truncation", True), "data.truncation", True),
+ padding=self._validate_string(data.get("padding", "max_length"), "data.padding", "max_length"),
+ enable_augmentation=self._validate_bool(data.get("enable_augmentation", False), "data.enable_augmentation", False),
+ validation_split=val_split,
+ test_split=test_split,
+ )Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/models/emotion_detection/enhanced_config.py around lines 385 to 406, the
data split parsing must ensure validation_split and test_split are explicitly
defaulted and their sum is less than 1.0; change the method to read each split
into local variables using the intended defaults (0.2 and 0.1), validate each
via the existing _validate_float_range, then check if validation_split +
test_split >= 1.0 and raise a clear ValueError (or config validation exception)
if so; return the DataConfig with those validated values.
Summary by Sourcery
Add a new SAMO-enhanced BERT-based multi-label emotion detection pipeline including model implementation, label definitions, configuration, and standalone tests
New Features:
Tests:
Summary by CodeRabbit
New Features
Tests