From /simplify codebase sweep (2026-06-10). Two related forward-pass wastes in robustness/assessors/base_assessor.py.
1. Clean forward redone once per assessor
Lines 228 (EmpiricalAttackAssessor.assess), 394 (FormalVerificationAssessor.assess) and 562 (StatisticalSamplingAssessor.assess) each compute clean_predictions = _argmax_predictions(model, inputs) over the full dataset. The orchestrator has already produced the full clean logits in ForwardOutput (orchestrator.py:164), and they sit unused in the same PhaseContext. robustness/phase.py:71-74 already derives the argmax from them for target resolution. A robustness matrix with 5 assessors therefore pays for 5 redundant full clean passes.
Fix: plumb an optional clean_predictions: torch.Tensor | None = None parameter through RobustnessAssessment/assess() from RobustnessPhase (which has ctx.forward_output), and keep the current computation as the fallback for direct callers.
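A minimal sketch of fix 1, assuming a simplified free function named assess (hypothetical; the real methods live on the assessor classes and take more arguments). The phase derives the argmax once from clean logits it already has and passes it to every assessor. Callers that pass nothing keep today's behaviour. CountingModel is a hypothetical wrapper that only exists to show that five assessments then cost one clean forward instead of six. The name clean_logits stands in for whatever field ForwardOutput actually exposes.

```python
from __future__ import annotations

import torch
from torch import nn


def _argmax_predictions(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # Stand-in for the existing helper: one forward over all inputs, then argmax.
    with torch.no_grad():
        return model(inputs).argmax(dim=-1)


def assess(
    model: nn.Module,
    inputs: torch.Tensor,
    clean_predictions: torch.Tensor | None = None,  # hypothetical new parameter
) -> torch.Tensor:
    # Reuse predictions supplied by the phase; fall back for direct callers.
    if clean_predictions is None:
        clean_predictions = _argmax_predictions(model, inputs)
    # ... the real assessor would run attacks/verification here ...
    return clean_predictions


class CountingModel(nn.Module):
    """Hypothetical toy classifier that counts how many forwards it serves."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 3)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return self.linear(x)


torch.manual_seed(0)
model = CountingModel()
inputs = torch.randn(8, 4)

# Phase side: the clean logits already exist (here computed once to simulate
# the orchestrator's pass); derive the argmax a single time.
with torch.no_grad():
    clean_logits = model(inputs)
clean_predictions = clean_logits.argmax(dim=-1)

# Five assessors reuse the same predictions: no extra clean forwards.
results = [assess(model, inputs, clean_predictions) for _ in range(5)]
```

Keeping the parameter optional means existing call sites that invoke an assessor directly still work unchanged. Only RobustnessPhase needs to start passing the tensor.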
2. Verdict forwards ignore raitap.batch_size
_argmax_predictions (lines 649-667) forwards the whole dataset in a single unbatched call (model(prepared) on all N samples). It is used at call sites 228-229 (clean + adversarial) and 562-564 (clean + corrupted), whereas attack generation is carefully chunked via _compute_with_optional_batches. The user's batching knob therefore silently does not apply to half of the assessment's forward work. That is exactly the OOM scenario batch_size exists to prevent.
Fix: reuse the chunked loop (or ClassificationFamily.extract_forward) for the argmax forwards, honouring the resolved batch size.