feat: add ModelPredictor class for model predictions #6
Conversation
📝 Walkthrough

Two new files introduce the CIAO model package.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 3 passed
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request establishes a new `ciao.model` package.

Highlights
Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Pull request overview
This PR introduces a new ciao.model module that provides a ModelPredictor wrapper around a PyTorch model to standardize inference, probability extraction, and top‑k predictions.
Changes:
- Added `ModelPredictor` with helper methods for predictions, top‑k class selection, and extracting per-class logits.
- Added `ciao.model` package init exporting `ModelPredictor`.
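For orientation, here is a minimal sketch of what such a wrapper might look like, reconstructed from the method names and behaviors mentioned in the reviews below (`get_predictions`, `get_class_logit_batch`, `class_names`, `device`). The tiny `nn.Linear` stand-in model and the `get_top_k` name are illustrative assumptions, not the PR's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelPredictor:
    """Thin inference wrapper around a trained PyTorch classifier."""

    def __init__(self, model: nn.Module, class_names: list, device: str = "cpu"):
        self.model = model.to(device)
        self.class_names = class_names
        self.device = device

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Return class probabilities of shape [N, num_classes]."""
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            outputs = self.model(input_batch)
        return F.softmax(outputs, dim=1)

    def get_top_k(self, input_batch: torch.Tensor, k: int = 3):
        """Return (values, indices) of the k most probable classes per sample."""
        probs = self.get_predictions(input_batch)
        return torch.topk(probs, k, dim=1)

    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        """Return raw logits for one target class across a batch."""
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            outputs = self.model(input_batch)
        return outputs[:, target_class_idx]


# Demo with a tiny linear stand-in for a real vision model
predictor = ModelPredictor(nn.Linear(4, 3), class_names=["cat", "dog", "bird"])
batch = torch.randn(2, 4)
probs = predictor.get_predictions(batch)
print(probs.shape)  # torch.Size([2, 3]); each row sums to 1
```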
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `ciao/model/predictor.py` | Implements the new `ModelPredictor` inference wrapper. |
| `ciao/model/__init__.py` | Exposes `ModelPredictor` as the public API of `ciao.model`. |
Code Review
This pull request introduces a ModelPredictor class, which is a good addition for wrapping model inference logic. My review focuses on improving its robustness and correctness. I've identified a high-severity bug in the predict_image method where it incorrectly handles batch inputs, and I've provided a fix to process batches correctly. Additionally, I've suggested a medium-severity improvement to ensure the model is always in evaluation mode during prediction, which is a crucial best practice.
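The evaluation-mode point the reviewer raises can be made defensive: force `eval()` inside the prediction path and restore the caller's mode afterwards. A sketch of that pattern, assuming a generic `predict` helper rather than the PR's actual API:

```python
import torch
import torch.nn as nn


def predict(model: nn.Module, input_batch: torch.Tensor) -> torch.Tensor:
    """Run inference with dropout/batch-norm frozen and autograd disabled."""
    was_training = model.training
    model.eval()  # freeze dropout and batch-norm running statistics
    try:
        with torch.no_grad():  # no gradient bookkeeping during inference
            return model(input_batch)
    finally:
        if was_training:
            model.train()  # restore the caller's original mode


model = nn.Sequential(nn.Linear(4, 3), nn.Dropout(p=0.5))
model.train()
out = predict(model, torch.randn(2, 4))
print(out.shape)       # torch.Size([2, 3])
print(model.training)  # True: the caller's training mode was restored
```

The `try`/`finally` ensures the mode is restored even if the forward pass raises.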
Force-pushed 36f96f0 to de00a02
Force-pushed d73d73e to a572f83
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
Force-pushed a572f83 to 0acc1c4
Force-pushed 0acc1c4 to ab2a08b
🧹 Nitpick comments (3)
`ciao/model/predictor.py` (3)
**29-37**: No bounds checking on `target_class_idx`.

If `target_class_idx` exceeds the number of output classes, this will raise an opaque `IndexError`. Since `class_names` is already stored, you could validate against it for a clearer error message.

🛡️ Optional: Add bounds validation

```diff
 def get_class_logit_batch(
     self, input_batch: torch.Tensor, target_class_idx: int
 ) -> torch.Tensor:
     """Get raw logits for a specific target class across a batch of images."""
+    if not 0 <= target_class_idx < len(self.class_names):
+        raise IndexError(
+            f"target_class_idx {target_class_idx} out of range for {len(self.class_names)} classes"
+        )
     input_batch = input_batch.to(self.device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/model/predictor.py` around lines 29 - 37, Add bounds validation in get_class_logit_batch: before moving input_batch or calling self.model, check that target_class_idx is within 0 <= target_class_idx < len(self.class_names) (use self.class_names to determine number of classes) and raise a clear ValueError (or similar) with a message including the invalid index and valid range; keep the rest of the function (input_batch.to(self.device), with torch.no_grad(), outputs = self.model(input_batch), return outputs[:, target_class_idx]) unchanged.
**20-27**: Consider documenting or validating expected input shape.

The `softmax(dim=1)` assumes model output has shape `[batch, num_classes]`. Per the context snippet in `ciao/data/preprocessing.py`, `load_and_preprocess_image` returns a tensor with shape `[3, 224, 224]` (no batch dimension). Users must manually add a batch dimension via `unsqueeze(0)` before passing to this method, or the softmax will operate on the wrong axis, producing incorrect probabilities.

Consider adding a brief docstring note about expected input shape, or optionally auto-expanding unbatched inputs:

📝 Option 1: Document expected shape in docstring

```diff
 def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
-    """Get model predictions (returns probabilities)."""
+    """Get model predictions (returns probabilities).
+
+    Args:
+        input_batch: Input tensor of shape [N, C, H, W] (batched).
+
+    Returns:
+        Probability tensor of shape [N, num_classes].
+    """
     input_batch = input_batch.to(self.device)
```

📝 Option 2: Auto-expand unbatched input

```diff
 def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
     """Get model predictions (returns probabilities)."""
+    if input_batch.dim() == 3:
+        input_batch = input_batch.unsqueeze(0)
     input_batch = input_batch.to(self.device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/model/predictor.py` around lines 20 - 27, get_predictions assumes a batched tensor and calls softmax(dim=1) which will be wrong for a single-image tensor from load_and_preprocess_image (shape [3,224,224]); update get_predictions to validate and handle unbatched inputs by checking input_batch.dim() and if dim==3 call input_batch = input_batch.unsqueeze(0) before .to(self.device), and also update the get_predictions docstring to state the expected shape ([batch, num_classes] output from the model) or that a [3,224,224] image will be auto-expanded to a batch.
**7-9**: `class_names` is stored but never used.

The `class_names` parameter is accepted and stored but no method in this class currently uses it. If it's intended for future use, consider documenting that intent. Otherwise, this may be dead code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/model/predictor.py` around lines 7 - 9, The __init__ stores class_names on the Predictor but it is never used; either remove the dead parameter and attribute or actually use it—so either (A) delete the class_names parameter and the self.class_names assignment from __init__ (and update all callers that pass class_names), or (B) keep the parameter and modify the prediction path (e.g., the predict / forward method in this class) to map predicted class indices to names using self.class_names (return or include the label strings instead of/alongside indices); choose one approach and update signatures and callers accordingly.
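Option (B) above, mapping predicted class indices to label strings via `class_names`, can be sketched as follows; the label list and probability values here are made up for illustration:

```python
import torch

# Hypothetical label list and a single-sample probability row
class_names = ["cat", "dog", "bird"]
probs = torch.tensor([[0.1, 0.7, 0.2]])

# topk gives the k highest probabilities and their class indices
top = torch.topk(probs, k=2, dim=1)
labels = [class_names[i] for i in top.indices[0].tolist()]
print(labels)  # ['dog', 'bird']
```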
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8b2b9d09-255a-4aae-99f5-a2efacd5d6f9
📒 Files selected for processing (2)
- `ciao/model/__init__.py`
- `ciao/model/predictor.py`
Adames4 left a comment
Please leave unresolved only relevant comments from agents!
Context:
This PR introduces the `model/` module.

What was changed:
- `model/predictor.py`: Added the `ModelPredictor` class. It serves as a clean wrapper around the PyTorch model, providing standardized methods for inference, probability extraction, and top-k predictions.

Related Task:
XAI-29