SPEC.md — Artifact-Driven AE / GM-VAE Educational Lab

1. Purpose

This repository implements a local-first, artifact-driven educational lab for exploring:

  • Supervised Autoencoders (AE)
  • Gaussian Mixture VAEs (GM-VAE)
  • Latent projections (PCA / UMAP, 2D & 3D)
  • kNN neighbor structure and related visualization tooling

The system consists of:

  • Backend: FastAPI + PyTorch for training, projections, neighbors, artifact persistence, and SSE streaming
  • Frontend: Next.js (TypeScript) for deterministic rendering of backend-generated artifacts

This document defines the current system contracts. It does not describe historical migrations or speculative roadmap features.


2. Design Principles

2.1 Artifact-First

Everything displayed in the UI must come from:

  • Streamed events (SSE)
  • Persisted artifacts under runs/<run_id>/

The frontend must never recompute:

  • Latent embeddings
  • PCA / UMAP
  • Distance matrices
  • kNN graphs
  • Model sampling

Frontend may compute only:

  • Pixel transforms
  • Camera transforms
  • Pure UI state
  • Visual scaling derived from persisted projection data

2.2 Backend Owns Computation

All model-related computation happens in Python:

  • Training loops
  • Loss computation
  • Latent embeddings
  • PCA / UMAP projections
  • kNN neighbor graphs
  • Sampling
  • Any derived metrics (e.g., uncertainty)

Visualization work must not modify training semantics.

If training behavior changes:

  • It must be explicit.
  • It must be documented.
  • Tests must reflect the change.

2.3 Reproducibility

A run must be:

  • Reloadable without retraining
  • Deterministically visualizable from artifacts alone
  • Stable across reconnects and page refreshes

Artifacts must be sufficient for rendering.


3. Canonical Training Semantics

3.1 Supervised Autoencoder (AE)

Loss:

loss = MSE(reconstruction, input)
     + class_weight * CrossEntropy(logits, label)
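The loss above can be written out for a single example without any framework. This is a minimal sketch, assuming one flattened image and one label. The function names are hypothetical. In PyTorch the same batch-averaged quantities are usually computed with `F.mse_loss` and `F.cross_entropy`, both using mean reduction.

```python
import math

def mse(recon, target):
    """Mean squared error over all (flattened) pixels."""
    assert len(recon) == len(target)
    return sum((r - t) ** 2 for r, t in zip(recon, target)) / len(target)

def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]

def supervised_ae_loss(recon, target, logits, label, class_weight=1.0):
    # loss = MSE(reconstruction, input) + class_weight * CrossEntropy(logits, label)
    return mse(recon, target) + class_weight * cross_entropy(logits, label)
```

For a perfect reconstruction and uniform logits over the 10 MNIST classes, the loss reduces to ln 10 ≈ 2.303. That is the cross-entropy of a classifier that has learned nothing yet.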

Defaults (configurable):

  • class_weight: 1.0
  • Optimizer: Adam
  • Dataset: MNIST (train split)
  • Checkpoint cadence: configurable (default: per epoch)

All execution paths (CLI, backend jobs, UI-triggered runs) must call the same underlying implementation.


3.2 Gaussian Mixture VAE (GM-VAE)

Loss components:

  • Reconstruction MSE
  • Categorical KL divergence:
    KL(q(c|x) || uniform)
    
  • Expected Gaussian KL:
    E_q(c|x)[ KL(q(z|x) || p(z|c)) ]
    

Total:

loss = recon + kld_weight * (kld_c + kld_z)
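The three loss terms are sketched below for a single example. This assumes diagonal Gaussians parameterized by mean and log-variance, which matches the mu and logvar vectors listed for latent.json. The per-component prior parameters (`prior_mus`, `prior_logvars`) and all function names are hypothetical. `recon` is an already-computed reconstruction MSE.

```python
import math

def categorical_kl_to_uniform(qc):
    """KL(q(c|x) || uniform over K components) = sum_c q_c log q_c + log K."""
    return sum(q * math.log(q) for q in qc if q > 0) + math.log(len(qc))

def gaussian_kl_diag(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, diag exp(logvar_q)) || N(mu_p, diag exp(logvar_p)))."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        total += lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0
    return 0.5 * total

def gmvae_loss(recon, qc, mu, logvar, prior_mus, prior_logvars, kld_weight=1.0):
    kld_c = categorical_kl_to_uniform(qc)
    # Expected Gaussian KL: weight each component's KL by its responsibility q(c|x).
    kld_z = sum(q * gaussian_kl_diag(mu, logvar, pm, pl)
                for q, pm, pl in zip(qc, prior_mus, prior_logvars))
    return recon + kld_weight * (kld_c + kld_z)
```

Both KL terms are zero when q(c|x) is uniform and q(z|x) matches every component prior. In that case the total collapses to the reconstruction term.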

Defaults (configurable):

  • kld_weight: 1.0
  • Mixture components: configurable
  • Optimizer: Adam

Training semantics must not change silently.

3.3 GM-VAE Anti-Collapse Options

Optional GM-VAE controls (configurable per run):

  • categorical_kl_anneal (bool)
  • categorical_kl_anneal_epochs (int)
  • categorical_kl_max_weight (float)
  • entropy_bonus_enable (bool)
  • entropy_bonus_weight (float)
  • entropy_bonus_epochs (int)

Per-epoch schedule semantics:

  • when categorical_kl_anneal is enabled, the categorical KL weight ramps from 0 to categorical_kl_max_weight over categorical_kl_anneal_epochs epochs
  • when entropy_bonus_enable is set, the entropy bonus weight decays from entropy_bonus_weight to 0 over entropy_bonus_epochs epochs
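Two per-epoch schedule functions are sketched below. They assume a linear ramp and a linear decay, because "ramps" and "decays" do not fix the curve shape. They also assume 0-based epoch indices. Both function names are hypothetical.

```python
def categorical_kl_weight(epoch, anneal_epochs, max_weight):
    """Linear ramp: 0 at epoch 0, max_weight at epoch == anneal_epochs, held afterwards."""
    if anneal_epochs <= 0:
        return max_weight
    return max_weight * min(epoch / anneal_epochs, 1.0)

def entropy_bonus_weight(epoch, bonus_epochs, start_weight):
    """Linear decay: start_weight at epoch 0, 0 at epoch == bonus_epochs, held at 0 afterwards."""
    if bonus_epochs <= 0:
        return 0.0
    return start_weight * max(1.0 - epoch / bonus_epochs, 0.0)
```

The values these return per epoch are the ones reported as cat_kl_weight and entropy_bonus_weight in train.metrics.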

When enabled, GM-VAE metrics include additional diagnostic fields in train.metrics:

  • cat_kl_weight
  • entropy_bonus_weight
  • qc_entropy_mean
  • qc_neff
  • qc_pmax
  • qc_argmax_counts

4. Run Lifecycle

4.1 States

created → running → completed
                 ↘ cancelled
                 ↘ failed

4.2 Cancellation

  • When cancellation is requested, a shared cancel_requested flag is set.
  • The training loop checks the flag at controlled intervals.
  • On cancellation:
    • Training stops gracefully.
    • A run.cancelled event is emitted.
    • The summary artifact records the final status.

Processes are not hard-terminated unless the cancellation mechanism is explicitly redesigned.
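The lifecycle and cooperative cancellation can be sketched together. This version assumes a `threading.Event` as the shared flag and a check before every optimizer step. The `emit` callback, `TRANSITIONS`, and the function names are hypothetical. The transition table is read directly off the state diagram in 4.1.

```python
import threading

# Allowed transitions, taken from the state diagram in 4.1.
TRANSITIONS = {
    "created": {"running"},
    "running": {"completed", "cancelled", "failed"},
}

def transition(current, new):
    if new not in TRANSITIONS.get(current, set()):
        raise ValueError(f"illegal transition {current} -> {new}")
    return new

def run_training(num_epochs, steps_per_epoch, cancel_requested, emit):
    """Check the shared flag between steps; stop cleanly instead of being killed."""
    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            if cancel_requested.is_set():
                emit("run.cancelled", {"epoch": epoch, "step": step})
                return "cancelled"
            # ... one optimizer step would run here ...
        emit("train.metrics", {"epoch": epoch})
    emit("run.completed", {})
    return "completed"
```

The returned status is what the caller would write into summary.json. Checking between steps keeps the model and optimizer state consistent at the moment the loop exits.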


5. API Contract

The authoritative API schema is the OpenAPI specification dynamically generated by FastAPI.

Canonical Sources

  • OpenAPI JSON:

    /openapi.json
    
  • Interactive documentation:

    /docs
    

This schema is the single source of truth for:

  • Endpoint paths
  • HTTP methods
  • Request models
  • Response models

This document does not duplicate endpoint tables to avoid drift.
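Endpoints are not restated here, but a client can enumerate them from the schema itself. This is a minimal sketch using only the standard library. The base URL (for example `http://localhost:8000`) is an assumption about where the backend runs. The function names are hypothetical.

```python
import json
import urllib.request

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "options", "head", "trace"}

def fetch_openapi(base_url):
    """Download the FastAPI-generated schema from <base_url>/openapi.json."""
    with urllib.request.urlopen(f"{base_url}/openapi.json") as resp:
        return json.load(resp)

def list_endpoints(spec):
    """Sorted (METHOD, path) pairs from the OpenAPI 'paths' object."""
    pairs = []
    for path, item in spec.get("paths", {}).items():
        for key in item:
            if key in HTTP_METHODS:  # skip non-operation keys such as "parameters"
                pairs.append((key.upper(), path))
    return sorted(pairs)
```

Generating client code or endpoint tables from this output keeps them in sync with the implementation rather than with a hand-maintained list.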


6. Server-Sent Events (SSE)

SSE is used for live training updates.

Requirements

  • Served directly by FastAPI
  • Events are append-only and persisted in events.jsonl
  • Supports Last-Event-ID replay
  • Emits periodic heartbeats
  • Reconnect behavior must be stable
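Last-Event-ID replay from events.jsonl is sketched below. It assumes each persisted line is a JSON object with an integer `id`, a `type`, and a `data` payload. That schema and the function names are hypothetical. SSE comment lines (starting with `:`) are ignored by `EventSource` clients, so they work as heartbeats.

```python
import json

HEARTBEAT = ": heartbeat\n\n"  # SSE comment frame; clients ignore it

def format_sse(event_id, event_type, data):
    """One SSE frame: id, event, and data fields terminated by a blank line."""
    return f"id: {event_id}\nevent: {event_type}\ndata: {json.dumps(data)}\n\n"

def replay(jsonl_lines, last_event_id=None):
    """Yield frames for persisted events strictly after last_event_id."""
    for line in jsonl_lines:
        if not line.strip():
            continue
        event = json.loads(line)
        if last_event_id is not None and event["id"] <= last_event_id:
            continue
        yield format_sse(event["id"], event["type"], event.get("data", {}))
```

In FastAPI, a generator like this can be wrapped in `StreamingResponse(..., media_type="text/event-stream")`. The `Last-Event-ID` request header arrives as a string and would be parsed with `int()` before calling `replay`. Because the log is append-only, a reconnecting client receives exactly the events it missed.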

Typical Event Types

Examples (non-exhaustive):

run.created
run.started
train.progress
train.metrics
artifact.created
projection.progress
run.completed
run.cancelled
run.failed

Event schema changes must remain backward-compatible unless versioned.


7. Artifact Model

Each run lives under:

runs/<run_id>/

Core structure:

config.json
events.jsonl
summary.json
projections/
neighbors/
density/
distortion/
mixture/
samples/
network/
checkpoints/

Artifacts must:

  • Be versioned per epoch when applicable
  • Provide *_latest.* mirrors
  • Be sufficient for deterministic rendering
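The per-epoch plus latest-mirror convention can be sketched as a single write helper. It assumes JSON payloads and uses an atomic replace, so a page refresh never reads a half-written `*_latest.*` file. This is a design choice, not something this spec mandates. The helper names are hypothetical.

```python
import json
import os
import tempfile
from pathlib import Path

def _atomic_write_json(path, payload):
    """Write to a temp file in the target directory, then rename it over the target."""
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f)
    os.replace(tmp, path)  # atomic on the same filesystem

def write_versioned(run_dir, subdir, epoch_name, latest_name, payload):
    """Persist an epoch-versioned artifact and refresh its *_latest.* mirror."""
    base = Path(run_dir) / subdir
    _atomic_write_json(base / epoch_name, payload)
    _atomic_write_json(base / latest_name, payload)
    return base / epoch_name, base / latest_name
```

For example, `write_versioned(run_dir, "neighbors", "knn_pca_epoch_3.json", "knn_pca_latest.json", payload)` produces both files named in 7.1.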

7.1 Artifact Filename Conventions (Authoritative)

PCA 2D

projections/latent_2d_epoch_<N>.json
projections/latent_2d_latest.json

PCA 3D

projections/latent_3d_epoch_<N>_pca.json
projections/latent_3d_latest_pca.json

UMAP 2D

projections/latent_2d_epoch_<N>_umap.json
projections/latent_2d_latest_umap.json

UMAP 3D

projections/latent_3d_epoch_<N>_umap.json
projections/latent_3d_latest_umap.json
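The four projection patterns above reduce to one path builder. Note that PCA 2D is the only variant without a method suffix. This is a sketch, and the function name is hypothetical.

```python
def projection_path(method, dim, epoch=None):
    """Relative projection path per 7.1; epoch=None selects the *_latest* mirror."""
    if method not in ("pca", "umap") or dim not in (2, 3):
        raise ValueError(f"unsupported projection: {method} {dim}D")
    tag = "latest" if epoch is None else f"epoch_{epoch}"
    # PCA 2D is the one variant that carries no method suffix.
    suffix = "" if (method, dim) == ("pca", 2) else f"_{method}"
    return f"projections/latent_{dim}d_{tag}{suffix}.json"
```

Centralizing the quirk in one function keeps backend writers and frontend loaders from drifting apart.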

kNN Neighbors (PCA)

neighbors/knn_pca_epoch_<N>.json
neighbors/knn_pca_latest.json

kNN Neighbors (UMAP)

neighbors/knn_umap_epoch_<N>.json
neighbors/knn_umap_latest.json

Density 2D (PCA, Posterior)

density/posterior_2d_epoch_<N>_pca.json
density/posterior_2d_latest_pca.json

Density 2D (PCA, Prior)

density/prior_2d_epoch_<N>_pca.json
density/prior_2d_latest_pca.json

Density 2D (UMAP, Posterior)

density/posterior_2d_epoch_<N>_umap.json
density/posterior_2d_latest_umap.json

Density 2D (UMAP, Prior)

density/prior_2d_epoch_<N>_umap.json
density/prior_2d_latest_umap.json

Density 2D (PCA, Prior Empirical)

density/prior_empirical_2d_epoch_<N>_pca.json
density/prior_empirical_2d_latest_pca.json

Density 2D (UMAP, Prior Empirical)

density/prior_empirical_2d_epoch_<N>_umap.json
density/prior_empirical_2d_latest_umap.json

Metric Distortion 2D (PCA)

distortion/metric_ratio_2d_epoch_<N>_pca.json
distortion/metric_ratio_2d_latest_pca.json

Metric Distortion 2D (UMAP)

distortion/metric_ratio_2d_epoch_<N>_umap.json
distortion/metric_ratio_2d_latest_umap.json

Mixture Health (Projection-Invariant, GM-VAE)

mixture/mixture_health_epoch_<N>.json
mixture/mixture_health_latest.json

GM-VAE Posterior-Local Sampling (On Demand)

samples/posterior_local_epoch_<N>_idx_<i>_m_<M>.png
samples/posterior_local_latest_idx_<i>_m_<M>.png
samples/posterior_local_epoch_<N>_idx_<i>_m_<M>.json

GM-VAE Prior-Component Sampling (On Demand)

samples/prior_component_epoch_<N>_c_<c>_m_<M>.png
samples/prior_component_latest_c_<c>_m_<M>.png
samples/prior_component_epoch_<N>_c_<c>_m_<M>.json

Network Forward / Inspect (On Demand)

network/forward/epoch_<N>/idx_<i>/input.png
network/forward/epoch_<N>/idx_<i>/recon.png
network/forward/epoch_<N>/idx_<i>/layers.json
network/forward/epoch_<N>/idx_<i>/latent.json
network/forward/epoch_<N>/idx_<i>/forward_meta.json
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/summary.json
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/grid.png
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/vector.png

network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/inspect.json
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/channel.png
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/heat_rgba.png
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/overlay.png

Network Single Convolution Step (On Demand, Conv1-Like Only)

network/conv_step/epoch_<N>/idx_<i>/layer_<layer_id>/out_<c>/x_<x>_y_<y>/conv_step.json

Network Kernel Artifacts (On Demand)

network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel.png

network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_agg.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_agg.png
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/top_in_channels.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_in_<in_channel>.png

Network Perturb + Decode (On Demand)

network/perturb/epoch_<N>/idx_<i>/vector_<key>/dim_<d>/delta_<tag>/recon.png
network/perturb/epoch_<N>/idx_<i>/vector_<key>/dim_<d>/delta_<tag>/meta.json

Network Top-K Activations (On Demand)

network/topk/epoch_<N>/layer_<layer_id>/ch_<c>/metric_<metric>/subset_<n>/k_<k>/topk.json
network/topk/epoch_<N>/layer_<layer_id>/ch_<c>/metric_<metric>/subset_<n>/k_<k>/input_<dataset_idx>.png

Density artifacts are compact 2D grids in projection space with:

  • source: posterior or prior
  • for source: prior, prior_kind is uniform or empirical
  • grid dimensions and bounds
  • row-major values
  • stats min/max
  • meta including subset_n, compute_ms, and prior weights summary when applicable
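The shape of a density artifact is illustrated below using a normalized 2D histogram as a stand-in estimator. The estimator the backend actually uses is not specified here. The field names (`grid`, `bounds`, `values`, `stats`) and the convention that rows follow the y axis are assumptions.

```python
def density_grid(points, width, height, bounds, source="posterior"):
    """Normalized 2D histogram over projection space, stored row-major (row = y bin)."""
    xmin, xmax, ymin, ymax = bounds
    counts = [0.0] * (width * height)
    for x, y in points:
        if not (xmin <= x <= xmax and ymin <= y <= ymax):
            continue
        col = min(int((x - xmin) / (xmax - xmin) * width), width - 1)
        row = min(int((y - ymin) / (ymax - ymin) * height), height - 1)
        counts[row * width + col] += 1.0  # row-major index
    total = sum(counts) or 1.0
    values = [c / total for c in counts]
    return {
        "source": source,
        "grid": {"width": width, "height": height},
        "bounds": {"xmin": xmin, "xmax": xmax, "ymin": ymin, "ymax": ymax},
        "values": values,
        "stats": {"min": min(values), "max": max(values)},
    }
```

A renderer reads cell (row, col) as `values[row * width + col]`. That is a pure pixel transform, so it stays within what the frontend may compute.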

Distortion artifacts are compact 2D grids derived from latent-space kNN plus projection-space distances:

  • metric: knn_mean_ratio
  • ratio per point: mean(d_proj(i, nn_j)) / mean(d_latent(i, nn_j))
  • k, grid, bounds, row-major values
  • stats including min, max, median
  • meta including subset_n and compute_ms
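The knn_mean_ratio metric is computed per point below, before any gridding. Neighbors come from latent space, and both distance means use the same neighbor set. This is a brute-force sketch for clarity on a small subset. The function names are hypothetical.

```python
import math
from statistics import median

def knn_indices(points, k):
    """Brute-force k nearest neighbors of each point, excluding the point itself."""
    result = []
    for i, p in enumerate(points):
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        result.append([j for _, j in dists[:k]])
    return result

def knn_mean_ratio(latent, proj, k):
    """Per point: mean(d_proj(i, nn_j)) / mean(d_latent(i, nn_j)), nn_j from latent kNN."""
    ratios = []
    for i, nbrs in enumerate(knn_indices(latent, k)):
        d_lat = sum(math.dist(latent[i], latent[j]) for j in nbrs) / k
        d_proj = sum(math.dist(proj[i], proj[j]) for j in nbrs) / k
        ratios.append(d_proj / d_lat if d_lat > 0 else float("nan"))
    return ratios

def ratio_stats(ratios):
    finite = [r for r in ratios if not math.isnan(r)]
    return {"min": min(finite), "max": max(finite), "median": median(finite)}
```

A ratio near a constant across the map means the projection scales latent neighborhoods uniformly. Local deviations mark regions where the projection stretches or compresses them.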

Mixture health artifacts are projection-independent GM-VAE usage diagnostics computed from empirical responsibilities:

  • p: empirical component weights (mean of the soft responsibilities q(c|x))
  • psum, entropy, neff, pmax, top
  • argmax_counts, argmax_neff
  • deterministic sampling identifiers (sample_key, sample_indices_ref)
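The fields listed above can be derived from a matrix of responsibilities, as sketched here. Several definitions are assumptions, not taken from the implementation: neff as the perplexity exp(entropy) of p, argmax_neff as the same quantity over hard argmax assignments, psum as a sanity sum of p, and top as the three highest-weight component indices.

```python
import math

def mixture_health(resp):
    """Diagnostics from per-sample responsibilities q(c|x); each row sums to 1."""
    n, k = len(resp), len(resp[0])
    p = [sum(row[c] for row in resp) / n for c in range(k)]  # soft mean of q(c|x)
    entropy = -sum(pc * math.log(pc) for pc in p if pc > 0)
    counts = [0] * k
    for row in resp:
        counts[max(range(k), key=row.__getitem__)] += 1  # hard assignment per sample
    hard = [c / n for c in counts]
    hard_entropy = -sum(h * math.log(h) for h in hard if h > 0)
    return {
        "p": p,
        "psum": sum(p),                       # should be ~1.0
        "entropy": entropy,
        "neff": math.exp(entropy),            # assumed: perplexity of p
        "pmax": max(p),
        "top": sorted(range(k), key=lambda c: -p[c])[:3],  # assumed: top-3 indices
        "argmax_counts": counts,
        "argmax_neff": math.exp(hard_entropy),
    }
```

Under these definitions, neff near 1 signals component collapse and neff near K signals balanced usage. A gap between neff and argmax_neff indicates soft assignments that rarely win the argmax.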

Sampling artifacts are backend-generated decoded image grids and optional metadata:

  • kind: posterior_local or prior_component
  • epoch and selection key (idx or component)
  • sample count m
  • deterministic seed (unless an explicit seed override is provided)
  • compute_ms in companion JSON
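Deterministic seeding can be sketched by hashing the request identity. This assumes the seed is derived from run id, epoch, kind, selection key, and sample count. All names here are hypothetical. A cryptographic hash is used because Python's built-in `hash()` is salted per process for strings and would not be reproducible across restarts.

```python
import hashlib
import random

def derive_seed(run_id, epoch, kind, key, m, seed_override=None):
    """Stable 64-bit seed from the request identity, unless an explicit override is given."""
    if seed_override is not None:
        return seed_override
    ident = f"{run_id}|{epoch}|{kind}|{key}|{m}".encode()
    return int.from_bytes(hashlib.sha256(ident).digest()[:8], "big")

def sample_noise(seed, m, latent_dim):
    """Draw m standard-normal vectors from a private RNG; global random state is untouched."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(m)]
```

In a PyTorch backend, the same seed could feed `torch.Generator().manual_seed(seed)` so decoded grids are reproducible without touching the global RNG.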

Network forward artifacts are backend-generated single-example captures:

  • forward_meta.json records the model, epoch, and example index, plus the GM-VAE sampling context (sample_id)
  • latent.json records vector payloads for vector layers (z, mu, logvar, qc_probs, z_sample as applicable)

Channel inspect artifacts may include optional probe fields:

  • xy: probe coordinate and value
  • rf: approximate receptive-field mapping (status=ok with clamped bounds, or status=unavailable with a reason)

Single convolution step artifacts are backend-computed neuron-level explanations for conv1-like layers:

  • supports only encoder Conv2d layers with C_in=1 and a 3x3 kernel
  • stores the sampled input patch, kernel, elementwise product, pre-bias sum, bias, and activation-mode info
  • stores both the value computed by this explicit step and the value captured from the model, for consistency checking
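One output neuron of a single-channel 3x3 convolution can be recomputed as follows. This assumes stride 1, zero padding of 1, and ReLU as the activation mode. Like PyTorch's `Conv2d`, it computes a cross-correlation with no kernel flip. The function names and record keys are hypothetical.

```python
def conv_step(image, kernel, bias, x, y, activation="relu"):
    """Recompute output (x, y) of a 3x3, stride-1, zero-padded single-channel conv."""
    h, w = len(image), len(image[0])
    patch = [[image[y + dy][x + dx] if 0 <= y + dy < h and 0 <= x + dx < w else 0.0
              for dx in (-1, 0, 1)] for dy in (-1, 0, 1)]
    product = [[patch[r][c] * kernel[r][c] for c in range(3)] for r in range(3)]
    pre_bias = sum(sum(row) for row in product)
    pre_act = pre_bias + bias
    value = max(pre_act, 0.0) if activation == "relu" else pre_act
    return {"patch": patch, "product": product, "pre_bias_sum": pre_bias,
            "bias": bias, "activation": activation, "computed_value": value}

def consistent(record, captured_value, tol=1e-5):
    """Compare the explicit recomputation against the value captured from the model."""
    return abs(record["computed_value"] - captured_value) <= tol
```

A tolerance is used because the captured value may come from a different kernel implementation with different floating-point summation order.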

Kernel artifacts support both conv1 and deeper encoder conv layers:

  • C_in=1: direct 3x3 kernel values (optionally with PNG heatmap)
  • C_in>1: aggregated 3x3 strength stencil (L2 over C_in) plus top contributing input-channel 3x3 slices
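The C_in>1 aggregation and channel ranking are sketched below. Weights are indexed as `weights[in_ch][row][col]` for one output channel. Ranking input channels by the L2 (Frobenius) norm of their 3x3 slice is an assumption about what "top contributing" means. The function names are hypothetical.

```python
import math

def aggregate_kernel(weights):
    """3x3 strength stencil: L2 norm over input channels at each kernel position."""
    return [[math.sqrt(sum(w[r][c] ** 2 for w in weights)) for c in range(3)]
            for r in range(3)]

def top_in_channels(weights, n):
    """Top-n input channels by the L2 norm of their 3x3 slice; ties broken by index."""
    norms = [math.sqrt(sum(v * v for row in w for v in row)) for w in weights]
    order = sorted(range(len(weights)), key=lambda i: (-norms[i], i))
    return [(i, norms[i]) for i in order[:n]]
```

Taking the L2 norm discards sign, so the stencil shows where a filter is sensitive rather than whether it excites or inhibits. The per-channel slices restore that detail.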

Perturb artifacts are decoder-path reconstructions for selected latent vectors:

  • deterministic path selection per vector key
  • per-request metadata includes delta, base_value, new_value, and decode vector identifiers

Top-K activation artifacts are deterministic subset scans for encoder spatial layers:

  • subset indices are deterministic (0..subset_n-1)
  • supports metrics max and mean
  • stores sorted top-K items with score and optional xy (for max)
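A deterministic top-K scan over the fixed subset 0..subset_n-1 is sketched below. The `activation_map` callable is hypothetical. It stands in for running the encoder and returning one channel's 2D map for a dataset index. Breaking score ties by dataset index is an assumption added so the ordering is fully deterministic.

```python
def topk_activations(activation_map, subset_n, k, metric="max"):
    """Rank dataset indices 0..subset_n-1 by max or mean activation of one channel."""
    items = []
    for idx in range(subset_n):
        fmap = activation_map(idx)  # 2D list indexed [y][x]
        if metric == "max":
            score, x, y = max((v, x, y) for y, row in enumerate(fmap)
                              for x, v in enumerate(row))
            items.append({"dataset_idx": idx, "score": score, "xy": [x, y]})
        elif metric == "mean":
            flat = [v for row in fmap for v in row]
            items.append({"dataset_idx": idx, "score": sum(flat) / len(flat)})
        else:
            raise ValueError(f"unknown metric: {metric}")
    items.sort(key=lambda it: (-it["score"], it["dataset_idx"]))
    return items[:k]
```

Only the max metric carries an xy location, because the mean has no single spatial position.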

New artifact types must follow the same per-epoch + latest convention.


8. Neighbors / Extent Semantics (Contract)

The following behavioral expectations must remain true:

  • d_k is defined in latent space, not projected space.
  • Extent radius scales monotonically with respect to d_k.
  • Normalization must not break monotonicity.
  • Two edge modes are supported:
    • Gated by extent
    • Always show top-k
  • Binary hard cutoffs that cause visual popping are discouraged.
  • Fuzzy gating (e.g., smooth exponential decay beyond radius) is preferred.
  • Density overlays are 2D-only and must be rendered from persisted density artifacts (posterior/prior sources).

Distance calculations must respect dimensionality:

  • 2D: sqrt(dx² + dy²)
  • 3D: sqrt(dx² + dy² + dz²)

Frontend must not recompute neighbor structure.
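The extent rules can be illustrated with a small sketch. It is written in Python for consistency with the other examples, although this logic lives on the rendering side. It consumes d_k from the neighbor artifact rather than recomputing neighbors. The min-max normalization and exponential falloff are one choice that satisfies the contract. The parameters `r_min`, `r_max`, and `falloff` are hypothetical.

```python
import math

def extent_radii(d_k, r_min, r_max):
    """Map latent-space d_k to display radii; min-max scaling preserves order."""
    lo, hi = min(d_k), max(d_k)
    span = hi - lo
    if span == 0:
        return [(r_min + r_max) / 2] * len(d_k)
    return [r_min + (r_max - r_min) * (d - lo) / span for d in d_k]

def edge_opacity(p, q, radius, falloff):
    """Fuzzy gate: 1 inside the radius, smooth exponential decay beyond it."""
    dist = math.dist(p, q)  # sqrt(dx²+dy²) in 2D, sqrt(dx²+dy²+dz²) in 3D
    if dist <= radius:
        return 1.0
    return math.exp(-(dist - radius) / falloff)
```

Because opacity is continuous at the radius and decays smoothly, edges fade rather than pop as the radius changes between epochs.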


9. Frontend Routes (Orientation)

Primary routes:

/                — configure & launch run
/runs            — browse runs
/runs/[id]       — live & replay view
/runs/[id]/latent — latent explorer

These routes render artifacts and stream SSE updates.


10. Visualization Stability

To ensure consistent comparison:

  • Visualization sample indices may be fixed at run start.
  • Indices are persisted in config.json.
  • Recon grids and latent maps must use persisted indices.

Projection scaling must be:

  • Method-aware (PCA vs UMAP)
  • Dimension-aware (2D vs 3D)
  • Based solely on persisted projection data
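Scaling from persisted data alone can be as simple as the sketch below. It fits a center and one uniform scale to each projection artifact, so PCA and UMAP, 2D and 3D each get their own fit. Using a single uniform scale across all axes, rather than per-axis scaling, is a design assumption that keeps relative distances undistorted. The function names are hypothetical.

```python
def fit_scale(points):
    """Center and uniform half-extent so every axis fits in [-1, 1]; works for 2D and 3D."""
    dim = len(points[0])
    lo = [min(p[a] for p in points) for a in range(dim)]
    hi = [max(p[a] for p in points) for a in range(dim)]
    center = [(l + h) / 2 for l, h in zip(lo, hi)]
    half = max((h - l) / 2 for l, h in zip(lo, hi)) or 1.0  # guard degenerate data
    return center, half

def normalize(points):
    center, half = fit_scale(points)
    return [tuple((v - c) / half for v, c in zip(p, center)) for p in points]
```

Fitting each artifact separately matters because PCA coordinates carry latent-space units while UMAP coordinates are on an arbitrary scale.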

11. Performance Expectations

  • Usable at ~100–200 focused points.
  • SSE updates feel live without flooding UI.
  • CPU-only execution remains viable.
  • Reconnect behavior is stable.

Visualization updates must avoid unnecessary O(N²) recomputation.


12. Testing & Verification

Minimum expectations:

  • Backend smoke tests verify:
    • Run creation
    • Event emission
    • Artifact persistence
  • Projection artifacts are written correctly per epoch.
  • Neighbor artifacts are written and loadable.
  • Backfill logic for projections and neighbors has at least one smoke test.
  • Frontend compiles cleanly (npx tsc --noEmit).
  • A run can be:
    1. Started
    2. Observed live
    3. Cancelled
    4. Reloaded from disk

13. Governance

  • AGENTS.md defines operational rules for agents.
  • This SPEC defines architectural and artifact contracts.
  • OpenAPI (/openapi.json) is the authoritative API schema.
  • Contract changes must be explicit and documented.

This document reflects the current implemented system and is intended to remain evergreen.