
feat: add ModelPredictor class for model predictions#6

Open
dhalmazna wants to merge 2 commits into master from feat/model-predictor

Conversation

@dhalmazna
Collaborator

@dhalmazna dhalmazna commented Mar 5, 2026

Context:
This PR introduces the model/ module.

What was changed:

  • model/predictor.py: Added the ModelPredictor class. It serves as a clean wrapper around the PyTorch model, providing standardized methods for inference, probability extraction, and top-k predictions.
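As a rough illustration of the interface described above, here is a minimal sketch of such a wrapper. This is hypothetical code, not the contents of this PR: the method names follow the PR summary, while the constructor signature, device handling, and the tiny model used below are assumptions.

```python
import torch


class ModelPredictor:
    """Hypothetical sketch of a thin inference wrapper around a PyTorch model."""

    def __init__(self, model: torch.nn.Module, class_names: list[str], device: str = "cpu"):
        self.device = torch.device(device)
        # eval() disables dropout/batch-norm updates for deterministic inference
        self.model = model.to(self.device).eval()
        self.class_names = class_names

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Return softmax probabilities of shape [N, num_classes]."""
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            logits = self.model(input_batch)
        return torch.softmax(logits, dim=1)

    def get_class_logit_batch(self, input_batch: torch.Tensor, target_class_idx: int) -> torch.Tensor:
        """Return the raw logit for one target class across the whole batch."""
        input_batch = input_batch.to(self.device)
        with torch.no_grad():
            outputs = self.model(input_batch)
        return outputs[:, target_class_idx]
```

A caller would construct it once with a trained model and reuse it for all inference, which is the standardization the PR description aims at.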

Related Task:
XAI-29

Summary by CodeRabbit

  • New Features
    • Introduced ModelPredictor, a standardized model inference wrapper with automatic device management and deterministic predictions.

@dhalmazna dhalmazna self-assigned this Mar 5, 2026
Copilot AI review requested due to automatic review settings March 5, 2026 10:23
@coderabbitai

coderabbitai bot commented Mar 5, 2026

📝 Walkthrough

Walkthrough

Two new files introduce the ciao model package. The __init__.py establishes the package and exports ModelPredictor. The predictor.py module defines ModelPredictor, a PyTorch model wrapper with standardized inference methods: softmax-based predictions and raw logit extraction.

Changes

  • Model Package Setup — ciao/model/__init__.py: Package initializer that imports and exports the ModelPredictor class for public API access.
  • ModelPredictor Wrapper — ciao/model/predictor.py: New class wrapping PyTorch models for standardized inference. Includes initialization with device management, get_predictions() for softmax probabilities, and get_class_logit_batch() for raw logit extraction.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐰 A model wrapped in PyTorch's embrace,
With predictions and logits in their rightful place,
Device-aware and eval-mode true,
The predictor hops through—fresh and new! ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: The title clearly and accurately describes the main change: introducing the ModelPredictor class for model predictions, which matches the core functionality added in the changeset.
  • Docstring Coverage — ✅ Passed: No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes a new model/ module within the ciao package, aiming to centralize and standardize model prediction functionalities. The core of this change is the introduction of the ModelPredictor class, which encapsulates PyTorch model inference, providing a consistent interface for obtaining predictions, probabilities, and top-k results. This enhances code organization and simplifies interaction with trained models.

Highlights

  • New model/ module introduced: The pull request introduces a new model/ module within the ciao package, designed to house model-related utilities and prediction logic.
  • ModelPredictor class added: A new ModelPredictor class has been added to ciao/model/predictor.py. This class provides a standardized wrapper for PyTorch models, offering methods for inference, probability extraction, and top-k predictions.
Changelog
  • ciao/model/__init__.py
    • Created the __init__.py file to define the ciao.model package.
    • Exported the ModelPredictor class, allowing it to be imported directly from ciao.model.
  • ciao/model/predictor.py
    • Created the predictor.py file to house the ModelPredictor class.
    • Implemented the ModelPredictor class with an initializer that takes a PyTorch model and class names.
    • Added the get_predictions method to return class probabilities from model outputs.
    • Included the predict_image method to retrieve top-k predictions (class index, name, and probability) for a given input.
    • Provided the get_class_logit_batch method to extract logits for a specific target class from a batch of inputs.
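The top-k retrieval described for predict_image typically reduces to a single torch.topk call. A hypothetical helper sketching that step (the function name and its (index, name, probability) return shape are assumptions based on the changelog above, not code from this PR):

```python
import torch


def top_k_predictions(probs: torch.Tensor, class_names: list[str], k: int = 3):
    """Return (class index, class name, probability) triples for the top-k
    classes of a single image, given probs of shape [1, num_classes]."""
    top_probs, top_idx = torch.topk(probs[0], k)  # largest k probabilities
    return [
        (int(i), class_names[int(i)], float(p))
        for p, i in zip(top_probs, top_idx)
    ]
```

For example, with probabilities [[0.1, 0.6, 0.3]] over classes ["cat", "dog", "bird"] and k=2, this yields dog first and bird second.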
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


Copilot AI left a comment


Pull request overview

This PR introduces a new ciao.model module that provides a ModelPredictor wrapper around a PyTorch model to standardize inference, probability extraction, and top‑k predictions.

Changes:

  • Added ModelPredictor with helper methods for predictions, top‑k class selection, and extracting per-class logits.
  • Added ciao.model package init exporting ModelPredictor.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

  • ciao/model/predictor.py — Implements the new ModelPredictor inference wrapper.
  • ciao/model/__init__.py — Exposes ModelPredictor as the public API of ciao.model.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a ModelPredictor class, which is a good addition for wrapping model inference logic. My review focuses on improving its robustness and correctness. I've identified a high-severity bug in the predict_image method where it incorrectly handles batch inputs, and I've provided a fix to process batches correctly. Additionally, I've suggested a medium-severity improvement to ensure the model is always in evaluation mode during prediction, which is a crucial best practice.

@dhalmazna dhalmazna force-pushed the feat/data-pipeline branch from 36f96f0 to de00a02 Compare March 10, 2026 12:39
@dhalmazna dhalmazna force-pushed the feat/model-predictor branch 2 times, most recently from d73d73e to a572f83 Compare March 10, 2026 13:18
@dhalmazna dhalmazna requested a review from Copilot March 10, 2026 13:25

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.



@dhalmazna dhalmazna force-pushed the feat/model-predictor branch from a572f83 to 0acc1c4 Compare March 11, 2026 09:09
@dhalmazna dhalmazna requested a review from vejtek March 12, 2026 07:31
@vejtek vejtek requested review from Adames4 and vojtech-kur March 13, 2026 10:18
Base automatically changed from feat/data-pipeline to master March 13, 2026 10:28
@vejtek vejtek requested a review from a team March 13, 2026 10:28
@dhalmazna dhalmazna force-pushed the feat/model-predictor branch from 0acc1c4 to ab2a08b Compare March 13, 2026 13:17

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (3)
ciao/model/predictor.py (3)

29-37: No bounds checking on target_class_idx.

If target_class_idx exceeds the number of output classes, this will raise an opaque IndexError. Since class_names is already stored, you could validate against it for a clearer error message.

🛡️ Optional: Add bounds validation
     def get_class_logit_batch(
         self, input_batch: torch.Tensor, target_class_idx: int
     ) -> torch.Tensor:
         """Get raw logits for a specific target class across a batch of images."""
+        if not 0 <= target_class_idx < len(self.class_names):
+            raise IndexError(
+                f"target_class_idx {target_class_idx} out of range for {len(self.class_names)} classes"
+            )
         input_batch = input_batch.to(self.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/model/predictor.py` around lines 29 - 37, Add bounds validation in
get_class_logit_batch: before moving input_batch or calling self.model, check
that target_class_idx is within 0 <= target_class_idx < len(self.class_names)
(use self.class_names to determine number of classes) and raise a clear
ValueError (or similar) with a message including the invalid index and valid
range; keep the rest of the function (input_batch.to(self.device), with
torch.no_grad(), outputs = self.model(input_batch), return outputs[:,
target_class_idx]) unchanged.

20-27: Consider documenting or validating expected input shape.

The softmax(dim=1) assumes model output has shape [batch, num_classes]. Per the context snippet in ciao/data/preprocessing.py, load_and_preprocess_image returns a tensor with shape [3, 224, 224] (no batch dimension). Users must manually add a batch dimension via unsqueeze(0) before passing to this method, or the softmax will operate on the wrong axis, producing incorrect probabilities.

Consider adding a brief docstring note about expected input shape, or optionally auto-expanding unbatched inputs:

📝 Option 1: Document expected shape in docstring
     def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
-        """Get model predictions (returns probabilities)."""
+        """Get model predictions (returns probabilities).
+
+        Args:
+            input_batch: Input tensor of shape [N, C, H, W] (batched).
+
+        Returns:
+            Probability tensor of shape [N, num_classes].
+        """
         input_batch = input_batch.to(self.device)
📝 Option 2: Auto-expand unbatched input
     def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
         """Get model predictions (returns probabilities)."""
+        if input_batch.dim() == 3:
+            input_batch = input_batch.unsqueeze(0)
         input_batch = input_batch.to(self.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/model/predictor.py` around lines 20 - 27, get_predictions assumes a
batched tensor and calls softmax(dim=1) which will be wrong for a single-image
tensor from load_and_preprocess_image (shape [3,224,224]); update
get_predictions to validate and handle unbatched inputs by checking
input_batch.dim() and if dim==3 call input_batch = input_batch.unsqueeze(0)
before .to(self.device), and also update the get_predictions docstring to state
the expected shape ([batch, num_classes] output from the model) or that a
[3,224,224] image will be auto-expanded to a batch.
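To make the shape concern above concrete, here is a standalone sketch of the auto-expand fix from Option 2 (the tensor values and the 5-class output are arbitrary stand-ins, not values from this PR):

```python
import torch

x = torch.randn(3, 224, 224)  # single preprocessed image, no batch dimension
if x.dim() == 3:              # auto-expand unbatched input
    x = x.unsqueeze(0)        # -> shape [1, 3, 224, 224]

logits = torch.randn(1, 5)            # stand-in for model(x) with 5 classes
probs = torch.softmax(logits, dim=1)  # dim=1 is the class axis only once batched
```

Without the unsqueeze, softmax(dim=1) would normalize over the 224 image rows instead of over classes, which is exactly the silent bug the comment warns about.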

7-9: class_names is stored but never used.

The class_names parameter is accepted and stored but no method in this class currently uses it. If it's intended for future use, consider documenting that intent. Otherwise, this may be dead code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/model/predictor.py` around lines 7 - 9, The __init__ stores class_names
on the Predictor but it is never used; either remove the dead parameter and
attribute or actually use it—so either (A) delete the class_names parameter and
the self.class_names assignment from __init__ (and update all callers that pass
class_names), or (B) keep the parameter and modify the prediction path (e.g.,
the predict / forward method in this class) to map predicted class indices to
names using self.class_names (return or include the label strings instead
of/alongside indices); choose one approach and update signatures and callers
accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8b2b9d09-255a-4aae-99f5-a2efacd5d6f9

📥 Commits

Reviewing files that changed from the base of the PR and between b85a5a3 and ab2a08b.

📒 Files selected for processing (2)
  • ciao/model/__init__.py
  • ciao/model/predictor.py


@vojtech-kur vojtech-kur left a comment


Fine afaik

Member

@Adames4 Adames4 left a comment


Please leave unresolved only relevant comments from agents!
