This repository implements a local-first, artifact-driven educational lab for exploring:
- Supervised Autoencoders (AE)
- Gaussian Mixture VAEs (GM-VAE)
- Latent projections (PCA / UMAP, 2D & 3D)
- kNN neighbor structure and related visualization tooling
The system consists of:
- Backend: FastAPI + PyTorch for training, projections, neighbors, artifact persistence, and SSE streaming
- Frontend: Next.js (TypeScript) for deterministic rendering of backend-generated artifacts
This document defines the current system contracts. It does not describe historical migrations or speculative roadmap features.
Everything displayed in the UI must come from:
- Streamed events (SSE)
- Persisted artifacts under runs/<run_id>/
The frontend must never recompute:
- Latent embeddings
- PCA / UMAP
- Distance matrices
- kNN graphs
- Model sampling
The frontend may compute only:
- Pixel transforms
- Camera transforms
- Pure UI state
- Visual scaling derived from persisted projection data
All model-related computation happens in Python:
- Training loops
- Loss computation
- Latent embeddings
- PCA / UMAP projections
- kNN neighbor graphs
- Sampling
- Any derived metrics (e.g., uncertainty)
Visualization work must not modify training semantics.
If training behavior changes:
- It must be explicit.
- It must be documented.
- Tests must reflect the change.
A run must be:
- Reloadable without retraining
- Deterministically visualizable from artifacts alone
- Stable across reconnects and page refreshes
Artifacts must be sufficient for rendering.
Supervised AE loss:
loss = MSE(reconstruction, input) + class_weight * CrossEntropy(logits, label)
Defaults (configurable):
- class_weight: 1.0
- Optimizer: Adam
- Dataset: MNIST (train split)
- Checkpoint cadence: configurable (default: per epoch)
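As a framework-free illustration of the supervised AE objective above, the sketch below computes the loss for a single example in plain Python. It is not the backend's PyTorch implementation; the helper names (mse, cross_entropy, supervised_ae_loss) are hypothetical. In PyTorch the same two terms would typically come from torch.nn.functional.mse_loss and torch.nn.functional.cross_entropy, which average over the batch by default.

```python
import math
from typing import Sequence


def mse(reconstruction: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over all elements (e.g. flattened pixels)."""
    if len(reconstruction) != len(target):
        raise ValueError("reconstruction and input must have the same length")
    return sum((r - t) ** 2 for r, t in zip(reconstruction, target)) / len(target)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy for one example via a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def supervised_ae_loss(reconstruction, target, logits, label, class_weight=1.0):
    # loss = MSE(reconstruction, input) + class_weight * CrossEntropy(logits, label)
    return mse(reconstruction, target) + class_weight * cross_entropy(logits, label)
```

With a perfect reconstruction and uniform logits over 10 classes, the loss reduces to the classification term, ln(10).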
All execution paths (CLI, backend jobs, UI-triggered runs) must call the same underlying implementation.
GM-VAE loss components:
- Reconstruction MSE
- Categorical KL divergence: KL(q(c|x) || uniform)
- Expected Gaussian KL: E_q(c|x)[ KL(q(z|x) || p(z|c)) ]
Total:
loss = recon + kld_weight * (kld_c + kld_z)
Defaults (configurable):
- kld_weight: 1.0
- Mixture components: configurable
- Optimizer: Adam
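The GM-VAE total can be sketched the same way. This is a minimal plain-Python illustration, assuming q(z|x) is a diagonal Gaussian parameterized by mu/logvar and each p(z|c) is a diagonal Gaussian with its own mean and log-variance; the function names and argument layout are hypothetical, and recon is passed in as an already computed reconstruction MSE.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def categorical_kl_to_uniform(qc: Vector) -> float:
    """KL(q(c|x) || uniform) = sum_c q_c * log(q_c * K)."""
    k = len(qc)
    return sum(q * math.log(q * k) for q in qc if q > 0.0)


def diag_gaussian_kl(mu_q: Vector, logvar_q: Vector, mu_p: Vector, logvar_p: Vector) -> float:
    """KL(N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p))))."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        total += lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0
    return 0.5 * total


def gmvae_loss(recon, qc, mu, logvar, prior_mus, prior_logvars, kld_weight=1.0):
    kld_c = categorical_kl_to_uniform(qc)
    # Expected Gaussian KL: E_q(c|x)[ KL(q(z|x) || p(z|c)) ], weighted by q(c|x)
    kld_z = sum(
        q * diag_gaussian_kl(mu, logvar, prior_mus[c], prior_logvars[c])
        for c, q in enumerate(qc)
    )
    # loss = recon + kld_weight * (kld_c + kld_z)
    return recon + kld_weight * (kld_c + kld_z)
```

A uniform q(c|x) contributes zero categorical KL, so any remaining regularization comes from the Gaussian term.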
Training semantics must not change silently.
Optional GM-VAE controls (configurable per run):
- categorical_kl_anneal (bool)
- categorical_kl_anneal_epochs (int)
- categorical_kl_max_weight (float)
- entropy_bonus_enable (bool)
- entropy_bonus_weight (float)
- entropy_bonus_epochs (int)
Per-epoch schedule semantics:
- When categorical_kl_anneal is enabled, the categorical KL weight ramps from 0 to categorical_kl_max_weight across the anneal window (categorical_kl_anneal_epochs).
- When entropy_bonus_enable is set, the entropy bonus weight decays from entropy_bonus_weight to 0 across the entropy window (entropy_bonus_epochs).
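The two schedules can be sketched as pure functions of the epoch. This assumes linear ramps and 0-based epochs, which the text above does not pin down; the function names and the unscheduled_weight fallback used when annealing is off are hypothetical.

```python
def categorical_kl_weight(epoch: int, anneal: bool, anneal_epochs: int,
                          max_weight: float, unscheduled_weight: float = 1.0) -> float:
    """Linear ramp from 0 at epoch 0 to max_weight at epoch anneal_epochs (0-based)."""
    if not anneal:
        return unscheduled_weight  # assumption: fixed weight when annealing is off
    if anneal_epochs <= 0:
        return max_weight
    return max_weight * min(1.0, epoch / anneal_epochs)


def entropy_bonus_weight_at(epoch: int, enable: bool, start_weight: float,
                            bonus_epochs: int) -> float:
    """Linear decay from start_weight at epoch 0 to 0 at epoch bonus_epochs."""
    if not enable or bonus_epochs <= 0:
        return 0.0
    return start_weight * max(0.0, 1.0 - epoch / bonus_epochs)
```

Because both are pure functions of the epoch, the values reported as cat_kl_weight and entropy_bonus_weight can be recomputed and checked from config.json alone.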
When enabled, GM-VAE metrics include additional diagnostic fields in train.metrics:
- cat_kl_weight
- entropy_bonus_weight
- qc_entropy_mean
- qc_neff
- qc_pmax
- qc_argmax_counts
Run status lifecycle:
- created → running → completed
- running → cancelled
- running → failed
Cancellation:
- A shared cancel_requested flag is set.
- The training loop checks the flag at controlled intervals.
- On cancellation:
  - Training stops gracefully.
  - A run.cancelled event is emitted.
  - The summary artifact reflects the final status.
The process is never hard-terminated unless cancellation is explicitly redesigned.
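A minimal sketch of this cooperative cancellation contract, assuming a threading.Event stands in for the shared cancel_requested flag and an emit callback stands in for event persistence. The train() function, its parameters, and the event payload fields are hypothetical.

```python
import threading
from typing import Callable, Dict


def train(num_epochs: int, steps_per_epoch: int, cancel_requested: threading.Event,
          emit: Callable[[Dict], None]) -> str:
    """Hypothetical training loop that honors cooperative cancellation."""
    emit({"type": "run.started"})
    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            # Controlled check point: between optimizer steps, never mid-step.
            if cancel_requested.is_set():
                emit({"type": "run.cancelled", "epoch": epoch, "step": step})
                return "cancelled"  # caller records this status in summary.json
            # ... forward, backward, and optimizer.step() would run here ...
        emit({"type": "train.progress", "epoch": epoch})
    emit({"type": "run.completed"})
    return "completed"
```

Checking the flag between steps, instead of killing the worker, lets the loop itself emit the terminal event, so events.jsonl and summary.json stay consistent.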
The authoritative API schema is the OpenAPI specification dynamically generated by FastAPI.
- OpenAPI JSON: /openapi.json
- Interactive documentation: /docs
This schema is the single source of truth for:
- Endpoint paths
- HTTP methods
- Request models
- Response models
This document does not duplicate endpoint tables to avoid drift.
SSE is used for live training updates.
- Served directly by FastAPI
- Events are append-only and persisted in events.jsonl
- Supports Last-Event-ID replay
- Emits periodic heartbeats
- Reconnect behavior must be stable
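Last-Event-ID replay from the persisted log can be sketched as a filter over events.jsonl lines. This assumes each line is a JSON object with a monotonically increasing integer id, a type, and a data payload; that record layout and the function names are hypothetical. The output follows the text/event-stream wire format (id:, event:, and data: fields ending in a blank line; lines starting with ":" are comments, which clients ignore and which work as heartbeats).

```python
import json
from typing import Iterable, Iterator, Optional

HEARTBEAT = ": heartbeat\n\n"  # SSE comment line; EventSource clients ignore it


def format_sse(event: dict) -> str:
    """Encode one persisted event as a text/event-stream message."""
    payload = json.dumps(event.get("data", {}), separators=(",", ":"))
    return f"id: {event['id']}\nevent: {event['type']}\ndata: {payload}\n\n"


def replay(lines: Iterable[str], last_event_id: Optional[str]) -> Iterator[str]:
    """Yield persisted events strictly after Last-Event-ID (all events if absent)."""
    after = int(last_event_id) if last_event_id else -1
    for line in lines:
        line = line.strip()
        if not line:
            continue
        event = json.loads(line)
        if int(event["id"]) > after:
            yield format_sse(event)
```

In FastAPI, a generator like this could be wrapped in fastapi.responses.StreamingResponse with media_type="text/event-stream", passing the client's Last-Event-ID request header as last_event_id and interleaving HEARTBEAT while waiting for new events.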
Examples (non-exhaustive):
run.created
run.started
train.progress
train.metrics
artifact.created
projection.progress
run.completed
run.cancelled
run.failed
Event schema changes must remain backward-compatible unless versioned.
Each run lives under:
runs/<run_id>/
Core structure:
config.json
events.jsonl
summary.json
projections/
neighbors/
density/
distortion/
mixture/
samples/
network/
checkpoints/
Artifacts must:
- Be versioned per epoch when applicable
- Provide *_latest.* mirrors
- Be sufficient for deterministic rendering
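The per-epoch plus latest convention can be sketched as one helper that writes both files. This assumes JSON payloads and uses a temp file plus os.replace so a reader never sees a half-written artifact (os.replace is atomic within one POSIX filesystem); write_versioned and its variant parameter are hypothetical names chosen to reproduce paths such as projections/latent_3d_epoch_<N>_pca.json and projections/latent_3d_latest_pca.json.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Tuple


def _atomic_write_json(path: Path, payload: Any) -> None:
    # Write to a temp file in the same directory, then swap it into place.
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f)
    os.replace(tmp, path)


def write_versioned(run_dir: Path, subdir: str, stem: str, epoch: int,
                    payload: Any, variant: str = "") -> Tuple[Path, Path]:
    """Write <stem>_epoch_<N>[_variant].json and its <stem>_latest[_variant].json mirror."""
    tail = f"_{variant}" if variant else ""
    epoch_path = run_dir / subdir / f"{stem}_epoch_{epoch}{tail}.json"
    latest_path = run_dir / subdir / f"{stem}_latest{tail}.json"
    _atomic_write_json(epoch_path, payload)
    _atomic_write_json(latest_path, payload)
    return epoch_path, latest_path
```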
projections/latent_2d_epoch_<N>.json
projections/latent_2d_latest.json
projections/latent_3d_epoch_<N>_pca.json
projections/latent_3d_latest_pca.json
projections/latent_2d_epoch_<N>_umap.json
projections/latent_2d_latest_umap.json
projections/latent_3d_epoch_<N>_umap.json
projections/latent_3d_latest_umap.json
neighbors/knn_pca_epoch_<N>.json
neighbors/knn_pca_latest.json
neighbors/knn_umap_epoch_<N>.json
neighbors/knn_umap_latest.json
density/posterior_2d_epoch_<N>_pca.json
density/posterior_2d_latest_pca.json
density/prior_2d_epoch_<N>_pca.json
density/prior_2d_latest_pca.json
density/posterior_2d_epoch_<N>_umap.json
density/posterior_2d_latest_umap.json
density/prior_2d_epoch_<N>_umap.json
density/prior_2d_latest_umap.json
density/prior_empirical_2d_epoch_<N>_pca.json
density/prior_empirical_2d_latest_pca.json
density/prior_empirical_2d_epoch_<N>_umap.json
density/prior_empirical_2d_latest_umap.json
distortion/metric_ratio_2d_epoch_<N>_pca.json
distortion/metric_ratio_2d_latest_pca.json
distortion/metric_ratio_2d_epoch_<N>_umap.json
distortion/metric_ratio_2d_latest_umap.json
mixture/mixture_health_epoch_<N>.json
mixture/mixture_health_latest.json
samples/posterior_local_epoch_<N>_idx_<i>_m_<M>.png
samples/posterior_local_latest_idx_<i>_m_<M>.png
samples/posterior_local_epoch_<N>_idx_<i>_m_<M>.json
samples/prior_component_epoch_<N>_c_<c>_m_<M>.png
samples/prior_component_latest_c_<c>_m_<M>.png
samples/prior_component_epoch_<N>_c_<c>_m_<M>.json
network/forward/epoch_<N>/idx_<i>/input.png
network/forward/epoch_<N>/idx_<i>/recon.png
network/forward/epoch_<N>/idx_<i>/layers.json
network/forward/epoch_<N>/idx_<i>/latent.json
network/forward/epoch_<N>/idx_<i>/forward_meta.json
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/summary.json
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/grid.png
network/forward/epoch_<N>/idx_<i>/layer_<layer_id>/vector.png
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/inspect.json
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/channel.png
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/heat_rgba.png
network/inspect/epoch_<N>/idx_<i>/layer_<layer_id>/ch_<c>/overlay.png
network/conv_step/epoch_<N>/idx_<i>/layer_<layer_id>/out_<c>/x_<x>_y_<y>/conv_step.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel.png
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_agg.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_agg.png
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/top_in_channels.json
network/kernels/epoch_<N>/layer_<layer_id>/out_<c>/kernel_in_<in_channel>.png
network/perturb/epoch_<N>/idx_<i>/vector_<key>/dim_<d>/delta_<tag>/recon.png
network/perturb/epoch_<N>/idx_<i>/vector_<key>/dim_<d>/delta_<tag>/meta.json
network/topk/epoch_<N>/layer_<layer_id>/ch_<c>/metric_<metric>/subset_<n>/k_<k>/topk.json
network/topk/epoch_<N>/layer_<layer_id>/ch_<c>/metric_<metric>/subset_<n>/k_<k>/input_<dataset_idx>.png
Density artifacts are compact 2D grids in projection space with:
- source: posterior or prior
- for source: prior, prior_kind is uniform or empirical
- grid dimensions and bounds
- row-major values
- stats (min/max)
- meta, including subset_n, compute_ms, and a prior weights summary when applicable
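Reading a value back out of a row-major density grid is plain index arithmetic. A minimal sketch, assuming bounds are ordered (x_min, x_max, y_min, y_max) and row 0 corresponds to y_min; both the bounds layout and the cell_value helper are hypothetical, since the exact field names are not specified here.

```python
from typing import Optional, Sequence, Tuple

Bounds = Tuple[float, float, float, float]  # hypothetical: (x_min, x_max, y_min, y_max)


def cell_value(values: Sequence[float], width: int, height: int,
               bounds: Bounds, x: float, y: float) -> Optional[float]:
    """Return the value of the grid cell containing projection point (x, y)."""
    x_min, x_max, y_min, y_max = bounds
    if not (x_min <= x <= x_max and y_min <= y <= y_max):
        return None  # outside the persisted grid
    # Map to integer cell indices; clamp so x == x_max lands in the last column.
    col = min(int((x - x_min) / (x_max - x_min) * width), width - 1)
    row = min(int((y - y_min) / (y_max - y_min) * height), height - 1)
    return values[row * width + col]  # row-major: each row is contiguous
```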
Distortion artifacts are compact 2D grids derived from latent-space kNN plus projection-space distances:
- metric: knn_mean_ratio
- ratio per point: mean(d_proj(i, nn_j)) / mean(d_latent(i, nn_j))
- k, grid, bounds, and row-major values
- stats, including min, max, and median
- meta, including subset_n and compute_ms
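The per-point ratio can be sketched directly from its definition: neighbors are found in latent space, and the same neighbor set is measured in both spaces. This brute-force version is O(N² log N) and is only meant for a small subset_n; the function names are hypothetical, and rasterizing the ratios onto the grid is omitted.

```python
import math
from statistics import median
from typing import Dict, List, Sequence

Point = Sequence[float]


def knn_indices(points: Sequence[Point], i: int, k: int) -> List[int]:
    """Brute-force k nearest neighbors of point i (excluding i itself)."""
    others = [j for j in range(len(points)) if j != i]
    others.sort(key=lambda j: math.dist(points[i], points[j]))
    return others[:k]


def knn_mean_ratio(latent: Sequence[Point], proj: Sequence[Point], k: int) -> List[float]:
    """Per point: mean(d_proj(i, nn_j)) / mean(d_latent(i, nn_j)), nn_j found in latent space."""
    ratios = []
    for i in range(len(latent)):
        nn = knn_indices(latent, i, k)
        d_lat = sum(math.dist(latent[i], latent[j]) for j in nn) / k
        d_proj = sum(math.dist(proj[i], proj[j]) for j in nn) / k
        ratios.append(d_proj / d_lat if d_lat > 0 else float("nan"))
    return ratios


def ratio_stats(ratios: Sequence[float]) -> Dict[str, float]:
    finite = [r for r in ratios if not math.isnan(r)]
    return {"min": min(finite), "max": max(finite), "median": median(finite)}
```

A projection that uniformly stretches the latent geometry by a factor of 2 yields a ratio of 2 everywhere; local deviations from the median show where the projection compresses or expands neighborhoods.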
Mixture health artifacts are projection-independent GM-VAE usage diagnostics computed from empirical responsibilities:
- p: empirical component weights (mean of soft q(c|x))
- psum, entropy, neff, pmax, top
- argmax_counts, argmax_neff
- deterministic sampling identifiers (sample_key, sample_indices_ref)
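A sketch of the core mixture-health quantities computed from per-example responsibilities q(c|x). The definition of neff as exp(entropy) (a perplexity-style effective component count) is an assumption; 1 / sum(p²) is another common choice. The function name is hypothetical, and the top field is omitted because its definition is not given here.

```python
import math
from typing import Dict, Sequence


def mixture_health(responsibilities: Sequence[Sequence[float]]) -> Dict[str, object]:
    """Usage diagnostics from soft assignments q(c|x); each row sums to 1."""
    n = len(responsibilities)
    k = len(responsibilities[0])
    # p: empirical component weights = mean of soft q(c|x) over the sample
    p = [sum(row[c] for row in responsibilities) / n for c in range(k)]
    entropy = -sum(pc * math.log(pc) for pc in p if pc > 0.0)
    argmax_counts = [0] * k
    for row in responsibilities:
        argmax_counts[max(range(k), key=row.__getitem__)] += 1
    hard = [cnt / n for cnt in argmax_counts]
    argmax_entropy = -sum(h * math.log(h) for h in hard if h > 0.0)
    return {
        "p": p,
        "psum": sum(p),
        "entropy": entropy,
        "neff": math.exp(entropy),  # assumption: exp(entropy) as effective count
        "pmax": max(p),
        "argmax_counts": argmax_counts,
        "argmax_neff": math.exp(argmax_entropy),  # same assumption on hard counts
    }
```

A collapsed mixture (every example assigned to one component) gives neff = 1, while balanced usage of K components gives neff close to K.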
Sampling artifacts are backend-generated decoded image grids and optional metadata:
- kind: posterior_local or prior_component
- epoch and selection key (idx or component)
- sample count m
- deterministic seed (unless an explicit seed override is provided)
- compute_ms in the companion JSON
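One way to make the sampling seed deterministic is to derive it from the request's identifying fields. A minimal sketch, assuming the seed is a function of (run_id, epoch, kind, key, m); this derivation and the function names are hypothetical. hashlib is used instead of the built-in hash() because Python randomizes str hashing per process (PYTHONHASHSEED), which would break reproducibility across backend restarts.

```python
import hashlib
from typing import Optional


def sample_seed(run_id: str, epoch: int, kind: str, key: int, m: int) -> int:
    """Derive a stable 64-bit seed from the fields that identify a sample grid."""
    ident = f"{run_id}|{epoch}|{kind}|{key}|{m}".encode()
    return int.from_bytes(hashlib.sha256(ident).digest()[:8], "big")


def resolve_seed(run_id: str, epoch: int, kind: str, key: int, m: int,
                 seed_override: Optional[int] = None) -> int:
    # An explicit override wins; otherwise the seed is a pure function of the request.
    if seed_override is not None:
        return seed_override
    return sample_seed(run_id, epoch, kind, key, m)
```

The resulting integer could then seed a dedicated generator (for example torch.Generator().manual_seed(seed)) so sampling never touches the global RNG used by training.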
Network forward artifacts are backend-generated single-example captures:
- forward_meta.json records model/epoch/index and, for GM-VAE, the sample_id sampling context
- latent.json records vector payloads for vector layers (z, mu, logvar, qc_probs, z_sample as applicable)
Channel inspect artifacts may include optional probe fields:
- xy: probe coordinate and value
- rf: approximate receptive-field mapping (status=ok with clamped bounds, or status=unavailable with a reason)
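An approximate receptive field for a probe position can be found by walking the conv stack backwards. A minimal sketch, assuming a square input and a chain of layers described only by (kernel, stride, padding), ignoring dilation and pooling; the receptive_field function and its result keys are hypothetical. An output range [a, b] depends on the input range [a*s - p, b*s - p + k - 1], and the final bounds are clamped to the image.

```python
from typing import Dict, List, Tuple

ConvSpec = Tuple[int, int, int]  # (kernel, stride, padding) for one spatial layer


def receptive_field(layers: List[ConvSpec], x: int, y: int, input_size: int) -> Dict:
    """Approximate input-pixel bounds that can influence output position (x, y)."""
    if not layers:
        return {"status": "unavailable", "reason": "no spatial layers"}
    x0, x1, y0, y1 = x, x, y, y
    # Walk from the probed layer back to the input.
    for k, s, p in reversed(layers):
        x0, x1 = x0 * s - p, x1 * s - p + k - 1
        y0, y1 = y0 * s - p, y1 * s - p + k - 1

    def clamp(v: int) -> int:
        return max(0, min(input_size - 1, v))

    return {"status": "ok", "x0": clamp(x0), "x1": clamp(x1),
            "y0": clamp(y0), "y1": clamp(y1)}
```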
Single convolution step artifacts are backend-computed neuron-level explanations for conv1-like layers:
- supports encoder Conv2d layers with C_in=1 and a 3x3 kernel
- stores the sampled input patch, kernel, elementwise product, pre-bias sum, bias, and activation-mode info
- stores both the computed fed-forward value and the captured model value for consistency checking
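The single-step explanation reduces to one multiply-accumulate. A minimal sketch, assuming a 3x3 patch, a single input channel, and two hypothetical activation modes ("relu" and "none"); conv_step and its output keys are hypothetical. Like PyTorch's Conv2d, it computes a cross-correlation, so the kernel is not flipped.

```python
import math
from typing import Dict, List, Sequence

Grid = Sequence[Sequence[float]]


def conv_step(patch: Grid, kernel: Grid, bias: float, captured: float,
              activation: str = "relu", tol: float = 1e-5) -> Dict:
    """Recompute one Conv2d output neuron (C_in=1, 3x3) and compare with the model."""
    # Cross-correlation, as in PyTorch Conv2d: patch[r][c] pairs with kernel[r][c].
    product: List[List[float]] = [
        [patch[r][c] * kernel[r][c] for c in range(3)] for r in range(3)
    ]
    pre_bias = sum(sum(row) for row in product)
    pre_act = pre_bias + bias
    if activation == "relu":
        value = max(0.0, pre_act)
    elif activation == "none":
        value = pre_act
    else:
        raise ValueError(f"unsupported activation mode: {activation}")
    return {
        "product": product,
        "pre_bias_sum": pre_bias,
        "bias": bias,
        "activation": activation,
        "computed": value,
        "captured": captured,
        "consistent": math.isclose(value, captured, abs_tol=tol),
    }
```

A consistent flag of False would indicate that the captured value was taken at a different point (for example pre- versus post-activation) than the activation mode assumes.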
Kernel artifacts support both conv1 and deeper encoder conv layers:
- C_in=1: direct 3x3 kernel values (optionally with a PNG heatmap)
- C_in>1: an aggregated 3x3 strength stencil (L2 norm over C_in) plus the top contributing input-channel 3x3 slices
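For C_in>1, the strength stencil is an L2 norm taken across input channels at each of the nine taps. A minimal sketch, assuming the weights for one output channel are laid out as [C_in][3][3]; ranking input channels by the L2 norm of their 3x3 slice is a hypothetical choice of "top contributing", and both function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Kernel = Sequence[Sequence[Sequence[float]]]  # [C_in][3][3] weights for one output channel


def aggregate_kernel(weights: Kernel) -> List[List[float]]:
    """3x3 strength stencil: L2 norm over C_in at each spatial tap."""
    return [
        [math.sqrt(sum(w[ky][kx] ** 2 for w in weights)) for kx in range(3)]
        for ky in range(3)
    ]


def top_in_channels(weights: Kernel, n: int) -> List[Tuple[int, float]]:
    """Rank input channels by the L2 norm of their 3x3 slice (hypothetical rule)."""
    norms = [
        (ci, math.sqrt(sum(v * v for row in w for v in row)))
        for ci, w in enumerate(weights)
    ]
    norms.sort(key=lambda t: t[1], reverse=True)
    return norms[:n]
```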
Perturb artifacts are decoder-path reconstructions for selected latent vectors:
- deterministic path selection per vector key
- per-request metadata includes delta, base_value, new_value, and decode vector identifiers
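The perturbation itself is a single-coordinate shift before decoding. A minimal sketch that builds the shifted vector and the per-request metadata; only delta, base_value, and new_value come from the text above, while perturb, vector_key, and dim are hypothetical names, and the decoder call is omitted.

```python
from typing import Dict, List, Sequence, Tuple


def perturb(vector: Sequence[float], key: str, dim: int,
            delta: float) -> Tuple[List[float], Dict]:
    """Return a copy of a latent vector with one dimension shifted, plus metadata."""
    if not 0 <= dim < len(vector):
        raise IndexError(f"dim {dim} out of range for vector '{key}'")
    new_vector = list(vector)  # copy: the captured vector itself is never mutated
    base_value = new_vector[dim]
    new_vector[dim] = base_value + delta
    meta = {
        "vector_key": key,  # hypothetical identifier of the vector fed to the decoder
        "dim": dim,
        "delta": delta,
        "base_value": base_value,
        "new_value": new_vector[dim],
    }
    return new_vector, meta
```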
Top-K activation artifacts are deterministic subset scans for encoder spatial layers:
- subset indices are deterministic (0..subset_n-1)
- supports metrics max and mean
- stores sorted top-K items with score and optional xy (for max)
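A deterministic Top-K scan can be sketched over pre-captured feature maps. This assumes activations[idx] is the H x W map of one channel for dataset index idx (in the backend these would come from forward passes); score, topk_scan, and the item keys other than score and xy are hypothetical. Ties are broken by the lower dataset index so repeated scans return the same order.

```python
import heapq
from typing import Dict, List, Optional, Sequence, Tuple

FeatureMap = Sequence[Sequence[float]]  # H x W activations for one channel


def score(fmap: FeatureMap, metric: str) -> Tuple[float, Optional[Tuple[int, int]]]:
    if metric == "max":
        best = max(
            ((v, (x, y)) for y, row in enumerate(fmap) for x, v in enumerate(row)),
            key=lambda t: t[0],
        )
        return best[0], best[1]  # keep the (x, y) location of the peak
    if metric == "mean":
        values = [v for row in fmap for v in row]
        return sum(values) / len(values), None
    raise ValueError(f"unsupported metric: {metric}")


def topk_scan(activations: Sequence[FeatureMap], subset_n: int, k: int,
              metric: str) -> List[Dict]:
    """Scan dataset indices 0..subset_n-1 and keep the k highest-scoring items."""
    scored = []
    for idx in range(subset_n):
        s, xy = score(activations[idx], metric)
        item = {"dataset_idx": idx, "score": s}
        if xy is not None:
            item["xy"] = list(xy)
        scored.append(item)
    # Highest score first; ties broken by lower dataset index.
    return heapq.nsmallest(k, scored, key=lambda it: (-it["score"], it["dataset_idx"]))
```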
New artifact types must follow the same per-epoch + latest convention.
The following behavioral expectations must remain true:
- d_k is defined in latent space, not projected space.
- Extent radius scales monotonically with respect to d_k.
- Normalization must not break monotonicity.
- Two edge modes are supported:
- Gated by extent
- Always show top-k
- Binary hard cutoffs that cause visual popping are discouraged.
- Fuzzy gating (e.g., smooth exponential decay beyond radius) is preferred.
- Density overlays are 2D-only and must be rendered from persisted density artifacts (posterior/prior sources).
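The extent and gating rules can be sketched as two small functions over persisted values. This assumes d_k arrives precomputed in the neighbor artifacts (for example as the latent distance to the k-th neighbor) and that min-max scaling maps it to a display radius; extent_radii, edge_opacity, and the softness parameter are hypothetical. Min-max scaling is non-decreasing, so it preserves the ordering of d_k, and the exponential gate equals 1 at the radius, so edges fade instead of popping.

```python
import math
from typing import List, Sequence


def extent_radii(d_k: Sequence[float], r_min: float, r_max: float) -> List[float]:
    """Map persisted latent-space d_k values to display radii, preserving order."""
    lo, hi = min(d_k), max(d_k)
    span = hi - lo
    # Min-max scaling is non-decreasing in d_k; a flat input maps to mid-range.
    return [r_min + (r_max - r_min) * ((d - lo) / span if span > 0 else 0.5) for d in d_k]


def edge_opacity(dist: float, radius: float, softness: float) -> float:
    """Fuzzy gate: full opacity inside the radius, smooth exponential decay beyond it."""
    if dist <= radius:
        return 1.0
    return math.exp(-(dist - radius) / softness)
```

In the "always show top-k" mode the gate would simply be skipped and every persisted neighbor edge drawn at full opacity.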
Distance calculations must respect dimensionality:
- 2D: sqrt(dx² + dy²)
- 3D: sqrt(dx² + dy² + dz²)
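A dimension-aware distance should refuse mismatched points rather than silently dropping a coordinate (zip() truncates to the shorter input, so a 3D point paired with a 2D point would lose its z). A minimal sketch; projection_distance is a hypothetical name.

```python
import math
from typing import Sequence


def projection_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance in 2D or 3D projection space; dimensions must match."""
    if len(a) != len(b) or len(a) not in (2, 3):
        raise ValueError("expected two points of the same dimension (2D or 3D)")
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))
```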
The frontend must not recompute neighbor structure.
Primary routes:
/ — configure & launch run
/runs — browse runs
/runs/[id] — live & replay view
/runs/[id]/latent — latent explorer
These routes render artifacts and stream SSE updates.
To ensure consistent comparison:
- Visualization sample indices may be fixed at run start.
- Indices are persisted in config.json.
- Recon grids and latent maps must use persisted indices.
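Fixing the visualization indices at run start can be sketched with a local, seeded RNG. This assumes the indices are drawn uniformly from the training split (60,000 images for MNIST) and stored under a config key; choose_viz_indices, persist_into_config, and the viz_sample_indices key are hypothetical. Using a private random.Random instance keeps the selection from touching any global RNG state that training depends on.

```python
import random
from typing import Dict, List


def choose_viz_indices(dataset_size: int, count: int, seed: int) -> List[int]:
    """Pick a reproducible, sorted set of distinct sample indices at run start."""
    rng = random.Random(seed)  # local RNG: does not disturb global training RNG state
    return sorted(rng.sample(range(dataset_size), count))


def persist_into_config(config: Dict, indices: List[int]) -> Dict:
    # Hypothetical key; later views read it back instead of re-sampling.
    return {**config, "viz_sample_indices": indices}
```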
Projection scaling must be:
- Method-aware (PCA vs UMAP)
- Dimension-aware (2D vs 3D)
- Based solely on persisted projection data
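Method- and dimension-aware scaling can be sketched as fitting one transform per (method, dimension) projection from its persisted coordinates. The frontend would do this in TypeScript; it is written in Python here only for consistency with the other examples. The uniform-scale choice (one factor for all axes, so relative distances inside a projection are preserved) and the function names are assumptions.

```python
from typing import Dict, List, Sequence, Tuple

Point = Sequence[float]
Key = Tuple[str, int]  # (method, dims), e.g. ("pca", 2) or ("umap", 3)


def fit_scale(points: Sequence[Point]) -> Tuple[List[float], float]:
    """Center and uniform factor that map a persisted projection into [-1, 1]."""
    dims = len(points[0])
    lo = [min(p[d] for p in points) for d in range(dims)]
    hi = [max(p[d] for p in points) for d in range(dims)]
    center = [(l + h) / 2 for l, h in zip(lo, hi)]
    half_extent = max((h - l) / 2 for l, h in zip(lo, hi))
    # One factor for all axes keeps the projection's aspect ratio intact.
    return center, (1.0 / half_extent if half_extent > 0 else 1.0)


def fit_scales(projections: Dict[Key, Sequence[Point]]) -> Dict[Key, Tuple[List[float], float]]:
    # Each method and dimension gets its own transform; they are never shared.
    return {key: fit_scale(points) for key, points in projections.items()}


def apply_scale(point: Point, center: Sequence[float], factor: float) -> List[float]:
    return [(v - c) * factor for v, c in zip(point, center)]
```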
Performance expectations:
- Usable at ~100–200 focused points.
- SSE updates feel live without flooding the UI.
- CPU-only execution remains viable.
- Reconnect behavior is stable.
Visualization updates must avoid unnecessary O(N²) recomputation.
Minimum expectations:
- Backend smoke tests verify:
- Run creation
- Event emission
- Artifact persistence
- Projection artifacts are written correctly per epoch.
- Neighbor artifacts are written and loadable.
- Backfill logic for projections and neighbors has at least one smoke test.
- Frontend compiles cleanly (npx tsc --noEmit).
- A run can be:
- Started
- Observed live
- Cancelled
- Reloaded from disk
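The "reloaded from disk" expectation can be backed by a small structural check of a finished run directory. A minimal sketch, assuming summary.json carries a status field and a finished run ends in one of the terminal lifecycle states; check_reloadable and that field name are hypothetical, and a real smoke test would also check the *_latest.* artifact mirrors.

```python
import json
from pathlib import Path
from typing import List

REQUIRED = ("config.json", "events.jsonl", "summary.json")
TERMINAL = {"completed", "cancelled", "failed"}


def check_reloadable(run_dir: Path) -> List[str]:
    """Return a list of problems; an empty list means the run can be rendered from disk."""
    problems = [f"missing {name}" for name in REQUIRED if not (run_dir / name).is_file()]
    if problems:
        return problems
    lines = (run_dir / "events.jsonl").read_text().splitlines()
    for n, line in enumerate(lines, 1):
        if line.strip():
            try:
                json.loads(line)
            except json.JSONDecodeError:
                problems.append(f"events.jsonl line {n} is not valid JSON")
    status = json.loads((run_dir / "summary.json").read_text()).get("status")
    if status not in TERMINAL:
        problems.append(f"summary status {status!r} is not terminal")
    return problems
```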
- AGENTS.md defines operational rules for agents.
- This SPEC defines architectural and artifact contracts.
- OpenAPI (/openapi.json) is the authoritative API schema.
- Contract changes must be explicit and documented.
This document reflects the current implemented system and is intended to remain evergreen.