SAFA-distill stage 2: landmark TagBundleExtractor + v17/v22 configs #629

Merged
efahnestock merged 8 commits into main from safa-extractors
May 22, 2026

@efahnestock
Collaborator

Summary

  • Adds the SAFA-distillation stage-2 training pipeline: frozen SAFA token + landmark TagBundleExtractor (OSM sat-side, panov2 pano-side) fused via a transformer aggregator, trained with cosine distillation against the frozen SAFA baseline plus an InfoNCE cross-view contrastive term.
  • Adds new model code: SafaExtractor, RandomTokenExtractor, OSMTagBundleExtractor, PanoramaTagBundleExtractor (each with its own TagBundleEncoder mirroring the correspondence classifier's pattern), Planar/Spherical position embeddings on landmark tokens, and the trainer (train_safa_distill.py).
  • Patches export_similarity_matrix.py to auto-detect images/landmarks requirements from the model and thread tensor-cache info correctly; adds --disable_safa_cache, --panorama_landmark_radius_px, --landmark_correspondence_inflation_factor.
  • Two stage-2 configs are kept: safa_v17_temp01_cv05_config.yaml (no positions, Seattle MRR 0.247) and safa_v22_positional_config.yaml (positions added — Seattle MRR 0.250, used to produce early_fusion_attempt_v1.pt similarity matrices on 9 cities).
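
The cosine distillation term mentioned in the summary can be sketched as follows. This is a dependency-free illustration, not the trainer's code: it assumes the loss is the batch mean of `1 - cos(student, teacher)`, and the function names are hypothetical. The actual implementation presumably operates on torch tensors.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def cosine_distillation_loss(student_embeddings, teacher_embeddings):
    """Assumed form: mean over the batch of (1 - cos(student_i, teacher_i)).

    The teacher embeddings stand in for the frozen SAFA baseline outputs.
    """
    losses = [
        1.0 - cosine_similarity(s, t)
        for s, t in zip(student_embeddings, teacher_embeddings)
    ]
    return sum(losses) / len(losses)


teacher = [[1.0, 0.0], [0.0, 2.0]]
aligned = [[3.0, 0.0], [0.0, 0.5]]      # same directions, different norms
orthogonal = [[0.0, 1.0], [1.0, 0.0]]   # perpendicular to the teacher

aligned_loss = cosine_distillation_loss(aligned, teacher)
orthogonal_loss = cosine_distillation_loss(orthogonal, teacher)
```

Because only direction matters, a student that matches the teacher up to scale incurs zero loss, which is why the aligned example ignores the differing norms.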

Test plan

  • bazel build //experimental/overhead_matching/swag/scripts:train_safa_distill //experimental/overhead_matching/swag/scripts:export_similarity_matrix //experimental/overhead_matching/swag/model:swag_patch_embedding //experimental/overhead_matching/swag/model:swag_patch_embedding_test //experimental/overhead_matching/swag/model:tag_bundle_extractor_test (all clean locally)
  • bazel test //experimental/overhead_matching/swag/model:swag_patch_embedding_test //experimental/overhead_matching/swag/model:tag_bundle_extractor_test
  • End-to-end smoke-train v22 on Chicago for a few epochs and confirm Seattle MRR climbs above the SAFA baseline
  • Run export_similarity_matrix against a v22 checkpoint on Seattle (with cache) and on a non-cached city like Boston (with --disable_safa_cache)
  • Calibrate sigma on the v22 Seattle matrix, run histogram-filter path eval, and compare against 260428_v6_safa_only_noise014

Adds the SwagPatchEmbedding extractor framework (frozen SAFA token +
landmark TagBundleExtractor + transformer aggregator) and the
train_safa_distill.py trainer used to fit it via cosine distillation
against the frozen SAFA baseline plus an InfoNCE cross-view contrastive
term.

Includes:
- Model side: SafaExtractor, RandomTokenExtractor, OSM/Pano
  TagBundleExtractor with TagBundleEncoder, Planar/Spherical position
  embeddings, MlpAggregator and MlpConcatAggregator variants, and the
  SwagPatchEmbedding glue.
- Trainer: cosine + InfoNCE / pairwise-contrastive cross-view loss
  dispatch, identity-init for SAFA projection so residual + pure
  passthrough variants work, HNM controls, attention-logging gated on
  aggregator type.
- Eval: export_similarity_matrix.py picks up landmark/image requirements
  from the model and threads tensor-cache info; added
  --disable_safa_cache, --panorama_landmark_radius_px,
  --landmark_correspondence_inflation_factor.
- All v8–v36 stage-2 configs explored (positions, per_tag_dim,
  proper-noun filtering, MLP residual variants, pairwise-contrastive
  cv-loss sweep). v22 (positions + per_tag_dim=1024) was the best at
  Seattle MRR=0.2503 and is what got path-eval'd as
  early_fusion_attempt_v1 on 9 cities.
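
For readers unfamiliar with the metric quoted above, MRR over a similarity matrix can be sketched like this. The layout (rows are panoramas, columns are satellite patches, one true column per row) and the optimistic tie handling are assumptions for illustration; the repository's evaluation code may differ.

```python
def mean_reciprocal_rank(similarity, true_columns):
    """Mean of 1/rank of the true column in each row.

    Rank is 1 plus the number of columns scoring strictly higher than the
    true column, so ties are resolved in the true match's favor.
    """
    reciprocal_ranks = []
    for row, true_col in zip(similarity, true_columns):
        true_score = row[true_col]
        rank = 1 + sum(1 for score in row if score > true_score)
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


similarity = [
    [0.9, 0.1, 0.3],  # true column 0 is the best score: rank 1
    [0.8, 0.2, 0.5],  # true column 2 is second best: rank 2
    [0.4, 0.6, 0.7],  # true column 0 is the worst: rank 3
]
mrr = mean_reciprocal_rank(similarity, [0, 2, 0])  # (1 + 1/2 + 1/3) / 3
```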

EXPERIMENT_SUMMARY.md walks through all runs, takeaways, and links to
result paths under /data/overhead_matching/evaluation/results/.

Cleans up the v8–v36 exploration. Retains:
- safa_v17_temp01_cv05_config.yaml (no-positions baseline)
- safa_v22_positional_config.yaml (early_fusion_attempt_v1 — best result)
- All model + trainer code those two configs actually use.

Drops:
- 41 other yaml configs (v8/v9-residual, v10–v16, v18–v21, v23–v36, plus
  pre-v8 safa_distill_*/safa_hybrid_*/safa_residual_*/safa_landmarks_*).
- Auxiliary scripts not referenced by the kept configs:
  dump_tb_scalars.py, eval_checkpoint_mrr.py, safa_baseline_mrr.py.
- Stray root config.yaml (orphan from a stage-1 run).
- MlpAggregatorConfig / MlpConcatAggregatorConfig types and the
  MlpAggregator / MlpConcatAggregator classes in swag_patch_embedding.
- Pairwise-contrastive cross-view loss dispatch and helper in
  train_safa_distill.py (cross_view_loss_kind +
  pairwise_{positive,semipositive,negative}_weight/avg_similarity
  fields, _cross_view_pairwise_contrastive); cv-loss path is now
  unconditionally InfoNCE.

The kept code path is exactly what produced the v22 checkpoint that
became early_fusion_attempt_v1.

Comment thread experimental/overhead_matching/swag/scripts/safa_v22_positional_config.yaml Outdated
Comment thread experimental/overhead_matching/swag/scripts/safa_v17_temp01_cv05_config.yaml Outdated
Comment thread experimental/overhead_matching/swag/model/BUILD Outdated
Comment thread experimental/overhead_matching/swag/model/tag_bundle_extractor.py Outdated
Comment thread experimental/overhead_matching/swag/model/tag_bundle_extractor.py Outdated
Comment thread experimental/overhead_matching/swag/model/swag_patch_embedding_test.py Outdated
Comment thread experimental/overhead_matching/swag/model/safa_extractor.py Outdated
Comment thread experimental/overhead_matching/swag/model/swag_patch_embedding.py Outdated
Comment thread experimental/overhead_matching/swag/model/swag_patch_embedding.py Outdated
Comment thread experimental/overhead_matching/swag/scripts/train.py Outdated

- Delete RandomTokenExtractor (config, module, BUILD target, attention test)
  — only consumer was train_safa_distill, which is also deleted.
- Delete train_safa_distill.py; fold Stage 2 into train.py via losses.py:
  - DistillationLossConfig + compute_distillation_loss (cosine teacher
    distill against a frozen extractor's output)
  - CrossViewInfoNCELossConfig + compute_cross_view_info_nce_loss
    (multi-positive symmetric InfoNCE with semipositives masked from the
    denominator, ported verbatim from _cross_view_info_nce)
  - LossInputs gains optional sat/pano_extractor_outputs, threaded through
    compute_forward_pass_and_loss
  - SwagPatchEmbedding.identity_init_extractor_projection(name) and
    TrainConfig.identity_init_extractors for the SAFA-passthrough init
- Drop return_attention_weights plumbing (kwarg, diagnostics path, layer
  helper) — no production caller after train_safa_distill removal.
- Drop residual-alpha machinery (config fields, buffer/parameter, setter,
  forward composition) — never used in any kept run.
- Drop _warm_start_from_stage1 + init_model_from from train.py — Stage 1
  was never trained.
- Drop dead BUILD targets: eval_checkpoint_mrr, dump_tb_scalars,
  safa_baseline_mrr (source files absent).
- Revert TagBundleEncoder LayerNorms and text_proj_dim=None branch in
  landmark_correspondence_model. v17 trains successfully without them.
  Restores state-dict compatibility with the simple_v1_v5 classifier
  checkpoint as a bonus.
- OSMTagBundleExtractor now processes all four geom types
  (point/linestring/polygon/multipolygon) with learned per-geom-type
  markers via nn.Embedding(4, repr_dim), matching the all-geom-type
  treatment that correspondence_matching already does. Drops landmark_type
  config field.
- PanoramaTagBundleExtractor.load_files: surface cross-city duplicate
  pano_ids via a count + warning (1,073 dupes exist in the data;
  last-seen-wins preserves prior behavior).
- SafaExtractor: use common.torch.load_and_save_models.load_model. Drops
  the hand-rolled config+weights fallback and the pyyaml/msgspec deps.
- tag_bundle_extractor doc nits: drop "v1" hedge, drop "768d" claim,
  comment the inverted True=real-tag convention in
  _stack_landmark_tag_tensors (opposite of ExtractorOutput.mask).
- Delete safa_v17_temp01_cv05_config.yaml and safa_v22_positional_config.yaml.
  v17 retrained on the new pipeline: best Seattle MRR 0.24777 at epoch 42
  vs original 0.2473 (well inside the seed-variance band).
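
The duplicate-handling behavior described for PanoramaTagBundleExtractor.load_files (count the duplicates, warn, keep the last one seen) might look roughly like the sketch below. The function name, the input layout, and the warning text are hypothetical; only the count-plus-warning and last-seen-wins behavior come from the commit message.

```python
import warnings


def index_panoramas(records_by_city):
    """Build a pano_id -> (city, record) index, counting repeated ids."""
    index = {}
    num_duplicates = 0
    for city, records in records_by_city.items():
        for pano_id, record in records:
            if pano_id in index:
                num_duplicates += 1
            # Last-seen-wins: a later entry silently replaces an earlier one.
            index[pano_id] = (city, record)
    if num_duplicates:
        warnings.warn(
            f"{num_duplicates} duplicate pano_ids found; keeping the last one seen"
        )
    return index, num_duplicates


records_by_city = {
    "chicago": [("p1", "a"), ("p2", "b")],
    "seattle": [("p2", "c"), ("p3", "d")],
}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    index, num_duplicates = index_panoramas(records_by_city)
```

Counting rather than raising keeps existing checkpoints reproducible while still making the data issue visible in the logs.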
- Replace the two InfoNCE configs/functions with a single InfoNCELossConfig +
  compute_info_nce_loss:
  - Reads similarity_matrix (same signature as every other loss)
  - Symmetric pano↔sat with multi-positive per-row averaging
  - Textbook temperature-scaled log_softmax InfoNCE math
  - Raises ValueError on non-empty semipositive_pairs (caller must collapse
    semi→pos upstream; train.py's NEAREST merge already does this)
  - Drops max_num_negative_pairs, scale_negative_by_num_items, use_pano_as_anchor,
    negative_scale (replaced by temperature)
  - Uses Pairs pairing data instead of PositiveAnchorSets
- Remove the now-dead ANCHOR_SETS pairing branch from train.py and lr_sweep.py;
  pairing_type plumbing collapses to plain create_pairs() calls
- Remove the resume_state machinery (parameters, restore blocks, --resume*
  CLI flags). Never exercised by any run on this branch; train_safa_distill
  (deleted earlier) had no resume path of its own
- Restore the SwagPatchEmbedding.forward docstring (cleaned for the removed
  return_attention_weights kwarg)
- Update losses_test.py: rewrite the InfoNCE test against the new Pairs
  signature, add an assertion that semipositives trigger the ValueError
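
The loss described in the bullets above can be sketched without torch so the arithmetic is visible. This is an illustration under stated assumptions, not compute_info_nce_loss itself: the signature, the default temperature, and the reading of "multi-positive per-row averaging" as the mean of the positives' negative log-probabilities are all assumptions.

```python
import math


def _log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def _directional_loss(logits, positives_per_row):
    """Average -log p(positive) per row, then average over rows with positives."""
    row_losses = []
    for row, positive_cols in zip(logits, positives_per_row):
        if not positive_cols:
            continue  # rows without any positive contribute nothing
        log_probs = _log_softmax(row)
        row_losses.append(-sum(log_probs[c] for c in positive_cols) / len(positive_cols))
    return sum(row_losses) / len(row_losses)


def info_nce_loss(similarity, positive_pairs, semipositive_pairs=(), temperature=0.1):
    """Symmetric pano<->sat InfoNCE; similarity[p][s] scores pano p vs sat s."""
    if len(semipositive_pairs) > 0:
        raise ValueError("collapse semipositives into positives before calling")
    if len(positive_pairs) == 0:
        raise ValueError("at least one positive pair is required")
    num_pano, num_sat = len(similarity), len(similarity[0])
    logits = [[s / temperature for s in row] for row in similarity]
    logits_t = [[logits[p][s] for p in range(num_pano)] for s in range(num_sat)]
    pano_pos = [[] for _ in range(num_pano)]
    sat_pos = [[] for _ in range(num_sat)]
    for p, s in positive_pairs:
        pano_pos[p].append(s)
        sat_pos[s].append(p)
    pano_to_sat = _directional_loss(logits, pano_pos)
    sat_to_pano = _directional_loss(logits_t, sat_pos)
    return 0.5 * (pano_to_sat + sat_to_pano)


pairs = [(0, 0), (1, 1)]
uniform_loss = info_nce_loss([[0.0, 0.0], [0.0, 0.0]], pairs)
diagonal_loss = info_nce_loss([[1.0, 0.0], [0.0, 1.0]], pairs)
```

With an uninformative similarity matrix the loss equals log(2) for two candidates, and it drops toward zero as the positives separate from the negatives.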
Mirrors the historical v17_full_pipeline.sh (preserved at
/data/overhead_matching/training_outputs/260518_090000_early_fusion_attempts/_pipeline_artifacts/scripts/).
Four phases: export sim matrices for 9 cities, calibrate sigma on Seattle,
write per-city aggregator configs, run histogram-filter path eval.

Eval hyperparams match historical v17 exactly (motion_noise_frac=0.141,
subdivision_factor=4, odometry_noise_frac=0.141, odometry_noise_seed=7919,
seed=42) so any delta vs 260513_early_fusion_no_positions_v1 is attributable
to the model change.
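
One way to make the "hyperparams match exactly" claim checkable is to diff the two settings dictionaries. The values below are the ones listed above; the dictionary layout and the helper name are hypothetical and do not reflect how the pipeline script stores its settings.

```python
HISTORICAL_V17_EVAL = {
    "motion_noise_frac": 0.141,
    "subdivision_factor": 4,
    "odometry_noise_frac": 0.141,
    "odometry_noise_seed": 7919,
    "seed": 42,
}


def diff_hyperparams(reference, candidate):
    """Return {key: (reference_value, candidate_value)} for every mismatch."""
    keys = sorted(set(reference) | set(candidate))
    return {
        key: (reference.get(key), candidate.get(key))
        for key in keys
        if reference.get(key) != candidate.get(key)
    }


rerun = dict(HISTORICAL_V17_EVAL)  # the rerun reuses the historical values
matching_diff = diff_hyperparams(HISTORICAL_V17_EVAL, rerun)

perturbed = dict(HISTORICAL_V17_EVAL, seed=43)
perturbed_diff = diff_hyperparams(HISTORICAL_V17_EVAL, perturbed)
```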

Run output: /data/overhead_matching/evaluation/results/260521_v17_rerun_v6_path_eval/
sigma_MLE=0.1949 (within 1e-4 of the historical 0.1950).

The InfoNCE merge collapsed all losses to Pairs, leaving PositiveAnchorSets /
create_anchors / collapse_anchors_to_torch / PairingType / PairingDataType
with no callers. The resume-from-checkpoint path was also removed but the
save side still wrote a training_state.pt that nothing read. Strip both.

Also drop a dangling `case _:` leftover in train.py from the removed
`match pairing_type` block — the file did not parse.
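
A leftover `case _:` outside any `match` statement is a syntax error, which is why the file failed to parse. The snippet below reproduces that with `ast.parse`; the source strings are a hypothetical reconstruction, not the actual contents of train.py.

```python
import ast


def parses(source):
    """Return True if the source is syntactically valid Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


# Hypothetical reconstruction: the `match pairing_type:` header was removed
# but one of its case arms was left behind.
dangling = (
    "pairs = create_pairs()\n"
    "case _:\n"
    "    raise ValueError('unknown pairing type')\n"
)
intact = "pairs = create_pairs()\n"

dangling_ok = parses(dangling)
intact_ok = parses(intact)
```

Since `ast.parse` only checks syntax, a check like this could run before any of the module's imports are resolved.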
@efahnestock efahnestock merged commit e1ce771 into main May 22, 2026
3 checks passed
@efahnestock efahnestock deleted the safa-extractors branch May 22, 2026 19:16