From bc6d907c1c7078041c102b60439635ded15bef44 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Thu, 5 Mar 2026 11:20:22 +0100 Subject: [PATCH 1/2] feat: add ModelPredictor class for model predictions --- ciao/model/__init__.py | 6 ++++++ ciao/model/predictor.py | 44 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 ciao/model/__init__.py create mode 100644 ciao/model/predictor.py diff --git a/ciao/model/__init__.py b/ciao/model/__init__.py new file mode 100644 index 0000000..dd425b4 --- /dev/null +++ b/ciao/model/__init__.py @@ -0,0 +1,6 @@ +"""Model prediction utilities for CIAO.""" + +from ciao.model.predictor import ModelPredictor + + +__all__ = ["ModelPredictor"] diff --git a/ciao/model/predictor.py b/ciao/model/predictor.py new file mode 100644 index 0000000..bb75c3c --- /dev/null +++ b/ciao/model/predictor.py @@ -0,0 +1,44 @@ +import torch + + +class ModelPredictor: + """Handles model predictions and class information.""" + + def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: + self.model = model + self.class_names = class_names + self.device = next(model.parameters()).device + + def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: + """Get model predictions (returns probabilities).""" + with torch.no_grad(): + outputs = self.model(input_batch) + probabilities = torch.nn.functional.softmax(outputs, dim=1) + return probabilities + + def predict_image( + self, input_batch: torch.Tensor, top_k: int = 5 + ) -> list[tuple[int, str, float]]: + """Get top-k predictions for an image.""" + probabilities = self.get_predictions(input_batch) + top_probs, top_indices = torch.topk(probabilities[0], top_k) + + results = [] + for i in range(top_k): + class_idx = int(top_indices[i].item()) + prob = float(top_probs[i].item()) + class_name = ( + self.class_names[class_idx] + if class_idx < len(self.class_names) + else f"class_{class_idx}" + ) + results.append((class_idx, class_name, prob)) + return results + + def get_class_logit_batch( + self, input_batch: torch.Tensor, target_class_idx: int + ) -> torch.Tensor: + """Get logits for a batch of images - optimized for batched inference (directly from model outputs).""" + with torch.no_grad(): + outputs = self.model(input_batch) + return outputs[:, target_class_idx] From ab2a08bbf99e3b17e3ffff77d7ac058daa205752 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Tue, 10 Mar 2026 14:07:17 +0100 Subject: [PATCH 2/2] chore: apply agents' suggestions --- ciao/model/predictor.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/ciao/model/predictor.py b/ciao/model/predictor.py index bb75c3c..5efd700 100644 --- a/ciao/model/predictor.py +++ b/ciao/model/predictor.py @@ -2,43 +2,36 @@ class ModelPredictor: - """Handles model predictions and class information.""" + """Handles model predictions and class information for the CIAO explainer.""" def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: self.model = model self.class_names = class_names - self.device = next(model.parameters()).device + + # Ensure deterministic inference by disabling Dropout and freezing BatchNorm + self.model.eval() + + # Robustly determine the device (fall back to CPU if model has no parameters) + try: + self.device = next(model.parameters()).device + except StopIteration: + self.device = torch.device("cpu") def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: """Get model predictions (returns probabilities).""" + input_batch = input_batch.to(self.device) + with torch.no_grad(): outputs = self.model(input_batch) probabilities = torch.nn.functional.softmax(outputs, dim=1) return probabilities - def predict_image( - self, input_batch: torch.Tensor, top_k: int = 5 - ) -> list[tuple[int, str, float]]: - """Get top-k predictions for an image.""" - probabilities = self.get_predictions(input_batch) - top_probs, top_indices = torch.topk(probabilities[0], top_k) - - results = [] - for i in range(top_k): - class_idx = int(top_indices[i].item()) - prob = float(top_probs[i].item()) - class_name = ( - self.class_names[class_idx] - if class_idx < len(self.class_names) - else f"class_{class_idx}" - ) - results.append((class_idx, class_name, prob)) - return results - def get_class_logit_batch( self, input_batch: torch.Tensor, target_class_idx: int ) -> torch.Tensor: - """Get logits for a batch of images - optimized for batched inference (directly from model outputs).""" + """Get raw logits for a specific target class across a batch of images.""" + input_batch = input_batch.to(self.device) + with torch.no_grad(): outputs = self.model(input_batch) return outputs[:, target_class_idx]