Add Quantus to grade transparency explanation quality (faithfulness/robustness/complexity/localisation/randomisation/axiomatic) #341

Description

@stanlrt

Summary

Add the Quantus XAI-evaluation library to grade the quality of transparency outputs (faithfulness, robustness, complexity, localisation, randomisation, axiomatic metrics over attribution maps).

The evaluation code lives inside the transparency module as a third sibling concern next to explainers/ (produce) and visualisers/ (render): transparency/evaluation/ (grade). Coupling is kept to a minimum by reusing the codebase's existing seams: decorator registration, a requires-capability gate, the invoker seam (#266), and lazy optional-extra imports.

Placement decision: inside transparency (not a new top-level module)

| Option | Verdict |
|---|---|
| Sub-package transparency/evaluation/ | ✅ chosen |
| New top-level evaluation/ module | ✗ name overclaims: implies it grades robustness/fairness too; it only grades XAI |
| Fold into metrics/ | ✗ wrong contract: metrics is model-perf, task-family driven (preds vs targets); explanation-quality needs a_batch |

Rationale:

Directory structure (mirrors explainers/ and robustness anatomy)

src/raitap/transparency/evaluation/
├── __init__.py                 # lazy facade, re-exported from transparency/__init__.py
├── contracts.py                # QuantusCategory, EvalRequirement enums; QuantusMetricSpec, EvaluationScore
├── bridge.py                   # torch<->numpy; raitap explainer -> Quantus explain_func; model wrap  ← SOLE coupling point
├── semantics.py                # requirement resolution + compat gate (typed skip)
├── results.py                  # EvaluationResult (scores + skipped reasons + per-metric meta)
├── report.py                   # builders folded into TransparencyPhaseResult reporting
├── step.py                     # grade_explanations(): post-step invoked by transparency/phase.py
├── evaluators/
│   ├── registration.py         # @transparency_evaluator() decorator (mirrors transparency_adapter)
│   ├── base_evaluator.py       # BaseEvaluator -> QuantusBaseEvaluator
│   └── quantus_evaluator.py    # all 6 categories in algorithm_registry
└── visualisers/
    ├── registration.py
    ├── base_visualiser.py
    └── score_visualisers.py    # per-category score panels/bars

The minimal-coupling seam: requires capability gate

Each metric declares the inputs it needs. The gate resolves them against what is available for a given ExplanationResult; missing → typed skip with a recorded reason (reuses the compat-gate convention and error-typing rule from #257 — typed only at the catch/gate boundary).

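from collections.abc import Mapping
from dataclasses import dataclass, field
from enum import StrEnum   # Python 3.11+
from typing import Any

class QuantusCategory(StrEnum):
    FAITHFULNESS  = "faithfulness"
    ROBUSTNESS    = "robustness"
    COMPLEXITY    = "complexity"
    LOCALISATION  = "localisation"
    RANDOMISATION = "randomisation"
    AXIOMATIC     = "axiomatic"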

class EvalRequirement(StrEnum):
    ATTRIBUTIONS = "attributions"   # a_batch — always
    MODEL        = "model"          # forward access for perturbation metrics
    RE_EXPLAIN   = "re_explain"     # bridge to originating explainer (robustness/randomisation)
    SEGMENTATION = "segmentation"   # s_batch ground-truth masks (localisation)
    BASELINE     = "baseline"       # reference tensor (some axiomatic)

@dataclass(frozen=True)
class QuantusMetricSpec:
    category: QuantusCategory
    quantus_cls: str                       # "FaithfulnessCorrelation" — resolved lazily from quantus
    requires: frozenset[EvalRequirement]
    higher_is_better: bool | None          # report orientation; None = signed/relative score
    default_kwargs: Mapping[str, Any] = field(default_factory=dict)
    invoker: EvalInvoker | None = None     # #266-style custom construct-and-call for odd metrics


def resolve_metric(spec, ctx):
    missing = spec.requires - ctx.available_requirements()
    if missing:
        return SkippedMetric(spec, EvaluationIncompatible(spec, missing))   # typed only at gate
    return ResolvedMetric(spec, ctx.gather(spec.requires))

requires is the lever. It applies the same idea as ExplainerAlgorithmSpec.requires: Capability and the #246 EstimatorProvider provider/require pattern to evaluation inputs. Coupling stays opt-in: localisation gates off when no mask provider exists; re-explain gates off when the originating explainer is absent or too stochastic.
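The gate can be tried in isolation with a minimal, self-contained sketch. MetricSpec, Skipped and Resolved here are hypothetical, trimmed stand-ins for QuantusMetricSpec, SkippedMetric and ResolvedMetric, and the set of available inputs is passed in directly instead of coming from a context object; none of this is the final raitap API.

```python
from dataclasses import dataclass
from enum import Enum


class EvalRequirement(str, Enum):  # the proposal uses StrEnum (Python 3.11+)
    ATTRIBUTIONS = "attributions"
    MODEL = "model"
    RE_EXPLAIN = "re_explain"
    SEGMENTATION = "segmentation"
    BASELINE = "baseline"


@dataclass(frozen=True)
class MetricSpec:  # hypothetical, trimmed stand-in for QuantusMetricSpec
    name: str
    requires: frozenset


@dataclass(frozen=True)
class Skipped:  # hypothetical stand-in for SkippedMetric
    spec: MetricSpec
    missing: frozenset

    @property
    def reason(self) -> str:
        names = ", ".join(sorted(r.value for r in self.missing))
        return f"{self.spec.name}: missing {names}"


@dataclass(frozen=True)
class Resolved:  # hypothetical stand-in for ResolvedMetric
    spec: MetricSpec


def resolve_metric(spec: MetricSpec, available: frozenset):
    """Return Resolved when every declared requirement is available, else Skipped."""
    missing = spec.requires - available
    return Skipped(spec, frozenset(missing)) if missing else Resolved(spec)


R = EvalRequirement
available = frozenset({R.ATTRIBUTIONS, R.MODEL})  # no masks, no explainer
sparseness = MetricSpec("sparseness", frozenset({R.ATTRIBUTIONS}))
pointing_game = MetricSpec("pointing_game", frozenset({R.ATTRIBUTIONS, R.SEGMENTATION}))

ok = resolve_metric(sparseness, available)
skipped = resolve_metric(pointing_game, available)
print(skipped.reason)
```

With only attributions and a model on hand, the complexity-style metric resolves while the localisation-style metric is skipped with a recorded reason instead of raising.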

The bridge (sole coupling point, isolated in bridge.py)

Quantus is numpy-native; raitap is torch. Both conversions + the explain_func wrapper live here. Surface = public ExplanationResult fields + public explainer.explain().

def to_quantus_arrays(result: ExplanationResult) -> QuantusArrays:
    return QuantusArrays(
        x_batch=result.inputs.detach().cpu().numpy(),
        y_batch=result.target_labels.detach().cpu().numpy(),
        a_batch=result.attributions.detach().cpu().numpy(),
        channel_first=result.semantics.output_space is ExplanationOutputSpace.IMAGE_SPATIAL_MAP,
    )

def explainer_to_explain_func(explainer, model_seam):   # robustness/randomisation only
    def explain_func(model, inputs, targets, **kw):
        t = torch.as_tensor(inputs, device=model_seam.device)
        res = explainer.explain(model_seam.wrap(model), t, target=targets)
        return res.attributions.detach().cpu().numpy()
    return explain_func

All six categories in one registry (decorator site)

QuantusMetricSpec instances are constructed inline inside the decorator, exactly like ExplainerAlgorithmSpec in captum_explainer.py. In the listing, Spec = QuantusMetricSpec, hib = higher_is_better, and the category and requirement enum members are abbreviated to their bare names.

@transparency_evaluator(registry_name="quantus", library="quantus", extra="quantus",
    algorithm_registry={
      # Faithfulness — a_batch + model re-inference (NO re-explain)
      "faithfulness_correlation": Spec(FAITHFULNESS, "FaithfulnessCorrelation", {ATTRIBUTIONS, MODEL}, hib=True),
      "pixel_flipping":           Spec(FAITHFULNESS, "PixelFlipping",  {ATTRIBUTIONS, MODEL}, hib=False),
      "road":                     Spec(FAITHFULNESS, "ROAD",           {ATTRIBUTIONS, MODEL}, hib=False),
      # Complexity — pure a_batch, cheapest
      "sparseness":               Spec(COMPLEXITY,   "Sparseness",     {ATTRIBUTIONS}, hib=True),
      "complexity":               Spec(COMPLEXITY,   "Complexity",     {ATTRIBUTIONS}, hib=False),
      # Robustness — needs RE_EXPLAIN
      "max_sensitivity":          Spec(ROBUSTNESS,   "MaxSensitivity", {ATTRIBUTIONS, MODEL, RE_EXPLAIN}, hib=False),
      "avg_sensitivity":          Spec(ROBUSTNESS,   "AvgSensitivity", {ATTRIBUTIONS, MODEL, RE_EXPLAIN}, hib=False),
      # Localisation — needs SEGMENTATION masks
      "pointing_game":            Spec(LOCALISATION, "PointingGame",   {ATTRIBUTIONS, SEGMENTATION}, hib=True),
      "relevance_rank_accuracy":  Spec(LOCALISATION, "RelevanceRankAccuracy", {ATTRIBUTIONS, SEGMENTATION}, hib=True),
      # Randomisation — needs RE_EXPLAIN (+ model copy, handled inside Quantus)
      "model_parameter_randomisation": Spec(RANDOMISATION, "ModelParameterRandomisation", {ATTRIBUTIONS, MODEL, RE_EXPLAIN}, hib=None),
      "random_logit":             Spec(RANDOMISATION, "RandomLogit",   {ATTRIBUTIONS, MODEL, RE_EXPLAIN}, hib=None),
      # Axiomatic
      "completeness":             Spec(AXIOMATIC,    "Completeness",   {ATTRIBUTIONS, MODEL, BASELINE}, hib=None),
      "input_invariance":         Spec(AXIOMATIC,    "InputInvariance",{ATTRIBUTIONS, MODEL, RE_EXPLAIN}, hib=None),
    })
class QuantusEvaluator(QuantusBaseEvaluator): ...

Coupling per category (why the gate matters)

| Category | Needs re-explain? | Extra data | Gate behaviour |
|---|---|---|---|
| Complexity | no | a_batch only | always runs |
| Faithfulness | no | model | runs when model present |
| Axiomatic (completeness) | no | model + baseline | runs when baseline recorded |
| Robustness | yes | originating explainer | gated on RE_EXPLAIN; warns if explainer stochastic |
| Randomisation | yes | originating explainer | gated on RE_EXPLAIN |
| Localisation | no | s_batch masks | gated on SEGMENTATION: skipped until a mask provider exists |
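The per-category behaviour above amounts to deriving a set of available requirements from whatever is on hand. The sketch below shows one possible derivation. The EvaluationContext fields, the plain-string requirement names and the hard_gate_stochastic switch are all hypothetical, and the warn-versus-gate choice for stochastic explainers is still an open follow-up in this proposal.

```python
from dataclasses import dataclass
from typing import Any, FrozenSet, List, Optional, Tuple


@dataclass
class EvaluationContext:
    """Hypothetical context: what is on hand for one ExplanationResult."""
    attributions: Any
    model: Optional[Any] = None
    explainer: Optional[Any] = None
    masks: Optional[Any] = None
    baseline: Optional[Any] = None


def available_requirements(
    ctx: EvaluationContext, hard_gate_stochastic: bool = False
) -> Tuple[FrozenSet[str], List[str]]:
    """Return (available requirement names, warnings) for the context."""
    have = {"attributions"}  # a_batch is always present
    warnings: List[str] = []
    if ctx.model is not None:
        have.add("model")
    if ctx.masks is not None:
        have.add("segmentation")
    if ctx.baseline is not None:
        have.add("baseline")
    if ctx.explainer is not None:
        stochastic = getattr(ctx.explainer, "stochastic", False)
        if stochastic and hard_gate_stochastic:
            warnings.append("re_explain withheld: explainer is stochastic")
        else:
            have.add("re_explain")
            if stochastic:
                warnings.append("re_explain scores may be noisy: explainer is stochastic")
    return frozenset(have), warnings


class NoisyExplainer:  # stand-in for an explainer flagged stochastic=True
    stochastic = True


ctx = EvaluationContext(attributions=[0.1, 0.9], model=object(), explainer=NoisyExplainer())
warn_have, warn_msgs = available_requirements(ctx)
gate_have, gate_msgs = available_requirements(ctx, hard_gate_stochastic=True)
```

In warn mode the robustness and randomisation metrics would still run with a recorded warning; in hard-gate mode they would be skipped by the requires gate because re_explain is never offered.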

Pipeline integration + config

TransparencyPhase
  produce explanations  ──►  [ExplanationResult, …]
  grade_explanations(config.evaluation, results, model, data)   # post-step, same phase
    per ExplanationResult:
       ctx = EvaluationContext(result, model_seam, originating_explainer, masks=data.masks?)
       per configured metric:
          resolved | skipped = resolve_metric(spec, ctx)   # typed skip → recorded reason
          score = quantus_metric(**bridge.inputs)          # arrays + optional explain_func
  ──►  TransparencyPhaseResult(explanations=[…], evaluations=[…])  ──►  reporting (score panels)
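The inner loop of the diagram can be sketched in a few lines. This is an illustration only: grade, EvaluationResult and the two toy metrics are hypothetical names, the toy metrics are plain Python functions and not Quantus metrics, and inputs are keyed by simple strings.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple

# A metric is (required input names, callable taking those inputs as kwargs).
Metric = Tuple[FrozenSet[str], Callable[..., float]]


@dataclass
class EvaluationResult:  # hypothetical shape: scores plus skip reasons
    scores: Dict[str, float] = field(default_factory=dict)
    skipped: Dict[str, str] = field(default_factory=dict)


def grade(metrics: Mapping[str, Metric], inputs: Mapping[str, Any]) -> EvaluationResult:
    """Run every metric whose inputs are present; record a reason for the rest."""
    result = EvaluationResult()
    for name, (requires, fn) in metrics.items():
        missing = requires - set(inputs)
        if missing:
            result.skipped[name] = "missing " + ", ".join(sorted(missing))
            continue
        result.scores[name] = fn(**{key: inputs[key] for key in requires})
    return result


def mean_abs(attributions):
    """Toy score: mean absolute attribution."""
    return sum(abs(a) for a in attributions) / len(attributions)


def mask_hit_rate(attributions, masks):
    """Toy score: 1.0 if the largest attribution falls inside the mask."""
    top = max(range(len(attributions)), key=lambda i: attributions[i])
    return float(masks[top])


metrics = {
    "mean_abs": (frozenset({"attributions"}), mean_abs),
    "mask_hit_rate": (frozenset({"attributions", "masks"}), mask_hit_rate),
}
result = grade(metrics, {"attributions": [0.0, 1.0, -0.5, 0.5]})
print(result.scores, result.skipped)
```

A skipped metric never raises; it lands in skipped with its reason, so the reporting step can show both what was graded and what was not.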

Config nests under the existing transparency: key (no new top-level key):

transparency:
  explainers:
    # … existing explainer config …
  evaluation:
    quantus:
      _target_: raitap.transparency.QuantusEvaluator
      metrics: [faithfulness_correlation, sparseness, max_sensitivity]
      constructor:
        faithfulness_correlation: {nr_runs: 10, subset_size: 224}
      raitap: {explanations: all}        # or list of explainer names to grade
      visualisers:
        - _target_: raitap.transparency.ScoreBarVisualiser
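One way the metrics list and the constructor overrides could be combined with each spec's default_kwargs is sketched below. The function name, the rule that user overrides win over defaults, the two validation checks and the default values shown are assumptions for illustration, not decided behaviour.

```python
from typing import Any, Dict, Mapping


def build_metric_kwargs(
    config: Mapping[str, Any], defaults: Mapping[str, Mapping[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """Merge per-metric defaults with user overrides for the selected metrics."""
    metrics = list(config["metrics"])
    unknown = sorted(m for m in metrics if m not in defaults)
    if unknown:
        raise KeyError(f"unknown metrics: {unknown}")
    overrides = config.get("constructor", {})
    stray = sorted(set(overrides) - set(metrics))
    if stray:
        raise KeyError(f"constructor overrides for unselected metrics: {stray}")
    # Assumption: user-supplied constructor kwargs take precedence over defaults.
    return {m: {**defaults[m], **overrides.get(m, {})} for m in metrics}


# Illustrative defaults only; real values would come from each spec's default_kwargs.
registry_defaults = {
    "faithfulness_correlation": {"nr_runs": 100},
    "sparseness": {},
    "max_sensitivity": {"nr_samples": 10},
}
# The evaluation block from the YAML above, as a plain dict.
evaluation_config = {
    "metrics": ["faithfulness_correlation", "sparseness", "max_sensitivity"],
    "constructor": {"faithfulness_correlation": {"nr_runs": 10, "subset_size": 224}},
}
kwargs = build_metric_kwargs(evaluation_config, registry_defaults)
```

Rejecting overrides for metrics that were not selected catches typos in metric names early, at config time instead of halfway through a grading run.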

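Invocation may use Quantus's batch API, quantus.evaluate(metrics, xai_methods, model, x_batch, y_batch, s_batch). It maps cleanly: xai_methods = transparency results (attributions are supplied per method through this argument rather than as a separate a_batch), metrics = the configured set.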

Dependency wiring

  • pyproject.toml: quantus = ["quantus>=0.5"] optional-extra. The transparency extra does not pull it.
  • deps-mapping entry so config inference adds quantus when an evaluation block is present.
  • transparency/__init__.py facade re-exports QuantusEvaluator + ScoreBarVisualiser (lazy __getattr__; see lazy-package monkeypatch caveat for tests).
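The lazy facade mentioned in the last bullet relies on module-level __getattr__ (PEP 562). The sketch below reproduces the mechanism in a single file by building stand-in module objects; demo_pkg, demo_pkg.evaluators and _LAZY_EXPORTS are hypothetical names, and in a real package the function would simply be defined as __getattr__ in transparency/__init__.py.

```python
import importlib
import sys
import types

# Stand-in for the heavy submodule that would import quantus.
heavy = types.ModuleType("demo_pkg.evaluators")
heavy.QuantusEvaluator = type("QuantusEvaluator", (), {})
sys.modules["demo_pkg.evaluators"] = heavy

# Export name -> module that defines it.
_LAZY_EXPORTS = {"QuantusEvaluator": "demo_pkg.evaluators"}

# Stand-in for the package facade (the package's __init__ module).
facade = types.ModuleType("demo_pkg")
facade.__path__ = []  # mark it as a package
sys.modules["demo_pkg"] = facade


def _lazy_getattr(name: str):
    """Import the defining module on first access, then cache the attribute."""
    module_name = _LAZY_EXPORTS.get(name)
    if module_name is None:
        raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")
    value = getattr(importlib.import_module(module_name), name)
    setattr(facade, name, value)  # later lookups no longer reach __getattr__
    return value


# PEP 562: a __getattr__ entry in the module namespace handles missing attributes.
facade.__getattr__ = _lazy_getattr
```

After the first access the attribute is cached on the facade, so later lookups bypass __getattr__ entirely.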

Scope (MVP)

  • All six categories registered in the registry.
  • Built + tested end-to-end now: Faithfulness, Complexity, Robustness, Randomisation, Axiomatic (completeness).
  • Localisation registered but gated off until a segmentation-mask provider exists (follow-up).

Follow-ups (out of MVP)

  • Segmentation-mask provider seam feeding s_batch, unlocking Localisation metrics.
  • Stochastic re-explain handling — robustness/randomisation re-invoke the explainer; when the originating explainer is stochastic=True (SHAP Gradient, NoiseTunnel) scores are noisy → decide warn vs hard-gate vs averaging.
  • Generalise the grade seam beyond transparency only if/when another module needs it (Create the 3rd party plugins system via entry-points #173).

Testing

  • Unit: resolve_metric gate (each requirement present/absent → run/skip with typed reason).
  • Unit: bridge torch↔numpy round-trip + explain_func wrapper shape.
  • Integration: tiny CNN + one captum explainer → run faithfulness + complexity, assert scores finite + orientation.
  • Lazy-import: monkeypatch the module object, not the string path (lazy __getattr__ caveat).
  • Run only touched-module tests locally; full regression is CI's job.

Docs to update

  • docs/modules/transparency/index.md — add the "grade explanations" capability.
  • docs/modules/transparency/configuration.md - evaluation: block + metric list.
  • docs/modules/transparency/output.md — evaluation scores in the result/report.
  • docs/contributor/modules/transparency.md — evaluator hierarchy, requires gate, bridge contract.
  • docs/contributor/adding/adding-an-algorithm.md — adding a Quantus metric to the registry.
  • A contributor-configs/ example demonstrating the evaluation: block.
