
Robustness assessors repeat the clean forward pass and run verdict forwards unbatched #326

Description

@stanlrt

From the /simplify codebase sweep (2026-06-10): two related sources of redundant forward passes in robustness/assessors/base_assessor.py.

1. Clean forward redone once per assessor

Lines 228 (EmpiricalAttackAssessor.assess), 394 (FormalVerificationAssessor.assess) and 562 (StatisticalSamplingAssessor.assess) each compute clean_predictions = _argmax_predictions(model, inputs) over the full dataset. However, the orchestrator has already produced full clean logits in ForwardOutput (orchestrator.py:164), which sit unused in the same PhaseContext, and robustness/phase.py:71-74 already derives the argmax from them for target resolution. A robustness matrix with 5 assessors therefore pays for 5 extra full clean passes.
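The redundancy can be illustrated with a minimal sketch, assuming a toy nn.Linear classifier and a hypothetical CountingModel wrapper (neither is part of the codebase). One pass stands in for the orchestrator's ForwardOutput logits; five more stand in for the assessors recomputing the clean argmax, even though taking the argmax of the cached logits yields the same predictions without touching the model again.

```python
import torch
from torch import nn


class CountingModel(nn.Module):
    """Hypothetical wrapper that counts every forward call reaching the model."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return self.inner(x)


torch.manual_seed(0)
model = CountingModel(nn.Linear(8, 3)).eval()
inputs = torch.randn(32, 8)
num_assessors = 5

with torch.no_grad():
    # Stand-in for the orchestrator's single clean pass (ForwardOutput logits).
    clean_logits = model(inputs)
    # Current behaviour: every assessor recomputes the clean argmax itself.
    recomputed = [model(inputs).argmax(dim=-1) for _ in range(num_assessors)]

# Proposed behaviour: derive predictions from the cached logits, no extra pass.
cached_predictions = clean_logits.argmax(dim=-1)
print(model.calls)  # 6: 1 orchestrator pass + 5 redundant assessor passes
```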

Fix: plumb an optional clean_predictions: torch.Tensor | None = None argument through RobustnessAssessment/assess() from RobustnessPhase (which has ctx.forward_output), and keep the current computation as a fallback for direct callers.
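A minimal sketch of the proposed signature, assuming a heavily simplified assessor. The class body, the stand-in _argmax_predictions, and the cached_clean_predictions helper for the phase side are illustrative only, not the actual code. When clean_predictions is supplied it is used directly (after a length check); otherwise the existing full-dataset computation runs as before, so direct callers keep working.

```python
from __future__ import annotations

import torch
from torch import nn


def _argmax_predictions(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the existing helper: one unbatched forward pass.
    with torch.no_grad():
        return model(inputs).argmax(dim=-1)


class EmpiricalAttackAssessor:
    """Hypothetical, heavily simplified assessor showing only the new parameter."""

    def assess(
        self,
        model: nn.Module,
        inputs: torch.Tensor,
        clean_predictions: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if clean_predictions is None:
            # Fallback for direct callers that have no cached forward output.
            clean_predictions = _argmax_predictions(model, inputs)
        elif clean_predictions.shape[0] != inputs.shape[0]:
            raise ValueError("clean_predictions must have one entry per input sample")
        # ... attack generation and verdicts would follow here ...
        return clean_predictions


def cached_clean_predictions(clean_logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical phase-side helper: argmax the logits held in ctx.forward_output.
    return clean_logits.argmax(dim=-1)
```

With this shape, RobustnessPhase computes the argmax once and hands the same tensor to every assessor in the matrix.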

2. Verdict forwards ignore raitap.batch_size

_argmax_predictions (lines 649-667) forwards the whole dataset in a single unbatched call (model(prepared) on all N samples). It is used at call sites 228-229 (clean + adversarial) and 562-564 (clean + corrupted), whereas attack generation is carefully chunked via _compute_with_optional_batches. As a result, the user's batching knob is silently ignored for half of the assessment's forward work, which is exactly the OOM scenario batch_size exists to prevent.

Fix: reuse the chunked loop (or ClassificationFamily.extract_forward) for the argmax forwards, honouring the resolved batch size.
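The chunked argmax forward could look roughly like the following sketch. The name argmax_predictions_batched is hypothetical, and it uses torch.split in place of the project's own _compute_with_optional_batches or ClassificationFamily.extract_forward, whose signatures are not shown in this issue. Passing batch_size=None reproduces today's single-call behaviour.

```python
from __future__ import annotations

import torch
from torch import nn


def argmax_predictions_batched(
    model: nn.Module, inputs: torch.Tensor, batch_size: int | None
) -> torch.Tensor:
    """Hypothetical chunked replacement for the unbatched argmax forward."""
    num_samples = inputs.shape[0]
    if num_samples == 0:
        return torch.empty(0, dtype=torch.long, device=inputs.device)
    if batch_size is None:
        batch_size = num_samples  # one call, as the current helper does
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    predictions = []
    with torch.no_grad():
        # At most batch_size samples go through the model per call, so peak
        # activation memory scales with batch_size rather than with N.
        for chunk in torch.split(inputs, batch_size):
            predictions.append(model(chunk).argmax(dim=-1))
    return torch.cat(predictions)
```

Because only the integer predictions of each chunk are kept between iterations, peak memory during the loop is governed by batch_size, which is the guarantee the batching knob is meant to give.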

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Relationships: None yet
Development: No branches or pull requests
