diff --git a/experiments/subtask_probe/FINDINGS.md b/experiments/subtask_probe/FINDINGS.md new file mode 100644 index 0000000..e7e4cc1 --- /dev/null +++ b/experiments/subtask_probe/FINDINGS.md @@ -0,0 +1,634 @@ +# pi0.5 Subtask Generation Probe: Findings + +## Core Result + +**The public pi0.5 base checkpoint contains working subtask text generation capability.** Using JAX with the correct prompt format (`"Task: {task}. Subtask: "`), the model generates coherent English subtask text with high confidence (60-87% top-1 probability). + +| Prompt | Generated Subtask | Confidence | +|---|---|---| +| "pick up the red cup and place it on the shelf" | "pick up cup" | 78% | +| "fold the towel neatly" | "fold the towel" | 87% | +| "open the drawer and put the block inside" | "pull out the drawer" | 34% | +| "stack the blue block on top of the red block" | "pick up the paper" | 57% | +| "wipe the table with the sponge" | "wipe the spill" | 61% | + +No retraining needed. No new head needed. The existing `embed_tokens.weight` matrix serves as the lm_head via weight tying (`dot(hidden_state, embed_tokens.T) -> vocab logits`). + +## What Makes It Work + +Two things were required, both non-obvious: + +1. **Prompt format**: `"Task: {task}. Subtask: "` -- NOT the action format `"Task: {task}, State: {state};\nAction: "`. The model has distinct modes triggered by the prompt suffix. The paper describes this conceptually but never specifies the exact format string. + +2. **Autoregressive decoding loop**: Embed prefix -> forward through PaliGemma backbone -> KV cache -> project last hidden state to vocab via `dot(h, embed_table.T)` -> take argmax -> embed that token -> forward with KV cache -> repeat. + +## What Doesn't Work + +### PyTorch path is broken for autoregressive text generation + +The PyTorch model (via HuggingFace transformers) gets the first token approximately right but degrades to Unicode garbage on the second token onwards. We tried: + +- **v1-v2**: Wrong prompt format (`"Action: "` suffix). Got Unicode attractors (ⓙ, ⤙, Ẁ, etc.) +- **v3**: Correct prompt format + proper AR loop through `GemmaModel.forward()`. First token correct ("put" at 43%), second token garbage. Cause: HuggingFace's `create_causal_mask` intermediary in `GemmaModel.forward()` creates a causal mask that conflicts with the bidirectional prefix attention. +- **v3 + 2D mask fix**: Passing 2D padding masks instead of 4D float masks. Improved first token confidence but didn't fix continuation. +- **v4**: Bypassed `GemmaModel.forward()` entirely, manually iterating through decoder layers with custom 4D masks. Same result: first token OK, continuation broken. + +**Root cause**: The PyTorch and JAX implementations produce different hidden states from the same input. The first token prediction differs between them ("put" vs "pick" for the same prompt), and continuation diverges completely. + +**Confirmed NOT the weights**: We tested with lerobot/pi05_base (independently converted by HuggingFace team, fp32) and our openpi conversion (bf16). Both produce identical results. Weight values are bitwise identical between the two checkpoints (max diff = 0.0 across all tested keys). Both are straight JAX→PyTorch conversions with uniform precision (no selective mixed precision). The `to_bfloat16_for_selected_params` step in our Modal conversion script is a deployment optimization for `torch.compile` stability, not part of the upstream conversion. + +**Root cause identified**: Side-by-side JAX vs PyTorch comparison revealed: +- A minimal 3-token prefix (no images): AR step cos_sim = **0.999** -- nearly identical. KV cache works correctly. +- Full 968-token prefix (768 image + 200 language): AR step cos_sim = **0.30** -- complete divergence. + +The KV cache mechanism is correct. The problem is that the **prefix forward** already puts slightly different values into the cache (cos_sim 0.998 per position). With 968 positions of slightly-wrong cached key/value vectors, the attention scores during the AR step accumulate these errors and produce a completely different weighted sum. This is a numerical amplification issue: a 0.2% per-position error in the prefix becomes a 70% error in the AR step's attention output over 968 cached positions. + +The prefix error likely comes from the image embedding pipeline (SigLIP implementation differences between JAX Flax and PyTorch HuggingFace) or from attention precision differences (the JAX code uses float32 for attention logits while PyTorch may use bfloat16 in some paths). + +### JAX works, PyTorch doesn't + +All successful subtask generation implementations in the community use JAX: +- @BrunoFANG1 (openpi#701) -- JAX, partial subtask text from base checkpoint +- @LisavilaLee (openpi_with_subtask fork) -- JAX, full implementation with position fix +- Our probe -- JAX works, PyTorch doesn't + +## Architecture + +pi0.5 is a two-expert transformer: +- **PaliGemma backbone** (Gemma-2B + SigLIP): processes images + text, dim 2048 +- **Action expert** (Gemma-300M): processes noisy actions + timestep, dim 1024 +- Both share attention through 18 layers (fused Q/K/V) +- **Action output**: `action_out_proj: Linear(1024, action_dim)` on the action expert hidden states +- **Text output**: `dot(paligemma_hidden, embed_tokens.T)` on the backbone hidden states (weight-tied lm_head) + +The model has two modes selected by prompt format: +- `"...; Action: "` → action generation (flow matching, iterative denoising) +- `"... Subtask: "` → text generation (autoregressive, standard LM decoding) + +## Paper Context + +The pi0.5 paper describes two training stages: +1. **Pre-training**: Discrete tokens, web data, subtask prediction (HL data), cross-embodiment data. This is where the subtask text capability comes from. +2. **Post-training**: Adds the action expert, flow matching for continuous actions, specializes for mobile manipulation. + +The public checkpoint includes the action expert, so it has been through post-training. The subtask capability persists but generates short sequences (3-4 tokens) before hitting EOS with dummy images, likely because: +- Post-training may have partially degraded the LM capability +- Dummy/zero images provide no visual context for the model to describe +- The exact prompt format may not match what PI used internally + +Community reports indicate ~100 gradient steps of LM fine-tuning on subtask data produces full-quality subtask text. + +## Files + +| File | Purpose | +|---|---| +| `decode_jax.py` | **Working** JAX subtask generation probe | +| `hybrid_prompt_experiment.py` | JAX subtask → PyTorch action, three prompt variants | +| `dual_runtime_benchmark.py` | JAX + PyTorch coexistence on single GPU, latency benchmarks | +| `compare_jax_pytorch.py` | Side-by-side JAX vs PyTorch hidden state comparison (root cause diagnosis) | +| `latency_profile.py` | Latency breakdown and JIT optimization tests (WIP) | + +## How to Run (JAX, working) + +```bash +# On L40S (Seoul instance): +ssh ubuntu@43.200.36.250 + +# Ensure JAX checkpoint is downloaded +cd ~/openpi +source $HOME/.local/bin/env +uv run python -c "from openpi.shared import download; download.maybe_download('gs://openpi-assets/checkpoints/pi05_base')" + +# Run the probe +uv run python experiments/subtask_probe/decode_jax.py +``` + +## Hybrid Prompt Experiment (2026-04-14) + +### Question + +The pi0.5 action prompt format (`"Task: X, State: S;\nAction: "`) is completely different from the subtask prompt format (`"Task: X. Subtask: "`). Can we inject JAX-generated subtask text into the action prompt without retraining? + +### Method + +Generated subtasks via JAX (Phase 1), then compared PyTorch action outputs across three prompt variants (Phase 2): + +1. **Baseline**: `"Task: X, State: S;\nAction: "` (standard, no subtask) +2. **Hybrid A**: `"Task: X. Subtask: Y, State: S;\nAction: "` (subtask injected before state) +3. **Hybrid B**: `"Task: X (Y), State: S;\nAction: "` (subtask in parentheses) + +Used zero images, zero state, fixed random seed. Both models loaded sequentially on a single L40S (JAX freed before PyTorch loaded via `XLA_PYTHON_CLIENT_PREALLOCATE=false`). + +### Results + +| Task | Subtask (JAX) | Baseline vs Hybrid A | Baseline vs Hybrid B | +|---|---|---|---| +| pick up red cup → shelf | "put the blue cup in the bin" | cos=0.36, L2=94% | cos=-0.19, L2=135% | +| fold the towel neatly | "fold the towel" | cos=0.34, L2=265% | cos=0.26, L2=168% | +| open drawer, put block | (garbage — zero images) | cos=-0.06, L2=205% | cos=-0.11, L2=130% | +| wipe table with sponge | "1No" (degraded) | cos=-0.24, L2=679% | cos=0.43, L2=130% | + +### Interpretation + +1. **The model IS conditioning on subtask text.** Cosine similarities of 0.36 and -0.06 mean completely different action trajectories — not noise, but a different policy output. +2. **Prompt format matters.** Hybrid A ≠ Hybrid B, confirming the model parses the text structure, not just bag-of-words. +3. **Cannot assess quality with zero images.** All actions are meaningless without visual context. The experiment proves the *mechanism* works (text changes actions) but not whether it produces *better* actions. Need real robot images. +4. **Subtask quality degrades with zero images.** 2/4 prompts produced garbage subtasks. Real images should fix this (the model needs visual context to describe what it sees). + +### Two-Phase Inference Architecture (from paper) + +The pi0.5 paper (Section V.E, Figure 3, Figure 7) describes: + +``` +Phase 1 — Subtask generation (autoregressive text, PaliGemma backbone): + Prompt: "Task: clean the bedroom. Subtask: " + images + state + → AR decode → "pick up pillow" + +Phase 2 — Action generation (flow matching, action expert): + Prompt: "Task: clean the bedroom. Subtask: pick up pillow" + images + state + → 10 denoising steps → action chunk [a_t:t+H] +``` + +Key detail: the prompt formats are **different modes** of the same model: +- `"Task: X. Subtask: "` triggers autoregressive text generation (PaliGemma LM head) +- `"Task: X, State: S;\nAction: "` triggers flow-matching action generation (action expert) + +@LisavilaLee's implementation (`build_full_observation`) splices generated subtask tokens into the padding of the subtask-format prompt, then runs `sample_actions` on that. This means the action step uses `"Task: X. Subtask: Y"` as its prompt — NOT the standard `"Task: X, State: S;\nAction: "` format. This requires ~100 gradient steps of fine-tuning to teach the model to produce actions from the subtask prompt format. + +Our hybrid approach (injecting subtask text into the standard action format) is an alternative that avoids retraining, but the quality is unvalidated — needs real images to assess. + +### DROID Evaluation with Real Images (2026-04-16) + +Ran the full eval pipeline against 10 DROID episodes (276 frames) using the deployed two-phase server (Seoul, g6e.2xlarge). Subtask generation uses the base JAX checkpoint (pi05_base); action generation uses the DROID PyTorch checkpoint (swatery/pi05_droid_base). + +**Prompt format comparison (Hybrid A vs Hybrid B):** + +Tested two ways to inject subtask text into the action prompt: +- Hybrid A: `"{instruction}. Subtask: {subtask}"` (closer to pre-training format) +- Hybrid B: `"{instruction} ({subtask})"` (parenthetical) + +Results were statistically indistinguishable (Wilcoxon p=0.89). Both improved over baseline similarly. **Conclusion: the format doesn't matter, only the presence of subtask text.** Going forward, we use only Hybrid A (now called "subtask" condition) since it's closer to the pre-training prompt format. + +**Noise control was critical:** + +Initial results showed no significant difference (p=0.96) between baseline and subtask conditions. Investigation revealed that each server request gets independent random noise in the flow matching denoising loop. The noise-induced L2 variance (~0.76) was 76x larger than the prompt effect (~0.01), completely drowning the signal. After adding a `seed` field to the obs dict that sets `torch.manual_seed()` before denoising, all 3 conditions for the same frame get identical noise. With controlled noise: **p=0.025** (significant), subtask is closer to ground truth 55% of the time. + +**DROID checkpoint cannot generate subtask text:** + +Tested `gs://openpi-assets/checkpoints/pi05_droid` for subtask generation — produces empty strings (immediate EOS) on all 276 frames. The DROID fine-tuning completely destroyed the LM head's text generation capability. The base checkpoint (pi05_base) is the correct choice for the subtask planner. This confirms the paper's note that post-training degrades the subtask capability. + +**Image format issues discovered and fixed:** + +Three bugs prevented the server from processing DROID images correctly: + +1. **Camera name mapping**: The subtask generator expects `base_0_rgb`/`left_wrist_0_rgb` but clients send embodiment-specific names (`cam_high` for ALOHA, `observation/exterior_image_1_left` for DROID). Fixed: auto-detect embodiment from key names. + +2. **Image normalization**: The subtask generator's `_build_subtask_observation` expected float32 [-1,1] but clients send uint8 [0,255]. No normalization was applied. Fixed: added `_normalize_image()` that handles uint8→float32, CHW→HWC, and [0,1]→[-1,1] automatically. + +3. **Aspect ratio distortion**: DROID images are 180x320 (16:9). The extraction script resized to 224x224 with plain `tf.image.resize`, squishing the images. The model's own `preprocess_observation` uses `resize_with_pad` which preserves aspect ratio by adding black padding. Fixed: store original dimensions, let the server's preprocessing handle resize. + +These fixes improved diversity (228 unique subtasks across 276 frames, up from identical outputs), but **Unicode garbage persists in ~51% of frames** (141/276). Quality is highly variable by episode — ep_0005 produces 100% valid English, while ep_0004 and ep_0007 produce 0%. Examples of garbage: `셍踯≎𝟻셍ᔑ毟⢱Ꮸ𨨏ѱ`, `শՔ䭈⠤ǎƞᇃḡᵐ䁱჻ັ` (Korean, CJK, math symbols, emoji, Cyrillic mixed together). + +**Root cause**: The base checkpoint's LM head was degraded by post-training. The logit distribution is flattened — non-English tokens (CJK, Korean, etc.) sometimes receive higher probability than correct English tokens. Greedy argmax picks whatever is highest, regardless of language. + +**Fix: ASCII vocabulary masking.** Before argmax at each decode step, set logits for all non-ASCII tokens to -inf. This forces generation from English-only tokens. This is a standard industry technique — the same approach as vLLM's `allowed_token_ids`, HuggingFace's `LogitsProcessor`, and OpenAI's `logit_bias` API parameter. Implemented in `SubtaskGenerator._build_ascii_vocab_mask()`. The mask is built once at init (scanning all 257K PaliGemma vocabulary tokens) and applied as a `jnp.where` before every argmax — zero runtime cost, deterministic, JIT-compatible. + +### Dual Runtime Coexistence Test (2026-04-14) + +Confirmed both JAX and PyTorch models loaded simultaneously on a single L40S using `XLA_PYTHON_CLIENT_MEM_FRACTION=0.5`: + +| Resource | Usage | +|---|---| +| JAX (subtask model) | 6.4 GB VRAM / 22.7 GB limit | +| PyTorch (action model) | 7.1 GB VRAM | +| **Total GPU** | **~13.6 GB / 46 GB** | +| JAX subtask latency (first call, JIT) | ~64s | +| JAX subtask latency (warm, eager) | **~14s** | +| PyTorch action latency | **~280ms** | +| Total two-phase (warm) | **~14.2s** | + +Memory is not the bottleneck — 13.6GB out of 46GB leaves plenty of headroom. A bigger instance is unnecessary for memory. + +**The latency bottleneck is JAX eager-mode AR decoding (~14s warm).** The breakdown is: +- Prefix forward (SigLIP image encoding + 18 transformer layers over 968 tokens → KV cache): majority of time +- AR loop (3-5 token generations with growing KV cache): each step retraces because KV cache shape changes + +Attempted to profile the exact breakdown with JIT optimization tests but the profiling script was OOM-killed on system RAM (32GB). The 30GB system RAM may be tight when both JAX and PyTorch runtimes plus XLA compilation buffers are active. This is a system RAM constraint, not GPU VRAM. + +**Latency reduction options (untested, for next session):** +1. **JIT-compile the prefix forward** — fixed shape, should compile well. The AR loop is harder because KV cache grows per step. +2. **Pre-allocate KV cache** to max size (prefix + max_gen_tokens) and use `jax.lax.while_loop` for fully JIT AR generation. +3. **Cache subtasks aggressively** — subtask only needs regeneration when the visual scene changes significantly, not every action cycle. At 14s per subtask, caching is essential. +4. **Larger system RAM** — profiling was killed at 32GB. A g6e.2xlarge (64GB RAM, same L40S GPU) would allow JIT compilation without OOM. + +### Hosting Architecture + +Both runtimes coexist on a single L40S (48GB VRAM) with `XLA_PYTHON_CLIENT_MEM_FRACTION=0.5`. Two separate instances are NOT needed for memory reasons. However: + +- **Single instance, two processes** is viable if latency can be reduced to <1s via JIT compilation +- **Two instances** only makes sense if system RAM (32GB on g6e.xlarge) is the bottleneck for JIT compilation, since the profiling script OOMed — a single g6e.2xlarge (64GB RAM) would be cheaper than two g6e.xlarge instances +- The two-phase inference is opaque behind QUIC — client sends `{task, images, state}`, gets back `{actions}` + +The subtask refresh rate is a design choice. The paper's Figure 7 shows subtask predictions changing frame-by-frame as the scene evolves, suggesting periodic re-generation (not just once per task). Given the 14s latency, aggressive caching is necessary until JIT optimization is done. + +## Checkpoint Conversion Notes + +### JAX → PyTorch conversion pipeline + +The stock openpi conversion script (`examples/convert_jax_model_to_pytorch.py`) does a **straight conversion** — uniform precision (float32 or bfloat16), no selective mixed precision. This is what the HuggingFace/LeRobot team used to produce `lerobot/pi05_base`. + +Our Modal conversion script (`convert_checkpoint_modal.py`) adds an extra post-conversion step: `to_bfloat16_for_selected_params()`, which keeps layernorms, vision patch embeddings, and position embeddings in float32 while converting everything else to bfloat16. This is a **deployment optimization for `torch.compile` stability** (prevents fp32/bf16 matmul crashes), not a correctness requirement. Standard inference without `torch.compile` works fine with uniform precision. + +### DROID checkpoint conversion (completed 2026-04-16) + +Converted `gs://openpi-assets/checkpoints/pi05_droid` → PyTorch on the Seoul g6e.2xlarge instance using the stock openpi conversion script. Uploaded to HuggingFace: **`swatery/pi05_droid_base`** + +```bash +uv run python examples/convert_jax_model_to_pytorch.py \ + --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid \ + --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch \ + --config_name pi05_droid \ + --precision bfloat16 +``` + +Output: `model.safetensors` (6.8GB, bfloat16), `config.json` (action_horizon=15, pi05=True). No assets directory was copied (DROID norm stats are in a separate location). + +Previous Modal attempt (2026-04-14) hit import bugs and client disconnect issues — running directly on AWS was simpler. + +## Next Steps + +### Validate with real robot images (DROID) + +3. **Download DROID dataset samples** — need actual robot camera frames to test subtask generation and action quality. DROID provides multi-camera observations with task labels, matching pi0.5's expected input format. + +4. **Test subtask generation with real images using DROID checkpoint** — feed actual robot workspace images into the JAX subtask generator with the DROID-finetuned checkpoint. Expect longer, more specific subtask text (vs. the 3-4 token outputs with zero images on the base checkpoint). + +5. **Test hybrid prompt action quality with real images** — compare the baseline (no subtask) vs. hybrid (with subtask) action outputs on real images. If the subtask-conditioned actions are measurably different AND more semantically aligned with the task, the hybrid approach works without retraining. + +### Fix PyTorch AR generation + +6. **Force float32 for the entire prefix forward** instead of bfloat16. The JAX code computes attention logits and RMSNorm variance in float32 but the PyTorch path may use bfloat16 in some places. Eliminating precision loss in the prefix would reduce per-position error and may prevent the amplification. Quick test -- just set `model.to(torch.float32)` before the prefix forward. + +7. **Audit SigLIP image embedding differences** between JAX (`openpi/src/openpi/models/siglip.py`) and HuggingFace's SigLIP implementation. The 768 image tokens contribute the most cached positions and are likely the largest source of error. + +8. **Audit attention precision paths**. The JAX gemma.py explicitly does `jnp.einsum(..., preferred_element_type=jnp.float32)` for attention logits. The PyTorch `eager_attention_forward` may not enforce float32 for the QK matmul. + +### Production integration + +9. **Single-process combined-mode serving** — one openpi-flash process loads both slots, exposes the action policy on ports 8000 (WebSocket) / 5555 (QUIC) and the JAX planner on ports 8002 / 5556, and shares one `SubtaskGenerator` instance between them. The action endpoint calls the planner first and splices the subtask into the prompt before action inference. See `src/hosting/serve.py`, `src/hosting/subtask_policy.py`, and the README's *Subtask generation (planner)* section. + +10. **Subtask caching and refresh policy** — decide how often to re-generate subtasks. Options: every action chunk, every N steps, or when visual change exceeds a threshold. @LisavilaLee's code caches by prompt only (never refreshes), but the paper's Figure 7 shows it should update with the scene. Today combined mode regenerates on every action `infer()` call unless the client opts out with `obs["mode"] = "action_only"`. + +### Future optimization: Flash attention for prefix forward + +The prefix forward processes ~968 tokens through 18 transformer layers. Each layer computes self-attention via explicit einsums in `gemma.py`: + +```python +logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k) # materializes full 968×968 matrix in HBM +probs = jax.nn.softmax(masked_logits, axis=-1) +encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) +``` + +`jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation='cudnn')` would replace all three ops with a single fused kernel that tiles the computation in GPU SRAM instead of materializing the full attention matrix in HBM. It handles GQA natively (Q `[B, T, 8, H]` against K/V `[B, T, 1, H]`). + +**Estimated impact**: Attention is ~40-60% of each layer's compute. Flash attention typically gives 2-3x speedup on the attention kernel. For the prefix forward (~1s JIT'd), this could save 200-400ms. For AR decode steps (~5ms each, query is single token), the benefit is negligible. + +**Blockers**: Requires modifying `gemma.py` (upstream openpi code shared by all models). Would need a compelling reason — the 200-400ms saving on prefix is nice but not critical given the larger wins from JIT-compiling the decode loop. Also requires cuDNN availability (L40S has it, but needs correct CUDA/cuDNN versions). Changing the `implementation` parameter requires recompilation (different XLA graph), so it's a deploy-time choice, not runtime-switchable. + +**Alternative approaches considered**: +- **Flax NNX `MultiHeadAttention` with `decode=True`**: Has built-in pre-allocated KV cache with `dynamic_update_slice` and `init_cache()`. Would enable `jax.lax.while_loop` for AR decode. But requires replacing Gemma's custom attention module — same upstream modification concern. +- **`chex.dataclass`**: Registers dataclasses as JAX pytrees. Would clean up a `DecodeState` carry struct if we used `while_loop`, but the JIT-unrolled approach doesn't need one. +- **`chex.assert_shape()`**: JIT-compatible shape validation. Nice for development but doesn't affect performance. + +## Comparison with openpi_with_subtask Fork (2026-04-17) + +Deep comparison of our `SubtaskGenerator` against @LisavilaLee's `openpi_with_subtask` fork to understand why our DROID eval produces nonsense subtasks. + +**Finding: the subtask generation code is functionally identical.** Same prompt format, same cleaning logic (lowercase, strip punctuation), same greedy argmax decoding, same SigLIP image encoding, same KV cache approach. There is no hidden inference-time trick in the fork. + +The fork's quality advantage comes entirely from ~100 gradient steps of fine-tuning with: +1. `token_ar_mask` — causal attention on subtask tokens, bidirectional on prefix +2. `token_loss_mask` — CE loss applied only to subtask portion, not prefix +3. Identity subtask training (`high_prompt = low_prompt = prompt`) + +One additional difference affects action quality (not subtask text quality): the fork's `build_full_observation()` splices generated subtask TOKEN IDs directly back into the padded prompt region, maintaining exact token fidelity. Our approach converts tokens→text→re-tokenizes as a new string, which can produce different token boundaries. + +### ASCII Vocabulary Masking (2026-04-17) + +**Problem:** ~51% of DROID eval frames produce Unicode garbage (Korean, CJK, math symbols, emoji, Cyrillic). The base checkpoint's degraded LM head assigns higher probability to non-English tokens than correct English tokens on many inputs. + +**Solution:** ASCII vocabulary masking — before argmax at each decode step, set logits for all non-ASCII tokens to `-inf`. Implemented in `SubtaskGenerator._build_ascii_vocab_mask()`: + +1. At init: scan all 257K PaliGemma vocabulary tokens via SentencePiece `id_to_piece()` +2. Mark token as valid if its piece (with `▁` treated as space) is fully ASCII +3. Store as `jnp.array` boolean mask — embedded as XLA constant in JIT-compiled graph +4. Before every argmax: `logits = jnp.where(valid_vocab_mask, logits, -1e9)` + +**Industry context:** This is the same technique as: +- **vLLM** `allowed_token_ids` — restrict sampling to a set of token IDs +- **HuggingFace** `LogitsProcessor` — arbitrary logit manipulation before sampling +- **OpenAI API** `logit_bias` — per-token logit adjustments +- **Constrained decoding** (`outlines`, `guidance`, `lm-format-enforcer`) — enforce regex/grammar on output (more powerful but heavier; these libraries are PyTorch-only, not compatible with our JAX backend) + +**Properties:** Zero runtime cost, deterministic, JIT-compatible. Does not change the model — only filters the output vocabulary at decode time. + +**Status:** Implemented, awaiting re-run of DROID eval to measure impact. Expect Unicode garbage rate to drop from ~51% to ~0%. English subtask quality should be unchanged since the mask only removes non-ASCII tokens — the correct English token was always in the distribution, just sometimes ranked below a non-English token. + +## Subtask Prompt-Format Sweep (2026-04-17) + +Ran all 276 DROID frames against 4 subtask-generation prompt formats on the Seoul g6e.2xlarge, swapped at runtime via the admin HTTP endpoint (`PATCH /config` with `subtask_prompt_format`). Results below use a stricter "usable" metric: printable ASCII only (`str.isprintable() and str.isascii()`), since the earlier `isascii()`-only vocab mask let control characters `\x16`, `\x19`, `\x1d` through. + +| Format | Prompt | Printable usable | Unique | Mean chars | Most common output | +|---|---|---|---|---|---| +| **default** | `Task: {task}. Subtask: ` | 166/276 (60%) | 89 | 23.1 | `'move the pan'` (9×) | +| raw | `{task}` | 13/276 (5%) | 6 | 20.1 | mostly empty | +| numbered | `Task: {task}.\n` | 169/276 (61%) | 87 | 31.4 | `'No, move the arms home'` (19×) | +| listprime | `Task: {task}. Subtask: 1` | 233/276 (84%) | 69 | 19.3 | `'No progress'` (32×) | + +### Decision: keep the default format + +Although `listprime` has higher printable-output coverage (84% vs 60%), its outputs skew toward self-critique phrases: 72/233 are variants of `'No progress'`, `'No movement'`, `'No significant movement'`, `'No skill'`. These aren't subtasks — they're the model describing the action history. Coverage went up, usefulness went down. + +`default` still produces the cleanest *imperative subtasks* when it works — `'pick up lid'`, `'wipe the spill'`, `'open the drawer'`, `'move the pan'`. This matches the pre-training format the paper describes. The 38% control-char garbage it produces is a *mask* problem, not a *prompt* problem, and is solved separately below. + +`raw` (bare instruction, no suffix cue) is effectively broken — the model EOSes immediately. Confirms the hypothesis that the `"Subtask: "` / newline suffix is a required mode-selector for AR text generation. + +`numbered` produces longer outputs than default but drifts into state descriptions ("No, the blue ring is in the gripper and then place the blue ring in the basket") rather than next-action subtasks. + +### Vocab mask tightening: printable + ASCII + +The existing `_build_ascii_vocab_mask` used `piece.isascii()`, which admits control characters (0x00–0x1F, 0x7F). These are legitimately ASCII but not text. After tightening to `piece.isascii() and piece.isprintable()`, the mask excludes those tokens at decode time, eliminating the control-char garbage class that accounted for 38% of "default" outputs. + +`str.isprintable()` is the right stdlib primitive here — it excludes control chars while keeping letters, digits, spaces, and punctuation. No extra dependency needed. For broader language filtering (e.g., "is this really English?") the industry-standard options are `langdetect`, `lingua-py`, or `fasttext` with `lid.176.bin`, but those are post-hoc heuristics; vocab masking at decode time is strictly cheaper and deterministic. + +### Bugs fixed during the sweep + +Two bugs in the admin-endpoint deploy were blocking the test and had to be fixed before any prompt format could run: + +1. **`serve.py:210`** was initializing `RuntimeConfig.subtask_prompt_format` from `SubtaskConfig.prompt_template` — but those are different strings. `prompt_template` is the *action* prompt format (`"{task}. Subtask: {subtask}"`) and contains `{subtask}`, which the subtask tokenizer's `.format(task=...)` call then raised `KeyError: 'subtask'` on. Seoul had been in a container crash-loop since the admin-endpoint deploy went out. Fix: use `RuntimeConfig()` with its default. + +2. **`admin_server.py:to_dict()`** used `dataclasses.asdict()`, which deep-copies every field recursively — including `_lock: threading.Lock`, which can't be pickled. Every `GET /config` and `PATCH /config` call returned an empty reply and the server logged a pickle error. Fix: build the dict manually via `fields(self)`, skipping underscore-prefixed fields. + +### Deployment config fix + +The admin endpoint was only reachable during the sweep because of a one-off `docker run -p 127.0.0.1:8001:8001` invocation. Terraform's `user_data.yaml.tftpl` only published `8000:8000` and `5555:5555/udp` — on a fresh cloud-init (stop→start or full redeploy) the admin port would have been lost. + +Added `-p 127.0.0.1:8001:8001` to `infra/modules/regional_inference_instance/templates/user_data.yaml.tftpl`. Bound to `127.0.0.1` (localhost on the host), not `0.0.0.0` — the admin endpoint has no auth and should never be internet-reachable. Operators reach it via SSH port forwarding: `ssh -L 8001:127.0.0.1:8001 ubuntu@` then `curl localhost:8001/config` from their laptop. No security-group change needed since the port isn't exposed publicly. + +## DROID duration distribution (2026-04-18) + +Measured on 3000 successful DROID v1.0.1 episodes streamed sequentially from `gs://gresearch/robotics/droid/1.0.1`: + +| stat | steps | seconds @15 Hz | +|---|---|---| +| mean | 305 | 20.3 | +| p50 | 234 | 15.6 | +| p95 | 792 | 52.8 | +| p99 | 1222 | 81.5 | +| max | 2324 | 154.9 | + +Only ~0.1% of episodes hit 2 min and most of the longest ones have empty language instructions. Naive "first N episodes with duration ≥ 120s" returns zero matches in the first 2000 scanned. + +**Selection strategy used for the long-horizon eval**: top-K longest with filters, implemented in `extract_droid_samples.py` via `--scan_episodes`, `--min_duration_s`, `--require_multi_step`. Scan 5000 episodes, reject empty instructions, floor at 60s, require the multi-step keyword heuristic (`pick…place`, `and then`, `put…in`, etc.), keep the 5 longest in an in-memory min-heap. Deterministic ~10-min GCS scan on a cloud box, guarantees we exercise the long-horizon regime where subtask conditioning is hypothesized to help. + +## Comet-Style Hierarchical Subtask Generation (2026-04-18) + +Experiment: test the `plan → critique → subtask` scaffold from **openpi-comet** (`src/openpi/shared/client.py`) on our long-horizon DROID cache, using two off-the-shelf VLM backends. Code: `experiments/subtask_probe/droid_eval/comet_style/`. + +### Paper vs. code: what we're actually testing + +Comet's **paper** (arxiv 2512.10071v3) reports their 0.3453 Q-score from π₀.₅ + RFT + expanded pre-training (§4.1-4.3). It **does not describe any reasoner/planner VLM**. The `client.py` plan/critique/subtask loop is activated only when `fine_grained_level > 0` (`eval_b1k_wrapper.py:59`), a training-data knob the paper doesn't ablate. Their released checkpoints (`pi05-b1kpt12-cs32`, `pi05-b1kpt50-cs32`) are `fine_grained_level=0`. §5 lists "more structured long-horizon reasoning" as **future work**. + +Conclusion: **the scaffold is exploratory code that wasn't part of their reported result.** We're not replicating a paper claim; we're testing the scaffold's idea on DROID with our own VLM backends. This is also why the scaffold's default reasoner endpoint (`b5k2m9x7-pre-exp011-043-32000.xenon.lepton.run`) is dead and the default model name (`Qwen3-VL-30B-A3B-Instruct`) is just a kwarg default never actually called. + +### Scaffold architecture + +Backend-agnostic `BaseReasoner` in `comet_style/reasoner_base.py` with two backends: +- **Gemini** (`gemini_reasoner.py`): `gemini-robotics-er-1.6-preview` via `google-genai`, with 120s request timeout and retry-on-transient-network-error (not just 429). +- **OpenAI-compatible** (`openai_compat_reasoner.py`): any vLLM-hosted VLM — we ran `Qwen3-VL-30B-A3B-Instruct` FP8-quantized on a g6e.2xlarge (L40S 48GB, 64GB RAM) in us-west-2. + +### Structured output is load-bearing + +Off-the-shelf VLMs do not reliably emit structured output for Comet's prompts out of the box. Initial runs showed: +- `generate_plan`: Gemini emitted a prose paragraph, Qwen-8B emitted a markdown numbered list. Neither parsed as JSON → fell back to a single-step plan = the global task verbatim. +- `generate_subtask`: Gemini Robotics-ER returned ~150-word reasoning paragraphs; Qwen-8B echoed the global task. +- `plan_critique`: free-form prose, wording varied every call, triggered the `if updated != plan_status` reset on every frame — effectively resetting `subtask_history` constantly, losing the hierarchical structure. + +Fix: enforce a schema on **all three** VLM calls via the backend's native structured-output API (Gemini `response_schema` / vLLM `response_format=json_schema`). Three schemas in `reasoner_base.py`: +- `PLAN_SCHEMA`: `{"type": "array", "items": {"type": "string"}, "minItems": 2, "maxItems": 10}` +- `SUBTASK_SCHEMA`: `{"type": "object", "properties": {"subtask": {"type": "string", "maxLength": 120}}}` +- `CRITIQUE_SCHEMA`: `{"type": "object", "properties": {"statuses": {"type": "array", "items": {"type": "string", "enum": ["done", "in_progress", "not_started"]}}}}` + +Refactored state: `plans: list[str]` + `plan_statuses: list[PlanStepStatus]` (parallel lists, canonical), `plan_status: str` becomes a derived property. The reset-on-change gate now compares status lists structurally, not prose strings. + +### Structured critique: 7× speedup, clean progression + +| Variant | Mean latency/replan | Unique subtasks (111-frame ep) | Plan progression | +|---|---|---|---| +| Qwen-30B + free-form critique | 5.74 s | 21 | Stuck 65 frames on "locate/search/scan" | +| Qwen-30B + **structured critique** | **0.81 s** | **5** | Clean monotonic: move → grasp → move to dish → place | + +Latency drop is output-token count: free-form critique emitted ~300-500-token paragraphs, structured critique emits ~5-15 tokens (enum list). No change in input, no change in model. + +### History stride matters (and Comet's hardcoded 5 is correct) + +`sample_images` walks the image history with stride=5 by default (Comet's original value, tuned for their 30 Hz sim buffer). With our 1 Hz cache (`--frame_subsample=15`), we assumed stride=1 would be better ("just give the model consecutive frames"). Tested both on the full 5-episode run: + +| Episode | Task | stride=5 transitions | stride=1 transitions | +|---|---|---|---| +| ep_0000 | cube + dish (simple) | **8** | 80 | +| ep_0001 | multi-step kitchen | 24 | **8** | +| ep_0002 | turnip plushie + duster | 90 | 82 (semantic oscillation, see below) | +| ep_0003 | cloths + markers | **4** | 8 | +| ep_0004 | bottle + blind | **4** | 5 | + +**stride=5 wins on 3/5 episodes and draws on 1.** The original intuition ("stride=5 is over-sparse for 1 Hz cache") was wrong: when history has ≥40 entries, stride=5 gives 8 images spanning ~40 seconds, providing the **temporal-contrast signal** the reasoner needs to detect progression ("40 s ago arm was here, now it's there"). stride=1 gives 8 consecutive seconds — too narrow a window for slow manipulation tasks, where the model over-interprets frame-to-frame motion. + +Default kept at `--history_stride=5`. + +### Two-call design: intentional, not wasteful + +Every replan fires two VLM calls (`plan_critique` + subtask-selection). We considered merging into one call with a `{statuses, subtask}` schema (halves cost/latency). The **load-bearing reason** to keep them separate (documented in `BaseReasoner` docstring): + +1. The subtask prompt is built from the **post-critique** `plan_status`. Merged calls force the model to emit both fields in one generation with pre-critique context — fine for reasoning models, risks inconsistent outputs on non-reasoning models (e.g. Gemini Robotics-ER at `thinking_budget=0`). +2. The `subtask_history` reset gate runs between the two calls. If the critique changed `plan_statuses`, `last_subtask` in the subtask prompt becomes `"None"` — merged calls can't apply this reset mid-generation. +3. Graceful degradation: a malformed critique keeps the old statuses and still runs subtask; a malformed merged call loses both. + +Merged-call mode is a reasonable follow-up if API cost ever becomes the bottleneck. + +### DROID frames continue past task completion + +~55% of frames in a cache-selected long-horizon episode are **post-task** — the operator retracted the arm, adjusted, or just held position until the fixed recording window closed. This dilutes Phase 2 L2 metrics (we're measuring whether pi0.5 predicts operator-idle behavior, not task execution). Added `all_steps_done(plan_statuses)` short-circuit so the reasoner skips VLM calls once every plan step is marked `done`, reusing the last real subtask. Pending a Phase 2 split-metrics run that reports pre-completion vs. post-completion separately. + +### Multi-object tasks produce semantic oscillation, not a bug + +ep_0002 (`"Place the turnip plushie on the table, then the duster on the box..."`) shows the reasoner flipping between `"place the turnip plushie on the table"` and `"place the duster on the box"` every 15 cached frames, regardless of stride. Not a scaffold bug: the task genuinely has two parallel placement sub-goals and each frame's "active step" depends on which object is more visually salient. Real finding worth reporting in the write-up — Comet's scaffold assumes strictly sequential plan steps, which doesn't fit every task structure. + +### Action-horizon alignment: the cache was wrong + +The previous 10-step-subsampled cache (`.experiments_cache/droid_eval_2min`, 1.5 Hz) was not aligned with pi0.5's `action_horizon=15`. Per-frame behavior-cloning eval with misaligned cache means every cached frame is a "fresh state reset," not a closed-loop decision point. Re-extracted with `--frame_subsample=15` → `.experiments_cache/droid_eval_ah15` (5 episodes, 475 frames total, 1 Hz, 1 cached frame = 1 action horizon of real time). + +### Serving infrastructure + +- **Local runs** (`run.py` from laptop): Gemini backend. ~0.8 s/replan, but WAN latency amplifies — the initial Gemini Comet run was killed at 50-min hang due to a `Server disconnected` not caught by the 429-only retry. Retry logic broadened to cover transient 5xx/disconnect errors; request timeout hard-capped at 120 s per call. +- **Remote runs** (on the vLLM host itself): Qwen backend. Runs the CLI directly on US West 2 (code + cache rsynced to `~/comet_eval/`), hits `http://localhost:8000/v1` — no SSH tunnel fragility. +- **US West 2 instance** resized from `g6e.xlarge` (32 GB RAM) to `g6e.2xlarge` (64 GB) + EBS grown from 100 GB to 200 GB (online `modify-volume` + `growpart` + `resize2fs`) to host Qwen-30B FP8 weights (~30 GB VRAM, ~58 GB disk). scratch.md updated. + +### Visualization + +`visualize_subtasks.py` renders both camera views (exterior + wrist) side-by-side with per-frame subtask text: +- **HTML** (default): self-contained page per episode, scrollable table with exterior/wrist thumbnails + subtask column. +- **Video** (`--video`): per-episode mp4 at 2 fps (matches cache rate), composite 640×180 frame with EXTERIOR/WRIST labels and a subtask banner. + +Used for inspecting plan progression without running the full Phase 2 action eval — much faster feedback loop during scaffold iteration. Bug caught through this: `visualize_subtasks` previously only rendered the exterior camera, hiding the gripper state that's visible only in the wrist view. + +### Status: Phase 2 not yet run + +The full Qwen-30B run on `.experiments_cache/droid_eval_ah15` (475 frames, structured critique, stride=5) is complete and saved to `subtasks_comet_qwen30b.json`. Phase 2 (action eval) and Phase 3 (metrics) against Seoul's pi0.5 action server are the next step — pending. + +## ForeAct release audit (2026-04-18) + +Before starting the ForeAct reconstruction on DROID, we read the paper (arxiv 2602.12322) and walked through the released code at `/Users/kkuan/openpi/foreact/` to pin down what we can actually reuse. Two load-bearing findings, recorded here because they scope everything downstream. + +### Finding 1 — "No architectural modification" ≠ "no training" + +The paper claims (§3.4) that ForeAct "requires no architectural modifications" to the VLA. Verified in code: pi0.5's image list is built dynamically at runtime by filtering the observation dict against `self.config.image_features` (`foreact/third-party/lerobot/src/lerobot/policies/pi05/modeling_pi05.py:1143-1150`). Adding a foresight image slot is just adding a new key to the dataset and the config — no new layers, no param changes, no architecture shim. That claim is literally true. + +But the public pi0.5 base checkpoint was pretrained with 2 images (`exterior`, `wrist`). Registering a 3rd key makes it an input but produces no useful behavior: the action head has no learned pathway from that feature slot to actions. ForeAct's recipe `foreact/third-party/lerobot/scripts/run_sub_task_100k.sh` fine-tunes from `lerobot/pi05_base` for 100k steps on `ForeAct_VLA_Dataset` — that fine-tune is what builds the pathway. "No architectural modification" is accurate; "plug-and-play at inference time" it is not. + +Implication for us: feeding foresight to our pi0.5 action server zero-shot is blocked at two independent layers. + +1. *Server-side*: `src/hosting/warmup.py:121-132` and the Rust QUIC sidecar hardcode a 2-image spec (`observation/exterior_image_1_left`, `observation/wrist_image_left`). A 3rd key is silently dropped before reaching the policy. +2. *Weights-side*: even if we bypassed serving and set `image_features` to include a foresight key, `_preprocess_images` would encode it but the action head has no learned attention/gating for it. Either silently ignored or actively harmful. + +So end-to-end action eval with foresight is out of scope for an inference-only reconstruction. Only fine-tuning pi0.5 on the augmented input (the paper's depth-C recipe) gets there. + +### Finding 2 — Only the foresight generator is released, not a fine-tuned VLA + +`mit-han-lab/foreact-pretrained` on HuggingFace contains exactly: +- 3 safetensors shards (10.2 GB total, ~5B params) +- `config.json` (941 B) +- `model.safetensors.index.json` (100 kB) +- No `vla/` or `pi0/` subfolder + +The 5B param count cleanly adds up to Sana-1600M (1.6B) + Gemma-2-2B-IT (2B) + DC-AE VAE — i.e. π_g only. The HF tag `visualforesight` and the `-pretrained` suffix (= cross-embodiment pretraining stage, *before* target-robot fine-tune) both confirm this. The `mit-han-lab` HF account has only one `foreact-*` repo; other pi0.5 checkpoints listed there (e.g. `vlash-pi05-libero-async5`) belong to different projects. + +The VLA training script `run_sub_task_100k.sh` starts from the *public* `lerobot/pi05_base` and produces a Galaxea-R1-Lite-specific fine-tune locally. That output is never uploaded. The `mit-han-lab/ForeActDataset` release contains the raw Galaxea episodes (subtask-segmented) but not a DROID-flavored preprocessing pipeline. + +Why it's not actually strange: the generator is robot-agnostic (pretrained on 10M cross-embodiment pairs across AgiBot / RoboMind / Galaxea / Bridge), useful to anyone with any robot. The fine-tuned VLA is Galaxea-R1-Lite-specific — different camera mount, different action space, different embodiment from DROID / LIBERO / anything else — so publishing it would have minimal downstream value. Standard pattern for robotics papers. + +Implication: even *with* training infra, reproducing ForeAct end-to-end on DROID would require running their fine-tune recipe ourselves on a DROID-flavored variant of `ForeAct_VLA_Dataset`, which they didn't release a preprocessing pipeline for (only the raw Galaxea episodes). + +### Scope of the reconstruction we're doing + +Given the above, the faithful-without-training reconstruction is: + +- **π_v planner** (their VLM subtask planner): fully faithful. Table 5 prompts verbatim, paper's `Qwen/Qwen3-VL-8B-Instruct` model. Only the harness is ours. +- **π_g generator** (their foresight image generator): faithful to the released `mit-han-lab/foreact-pretrained` checkpoint + `VisualForesightPipeline`. **Skip** the 5-epoch target-data fine-tune the paper always runs. DROID was deliberately excluded from π_g pretraining (§3.2) so this is genuinely zero-shot OOD — a setting the paper never evaluates. +- **VLA integration**: unreachable. Foresight images are outputs for human inspection only. No action eval. + +## ForeAct zero-shot reconstruction on DROID (2026-04-18) + +Faithful inference-only reconstruction of ForeAct's π_v (planner) + π_g (foresight generator) on our DROID cache. No fine-tuning — see the "ForeAct release audit" section above for why that scopes what's reachable. Code lives at `experiments/subtask_probe/droid_eval/foreact_eval/`. + +### Setup + +- Cache: `.experiments_cache/droid_eval_ah15/` (5 episodes, 475 frames @ ~1 Hz stride-15). +- Planner: paper's exact Table 5 two-turn prompts, `Qwen/Qwen3-VL-8B-Instruct` served by vLLM on US West 2 L40S (bf16, no quantization). This is the model the paper uses for both its VLM+π_0 baseline and ForeAct "Ours" in §4.3. +- Generator: `mit-han-lab/foreact-pretrained` checkpoint (5B params, 10.2GB) loaded in bf16 via the paper's `VisualForesightPipeline`. Paper's inference hparams: `guidance_scale=4.5`, `image_guidance_scale=1.5`, `num_inference_steps=8`. Runs in the foreact conda env on the same L40S. + +### Planner behavior — decomposition is the bottleneck + +Per-episode subtask counts from `subtasks_foreact_qwen8b.json`: + +| Episode | Instruction complexity | Unique subtasks | Transitions | Mean latency | +|---|---|---|---|---| +| ep_0000 | 1-step ("put cube in dish") | 6 | 28 | 0.73s | +| ep_0001 | 6-step (sink / cup / pan) | **1** | **0** | 0.97s | +| ep_0002 | 10-step (sort plushies) | **1** | **0** | 0.81s | +| ep_0003 | 2-step (cloths + markers) | 3 | 40 | 0.85s | +| ep_0004 | 2-step (bottle + blind) | 3 | 25 | 1.01s | + +For ep_0001 and ep_0002, the 8B planner emitted a single subtask for every single frame — the first sub-step of the instruction (ep_0001: "Take the straw out of the sink...", ep_0002: "Place the turnip plushie on the table"). It never advanced. For ep_0003 and ep_0004 it oscillated between echoing the full instruction and the first half. On the simplest task (ep_0000) it produced six near-synonyms of "pick up the cube" / "put the cube in the dish". + +The 8B model is clearly on the weaker end of the paper's VLM scaling analysis. Figure 13 reports Qwen3-VL-8B at 84.4% planning accuracy (σ=19.8) vs. Qwen2.5-VL-32B at 73.9% and Qwen2.5-VL-7B at 40.8%. Our DROID episodes pattern-match the long-horizon instructions where 7B failed and 32B was just acceptable. To get meaningful decomposition we'd likely need 32B+ — the paper's success numbers are on their own shorter-horizon Galaxea benchmark, not DROID-style free-form instructions. + +Secondary observation: the planner reported `previous_finished=True` on **0 of 475 frames**. We added this as an extra schema field for observability; the 8B model always returns `False` regardless of actual progress. Paper's prompt doesn't ask for this field, so the model has no training signal for it — fine, but the observability is lost. + +### Generator behavior — quality varies per episode, pretraining distribution-match matters + +The pretrained foresight generator produced 475 DROID-size PNGs at 0.48s mean latency on the L40S (claimed 0.33s on H100 — close enough for the architecture). + +No wholesale viewpoint hallucination — outputs don't snap to AgiBot/Galaxea scenes. But contrary to my first-pass "it's just an autoencoder" claim, the generator does produce *varying* predictions across frames. Quality and faithfulness to the subtask depend heavily on how close the DROID scene sits to the pretraining distribution: + +| Episode | Scene | Foresight quality | Notes | +|---|---|---|---| +| **ep_0001** | kitchen sink, overhead, dishes | plausible motion, **but hallucinates a phantom second arm** | Arm visibly reaches into the sink in a pose the current frame hasn't reached yet. But every foresight frame also shows a large dark blob on the right that's clearly a second end-effector — DROID is single-arm Franka. See explanation below. | +| ep_0003 | top-down lab, cloths + markers | moderate | Arm position shifts in roughly the right direction, some detail blur. | +| ep_0002 | dim side view, plushie sort | mixed | Arm motion visible but scene is dark and details smear; box sometimes drops out of the foresight. | +| ep_0000 | inverted lab shelf, cube + dish | **worst — heavy distortion** | Black blobs obscure the right half of many frames; occasional hallucinated extra cubes on the shelf. The ceiling-mounted / inverted camera orientation appears to be out-of-distribution. | +| ep_0004 | side view kitchen, bottle | near-identity | Foresight is nearly identical to the current frame; little motion predicted. | + +The pattern: the generator works best on scenes that visually resemble its pretraining data (kitchen/sink = ep_0001, similar to AgiBot-Beta). It degrades into artifacts on unusual camera orientations (ep_0000's inverted view) and collapses toward the identity map when the scene is too far from anything it was trained on (ep_0004). Subtask text has *some* effect on the motion trajectory but is easily dominated by the image-conditioning pathway in the weaker-distributional-match cases. + +**Embodiment bias — phantom second arm on ep_0001.** Every foresight frame in ep_0001 shows a large dark blob on the right side of the frame that's unmistakably a second bimanual end-effector, despite DROID being a single-arm Franka setup. This is a direct artifact of the pretraining composition (§3.2, Figure 4): AgiBot-World-Colosseo (947k subtasks, ~80% of pretraining volume) is bimanual (AgiBot-Beta), Galaxea Open-World is bimanual (Galaxea R1), and RoboMind includes bimanual ALOHA. Only Bridge (WidowX, single-arm) counters that prior, and it's a small slice. When the DROID scene pattern-matches an AgiBot-like sink/manipulation view, the generator's learned prior inserts the second arm that "should" be there in its training distribution. This is a clean illustration of why the paper always fine-tunes on target-embodiment data before reporting any number — pretraining alone bakes in the dominant embodiment prior. + +This mirrors the paper's §4.2 pretraining ablation but from the opposite side — they reported Fidelity=0.00 / Quality=0.00 OOD *without* pretraining; we see pretraining gets us to "mixed, episode-dependent" on DROID, but without the 5-epoch target-data fine-tune (`configs/finetune.yaml:24`) we're still far from consistent subtask-following. + +### Verdict: depth-C is a prerequisite for any real signal + +Both pieces ran without incident, but neither produces DROID signal that would improve pi0.5 even if we could route it into the action server: + +- **Planner**: the 8B choice from the paper is too small for DROID's long-horizon instructions. To reproduce the paper's planner quality on DROID we'd need at minimum Qwen3-VL-32B+. +- **Generator**: zero-shot quality is episode-dependent — decent on kitchen-sink-like scenes (ep_0001 looks genuinely promising), degenerate on inverted / unusual viewpoints (ep_0000, ep_0004). The paper's 5-epoch target fine-tune (`configs/finetune.yaml:24`) is load-bearing, not optional, for consistent subtask-following. + +So the depth-C path to a real replication has two training prerequisites before touching pi0.5 at all: +1. Fine-tune the foresight generator on DROID (or a DROID-flavored subtask-segmented dataset we'd need to construct). +2. Fine-tune pi0.5 via the paper's `run_sub_task_100k.sh` recipe to consume the third image slot. + +Each of those is a multi-day commitment. Punt unless the finding above is insufficient to kill this direction. + +### In-distribution sanity check on ForeActDataset / Galaxea R1 Lite (2026-04-19) + +To separate "the generator is broken" from "DROID is OOD for the generator", I ran the same pretrained checkpoint on 2 episodes from `mit-han-lab/ForeActDataset` (the paper's own Galaxea R1 Lite recordings, same robot family as one of the pretraining datasets). Same checkpoint, same inference hparams, same 1 Hz stride — only the input distribution changed. Adapter driver: `foreact_eval/generate_foresight_lerobot.py`. + +**Result: the generator clearly works here.** Across 12 foresight frames from episodes 0 ("Pick up the eggplant and place it into the plate") and 3 ("Pick up the corn and place it into the plate"): + +- **Subtask-conditioned motion**: Episode 0's foresight shows the eggplant being moved — no corn manipulation. Episode 3's foresight shows the yellow corn being picked up — no eggplant manipulation. The text conditioning actually steers the output, which was the missing signal on DROID. +- **Clean scene reconstruction**: no black-blob artifacts, no hallucinated extra vegetables, no viewpoint shift. All 5 veg items (leek, carrot, cucumber, corn, eggplant) + plate stay in their correct positions in frames where they shouldn't move. +- **Correct single-arm behavior**: the arm enters from the right at the correct angle for the Galaxea R1 Lite mounting. No phantom second arm like we saw on DROID ep_0001 — the embodiment prior matches the target scene, so it doesn't need to "pattern-complete" a missing arm. +- **Task progression**: across a single episode's strided frames, the arm visibly moves closer to the target object and the target object visibly shifts toward the plate. This is the subtask-end-state prediction the paper describes. + +Latency was 0.47-0.68s/frame on L40S (bf16), consistent with the DROID run. + +**Takeaway**: the pretrained checkpoint is doing its job. Everything we saw on DROID — near-identity autoencoding on some scenes, distortion/artifacts on inverted views, phantom bimanual arms on kitchen-sink scenes — is a distribution-mismatch problem, not a generator problem. With target-robot fine-tune (the paper's 5-epoch recipe), or even just running on robots in the pretraining pool, the generator produces genuinely useful foresight. Depth-C on DROID (fine-tune on DROID-flavored subtask data) would very likely recover this quality level. + +### Artifacts + +- `experiments/subtask_probe/droid_eval/foreact_eval/{planner,generate_subtasks,generate_foresight,generate_foresight_lerobot,visualize_foreact}.py` +- DROID zero-shot (OOD): + - `.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json` (475 records) + - `.experiments_cache/droid_eval_ah15/foresight_foreact/{ep_0000..ep_0004}/frame_*.png` (475 PNGs @ 640×480) + - `.experiments_cache/droid_eval_ah15/foreact_html/*.html` (5 per-episode reports, exterior | wrist | foresight | subtask) + - `.experiments_cache/droid_eval_ah15/foreact_videos/*.mp4` (5 per-episode mp4s at 2 fps) +- ForeActDataset / Galaxea R1 Lite in-distribution (moved to its own cache dir since it's not a DROID run): + - `.experiments_cache/foreact_eval/foresight_picksveg/` (first 2-episode stride=15 test, with `picksveg_report.html`) + - `.experiments_cache/foreact_eval/foresight_picksveg_dense/` (5 episodes at stride=5, with per-episode mp4s) + - `.experiments_cache/foreact_eval/foresight_chain_eggplant/` (full 3-sub-episode chain ep[0,1,2], unified subtask text, seed=42 — apple-morphing artifact in placement phase) + - `.experiments_cache/foreact_eval/foresight_chain_eggplant_v2/` (**golden** — ep0+ep1 use `Pick up the eggplant`, ep2 uses `Place the eggplant into the plate`, seed=123; see `chain_eggplant_v2.mp4` / `chain_eggplant_v2.html`) + +Ablations tried for the golden chain (all drifted worse than v2, deleted): +- ep1+ep2 both `Place the eggplant into the plate` — drifted the eggplant identity starting mid-ep1 +- mid-ep1 swap at frame 55 — same drift starting at the swap point + +Takeaway on labeling: the generator was trained to predict *half-a-subtask ahead*, so the right text for a given frame describes the state ~2.5 s ahead, not the *current* motion phase. For ep1 mid-frames that future state is still "arm holding eggplant over plate" → a pick-phase description. Switching to `Place ...` before the arm is actually descending onto the plate creates a text-vs-image conflict and drifts the eggplant identity. + +## Open Questions + +1. What is the exact prompt format PI used internally for subtask training? +2. Would real robot images produce longer, more specific subtask text? +3. Is the short output (3-4 tokens) an inherent limitation of the base checkpoint, or a prompt format / missing-images issue? +4. Does the ~100 gradient step fine-tuning need real robot data, or would LLM-generated subtask decompositions work? +5. How much does subtask conditioning improve action quality for long-horizon tasks (the paper reports it's significant)? +6. For the hybrid prompt approach (no retraining), does injecting subtask text into the action format produce *better* actions with real images, or just *different* ones? +7. Does ASCII vocabulary masking eliminate Unicode garbage and preserve English subtask quality? (pending re-run) diff --git a/experiments/subtask_probe/__init__.py b/experiments/subtask_probe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/subtask_probe/decode_jax.py b/experiments/subtask_probe/decode_jax.py new file mode 100644 index 0000000..07e6990 --- /dev/null +++ b/experiments/subtask_probe/decode_jax.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +"""JAX-based subtask generation probe for pi0.5. + +Uses the proven JAX code path (adapted from LisavilaLee/openpi_with_subtask +and BrunoFANG1's implementation referenced in Physical-Intelligence/openpi#701). + +The approach: + 1. Load the pi0.5 JAX model from the Orbax checkpoint + 2. Build a proper observation with images + prompt + 3. embed_prefix() -> forward through PaliGemma -> KV cache + 4. decode_to_logits (Embedder.decode = dot(h, embed_table.T)) + 5. Autoregressive token generation from last prefix position +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +OPENPI_SRC = Path(__file__).resolve().parents[2] / "src" +sys.path.insert(0, str(OPENPI_SRC)) + +from openpi.models import model as _model # noqa: E402 +from openpi.models.pi0 import Pi0, make_attn_mask # noqa: E402 +from openpi.models.pi0_config import Pi0Config # noqa: E402 +from openpi.models.tokenizer import PaligemmaTokenizer # noqa: E402 + + +def _add_subtask_tokenizer_methods(tok: PaligemmaTokenizer) -> None: + """Monkey-patch the subtask tokenization methods onto the stock tokenizer.""" + import string + + def tokenize_high_level_prefix( + self: PaligemmaTokenizer, high_prompt: str + ) -> tuple[np.ndarray, np.ndarray]: + cleaned = high_prompt.lower().strip().replace("_", " ").replace("\n", " ") + if cleaned and cleaned[-1] in string.punctuation: + cleaned = cleaned[:-1] + prefix_str = f"Task: {cleaned}. Subtask: " + tokens = self._tokenizer.encode(prefix_str, add_bos=True) # ty: ignore[unresolved-attribute] + tokens_len = len(tokens) + if tokens_len < self._max_len: + pad_len = self._max_len - tokens_len + mask = [True] * tokens_len + [False] * pad_len + tokens = tokens + [0] * pad_len + else: + tokens = tokens[: self._max_len] + mask = [True] * self._max_len + return np.asarray(tokens, dtype=np.int32), np.asarray(mask, dtype=np.bool_) + + def detokenize(self: PaligemmaTokenizer, tokens: np.ndarray) -> str: + valid = [int(t) for t in tokens if t != 0 and t != 1] + return self._tokenizer.decode(valid) # ty: ignore[unresolved-attribute] + + # Bind methods + import types + + tok.tokenize_high_level_prefix = types.MethodType(tokenize_high_level_prefix, tok) # ty: ignore[unresolved-attribute] + tok.detokenize = types.MethodType(detokenize, tok) # ty: ignore[unresolved-attribute] + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- + + +def load_pi05_jax(checkpoint_dir: str) -> Pi0: + """Load the pi0.5 JAX model from an Orbax checkpoint.""" + config = Pi0Config(pi05=True) + rng = jax.random.key(0) + model = config.create(rng) + + print(f"[load] Restoring params from {checkpoint_dir}/params ...") + params = _model.restore_params(f"{checkpoint_dir}/params", dtype=jnp.bfloat16) + import flax.nnx as nnx + + nnx.update(model, nnx.State(params)) + model.eval() + print("[load] Model loaded and in eval mode.") + return model + + +# --------------------------------------------------------------------------- +# Observation construction +# --------------------------------------------------------------------------- + + +def make_observation( + prompt: str, + tokenizer: PaligemmaTokenizer, + action_dim: int = 32, + use_random_images: bool = True, +) -> tuple[_model.Observation, int]: + """Build an observation for the probe. + + Uses the SUBTASK prompt format: "Task: {task}. Subtask: " (no state, no Action:). + This matches what the model expects for subtask generation mode. + + Returns (observation, num_real_tokens). + """ + # Tokenize with subtask prefix format (NOT the action format) + # "Task: {task}. Subtask: " -- no state, no Action: + tokens, mask = tokenizer.tokenize_high_level_prefix(prompt) # ty: ignore[unresolved-attribute] + num_real_tokens = int(mask.sum()) + state = np.zeros(action_dim, dtype=np.float32) + + # Images: random noise or zeros + def make_img() -> np.ndarray: + if use_random_images: + return np.random.default_rng(42).random((224, 224, 3)).astype(np.float32) * 2 - 1 + return np.zeros((224, 224, 3), dtype=np.float32) + + # Build observation (batch dim added, convert to JAX arrays) + obs = _model.Observation( + images={ + "base_0_rgb": jnp.array(make_img()[None]), + "left_wrist_0_rgb": jnp.array(make_img()[None]), + "right_wrist_0_rgb": jnp.array(make_img()[None]), + }, + image_masks={ + "base_0_rgb": jnp.array([True]), + "left_wrist_0_rgb": jnp.array([True]), + "right_wrist_0_rgb": jnp.array([True]), + }, + state=jnp.array(state[None]), + tokenized_prompt=jnp.array(tokens[None]), + tokenized_prompt_mask=jnp.array(mask[None]), + ) + + return obs, num_real_tokens + + +# --------------------------------------------------------------------------- +# Subtask generation (JAX, adapted from LisavilaLee's generate_subtask) +# --------------------------------------------------------------------------- + + +def generate_subtask_text( + model: Pi0, + observation: _model.Observation, + max_decoding_steps: int = 50, + temperature: float = 0.0, +) -> dict: + """Generate subtask text autoregressively from the pi0.5 model. + + This follows the proven JAX approach: + 1. embed_prefix -> PaliGemma forward -> KV cache + 2. Embedder.decode (dot with embedding table transpose) -> logits + 3. Greedy/sampled token generation in a loop + """ + observation = _model.preprocess_observation(None, observation, train=False) + + # Step 1: Embed prefix (images + language tokens) + prefix_tokens, prefix_mask, prefix_ar_mask = model.embed_prefix(observation) + B, prefix_S, _ = prefix_tokens.shape + + # Step 2: Forward through PaliGemma to get KV cache + prefix output + # No mask padding -- KV cache size matches prefix length. + prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) + positions = jnp.cumsum(prefix_mask, axis=1) - 1 # ty: ignore[invalid-argument-type] + + (prefix_out, _), kv_cache = model.PaliGemma.llm( + [prefix_tokens, None], + mask=prefix_attn_mask, + positions=positions, + adarms_cond=[None, None], + ) + + # Step 3: Find the last VALID token position (LisavilaLee's fix) + seq_indices = jnp.arange(prefix_S)[None, :] # [1, S] + last_pos = jnp.max(jnp.where(prefix_mask, seq_indices, -1), axis=1).astype(jnp.int32) # ty: ignore[no-matching-overload] + + last_hidden = prefix_out[jnp.arange(B), last_pos, :] # [B, D] + + # Step 4: Project to vocab logits via Embedder.decode + # Embedder.decode = dot(x, embedding_table.T) + embed_table = model.PaliGemma.llm.embedder["input_embedding"].value # ty: ignore[unresolved-attribute] + logits = jnp.dot(last_hidden, embed_table.T) # [B, vocab_size] + + # Collect initial top predictions + probs = jax.nn.softmax(logits, axis=-1) + top_k_probs, top_k_indices = jax.lax.top_k(probs[0], 10) + initial_predictions = list(zip(top_k_indices.tolist(), top_k_probs.tolist(), strict=True)) + + # Step 5: Autoregressive generation loop (eager, matches LisavilaLee's generate_subtask) + EOS_TOKEN = 1 + num_real = int(jnp.sum(prefix_mask, axis=-1)[0]) # ty: ignore[invalid-argument-type] + next_pos = jnp.array([num_real], dtype=jnp.int32) # [B] + generated_token_ids = [] + + current_logits = logits[ + None, :, : + ] # reshape to [B, 1, V] to match their pattern... actually [B, V] + current_cache = kv_cache + + for step_idx in range(max_decoding_steps): + # Greedy decode + if temperature > 0: + token_id = int( + jax.random.categorical(jax.random.key(step_idx), current_logits[0] / temperature) + ) + else: + token_id = int(jnp.argmax(current_logits[0])) + + generated_token_ids.append(token_id) + if token_id == EOS_TOKEN: + break + + # Embed the token + token_jax = jnp.array([[token_id]], dtype=jnp.int32) + token_embedding = model.PaliGemma.llm(token_jax, method="embed") # [B, 1, D] + + # Attention mask: [B, 1, prefix_S + gen_count] + # New token attends to all prefix tokens + all previously generated tokens + itself. + gen_count = step_idx + 1 + gen_mask = jnp.ones((B, gen_count), dtype=jnp.bool_) + full_mask = jnp.concatenate([prefix_mask, gen_mask], axis=1) # ty: ignore[invalid-argument-type] + attn_mask = full_mask[:, None, :] # [B, 1, prefix_S + gen_count] + + new_positions = next_pos[:, None] # [B, 1] + + # Forward with KV cache + (new_out, _), current_cache = model.PaliGemma.llm( + [token_embedding, None], + mask=attn_mask, + positions=new_positions, + kv_cache=current_cache, + ) + + # Project to logits via embedding table transpose + new_hidden = new_out[:, -1, :] # [B, D] + current_logits = jnp.dot(new_hidden, embed_table.T) # [B, vocab_size] + next_pos = next_pos + 1 + + return { + "output_tokens": generated_token_ids, + "num_steps": len(generated_token_ids), + "initial_predictions": initial_predictions, + "last_valid_position": int(last_pos[0]), + "prefix_length": prefix_S, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser(description="JAX subtask generation probe for pi0.5") + parser.add_argument( + "--checkpoint_dir", + type=str, + default=str(Path.home() / ".cache/openpi/openpi-assets/checkpoints/pi05_base"), + ) + parser.add_argument("--random_images", action="store_true", default=True) + parser.add_argument("--zero_images", action="store_true") + args = parser.parse_args() + + use_random = not args.zero_images + + tokenizer = PaligemmaTokenizer(max_len=200) + _add_subtask_tokenizer_methods(tokenizer) + print("Tokenizer loaded (with subtask methods)") + + model = load_pi05_jax(args.checkpoint_dir) + + test_prompts = [ + "pick up the red cup and place it on the shelf", + "fold the towel neatly", + "open the drawer and put the block inside", + "stack the blue block on top of the red block", + "wipe the table with the sponge", + ] + + for prompt in test_prompts: + print(f"\n{'=' * 80}") + print(f' Prompt: "{prompt}"') + print(f"{'=' * 80}") + + obs, num_real = make_observation(prompt, tokenizer, use_random_images=use_random) + print(f" Real tokens: {num_real}/200, images: {'random' if use_random else 'zeros'}") + + result = generate_subtask_text(model, obs, max_decoding_steps=50, temperature=0.0) + + # Decode tokens + token_ids = result["output_tokens"] + # Remove EOS if present + if 1 in token_ids: + token_ids = token_ids[: token_ids.index(1)] + + generated_text = tokenizer._tokenizer.decode(token_ids) # ty: ignore[unresolved-attribute] + + print( + f" Prefix length: {result['prefix_length']}, last valid pos: {result['last_valid_position']}" + ) + print(f" Steps generated: {result['num_steps']}") + + # Initial predictions + print(" Top-10 next-token predictions from last prefix position:") + for token_id, prob in result["initial_predictions"]: + token_str = tokenizer._tokenizer.decode([token_id]) # ty: ignore[unresolved-attribute] + print(f' [{token_id:6d}] "{token_str}" (prob={prob:.4f})') + + print(f' Generated text: "{generated_text}"') + print(f" Raw tokens: {token_ids[:30]}") + + print(f"\n{'=' * 80}") + print(" DONE") + print(f"{'=' * 80}") + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/README.md b/experiments/subtask_probe/droid_eval/README.md new file mode 100644 index 0000000..4e24752 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/README.md @@ -0,0 +1,154 @@ +# DROID Evaluation: Planner+Action vs Action-Only + +## Goal + +Validate whether injecting JAX-generated subtask text into the pi0.5 action prompt produces **better** actions (closer to ground truth) compared to action-only inference. Previous experiments with zero images proved the mechanism works (subtask text changes actions) but couldn't assess quality. + +## Setup + +**Models:** +- **Subtask generator (JAX):** `pi05_base` checkpoint — generates subtask text via AR decoding +- **Action generator (PyTorch):** `swatery/pi05_droid_base` on HuggingFace — DROID-finetuned pi0.5 + +**Data:** DROID v1.0.1 episodes streamed from GCS (`gs://gresearch/robotics/droid/1.0.1`). Each episode provides real robot images, proprioceptive state, language instructions, and ground truth actions. + +## Pipeline + +``` +Phase 0: extract_droid_samples.py + Stream DROID episodes from GCS → cache frames + ground truth actions to .npz + Two modes: first-K (default) or top-K longest (--scan_episodes N) + +Phase 1: generate_subtasks.py (pi0.5 server) or generate_subtasks_gemini.py (Gemini) + Emit subtask text for each cached frame → JSON + +Phase 2: run_action_eval.py + Load PyTorch DROID checkpoint → run 3 prompt conditions per frame: + A. Baseline: "Task: X, State: S;\nAction: " + B. Hybrid A: "Task: X. Subtask: Y, State: S;\nAction: " + C. Hybrid B: "Task: X (Y), State: S;\nAction: " + +Phase 3: compute_metrics.py + Compare predicted actions to ground truth → L2 distance, cosine sim, per-dim MAE +``` + +## Prerequisites + +```bash +# On the Seoul L40S instance: + +# 1. Install RLDS dependencies (in the openpi project root) +cd ~/openpi +uv sync --group rlds + +# 2. Download DROID norm stats +gsutil cp -r gs://openpi-assets/checkpoints/pi05_droid/assets/droid/ \ + ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid/assets/droid/ + +# 3. Download the PyTorch DROID checkpoint +hf download swatery/pi05_droid_base + +# 4. Ensure JAX base checkpoint is available +uv run python -c "from openpi.shared import download; download.maybe_download('gs://openpi-assets/checkpoints/pi05_base')" +``` + +## Running + +```bash +cd ~/openpi/hosting + +# Phase 0 (quick): Extract the first 10 successful episodes (short demos, ~30–60 min). +uv run python -m experiments.subtask_probe.droid_eval.extract_droid_samples \ + --num_episodes 10 \ + --output_dir ./.experiments_cache/droid_eval + +# Phase 0 (long-horizon eval): Scan 5k DROID episodes and keep the 5 longest +# multi-step ones (>=60s, keyword-matched). This is the setup used to measure +# whether subtask conditioning helps on genuinely long tasks. Streams from GCS, +# ~10 min on a cloud box. +uv run python -m experiments.subtask_probe.droid_eval.extract_droid_samples \ + --num_episodes 5 --scan_episodes 5000 \ + --min_duration_s 60 --require_multi_step \ + --output_dir ./.experiments_cache/droid_eval_2min + +# Phase 1: Generate subtasks (JAX subtask generator, ~1s per frame warm) +uv run python experiments/subtask_probe/droid_eval/generate_subtasks.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/subtasks.json + +# Phase 1 (alt): Generate subtasks via Gemini Robotics-ER instead of pi0.5. +# Requires GEMINI_API_KEY in the environment (or .env). Emits the same JSON +# schema so Phase 2 / 3 consume it unchanged, and pairs cleanly with +# compare_subtask_outputs.py for pi0.5-vs-Gemini diffs. +uv run python experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/subtasks_gemini.json + +# Phase 1 (alt 2): Comet-style hierarchical subtask generation. +# +# Runs a stateful plan -> critique -> subtask loop per episode, ported from +# openpi-comet/src/openpi/shared/client.py. Each cached frame issues 2 VLM +# calls (critique + subtask), so expect ~2x the wall clock and API spend of +# the stateless Gemini run. Output JSON schema is identical — drop-in for +# Phase 2 / 3. + +# Backend A: Gemini Robotics-ER 1.6 Preview (requires GEMINI_API_KEY). +uv run python -m experiments.subtask_probe.droid_eval.comet_style.run \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/subtasks_comet_gemini.json \ + --backend gemini + +# Backend B: OpenAI-compatible VLM (e.g. vLLM hosting Qwen3-VL-30B). +# First, on a GPU host with >=48 GB VRAM, serve the model: +# uv pip install vllm +# vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct \ +# --port 8000 --max-model-len 32768 --limit-mm-per-prompt image=64 +# Then from the local machine (tunnel the port if the server is remote): +uv run python -m experiments.subtask_probe.droid_eval.comet_style.run \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/subtasks_comet_qwen.json \ + --backend openai_compat \ + --base_url http://localhost:8000/v1 \ + --model Qwen/Qwen3-VL-30B-A3B-Instruct + +# Phase 2: Run action evaluation (~280ms per inference × 3 conditions) +uv run python experiments/subtask_probe/droid_eval/run_action_eval.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --subtasks ./.experiments_cache/droid_eval/subtasks.json \ + --output_dir ./.experiments_cache/droid_eval/predictions + +# Phase 3: Compute metrics +uv run python experiments/subtask_probe/droid_eval/compute_metrics.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --predictions_dir ./.experiments_cache/droid_eval/predictions \ + --output ./.experiments_cache/droid_eval/results.json +``` + +## Metrics + +| Metric | What it measures | +|--------|-----------------| +| L2 distance to ground truth | Primary quality signal — are predicted actions closer to what the robot actually did? | +| Per-dimension MAE | Which joints benefit most from subtask conditioning | +| Cosine similarity | Directional alignment independent of magnitude | +| Gripper accuracy | Binary open/closed correctness (threshold 0.5) | + +Results are aggregated by: +- Multi-step vs single-step tasks +- Episode progress (early / middle / late) +- Overall + +## Interpretation + +- **Hybrid < Baseline L2** on multi-step tasks → subtask conditioning helps +- Improvement **only on multi-step**, not single-step → genuine hierarchical decomposition +- **Hybrid A ≠ Hybrid B** → prompt format matters for how the model parses injected text +- No improvement → hybrid prompt injection doesn't work without fine-tuning + +## Key Implementation Details + +- **Normalization:** Raw DROID joint positions must be z-score normalized before tokenizing into the action prompt. Norm stats from `gs://openpi-assets/checkpoints/pi05_droid/assets/droid/norm_stats.json`. +- **Action horizon:** DROID checkpoint uses `action_horizon=15` (not the base checkpoint's 50). +- **Same noise seed:** All 3 conditions use identical initial noise per frame for fair comparison. +- **Image format:** JAX expects HWC float32 [-1,1]; PyTorch expects CHW float32 [-1,1]. +- **DroidOutputs:** Only first 8 dims of 32D model output are meaningful (7 joints + 1 gripper). diff --git a/experiments/subtask_probe/droid_eval/__init__.py b/experiments/subtask_probe/droid_eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/subtask_probe/droid_eval/comet_style/__init__.py b/experiments/subtask_probe/droid_eval/comet_style/__init__.py new file mode 100644 index 0000000..aa4ab8f --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/__init__.py @@ -0,0 +1,9 @@ +"""Comet-style hierarchical subtask generation for DROID frames. + +Ports the plan -> critique -> subtask loop from +openpi-comet/src/openpi/shared/client.py into a backend-agnostic scaffold with +two concrete reasoner backends: Gemini Robotics-ER and any OpenAI-compatible +VLM server (e.g. a local vLLM hosting Qwen3-VL). + +This is evaluation-only code; the live hosting stack is untouched. +""" diff --git a/experiments/subtask_probe/droid_eval/comet_style/_gemini_utils.py b/experiments/subtask_probe/droid_eval/comet_style/_gemini_utils.py new file mode 100644 index 0000000..765cf40 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/_gemini_utils.py @@ -0,0 +1,118 @@ +"""Shared helpers for Gemini-backed subtask generators. + +Split out of generate_subtasks_gemini.py so both the stateless generator and +the Comet-style hierarchical reasoner use the same PNG encoding and 429 +retry-with-backoff logic. +""" + +from __future__ import annotations + +import io +import logging +import re +import time +from collections.abc import Callable +from typing import TypeVar + +import numpy as np +from PIL import Image + +logger = logging.getLogger(__name__) + +_RETRY_DELAY_PATTERN = re.compile(r"'retryDelay':\s*'(\d+(?:\.\d+)?)s'") + +T = TypeVar("T") + + +def encode_png(image: np.ndarray) -> bytes: + """PNG-encode an HxWx3 uint8 RGB image.""" + buffer = io.BytesIO() + Image.fromarray(np.asarray(image, dtype=np.uint8)).save(buffer, format="PNG") + return buffer.getvalue() + + +def parse_retry_delay_seconds(exc: Exception) -> float | None: + """Extract the server's suggested retry delay (seconds) from a 429 error string.""" + match = _RETRY_DELAY_PATTERN.search(str(exc)) + if match is None: + return None + return float(match.group(1)) + + +def is_rate_limit_error(exc: Exception) -> bool: + message = str(exc) + return "429" in message or "RESOURCE_EXHAUSTED" in message + + +# Message fragments that indicate a *transient* network problem worth retrying +# (vs. a deterministic API error like 400 or schema validation). Matched against +# str(exc) so we're robust to the specific exception classes google-genai uses +# across versions. +_TRANSIENT_NETWORK_FRAGMENTS = ( + "server disconnected", + "remote end closed connection", + "connection reset", + "connection aborted", + "read timed out", + "readtimeout", + "connecttimeout", + "503", # service unavailable + "502", # bad gateway + "504", # gateway timeout + "500", # internal server error +) + + +def is_transient_network_error(exc: Exception) -> bool: + message = str(exc).lower() + return any(fragment in message for fragment in _TRANSIENT_NETWORK_FRAGMENTS) + + +def call_with_retry( + call: Callable[[], T], + *, + max_retries: int, +) -> T: + """Call a Gemini API function with retry on 429s and transient network errors. + + * 429 / RESOURCE_EXHAUSTED -> sleep for the server-suggested ``retryDelay`` + (or a capped exponential fallback) and try again. + * Transient network failures (server disconnect, 5xx, read timeout) -> + exponential backoff and try again, because these previously caused a + 50+ minute hang when they surfaced at the wrong point in the Gemini + client's internal retry logic. + * Other exceptions propagate immediately. + """ + last_exc: Exception | None = None + for attempt in range(max_retries + 1): + try: + return call() + except Exception as exc: + rate_limited = is_rate_limit_error(exc) + transient = is_transient_network_error(exc) + if (not rate_limited and not transient) or attempt == max_retries: + raise + last_exc = exc + if rate_limited: + # Server sends retryDelay like '3s'; add jitter so a stampede + # of workers doesn't all wake up at the same instant and trip + # the quota again. + base_delay = parse_retry_delay_seconds(exc) or min(2**attempt, 30.0) + sleep_s = base_delay + (0.2 * attempt) + logger.info( + "Rate-limited (attempt %d/%d); sleeping %.1fs before retry", + attempt + 1, + max_retries, + sleep_s, + ) + else: + sleep_s = min(2**attempt, 10.0) + (0.2 * attempt) + logger.warning( + "Transient network error (attempt %d/%d): %s; sleeping %.1fs", + attempt + 1, + max_retries, + exc, + sleep_s, + ) + time.sleep(sleep_s) + raise RuntimeError("retry loop exited without returning") from last_exc diff --git a/experiments/subtask_probe/droid_eval/comet_style/gemini_reasoner.py b/experiments/subtask_probe/droid_eval/comet_style/gemini_reasoner.py new file mode 100644 index 0000000..1f60c3e --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/gemini_reasoner.py @@ -0,0 +1,85 @@ +"""Gemini Robotics-ER backed reasoner for Comet-style hierarchical subtask generation.""" + +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +from google import genai +from google.genai import types + +from ._gemini_utils import call_with_retry, encode_png +from .reasoner_base import BaseReasoner + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = "gemini-robotics-er-1.6-preview" + + +class GeminiReasoner(BaseReasoner): + """Plan/critique/subtask loop backed by Gemini Robotics-ER 1.6 Preview. + + Accepts any Gemini model via ``model=`` — e.g. ``gemini-3.1-pro-preview`` + if you want a general-purpose reasoning VLM instead of the robotics-tuned + default. + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + thinking_budget: int = 0, + max_retries: int = 10, + request_timeout_s: float = 120.0, + history_maxlen: int = 640, + sampled_images_max: int = 64, + history_stride: int = 5, + ) -> None: + super().__init__( + history_maxlen=history_maxlen, + sampled_images_max=sampled_images_max, + history_stride=history_stride, + ) + # timeout is in milliseconds per google-genai's HttpOptions. A 2-minute + # cap is well above normal ~5-10s latency but short enough that a hung + # socket surfaces fast instead of stalling the whole run. + self._client = genai.Client( + http_options=types.HttpOptions(timeout=int(request_timeout_s * 1000)) + ) + self._model = model + self._thinking_budget = thinking_budget + self._max_retries = max_retries + + def _chat( + self, + user_prompt: str, + images: list[np.ndarray], + response_schema: dict[str, Any] | None = None, + ) -> str: + contents: list = [] + for image in images: + contents.append(types.Part.from_bytes(data=encode_png(image), mime_type="image/png")) + contents.append(user_prompt) + + config_kwargs: dict[str, Any] = { + "temperature": 1.0, + "thinking_config": types.ThinkingConfig(thinking_budget=self._thinking_budget), + } + if response_schema is not None: + # The google-genai SDK accepts a dict JSON schema directly and + # converts it to a types.Schema internally. Pairing with + # response_mime_type="application/json" enforces the structure. + config_kwargs["response_mime_type"] = "application/json" + config_kwargs["response_schema"] = response_schema + + config = types.GenerateContentConfig(**config_kwargs) + + def _call() -> str: + response = self._client.models.generate_content( + model=self._model, + contents=contents, + config=config, + ) + return (response.text or "").strip() + + return call_with_retry(_call, max_retries=self._max_retries) diff --git a/experiments/subtask_probe/droid_eval/comet_style/openai_compat_reasoner.py b/experiments/subtask_probe/droid_eval/comet_style/openai_compat_reasoner.py new file mode 100644 index 0000000..8192a27 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/openai_compat_reasoner.py @@ -0,0 +1,95 @@ +"""OpenAI-compatible VLM backend for Comet-style hierarchical subtask generation. + +Targets a self-hosted vLLM server (or any server that speaks the OpenAI +chat-completions protocol with image inputs). Default model name matches the +only VLM named in the openpi-comet repo, ``Qwen3-VL-30B-A3B-Instruct`` +(``openpi-comet/src/openpi/shared/client.py:169``), but any multimodal chat +model the server hosts will work. +""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any, cast + +import numpy as np +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from ._gemini_utils import encode_png +from .reasoner_base import BaseReasoner + +logger = logging.getLogger(__name__) + +DEFAULT_BASE_URL = "http://localhost:8000/v1" +DEFAULT_MODEL = "Qwen/Qwen3-VL-30B-A3B-Instruct" + + +def _encode_data_url(image: np.ndarray) -> str: + """PNG-encode an image as a ``data:`` URL suitable for OpenAI image inputs.""" + png_bytes = encode_png(image) + return f"data:image/png;base64,{base64.b64encode(png_bytes).decode('utf-8')}" + + +class OpenAICompatReasoner(BaseReasoner): + def __init__( + self, + base_url: str = DEFAULT_BASE_URL, + model: str = DEFAULT_MODEL, + api_key: str = "none", + temperature: float = 1.0, + timeout_s: float = 600.0, + history_maxlen: int = 640, + sampled_images_max: int = 64, + history_stride: int = 5, + ) -> None: + super().__init__( + history_maxlen=history_maxlen, + sampled_images_max=sampled_images_max, + history_stride=history_stride, + ) + self._client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s) + self._model = model + self._temperature = temperature + + def _chat( + self, + user_prompt: str, + images: list[np.ndarray], + response_schema: dict[str, Any] | None = None, + ) -> str: + content: list[dict[str, Any]] = [{"type": "text", "text": user_prompt}] + for image in images: + content.append( + { + "type": "image_url", + "image_url": {"url": _encode_data_url(image)}, + } + ) + # The chat-completions SDK types are strict TypedDicts that don't + # accept our general-purpose content list; cast is safer than + # maintaining parallel TypedDict literals for every image. + messages = cast(list[ChatCompletionMessageParam], [{"role": "user", "content": content}]) + + create_kwargs: dict[str, Any] = { + "model": self._model, + "messages": messages, + "temperature": self._temperature, + } + if response_schema is not None: + # vLLM exposes OpenAI-spec structured outputs backed by xgrammar. + # The server fills in decoding constraints from the JSON schema so + # the response is guaranteed to parse. + create_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "response", + "schema": response_schema, + "strict": True, + }, + } + + response = self._client.chat.completions.create(**create_kwargs) + message = response.choices[0].message.content or "" + return message.strip() diff --git a/experiments/subtask_probe/droid_eval/comet_style/reasoner_base.py b/experiments/subtask_probe/droid_eval/comet_style/reasoner_base.py new file mode 100644 index 0000000..90b4fe2 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/reasoner_base.py @@ -0,0 +1,470 @@ +"""Backend-agnostic scaffold for Comet-style hierarchical subtask generation. + +Ports ``strip_think_tags``, ``generate_stylized_plan``, ``sample_images`` and +the three-step plan -> critique -> subtask control flow verbatim from +``openpi-comet/src/openpi/shared/client.py``. The prompt strings are copied +byte-for-byte from that file so we're testing Comet's prompting, not our +rewording. + +Concrete backends subclass ``BaseReasoner`` and implement a single ``_chat`` +hook that sends a VLM request with a user prompt plus a list of images and +returns the response text. +""" + +from __future__ import annotations + +import abc +import collections +import json +import logging +import re +from typing import Any, Literal + +import numpy as np + +logger = logging.getLogger(__name__) + + +# Comet defaults: deque(maxlen=64*10) for image history, sample at most 64 images per call +# (client.py:178, 72-79). +DEFAULT_HISTORY_MAXLEN = 64 * 10 +DEFAULT_SAMPLED_IMAGES_MAX = 64 + +# Status markers for plan steps, matching Comet's stylized plan output: +# [o] done, [-] in_progress, [x] not_started +PlanStepStatus = Literal["done", "in_progress", "not_started"] +_STATUS_VALUES: tuple[PlanStepStatus, ...] = ("done", "in_progress", "not_started") +_STATUS_MARKERS: dict[PlanStepStatus, str] = { + "done": "[o]", + "in_progress": "[-]", + "not_started": "[x]", +} + +# JSON schemas enforced via the backend's native structured-output mechanism +# (Gemini response_schema / vLLM xgrammar json_schema). Both backends translate +# these dicts to their native representation. +# +# The schemas are designed to produce short outputs compatible with the pi0.5 +# action prompt budget: plans are 2-10 short step strings, subtasks are a +# single short imperative phrase (~120 chars / ~20 words). +PLAN_SCHEMA: dict[str, Any] = { + "type": "array", + "items": {"type": "string"}, + "minItems": 2, + "maxItems": 10, +} + +SUBTASK_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "subtask": { + "type": "string", + "maxLength": 120, + }, + }, + "required": ["subtask"], +} + +# plan_critique returns a parallel list of statuses — one per step of the +# original plan, in the same order. Using an enum makes structural equality +# meaningful (no prose drift) so we only reset ``subtask_history`` when the +# plan state actually changes, not when wording varies. +CRITIQUE_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "statuses": { + "type": "array", + "items": {"type": "string", "enum": list(_STATUS_VALUES)}, + }, + }, + "required": ["statuses"], +} + + +def strip_think_tags(text: str) -> str: + """Remove ``...`` content or any leading ```` preamble. + + Mirrors ``openpi-comet/src/openpi/shared/client.py:82``. + """ + lower = text.lower() + if "" in lower and "" in lower: + cleaned_text = re.sub(r".*?", "", text, flags=re.DOTALL | re.IGNORECASE) + elif "" in lower: + cleaned_text = re.sub(r"^.*?", "", text, flags=re.DOTALL | re.IGNORECASE) + else: + cleaned_text = text + return re.sub(r"\n\s*\n", "\n", cleaned_text.strip()) + + +def parse_plan_list(text: str) -> list[str] | None: + """Parse a VLM response that should be a JSON list of plan-step strings. + + Returns the parsed list on success, or ``None`` when the response cannot + be coerced into a non-empty list of strings (bad JSON, wrong shape, + non-string items, empty list). Callers are expected to fall back to a + single-step plan in that case. + + Handles the two common messy outputs from reasoning VLMs: ```` + tags around the response and markdown code fences (```json ... ```). + Comet uses ``json_repair`` for further tolerance; we avoid the dep and + surface parse failures as ``None`` instead. + """ + cleaned = strip_think_tags(text).strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```[a-zA-Z]*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + try: + parsed = json.loads(cleaned) + except json.JSONDecodeError: + return None + if not isinstance(parsed, list) or not parsed: + return None + result: list[str] = [] + for item in parsed: + if not isinstance(item, str): + return None + result.append(item) + return result + + +def render_plan_status(plans: list[str], statuses: list[PlanStepStatus]) -> str: + """Render a plan + per-step status list into Comet's stylized string format. + + Example output:: + + [o] pick up the cube + [-] move to the dish + [x] release the cube + + Mirrors the shape of ``openpi-comet/src/openpi/shared/client.py:118`` but + is driven by discrete enum statuses rather than an index + completion + flag, which maps cleanly onto our structured-output critique responses. + """ + if len(plans) != len(statuses): + raise ValueError( + f"plans/statuses length mismatch: {len(plans)} steps vs {len(statuses)} statuses" + ) + return "\n".join( + f"{_STATUS_MARKERS[status]} {step}" for step, status in zip(plans, statuses, strict=True) + ) + + +def initial_statuses(num_steps: int) -> list[PlanStepStatus]: + """Starting state: first step in_progress, the rest not_started.""" + if num_steps <= 0: + raise ValueError("initial_statuses requires at least one step") + statuses: list[PlanStepStatus] = ["in_progress"] + statuses.extend(["not_started"] * (num_steps - 1)) + return statuses + + +def sample_images( + image_list: list[np.ndarray], + max_len: int = 64, + stride: int = 5, +) -> list[np.ndarray]: + """Sample every ``stride``-th image in reverse from the most recent, up to ``max_len``. + + Mirrors ``openpi-comet/src/openpi/shared/client.py:72``, which hardcoded + stride=5 because their history buffer was populated at the 30 Hz sim rate + and they wanted ~6 Hz of temporal density per VLM call. On our DROID + cache the buffer is populated at the *cached* rate (1 Hz with + ``--frame_subsample=15``), so stride=5 skips over almost everything. + Rule of thumb: ``stride ≈ cache_hz / desired_sample_hz`` (with floor 1). + For our 1 Hz cache, stride=1 means the VLM sees the last ``max_len`` + consecutive seconds of history. + """ + if stride <= 0: + raise ValueError(f"stride must be positive, got {stride}") + sampled: list[np.ndarray] = [] + for i in range(len(image_list) - 1, -1, -stride): + sampled.append(image_list[i]) + if len(sampled) >= max_len: + break + sampled.reverse() + return sampled + + +def all_steps_done(statuses: list[PlanStepStatus] | None) -> bool: + """Return True when every step is marked ``done``. + + DROID cached frames can continue well past the point where the robot has + finished the instructed task. Once the plan is fully complete, the VLM + has nothing useful to say and its degenerate outputs ("finished", "none") + pollute the action prompt. Callers short-circuit on this condition. + """ + return bool(statuses) and all(s == "done" for s in statuses) + + +class BaseReasoner(abc.ABC): + """Stateful hierarchical reasoner shared by all backends. + + State that carries across frames within an episode: + * ``history_multi_modals`` — ring buffer of past images + * ``subtask_history`` — ordered list of subtask strings produced so far + * ``plan_status`` — the current stylized plan string with status markers, + or ``None`` when no plan has been generated yet (pre-first-call or + post-reset) + + Call ``reset()`` at the start of each new episode. + + Design note — two-call vs merged ``plan_critique`` + ``generate_subtask``: + Every replan fires two VLM calls back-to-back (critique + subtask select). + These could be merged into one call with a schema like + ``{"statuses": [...], "subtask": "..."}``, which would halve cost + latency. + We keep them separate because Comet's original flow depends on the + sequential dependency — the subtask prompt is built from the *updated* + ``plan_status`` *and* a possibly-reset ``subtask_history`` (triggered when + the statuses change). Merging loses the reset gate and forces the model to + maintain cross-field consistency in one generation, which is fine for + strong reasoning models but risks inconsistent outputs on non-reasoning + models (e.g. Gemini Robotics-ER at ``thinking_budget=0``). Revisit this + if API cost becomes the bottleneck and the model can handle it. + """ + + def __init__( + self, + history_maxlen: int = DEFAULT_HISTORY_MAXLEN, + sampled_images_max: int = DEFAULT_SAMPLED_IMAGES_MAX, + history_stride: int = 5, + ) -> None: + self.history_multi_modals: collections.deque[np.ndarray] = collections.deque( + maxlen=history_maxlen + ) + self.sampled_images_max = sampled_images_max + # Stride used by _sampled_history when picking frames from the + # history deque. Default 5 matches Comet's original hardcoded value + # and empirically gives the best plan stability on our 1 Hz cache + # because the wider temporal span (~40 s of history with max=8) + # provides the temporal-contrast signal the reasoner needs to + # detect plan progression. stride=1 looks tempting ("consecutive + # frames!") but flips plan interpretation on every tiny frame + # change on slow tasks. + self.history_stride = history_stride + self.subtask_history: list[str] = [] + # Canonical plan state: the immutable step texts from the one-shot + # generate_plan call and a parallel list of per-step statuses that + # plan_critique updates. plan_status (the rendered string) is derived. + self.plans: list[str] | None = None + self.plan_statuses: list[PlanStepStatus] | None = None + + def reset(self) -> None: + self.history_multi_modals.clear() + self.subtask_history = [] + self.plans = None + self.plan_statuses = None + + @property + def plan_status(self) -> str | None: + """Rendered Comet-style plan string, or None before generate_plan runs.""" + if self.plans is None or self.plan_statuses is None: + return None + return render_plan_status(self.plans, self.plan_statuses) + + @abc.abstractmethod + def _chat( + self, + user_prompt: str, + images: list[np.ndarray], + response_schema: dict[str, Any] | None = None, + ) -> str: + """Send a VLM request and return the raw response string. + + When ``response_schema`` is provided the backend must enforce it via + its native structured-output mechanism (Gemini ``response_schema`` / + vLLM ``json_schema``); otherwise the backend should free-form generate. + + Implementations should NOT strip ```` tags — the base class does + that where appropriate. + """ + + def _sampled_history(self) -> list[np.ndarray]: + return sample_images( + list(self.history_multi_modals), + max_len=self.sampled_images_max, + stride=self.history_stride, + ) + + def generate_plan(self, task: str, initial_image: np.ndarray) -> str: + """One-shot plan generation from the initial observation of the episode. + + Prompt adapted from ``openpi-comet/src/openpi/shared/client.py:274``, + with an explicit JSON-array constraint enforced via ``PLAN_SCHEMA`` so + off-the-shelf VLMs emit the structured list Comet's scaffold expects. + + Stores the parsed plan as ``self.plans`` and initializes + ``self.plan_statuses`` with the first step in_progress, others + not_started. Returns the rendered plan_status string. + """ + user_prompt = ( + f"Given the task '{task}', break it down into several concrete high-level steps. " + "Respond with a JSON array of 2-10 short imperative step strings, nothing else." + ) + self.history_multi_modals.append(initial_image) + response = self._chat(user_prompt, [initial_image], response_schema=PLAN_SCHEMA) + plans = parse_plan_list(response) + if plans is None: + logger.warning( + "Plan response could not be parsed as a non-empty JSON list of strings; " + "falling back to single-step plan. Raw response: %r", + response, + ) + plans = [task] + self.plans = plans + self.plan_statuses = initial_statuses(len(plans)) + rendered = self.plan_status + assert rendered is not None + return rendered + + def plan_critique(self, task: str) -> list[PlanStepStatus]: + """Ask the VLM to update per-step statuses given the image history. + + Returns a parallel list of ``PlanStepStatus`` values, one per step of + ``self.plans`` in the same order. Enforced via ``CRITIQUE_SCHEMA`` so + the backend returns discrete enum values rather than prose — this is + load-bearing for the reset-on-change logic in ``generate_subtask``, + which previously thrashed on tiny wording variations in free-form + critique output. + """ + if self.plans is None or self.plan_statuses is None: + raise RuntimeError("plan_critique called before generate_plan — no plan to critique") + + last_subtask = self.subtask_history[-1] if self.subtask_history else "None" + numbered_plan = "\n".join(f" {i + 1}. {step}" for i, step in enumerate(self.plans)) + current_state = "\n".join( + f" {i + 1}. {status}" for i, status in enumerate(self.plan_statuses) + ) + user_prompt = ( + f"You are given the task of '{task}'. The plan has these steps in order:\n" + f"{numbered_plan}\n" + f"\nTheir current statuses are:\n" + f"{current_state}\n" + f"\nThe last high-level objective given to the robot was '{last_subtask}'. " + "Looking at the images, update each step's status to one of " + f"{list(_STATUS_VALUES)}. Respond with a JSON object " + '{"statuses": [...]} containing exactly ' + f"{len(self.plans)} status values in the same order as the steps." + ) + response = self._chat(user_prompt, self._sampled_history(), response_schema=CRITIQUE_SCHEMA) + new_statuses = _extract_statuses_field(response, expected_len=len(self.plans)) + if new_statuses is None: + logger.warning( + "plan_critique response was not a valid statuses list of length %d; " + "keeping previous statuses. Raw response: %r", + len(self.plans), + response, + ) + return list(self.plan_statuses) + return new_statuses + + def generate_subtask(self, task: str, images: list[np.ndarray]) -> str: + """Produce the next high-level subtask given the current observation. + + On the first call of an episode this also bootstraps the plan. + Mirrors the control flow of + ``openpi-comet/src/openpi/shared/client.py:294`` but avoids double- + pushing the current image into the history deque and uses structural + status comparison (not prose equality) to decide when to reset + ``subtask_history``. + + Prompt adapted from ``openpi-comet/src/openpi/shared/client.py:305``; + the subtask string is enforced via ``SUBTASK_SCHEMA`` to keep the + output short enough for the pi0.5 action prompt budget. + """ + if not images: + raise ValueError("generate_subtask requires at least one image") + + if self.plans is None: + # First frame of the episode — bootstrap the plan. generate_plan + # already pushes initial_image into the history deque. + self.generate_plan(task, images[0]) + self.subtask_history = [] + for extra in images[1:]: + self.history_multi_modals.append(extra) + else: + for img in images: + self.history_multi_modals.append(img) + # Once the plan is fully complete, subsequent critique/subtask + # calls produce degenerate output ("finished", "none") because + # there is nothing left to plan. Skip the VLM calls entirely and + # reuse the last real subtask so the action policy still gets a + # meaningful prompt. + if all_steps_done(self.plan_statuses): + last_subtask = self.subtask_history[-1] if self.subtask_history else task + self.subtask_history.append(last_subtask) + return last_subtask + if self.subtask_history: + updated_statuses = self.plan_critique(task) + if updated_statuses != self.plan_statuses: + self.subtask_history = [] + self.plan_statuses = updated_statuses + + # Unreachable in practice: generate_plan always sets plans/statuses, + # and we took the "first frame" branch above when plans was None. + assert self.plans is not None and self.plan_statuses is not None + + rendered_plan_status = render_plan_status(self.plans, self.plan_statuses) + last_subtask = self.subtask_history[-1] if self.subtask_history else "None" + user_prompt = ( + f"You are given the task of '{task}'. The status of the plans are:\n" + f" {rendered_plan_status}\n" + f" Note that [-] indicates in progress. [o] indicates completed. [x] indicates not started.\n" + f" The last high-level objective given to the robot was '{last_subtask}'." + f"Based on your analysis, what should be the next high-level objective the robot should achieve? " + 'Respond with a JSON object {"subtask": "..."} where the value is ' + "a 3-6 word lowercase imperative phrase." + ) + response = self._chat(user_prompt, self._sampled_history(), response_schema=SUBTASK_SCHEMA) + subtask = _extract_subtask_field(response) + self.subtask_history.append(subtask) + return subtask + + +def _extract_subtask_field(response: str) -> str: + """Pull the ``subtask`` field from a JSON object response. + + Falls back to the cleaned raw text if parsing fails or the field is + missing so we never lose the data the model actually returned. + """ + cleaned = strip_think_tags(response).strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```[a-zA-Z]*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + try: + obj = json.loads(cleaned) + except json.JSONDecodeError: + logger.warning("Subtask response was not valid JSON; using raw text. Got: %r", cleaned) + return cleaned + if isinstance(obj, dict) and isinstance(obj.get("subtask"), str): + return obj["subtask"].strip() + logger.warning("Subtask JSON missing 'subtask' string field; using raw text. Got: %r", obj) + return cleaned + + +def _extract_statuses_field(response: str, expected_len: int) -> list[PlanStepStatus] | None: + """Pull the ``statuses`` array from a JSON critique response. + + Returns a list of exactly ``expected_len`` validated ``PlanStepStatus`` + values, or ``None`` when parsing fails, the array is missing, lengths + don't match, or any element isn't a known status. Callers should keep + the prior plan_statuses on ``None`` rather than reset them. + """ + cleaned = strip_think_tags(response).strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```[a-zA-Z]*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + try: + obj = json.loads(cleaned) + except json.JSONDecodeError: + return None + if not isinstance(obj, dict): + return None + statuses = obj.get("statuses") + if not isinstance(statuses, list) or len(statuses) != expected_len: + return None + validated: list[PlanStepStatus] = [] + for item in statuses: + if item not in _STATUS_VALUES: + return None + validated.append(item) + return validated diff --git a/experiments/subtask_probe/droid_eval/comet_style/run.py b/experiments/subtask_probe/droid_eval/comet_style/run.py new file mode 100644 index 0000000..6ac9ed4 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/comet_style/run.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +"""Phase 1 (alt 2): Comet-style hierarchical subtask generation for DROID frames. + +Runs a stateful plan -> critique -> subtask loop per episode, ported from +``openpi-comet/src/openpi/shared/client.py``. Supports two backends: + + * ``--backend gemini`` — Gemini Robotics-ER 1.6 Preview (default). + * ``--backend openai_compat`` — any OpenAI-compatible chat-completions + server (e.g. a local vLLM hosting + Qwen3-VL-30B-A3B-Instruct). + +Output JSON schema matches ``generate_subtasks_gemini.py`` so ``run_action_eval``, +``compute_metrics`` and ``visualize_results`` consume it unchanged. + +Usage: + # Gemini backend (requires GEMINI_API_KEY) + uv run python -m experiments.subtask_probe.droid_eval.comet_style.run \\ + --samples_dir ./.experiments_cache/droid_eval_2min \\ + --output ./.experiments_cache/droid_eval_2min/subtasks_comet_gemini.json \\ + --backend gemini + + # OpenAI-compatible backend (vLLM hosting Qwen3-VL-30B) + uv run python -m experiments.subtask_probe.droid_eval.comet_style.run \\ + --samples_dir ./.experiments_cache/droid_eval_2min \\ + --output ./.experiments_cache/droid_eval_2min/subtasks_comet_qwen.json \\ + --backend openai_compat \\ + --base_url http://localhost:8000/v1 \\ + --model Qwen/Qwen3-VL-30B-A3B-Instruct +""" + +from __future__ import annotations + +import argparse +import json +import logging +import time +from pathlib import Path +from typing import Any, Literal, assert_never, cast, get_args + +import numpy as np +from dotenv import load_dotenv + +from experiments.subtask_probe.droid_eval.utils import load_manifest + +from .reasoner_base import BaseReasoner + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +Backend = Literal["gemini", "openai_compat"] +BACKEND_CHOICES: tuple[Backend, ...] = get_args(Backend) + + +def _parse_backend(raw: str) -> Backend: + """Narrow the argparse string into the ``Backend`` literal union. + + argparse's ``choices=`` already rejects bad values at runtime; this + function exists to carry that proof into the type system. + """ + if raw in BACKEND_CHOICES: + return cast(Backend, raw) + raise ValueError(f"Unknown backend: {raw!r} (expected one of {BACKEND_CHOICES})") + + +def _build_reasoner( + backend: Backend, + args: argparse.Namespace, +) -> tuple[BaseReasoner, str]: + """Instantiate the requested backend and return (reasoner, backend_label). + + ``backend_label`` gets written into the output JSON's ``backend`` field so + downstream tooling can tell runs apart. + """ + match backend: + case "gemini": + from .gemini_reasoner import DEFAULT_MODEL as GEMINI_DEFAULT + from .gemini_reasoner import GeminiReasoner + + model = args.model or GEMINI_DEFAULT + reasoner: BaseReasoner = GeminiReasoner( + model=model, + thinking_budget=args.thinking_budget, + max_retries=args.max_retries, + history_maxlen=args.history_maxlen, + sampled_images_max=args.sampled_images_max, + history_stride=args.history_stride, + ) + return reasoner, model + case "openai_compat": + from .openai_compat_reasoner import DEFAULT_BASE_URL, OpenAICompatReasoner + from .openai_compat_reasoner import DEFAULT_MODEL as OAI_DEFAULT + + base_url = args.base_url or DEFAULT_BASE_URL + model = args.model or OAI_DEFAULT + reasoner = OpenAICompatReasoner( + base_url=base_url, + model=model, + api_key=args.api_key, + history_maxlen=args.history_maxlen, + sampled_images_max=args.sampled_images_max, + history_stride=args.history_stride, + ) + return reasoner, f"{model}@{base_url}" + case _ as unreachable: + assert_never(unreachable) + + +def _process_episode( + reasoner: BaseReasoner, + samples_dir: Path, + episode: dict[str, Any], + replan_every: int, +) -> list[dict[str, Any]]: + """Run the reasoner across all frames of one episode. + + Calls ``reasoner.reset()`` at the start. On replan frames the reasoner + issues 2 VLM calls (plan/critique + subtask); non-replan frames reuse + the last subtask text and make no VLM calls. + """ + reasoner.reset() + records: list[dict[str, Any]] = [] + episode_id = episode["episode_id"] + instruction = episode["instruction"] + last_subtask = "" + + for step_idx, frame_info in enumerate(episode["frames"]): + frame_idx = frame_info["frame_idx"] + frame_path = samples_dir / frame_info["file"] + frame_data = np.load(frame_path) + exterior_image = np.asarray(frame_data["exterior_image"], dtype=np.uint8) + + is_replan = step_idx % replan_every == 0 + elapsed = 0.0 + if is_replan: + start = time.time() + try: + last_subtask = reasoner.generate_subtask(instruction, [exterior_image]) + except Exception as exc: + logger.warning( + "Reasoner call failed for %s frame %d: %s", + episode_id, + frame_idx, + exc, + ) + last_subtask = "" + elapsed = time.time() - start + + records.append( + { + "episode_id": episode_id, + "frame_idx": frame_idx, + "instruction": instruction, + "subtask_text": last_subtask, + "generation_time_s": round(elapsed, 2), + "server_subtask_ms": round(elapsed * 1000, 1), + } + ) + + if step_idx % 5 == 0 or step_idx == len(episode["frames"]) - 1: + logger.info( + "[%s] frame %d/%d: %r (%.1fs, replan=%s)", + episode_id, + step_idx + 1, + len(episode["frames"]), + last_subtask, + elapsed, + is_replan, + ) + + logger.info( + "[%s] done — final plan_status:\n%s", + episode_id, + reasoner.plan_status or "", + ) + return records + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Comet-style hierarchical subtask generation for DROID frames" + ) + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument( + "--backend", + choices=list(BACKEND_CHOICES), + default="gemini", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID. Defaults: gemini-robotics-er-1.6-preview (gemini) / " + "Qwen/Qwen3-VL-30B-A3B-Instruct (openai_compat).", + ) + parser.add_argument( + "--base_url", + type=str, + default=None, + help="OpenAI-compatible server URL (openai_compat backend only). " + "Default: http://localhost:8000/v1", + ) + parser.add_argument( + "--api_key", + type=str, + default="none", + help="API key for openai_compat backend (local vLLM ignores this)", + ) + parser.add_argument( + "--thinking_budget", + type=int, + default=0, + help="Gemini thinking budget (gemini backend only; 0 disables thinking)", + ) + parser.add_argument( + "--max_retries", + type=int, + default=10, + help="Per-call retry budget for 429 (gemini backend only)", + ) + parser.add_argument( + "--replan_every", + type=int, + default=1, + help=( + "Issue a new plan+critique+subtask every N cached frames; " + "intermediate frames reuse the previous subtask text. Default 1." + ), + ) + parser.add_argument( + "--history_maxlen", + type=int, + default=640, + help="Per-episode image history deque size. Default matches Comet's 64*10.", + ) + parser.add_argument( + "--sampled_images_max", + type=int, + default=64, + help="Max images sampled from history per VLM call.", + ) + parser.add_argument( + "--history_stride", + type=int, + default=5, + help=( + "Stride used when sampling history for each VLM call. Default " + "5 matches Comet's original hardcoded value and empirically " + "gives the best plan-stability on our 1 Hz cache because the " + "wider temporal span provides the contrast signal the reasoner " + "needs to detect progression. stride=1 over-interprets " + "frame-to-frame motion on slow tasks — see FINDINGS.md." + ), + ) + parser.add_argument( + "--max_episodes", + type=int, + default=None, + help="Only process the first N episodes from the manifest (for triage runs).", + ) + return parser.parse_args() + + +def main() -> None: + load_dotenv() + args = _parse_args() + + samples_dir = Path(args.samples_dir) + manifest = load_manifest(samples_dir) + if args.max_episodes is not None: + manifest = manifest[: args.max_episodes] + logger.info("Loaded manifest: %d episodes", len(manifest)) + + backend = _parse_backend(args.backend) + reasoner, backend_label = _build_reasoner(backend, args) + logger.info( + "Backend=%s model=%s replan_every=%d history_maxlen=%d sampled_images_max=%d", + backend, + backend_label, + args.replan_every, + args.history_maxlen, + args.sampled_images_max, + ) + + all_records: list[dict[str, Any]] = [] + for episode in manifest: + logger.info( + "Starting episode %s (%d frames): %r", + episode["episode_id"], + len(episode["frames"]), + episode["instruction"], + ) + all_records.extend( + _process_episode(reasoner, samples_dir, episode, replan_every=args.replan_every) + ) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump( + { + "prompt_format": "comet_style_hierarchical", + "backend": backend_label, + "results": all_records, + }, + f, + indent=2, + ) + + logger.info("Saved %d records to %s (backend=%s)", len(all_records), output_path, backend_label) + + gen_times = [r["generation_time_s"] for r in all_records if r["generation_time_s"] > 0] + if gen_times: + logger.info( + "Latency: mean=%.2fs, min=%.2fs, max=%.2fs (replan calls only)", + float(np.mean(gen_times)), + float(np.min(gen_times)), + float(np.max(gen_times)), + ) + unique_subtasks = {r["subtask_text"] for r in all_records if r["subtask_text"]} + logger.info("Unique subtask texts: %d out of %d frames", len(unique_subtasks), len(all_records)) + failed = sum(1 for r in all_records if not r["subtask_text"]) + if failed: + logger.warning("Empty subtasks: %d / %d frames", failed, len(all_records)) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/compute_metrics.py b/experiments/subtask_probe/droid_eval/compute_metrics.py new file mode 100644 index 0000000..5536561 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/compute_metrics.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +"""Phase 3: Compute metrics comparing action conditions against ground truth. + +Loads ground truth action chunks and predicted actions from all 3 conditions, +computes L2 distance, cosine similarity, per-dimension MAE, and gripper accuracy. + +Usage: + uv run python experiments/subtask_probe/droid_eval/compute_metrics.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --predictions_dir ./.experiments_cache/droid_eval/predictions \ + --output ./.experiments_cache/droid_eval/results.json +""" + +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, TypedDict + +import numpy as np + +from .constants import CONDITION_NAMES, GRIPPER_THRESHOLD, JOINT_NAMES +from .utils import load_manifest + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +class ConditionMetrics(TypedDict): + l2_distance: float + cosine_similarity: float + per_dim_mae: list[float] + gripper_accuracy: float + per_step_l2: list[float] + + +def compute_frame_metrics( + ground_truth: np.ndarray, + predictions: dict[str, np.ndarray], +) -> dict[str, ConditionMetrics]: + """Compute per-condition metrics for a single frame. + + Args: + ground_truth: [action_horizon, 8] ground truth action chunk + predictions: dict mapping condition name → [action_horizon, 8] predicted actions + """ + results = {} + + for condition_name, pred in predictions.items(): + # Ensure shapes match + min_horizon = min(ground_truth.shape[0], pred.shape[0]) + gt = ground_truth[:min_horizon] + pr = pred[:min_horizon] + + # L2 distance (over entire action chunk) + l2_distance = float(np.linalg.norm(gt - pr)) + + # Cosine similarity (flatten both) + gt_flat = gt.flatten() + pr_flat = pr.flatten() + cos_sim = float( + np.dot(gt_flat, pr_flat) / (np.linalg.norm(gt_flat) * np.linalg.norm(pr_flat) + 1e-10) + ) + + # Per-dimension MAE + per_dim_mae = np.mean(np.abs(gt - pr), axis=0).tolist() # [8] + + # Gripper accuracy (binary: open/closed) + gt_gripper = (gt[:, -1] > GRIPPER_THRESHOLD).astype(int) + pred_gripper = (pr[:, -1] > GRIPPER_THRESHOLD).astype(int) + gripper_accuracy = float(np.mean(gt_gripper == pred_gripper)) + + # Per-timestep L2 (for trajectory analysis) + per_step_l2 = np.linalg.norm(gt - pr, axis=-1).tolist() # [min_horizon] + + results[condition_name] = { + "l2_distance": l2_distance, + "cosine_similarity": cos_sim, + "per_dim_mae": per_dim_mae, + "gripper_accuracy": gripper_accuracy, + "per_step_l2": per_step_l2, + } + + return results + + +def aggregate_metrics( + all_frame_metrics: list[dict[str, ConditionMetrics]], + task_types: list[str], + episode_progress: list[float], +) -> dict[str, Any]: + """Aggregate per-frame metrics into summary statistics.""" + results = {} + + for condition in CONDITION_NAMES: + l2_distances = [fm[condition]["l2_distance"] for fm in all_frame_metrics] + cos_sims = [fm[condition]["cosine_similarity"] for fm in all_frame_metrics] + gripper_accs = [fm[condition]["gripper_accuracy"] for fm in all_frame_metrics] + per_dim_maes = np.array([fm[condition]["per_dim_mae"] for fm in all_frame_metrics]) + + results[condition] = { + "overall": { + "l2_distance": { + "mean": float(np.mean(l2_distances)), + "std": float(np.std(l2_distances)), + }, + "cosine_similarity": { + "mean": float(np.mean(cos_sims)), + "std": float(np.std(cos_sims)), + }, + "gripper_accuracy": { + "mean": float(np.mean(gripper_accs)), + "std": float(np.std(gripper_accs)), + }, + "per_dim_mae": { + name: float(per_dim_maes[:, i].mean()) + for i, name in enumerate(JOINT_NAMES[: per_dim_maes.shape[1]]) + }, + "n_frames": len(l2_distances), + }, + } + + # By task type + for task_type in ["multi_step", "single_step"]: + mask = [t == task_type for t in task_types] + if not any(mask): + continue + + type_l2 = [d for d, m in zip(l2_distances, mask, strict=True) if m] + type_cos = [d for d, m in zip(cos_sims, mask, strict=True) if m] + type_gripper = [d for d, m in zip(gripper_accs, mask, strict=True) if m] + + results[condition][f"task_type_{task_type}"] = { + "l2_distance": {"mean": float(np.mean(type_l2)), "std": float(np.std(type_l2))}, + "cosine_similarity": { + "mean": float(np.mean(type_cos)), + "std": float(np.std(type_cos)), + }, + "gripper_accuracy": { + "mean": float(np.mean(type_gripper)), + "std": float(np.std(type_gripper)), + }, + "n_frames": len(type_l2), + } + + # By episode progress (early/middle/late) + for label, low, high in [ + ("early", 0.0, 0.33), + ("middle", 0.33, 0.67), + ("late", 0.67, 1.01), + ]: + mask = [low <= p < high for p in episode_progress] + if not any(mask): + continue + + prog_l2 = [d for d, m in zip(l2_distances, mask, strict=True) if m] + prog_cos = [d for d, m in zip(cos_sims, mask, strict=True) if m] + + results[condition][f"progress_{label}"] = { + "l2_distance": {"mean": float(np.mean(prog_l2)), "std": float(np.std(prog_l2))}, + "cosine_similarity": { + "mean": float(np.mean(prog_cos)), + "std": float(np.std(prog_cos)), + }, + "n_frames": len(prog_l2), + } + + # Pairwise comparisons (paired differences) + pairwise = {} + for condition_a, condition_b in [ + ("baseline", "subtask"), + ]: + l2_a = np.array([fm[condition_a]["l2_distance"] for fm in all_frame_metrics]) + l2_b = np.array([fm[condition_b]["l2_distance"] for fm in all_frame_metrics]) + diff = l2_a - l2_b # positive = condition_b is closer to ground truth + + pairwise[f"{condition_a}_vs_{condition_b}"] = { + "l2_diff_mean": float(np.mean(diff)), + "l2_diff_std": float(np.std(diff)), + "pct_b_better": float(np.mean(diff > 0) * 100), + "pct_a_better": float(np.mean(diff < 0) * 100), + } + + # Wilcoxon signed-rank test + try: + from scipy.stats import wilcoxon + + stat, p_value = wilcoxon(l2_a, l2_b, alternative="two-sided") + pairwise[f"{condition_a}_vs_{condition_b}"]["wilcoxon_p"] = float(p_value) + pairwise[f"{condition_a}_vs_{condition_b}"]["wilcoxon_stat"] = float(stat) + except ImportError: + logger.warning("scipy not available, skipping Wilcoxon test") + except ValueError as e: + logger.warning("Wilcoxon test failed: %s", e) + + results["pairwise"] = pairwise + return results + + +def print_summary(results: dict[str, Any]) -> None: + """Print a human-readable summary table.""" + print("\n" + "=" * 80) + print(" DROID EVALUATION RESULTS") + print("=" * 80) + + # Overall metrics table + print(f"\n {'Condition':<15} {'L2 Dist':>12} {'Cos Sim':>12} {'Grip Acc':>12} {'N':>6}") + print(f" {'-' * 57}") + + for condition in CONDITION_NAMES: + overall = results[condition]["overall"] + l2 = overall["l2_distance"] + cos = overall["cosine_similarity"] + grip = overall["gripper_accuracy"] + n = overall["n_frames"] + print( + f" {condition:<15} {l2['mean']:>8.4f}+-{l2['std']:<4.4f}" + f" {cos['mean']:>8.4f}+-{cos['std']:<4.4f}" + f" {grip['mean']:>8.4f}+-{grip['std']:<4.4f}" + f" {n:>6}" + ) + + # Per-dimension MAE + print("\n Per-dimension MAE:") + print(f" {'Condition':<15}", end="") + for name in JOINT_NAMES: + print(f" {name:>8}", end="") + print() + print(f" {'-' * (15 + 8 * len(JOINT_NAMES))}") + + for condition in CONDITION_NAMES: + per_dim = results[condition]["overall"]["per_dim_mae"] + print(f" {condition:<15}", end="") + for name in JOINT_NAMES: + if name in per_dim: + print(f" {per_dim[name]:>8.4f}", end="") + print() + + # Task type breakdown + for task_type in ["multi_step", "single_step"]: + key = f"task_type_{task_type}" + if key not in results["baseline"]: + continue + + print(f"\n {task_type.replace('_', ' ').title()} Tasks:") + print(f" {'Condition':<15} {'L2 Dist':>12} {'Cos Sim':>12} {'N':>6}") + print(f" {'-' * 45}") + + for condition in CONDITION_NAMES: + if key not in results[condition]: + continue + data = results[condition][key] + l2 = data["l2_distance"] + cos = data["cosine_similarity"] + n = data["n_frames"] + print( + f" {condition:<15} {l2['mean']:>8.4f}+-{l2['std']:<4.4f}" + f" {cos['mean']:>8.4f}+-{cos['std']:<4.4f}" + f" {n:>6}" + ) + + # Pairwise comparisons + print("\n Pairwise Comparisons:") + print(f" {'Comparison':<30} {'L2 Diff':>10} {'% B Better':>12} {'Wilcoxon p':>12}") + print(f" {'-' * 64}") + + for comparison, data in results["pairwise"].items(): + p_str = f"{data['wilcoxon_p']:.4f}" if "wilcoxon_p" in data else "N/A" + print( + f" {comparison:<30} {data['l2_diff_mean']:>10.4f}" + f" {data['pct_b_better']:>11.1f}%" + f" {p_str:>12}" + ) + + print("\n" + "=" * 80) + + # Interpretation + pairwise = results["pairwise"] + comparison = pairwise.get("baseline_vs_subtask", {}) + + print("\n Interpretation:") + if comparison.get("l2_diff_mean", 0) > 0: + print(" - Subtask produces actions CLOSER to ground truth than baseline (L2 diff > 0)") + else: + print(" - Subtask produces actions FARTHER from ground truth than baseline") + + p = comparison.get("wilcoxon_p", 1.0) + if p < 0.05: + print(f" - Difference is statistically significant (p={p:.4f})") + else: + print(f" - Difference is NOT statistically significant (p={p:.4f})") + + print() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Compute eval metrics") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--predictions_dir", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + predictions_dir = Path(args.predictions_dir) + + # Load manifests + manifest = load_manifest(samples_dir) + + with (predictions_dir / "prediction_manifest.json").open() as f: + pred_manifest = json.load(f) + + # Build episode metadata index + episode_metadata = {} + for episode in manifest: + episode_metadata[episode["episode_id"]] = episode + + # Process all frames + all_frame_metrics = [] + task_types = [] + episode_progress_values = [] + + for pred_entry in pred_manifest: + episode_id = pred_entry["episode_id"] + frame_idx = pred_entry["frame_idx"] + episode = episode_metadata[episode_id] + + # Load ground truth + frame_file = samples_dir / next( + f["file"] for f in episode["frames"] if f["frame_idx"] == frame_idx + ) + frame_data = np.load(frame_file) + ground_truth = frame_data["ground_truth_actions"] # [15, 8] + + # Load predictions + pred_file = predictions_dir / pred_entry["prediction_file"] + pred_data = np.load(pred_file) + + predictions = { + "baseline": pred_data["baseline"], + "subtask": pred_data["subtask"], + } + + # Compute metrics + frame_metrics = compute_frame_metrics(ground_truth, predictions) + all_frame_metrics.append(frame_metrics) + task_types.append(episode["task_type"]) + + # Episode progress: where in the trajectory is this frame? + progress = frame_idx / max(episode["traj_len"] - 1, 1) + episode_progress_values.append(progress) + + logger.info("Computed metrics for %d frames", len(all_frame_metrics)) + + # Aggregate + results = aggregate_metrics(all_frame_metrics, task_types, episode_progress_values) + + # Save results + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(results, f, indent=2) + + logger.info("Results saved to %s", output_path) + + # Print summary + print_summary(results) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/constants.py b/experiments/subtask_probe/droid_eval/constants.py new file mode 100644 index 0000000..5eec722 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/constants.py @@ -0,0 +1,32 @@ +"""Shared constants for the DROID subtask evaluation pipeline.""" + +from typing import Literal + +# Inference mode — finite set of valid values for the "mode" field in observation dicts. +# "subtask_only": generate subtask text only (no action generation) +# "action_only": generate actions only (skip subtask generation) +InferenceMode = Literal["subtask_only", "action_only"] + +# DROID action dimensions +ACTION_HORIZON = 15 # Number of future action steps predicted per frame +DROID_ACTION_DIM = 8 # 7 joints + 1 gripper +MODEL_ACTION_DIM = 32 # pi0.5 internal latent action dimension + +# Evaluation conditions +CONDITION_NAMES = ["baseline", "subtask"] + +# Joint names for per-dimension metrics and visualization +JOINT_NAMES = ["j1", "j2", "j3", "j4", "j5", "j6", "j7", "gripper"] + +# Gripper state threshold (values above = closed, below = open) +GRIPPER_THRESHOLD = 0.5 + +# Visualization colors per condition +CONDITION_COLORS = { + "baseline": "#4a90d9", + "subtask": "#d94a4a", + "ground_truth": "#888888", +} + +# Default server ports +DEFAULT_QUIC_PORT = 5555 diff --git a/experiments/subtask_probe/droid_eval/extract_droid_samples.py b/experiments/subtask_probe/droid_eval/extract_droid_samples.py new file mode 100644 index 0000000..5b14d8f --- /dev/null +++ b/experiments/subtask_probe/droid_eval/extract_droid_samples.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +"""Phase 0: Extract DROID samples from RLDS for evaluation. + +Streams episodes from gs://gresearch/robotics/droid/1.0.1, extracts frames +with images, proprioceptive state, language instructions, and ground truth +action chunks. Caches to local .npz files. + +Two selection modes: + +* **First-K** (default): take the first ``--num_episodes`` matches in stream + order. Fast and useful when the dataset is already well-shaped. + +* **Top-K longest** (``--scan_episodes N``): scan the first N episodes in the + stream, buffer qualifiers in an in-memory min-heap, save the + ``--num_episodes`` longest. Use for long-horizon evals where the raw stream + is dominated by short demos. Works well paired with ``--require_multi_step``. + +Usage: + # First-K (legacy behavior): + uv run python -m experiments.subtask_probe.droid_eval.extract_droid_samples \\ + --num_episodes 10 \\ + --output_dir ./.experiments_cache/droid_eval + + # Top-K longest with multi-step filter (long-horizon eval): + uv run python -m experiments.subtask_probe.droid_eval.extract_droid_samples \\ + --num_episodes 5 --scan_episodes 5000 \\ + --min_duration_s 60 --require_multi_step \\ + --output_dir ./.experiments_cache/droid_eval_2min +""" + +from __future__ import annotations + +import argparse +import heapq +import json +import logging +import re +from pathlib import Path +from typing import Any + +import numpy as np +import tqdm + +from .constants import ACTION_HORIZON +from .utils import decode_droid_image + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +# DROID is recorded at 15 Hz — used to convert --min_duration_s to a step count. +DROID_FPS = 15 + +# Multi-step task keywords — tasks containing these are likely long-horizon +MULTI_STEP_KEYWORDS = re.compile( + r"\b(and then|then|after|pick.{1,20}place|put.{1,20}in|open.{1,20}put|grab.{1,20}move)\b", + re.IGNORECASE, +) + + +def is_multi_step_task(instruction: str) -> bool: + """Heuristic: does the instruction describe a multi-step task?""" + return bool(MULTI_STEP_KEYWORDS.search(instruction)) + + +def _decode_str(value: bytes | str) -> str: + return value.decode("utf-8") if isinstance(value, bytes) else value + + +def _extract_traj_fields(traj: dict[str, Any]) -> dict[str, Any] | None: + """Pull metadata + arrays from a raw numpy traj. + + Returns None if the episode is unusable (empty instruction, too short). + """ + file_path = _decode_str(traj["traj_metadata"]["episode_metadata"]["file_path"][0]) + instruction = _decode_str(traj["language_instruction"][0]).strip() + if not instruction or instruction.lower() in ("nan", "none"): + logger.debug("Skipping episode with empty instruction: %s", file_path) + return None + + actions_joint = traj["action_dict"]["joint_position"] # [T, 7] + actions_gripper = traj["action_dict"]["gripper_position"] # [T, 1] + actions = np.concatenate([actions_joint, actions_gripper], axis=-1) # [T, 8] + + traj_len = len(actions) + if traj_len < ACTION_HORIZON: + logger.debug("Skipping short episode (%d frames): %s", traj_len, file_path) + return None + + return { + "file_path": file_path, + "instruction": instruction, + "traj_len": int(traj_len), + "actions": actions, + "exterior_images": traj["observation"]["exterior_image_1_left"], + "wrist_images": traj["observation"]["wrist_image_left"], + "joint_positions": traj["observation"]["joint_position"], + "gripper_positions": traj["observation"]["gripper_position"], + } + + +def _passes_filters( + fields: dict[str, Any], min_duration_s: float, require_multi_step: bool +) -> bool: + min_traj_len = int(min_duration_s * DROID_FPS) + if fields["traj_len"] < min_traj_len: + return False + return not (require_multi_step and not is_multi_step_task(fields["instruction"])) + + +def _save_episode( + fields: dict[str, Any], + episode_id: str, + output_dir: Path, + frame_subsample: int, +) -> dict[str, Any]: + """Decode images + save npz frames for one episode. Returns the manifest entry.""" + episode_dir = output_dir / episode_id + episode_dir.mkdir(exist_ok=True) + + traj_len = fields["traj_len"] + actions = fields["actions"] + frame_indices = list(range(0, traj_len, frame_subsample)) + + frame_records = [] + for frame_idx in frame_indices: + # Decode images from JPEG bytes (kept at native DROID res; server pads to 224x224). + exterior_img = decode_droid_image(fields["exterior_images"][frame_idx]) + wrist_img = decode_droid_image(fields["wrist_images"][frame_idx]) + + # Ground truth action chunk: next ACTION_HORIZON actions, pad tail with last. + action_chunk_indices = np.minimum( + np.arange(frame_idx, frame_idx + ACTION_HORIZON), traj_len - 1 + ) + ground_truth_action_chunk = actions[action_chunk_indices] # [15, 8] + + state = np.concatenate( + [fields["joint_positions"][frame_idx], fields["gripper_positions"][frame_idx]] + ) # [8] + + frame_file = episode_dir / f"frame_{frame_idx:05d}.npz" + np.savez_compressed( + frame_file, + exterior_image=exterior_img, + wrist_image=wrist_img, + state=state, + ground_truth_actions=ground_truth_action_chunk, + frame_idx=frame_idx, + ) + + frame_records.append( + {"frame_idx": int(frame_idx), "file": str(frame_file.relative_to(output_dir))} + ) + + return { + "episode_id": episode_id, + "instruction": fields["instruction"], + "task_type": "multi_step" if is_multi_step_task(fields["instruction"]) else "single_step", + "file_path": fields["file_path"], + "traj_len": traj_len, + "num_frames": len(frame_records), + "frames": frame_records, + } + + +def _write_manifest(manifest: list[dict[str, Any]], output_dir: Path) -> None: + manifest_path = output_dir / "manifest.json" + with manifest_path.open("w") as f: + json.dump(manifest, f, indent=2) + + multi_step_count = sum(1 for ep in manifest if ep["task_type"] == "multi_step") + single_step_count = sum(1 for ep in manifest if ep["task_type"] == "single_step") + total_frames = sum(ep["num_frames"] for ep in manifest) + logger.info("Extraction complete:") + logger.info( + " Episodes: %d (%d multi-step, %d single-step)", + len(manifest), + multi_step_count, + single_step_count, + ) + logger.info(" Total frames: %d", total_frames) + logger.info(" Manifest: %s", manifest_path) + + +def _extract_first_k( + dataset: Any, + num_episodes: int, + output_dir: Path, + frame_subsample: int, + min_duration_s: float, + require_multi_step: bool, +) -> None: + """Stream-scan mode: save the first ``num_episodes`` qualifying episodes.""" + output_dir.mkdir(parents=True, exist_ok=True) + manifest: list[dict[str, Any]] = [] + + for traj in tqdm.tqdm(dataset.as_numpy_iterator(), desc="Extracting", total=num_episodes): + if len(manifest) >= num_episodes: + break + fields = _extract_traj_fields(traj) + if fields is None or not _passes_filters(fields, min_duration_s, require_multi_step): + continue + entry = _save_episode(fields, f"ep_{len(manifest):04d}", output_dir, frame_subsample) + manifest.append(entry) + logger.info( + "Episode %s (%d steps / %.1fs, %s): %r", + entry["episode_id"], + entry["traj_len"], + entry["traj_len"] / DROID_FPS, + entry["task_type"], + entry["instruction"][:80], + ) + + _write_manifest(manifest, output_dir) + + +def _extract_top_k( + dataset: Any, + num_episodes: int, + scan_episodes: int, + output_dir: Path, + frame_subsample: int, + min_duration_s: float, + require_multi_step: bool, +) -> None: + """Top-K longest mode: scan ``scan_episodes`` and keep the longest ``num_episodes`` matches. + + Uses a size-``num_episodes`` min-heap keyed on traj_len so memory stays bounded + (roughly ``num_episodes`` x episode_size in RAM). Episode size is ~20-40 MB for + a 1-2 minute DROID episode, so keeping 5 in memory is ~150 MB — fine. + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # heap: (traj_len, insertion_index, fields_dict). insertion_index is a tiebreaker + # so heapq never tries to compare raw dicts when traj_len matches. + heap: list[tuple[int, int, dict[str, Any]]] = [] + scanned = 0 + qualifying = 0 + + for i, traj in enumerate( + tqdm.tqdm(dataset.as_numpy_iterator(), desc="Scanning", total=scan_episodes) + ): + if i >= scan_episodes: + break + scanned = i + 1 + fields = _extract_traj_fields(traj) + if fields is None or not _passes_filters(fields, min_duration_s, require_multi_step): + continue + qualifying += 1 + key = (fields["traj_len"], i, fields) + if len(heap) < num_episodes: + heapq.heappush(heap, key) + elif fields["traj_len"] > heap[0][0]: + heapq.heapreplace(heap, key) + + logger.info( + "Scan complete: %d episodes read, %d passed filters, keeping top %d by duration", + scanned, + qualifying, + len(heap), + ) + + # Sort longest-first so ep_0000 is the longest — easier to eyeball the manifest. + selected = sorted(heap, key=lambda entry: -entry[0]) + + manifest: list[dict[str, Any]] = [] + for episode_count, (_, _, fields) in enumerate(selected): + entry = _save_episode(fields, f"ep_{episode_count:04d}", output_dir, frame_subsample) + manifest.append(entry) + logger.info( + "Saved %s (%d steps / %.1fs, %s): %r", + entry["episode_id"], + entry["traj_len"], + entry["traj_len"] / DROID_FPS, + entry["task_type"], + entry["instruction"][:80], + ) + + _write_manifest(manifest, output_dir) + + +def extract_episodes( + data_dir: str, + num_episodes: int, + output_dir: Path, + frame_subsample: int = 10, + min_duration_s: float = 0.0, + require_multi_step: bool = False, + scan_episodes: int | None = None, +) -> None: + """Stream DROID episodes from GCS and cache selected frames.""" + # Lazy imports — tensorflow is heavy and optional + import dlimp as dl # ty: ignore[unresolved-import] + import tensorflow as tf # ty: ignore[unresolved-import] + import tensorflow_datasets as tfds # ty: ignore[unresolved-import] + + # Prevent TF from grabbing GPU + tf.config.set_visible_devices([], "GPU") + + logger.info("Building DROID RLDS dataset from %s ...", data_dir) + builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1") + dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=False, num_parallel_reads=4) + + # Filter for successful episodes only + dataset = dataset.filter( + lambda traj: tf.strings.regex_full_match( + traj["traj_metadata"]["episode_metadata"]["file_path"][0], + ".*success.*", + ) + ) + + if scan_episodes is None: + _extract_first_k( + dataset=dataset, + num_episodes=num_episodes, + output_dir=output_dir, + frame_subsample=frame_subsample, + min_duration_s=min_duration_s, + require_multi_step=require_multi_step, + ) + else: + _extract_top_k( + dataset=dataset, + num_episodes=num_episodes, + scan_episodes=scan_episodes, + output_dir=output_dir, + frame_subsample=frame_subsample, + min_duration_s=min_duration_s, + require_multi_step=require_multi_step, + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Extract DROID samples for evaluation") + parser.add_argument( + "--data_dir", + type=str, + default="gs://gresearch/robotics", + help=( + "GCS path that contains the `droid/` TFDS tree. The public DROID v1.0.1 " + "lives at gs://gresearch/robotics/droid/1.0.1, so the right data_dir is " + "`gs://gresearch/robotics`." + ), + ) + parser.add_argument( + "--num_episodes", + type=int, + default=10, + help="Number of episodes to save (start small to validate pipeline)", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./.experiments_cache/droid_eval", + help="Local directory to cache extracted frames", + ) + parser.add_argument( + "--frame_subsample", + type=int, + default=10, + help="Take every Nth frame from each episode", + ) + parser.add_argument( + "--min_duration_s", + type=float, + default=0.0, + help=( + "Minimum episode duration in seconds. DROID runs at 15 Hz, so " + "e.g. 60 keeps only episodes >= 1:00. Default 0 keeps all." + ), + ) + parser.add_argument( + "--require_multi_step", + action="store_true", + help=( + "Keep only episodes whose language instruction matches the multi-step " + "keyword heuristic (and-then, pick…place, put…in, etc.)." + ), + ) + parser.add_argument( + "--scan_episodes", + type=int, + default=None, + help=( + "If set, enable top-K longest mode: scan this many stream episodes, buffer " + "qualifying ones in a min-heap, save the --num_episodes longest. If omitted, " + "use legacy first-K streaming." + ), + ) + args = parser.parse_args() + + extract_episodes( + data_dir=args.data_dir, + num_episodes=args.num_episodes, + output_dir=Path(args.output_dir), + frame_subsample=args.frame_subsample, + min_duration_s=args.min_duration_s, + require_multi_step=args.require_multi_step, + scan_episodes=args.scan_episodes, + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/README.md b/experiments/subtask_probe/droid_eval/foreact_eval/README.md new file mode 100644 index 0000000..24e7730 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/README.md @@ -0,0 +1,80 @@ +# ForeAct reconstruction (inference-only) + +Faithful-as-possible reconstruction of [ForeAct (arxiv 2602.12322)](https://arxiv.org/abs/2602.12322) +on our DROID subtask-probe pipeline. **No training.** See the "ForeAct release +audit" section in `../../FINDINGS.md` for the scope-defining findings — short +version: the paper's released checkpoint is the foresight generator only; the +fine-tuned VLA is not released; and our pi0.5 action server has a fixed +2-image interface so we can't feed foresight to actions zero-shot. + +This module covers two of the paper's three components: + +| Component | File | Status | +|---|---|---| +| π_v — VLM planner (Table 5 prompts) | `planner.py` + `generate_subtasks.py` | faithful | +| π_g — foresight image generator | `generate_foresight.py` | faithful to released checkpoint, runs on remote GPU | +| VLA with augmented visual input | — | **unreachable without fine-tuning** | + +## Phase 1: Planner subtasks + +The planner uses the exact Table 5 prompts (initial + follow-up). Per-episode +state is literally just `previous_subtask: str | None`. We enforce a JSON +schema on the VLM output (`{"subtask": str, "previous_finished": bool}`) — +the paper only says "concise and deterministic"; the schema is our addition +for reliability and observability. + +```bash +# Local vLLM on US West 2 serving the paper's model (Qwen3-VL-8B-Instruct): +uv run python -m experiments.subtask_probe.droid_eval.foreact_eval.generate_subtasks \ + --samples_dir ./.experiments_cache/droid_eval_ah15 \ + --output ./.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json \ + --backend openai_compat \ + --base_url http://localhost:8000/v1 \ + --model Qwen/Qwen3-VL-8B-Instruct +``` + +Backends: `openai_compat` (paper's setup), `gemini` (optional, for comparison +with our prior Comet-Gemini run). + +## Phase 2: Foresight image generator + +Runs the released `mit-han-lab/foreact-pretrained` checkpoint on our DROID +cache frames + the Phase 1 subtasks. **Must run on a GPU box** inside the +foreact conda env — the `diffusers`/`transformers`/`deepspeed` pins conflict +with our hosting project's deps. + +On the remote box (US West 2 L40S): +```bash +# After stopping the Qwen 8B vLLM to free VRAM: +conda activate foreact +huggingface-cli download mit-han-lab/foreact-pretrained --local-dir ~/foreact_ckpt + +python generate_foresight.py \ + --samples_dir ~/.experiments_cache/droid_eval_ah15 \ + --subtasks ~/.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json \ + --output_dir ~/.experiments_cache/droid_eval_ah15/foresight_foreact \ + --checkpoint ~/foreact_ckpt +``` + +Then rsync the PNGs back to local: +```bash +rsync -av us-west-2-l40s:~/.experiments_cache/droid_eval_ah15/foresight_foreact/ \ + ./.experiments_cache/droid_eval_ah15/foresight_foreact/ +``` + +Paper's recommended inference hparams (from `foreact/app_cli.py`): +`guidance_scale=4.5`, `image_guidance_scale=1.5`, `num_inference_steps=8`. + +## Phase 3: Visualization + +HTML + mp4 with the third "predicted foresight" image column alongside +exterior / wrist / subtask: + +```bash +uv run python -m experiments.subtask_probe.droid_eval.foreact_eval.visualize_foreact \ + --samples_dir ./.experiments_cache/droid_eval_ah15 \ + --subtasks ./.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json \ + --foresight_dir ./.experiments_cache/droid_eval_ah15/foresight_foreact \ + --output_dir ./.experiments_cache/droid_eval_ah15/foreact_report \ + [--video --fps 2] +``` diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/__init__.py b/experiments/subtask_probe/droid_eval/foreact_eval/__init__.py new file mode 100644 index 0000000..c14396a --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/__init__.py @@ -0,0 +1,6 @@ +"""Faithful ForeAct reconstruction on DROID (inference-only, no training). + +See ``README.md`` for scope and the ``FINDINGS.md`` ForeAct release-audit +section for why this module only covers π_v (planner) + π_g (foresight +generator) and not the end-to-end VLA integration the paper reports. +""" diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/_io.py b/experiments/subtask_probe/droid_eval/foreact_eval/_io.py new file mode 100644 index 0000000..6ba6e26 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/_io.py @@ -0,0 +1,59 @@ +"""Shared helpers for the foreact_eval package. + +Currently imported by the nano-banana generator + visualizer. The ForeAct +generators (generate_foresight.py, generate_foresight_lerobot.py) run in a +separate conda env on the remote GPU box and don't import these helpers, +but nothing here prevents them from doing so later. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal, NamedTuple + + +def foresight_path(output_dir: Path, episode_index: int, frame_idx: int) -> Path: + """Where a foresight generator writes its PNG for one frame. + + Single owner of the on-disk ``episode_{:06d}/frame_{:05d}.png`` layout. + If we ever change it (e.g. to include model name in the path), this is + the only place to touch. + """ + return output_dir / f"episode_{episode_index:06d}" / f"frame_{frame_idx:05d}.png" + + +class SourceFrame(NamedTuple): + """One input frame for foresight generation or visualization. + + Produced by ``iter_source_frames`` at the filesystem boundary so + downstream code treats (episode_index, frame_idx) as one logical ID + and doesn't re-parse filenames. + """ + + episode_index: int + frame_idx: int + actual_path: Path + + +def iter_source_frames(source_root: Path) -> list[SourceFrame]: + """Every frame under ``source_root/episode_*/actual/frame_*.png``. + + Sorted by (episode_index, frame_idx) so callers render or generate in + temporal order. + """ + frames: list[SourceFrame] = [] + for episode_dir in sorted(source_root.glob("episode_*")): + actual_dir = episode_dir / "actual" + if not actual_dir.exists(): + continue + episode_index = int(episode_dir.name.removeprefix("episode_")) + for path in sorted(actual_dir.glob("frame_*.png")): + frame_idx = int(path.stem.removeprefix("frame_")) + frames.append(SourceFrame(episode_index, frame_idx, path)) + return frames + + +# Chain-mode manifest records one of these per frame. Typing as a Literal +# (not free-form str) means a typo in any write site fails at type-check +# instead of silent log-grep later. +ForesightStatus = Literal["cached", "generated", "failed", "refused"] diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight.py b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight.py new file mode 100644 index 0000000..490080e --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +"""Run the ForeAct foresight image generator across DROID cache frames. + +IMPORTANT: this driver is designed to run **inside the foreact conda env on a +remote GPU box**, not inside our hosting project's Python env. It imports +``pipeline.VisualForesightPipeline`` and ``utils.trainer_utils`` from the +``/Users/kkuan/openpi/foreact/`` repo — those modules live in the foreact +env, not ours, which is why our linter ignores the imports. + +Typical workflow: + +1. On the remote box, clone the foreact repo and run its ``environment_setup.sh``. +2. ``huggingface-cli download mit-han-lab/foreact-pretrained --local-dir ~/foreact_ckpt``. +3. Copy our DROID cache + subtasks JSON to the box. +4. ``scp`` this script into the foreact repo directory. +5. ``conda activate foreact && python generate_foresight.py ...``. + +Uses the paper's recommended inference hparams (``foreact/app_cli.py``): +``guidance_scale=4.5``, ``image_guidance_scale=1.5``, ``num_inference_steps=8``. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import random +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image + +# These imports resolve inside the foreact conda env, not ours. This script +# is meant to be run on the remote GPU box after ``conda activate foreact``. +from pipeline import VisualForesightPipeline # ty: ignore[unresolved-import] +from utils.trainer_utils import find_newest_checkpoint # ty: ignore[unresolved-import] + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def _load_manifest(samples_dir: Path) -> list[dict[str, Any]]: + with (samples_dir / "manifest.json").open() as f: + return json.load(f) + + +def _load_subtask_records(path: Path) -> list[dict[str, Any]]: + with path.open() as f: + payload = json.load(f) + if isinstance(payload, dict) and "results" in payload: + return payload["results"] + if isinstance(payload, list): + return payload + raise ValueError(f"Unrecognized subtask JSON shape in {path}") + + +def _index_subtasks(records: list[dict[str, Any]]) -> dict[tuple[str, int], str]: + return {(r["episode_id"], r["frame_idx"]): r["subtask_text"] for r in records} + + +def _process_frame( + pipeline: VisualForesightPipeline, + exterior_image: np.ndarray, + subtask_text: str, + *, + guidance_scale: float, + image_guidance_scale: float, + num_inference_steps: int, + seed: int, +) -> tuple[Image.Image, float]: + """Generate one foresight image and return (image, elapsed_seconds).""" + pil_in = Image.fromarray(exterior_image).convert("RGB") + generator = torch.Generator().manual_seed(seed) + start = time.time() + out = pipeline( + caption=subtask_text, + image=pil_in, + guidance_scale=guidance_scale, + image_guidance_scale=image_guidance_scale, + num_inference_steps=num_inference_steps, + num_images_per_prompt=1, + generator=generator, + ).images + elapsed = time.time() - start + if not out: + raise RuntimeError("VisualForesightPipeline returned no images") + return out[0], elapsed + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="ForeAct foresight generator over DROID cache frames" + ) + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument( + "--subtasks", + type=str, + required=True, + help="Path to a subtasks_*.json (e.g. from foreact_eval.generate_subtasks).", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Where to write foresight PNGs and foresight_manifest.json.", + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Local path to the foreact-pretrained checkpoint directory.", + ) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument("--image_guidance_scale", type=float, default=1.5) + parser.add_argument("--num_inference_steps", type=int, default=8) + parser.add_argument( + "--max_episodes", + type=int, + default=None, + help="Only process the first N episodes (for smoke tests).", + ) + parser.add_argument("--seed", type=int, default=42) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + samples_dir = Path(args.samples_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + manifest = _load_manifest(samples_dir) + if args.max_episodes is not None: + manifest = manifest[: args.max_episodes] + logger.info("Loaded manifest: %d episodes", len(manifest)) + + records = _load_subtask_records(Path(args.subtasks)) + subtask_index = _index_subtasks(records) + logger.info("Loaded %d subtask records", len(records)) + + logger.info("Loading VisualForesightPipeline from %s ...", args.checkpoint) + pipeline = VisualForesightPipeline.from_pretrained( + find_newest_checkpoint(args.checkpoint), + ignore_mismatched_sizes=True, + _gradient_checkpointing=False, + torch_dtype=torch.bfloat16, + ) + pipeline = pipeline.to(device="cuda", dtype=torch.bfloat16) + logger.info("Pipeline loaded.") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + generation_records: list[dict[str, Any]] = [] + for episode in manifest: + episode_id = episode["episode_id"] + episode_dir = output_dir / episode_id + episode_dir.mkdir(parents=True, exist_ok=True) + logger.info("Starting episode %s (%d frames)", episode_id, len(episode["frames"])) + + for step_idx, frame_info in enumerate(episode["frames"]): + frame_idx = frame_info["frame_idx"] + subtask_text = subtask_index.get((episode_id, frame_idx), "") + if not subtask_text: + logger.warning("No subtask for %s frame %d; skipping", episode_id, frame_idx) + continue + + frame_path = samples_dir / frame_info["file"] + frame_data = np.load(frame_path) + exterior_image = np.asarray(frame_data["exterior_image"], dtype=np.uint8) + + # Per-frame seed so the generation is reproducible but each frame + # has independent noise. + frame_seed = (args.seed, episode_id, frame_idx).__hash__() & 0x7FFFFFFF + try: + image, elapsed = _process_frame( + pipeline, + exterior_image, + subtask_text, + guidance_scale=args.guidance_scale, + image_guidance_scale=args.image_guidance_scale, + num_inference_steps=args.num_inference_steps, + seed=frame_seed, + ) + except Exception as exc: + logger.warning( + "Foresight generation failed for %s frame %d: %s", + episode_id, + frame_idx, + exc, + ) + continue + + out_path = episode_dir / f"frame_{frame_idx:05d}.png" + image.save(out_path) + generation_records.append( + { + "episode_id": episode_id, + "frame_idx": frame_idx, + "subtask_text": subtask_text, + "output": str(out_path.relative_to(output_dir)), + "generation_time_s": round(elapsed, 3), + "seed": frame_seed, + } + ) + + if step_idx % 5 == 0 or step_idx == len(episode["frames"]) - 1: + logger.info( + "[%s] frame %d/%d: %r (%.2fs)", + episode_id, + step_idx + 1, + len(episode["frames"]), + subtask_text, + elapsed, + ) + + manifest_path = output_dir / "foresight_manifest.json" + with manifest_path.open("w") as f: + json.dump( + { + "hparams": { + "guidance_scale": args.guidance_scale, + "image_guidance_scale": args.image_guidance_scale, + "num_inference_steps": args.num_inference_steps, + "seed": args.seed, + "checkpoint": args.checkpoint, + }, + "records": generation_records, + }, + f, + indent=2, + ) + logger.info( + "Wrote %d foresight images + manifest to %s", + len(generation_records), + output_dir, + ) + + latencies = [r["generation_time_s"] for r in generation_records] + if latencies: + logger.info( + "Foresight latency: mean=%.2fs min=%.2fs max=%.2fs (n=%d)", + float(np.mean(latencies)), + float(np.min(latencies)), + float(np.max(latencies)), + len(latencies), + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_lerobot.py b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_lerobot.py new file mode 100644 index 0000000..4b2abc1 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_lerobot.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +"""Run the ForeAct foresight generator across a LeRobot-format dataset episode. + +Same generator, same hparams as ``generate_foresight.py``, but reads input +frames from LeRobot-v2.1 shards (mp4 videos + parquet) instead of our DROID +.npz cache. Used to test the pretrained generator on a dataset it was +*pretrained on* (Galaxea R1 Lite, `mit-han-lab/ForeActDataset`). + +This script runs inside the foreact conda env on the remote GPU box: + + ssh us-west-2 "cd ~/foreact && source .venv/bin/activate && \\ + python generate_foresight_lerobot.py \\ + --dataset_root ~/foreact_dataset/20251102_Pick_Veg \\ + --camera_key observation.images.head_left_rgb \\ + --output_dir ~/foresight_foreact_picksveg \\ + --checkpoint ~/foreact_ckpt \\ + --episode_indices 0,1,2,3,4 \\ + --stride 15" +""" + +from __future__ import annotations + +import argparse +import json +import logging +import random +import time +from pathlib import Path +from typing import Any + +import av +import numpy as np +import torch +from PIL import Image + +# Resolved inside the foreact conda env on the remote box. +from pipeline import VisualForesightPipeline # ty: ignore[unresolved-import] +from utils.trainer_utils import find_newest_checkpoint # ty: ignore[unresolved-import] + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def _load_episode_index(dataset_root: Path) -> dict[int, dict[str, Any]]: + """Parse meta/episodes.jsonl into {episode_index: {tasks: [...], length: int}}.""" + by_idx: dict[int, dict[str, Any]] = {} + with (dataset_root / "meta" / "episodes.jsonl").open() as f: + for line in f: + row = json.loads(line) + by_idx[row["episode_index"]] = row + return by_idx + + +def _decode_mp4_frames(video_path: Path, stride: int) -> list[np.ndarray]: + """Return every `stride`-th frame from an mp4 as HxWx3 uint8 RGB arrays. + + LeRobot stores each camera as a single mp4 per episode; we decode all + frames sequentially and pick out the strided ones. For Galaxea R1 Lite at + 15 fps, stride=15 gives ~1 Hz — matches the paper's pretraining sampling + cadence (§3.2: "sample condition frames at 1-second intervals"). + """ + frames: list[np.ndarray] = [] + with av.open(str(video_path)) as container: + stream = container.streams.video[0] + stream.thread_type = "AUTO" + for idx, frame in enumerate(container.decode(stream)): + if idx % stride == 0: + frames.append(frame.to_ndarray(format="rgb24")) + return frames + + +def _process_frame( + pipeline: VisualForesightPipeline, + exterior_image: np.ndarray, + subtask_text: str, + *, + guidance_scale: float, + image_guidance_scale: float, + num_inference_steps: int, + seed: int, +) -> tuple[Image.Image, float]: + pil_in = Image.fromarray(exterior_image).convert("RGB") + generator = torch.Generator().manual_seed(seed) + start = time.time() + out = pipeline( + caption=subtask_text, + image=pil_in, + guidance_scale=guidance_scale, + image_guidance_scale=image_guidance_scale, + num_inference_steps=num_inference_steps, + num_images_per_prompt=1, + generator=generator, + ).images + elapsed = time.time() - start + if not out: + raise RuntimeError("VisualForesightPipeline returned no images") + return out[0], elapsed + + +def _parse_episode_indices(raw: str) -> list[int]: + return [int(x.strip()) for x in raw.split(",") if x.strip()] + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="ForeAct foresight over a LeRobot dataset") + parser.add_argument( + "--dataset_root", + type=str, + required=True, + help="Path to a LeRobot-v2.1 dataset dir containing meta/, data/, videos/.", + ) + parser.add_argument( + "--camera_key", + type=str, + default="observation.images.head_left_rgb", + help="Camera feature name. Default matches foreact/configs/finetune.yaml.", + ) + parser.add_argument( + "--episode_indices", + type=str, + default="0", + help="Comma-separated episode indices to process (e.g. '0,1,2,3,4').", + ) + parser.add_argument( + "--stride", + type=int, + default=15, + help="Decode every Nth frame (15 Hz dataset / stride=15 = ~1 Hz).", + ) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--checkpoint", type=str, required=True) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument("--image_guidance_scale", type=float, default=1.5) + parser.add_argument("--num_inference_steps", type=int, default=8) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--force_task", + type=str, + default=None, + help=( + "Override the per-episode task label with this text for every processed " + "episode. Use when rendering a linked chain where later sub-episodes are " + "labeled 'Finish.' but are really continuations of the earlier task." + ), + ) + parser.add_argument( + "--task_schedule", + type=str, + default=None, + help=( + "Path to a JSON file mapping (episode, frame-range) -> subtask text. " + "Expected shape: list of {episode, start_frame, end_frame, task}. " + "Frames are source-frame indices (pre-stride). A frame is conditioned " + "on the first matching entry. Overrides --force_task when both are set." + ), + ) + return parser.parse_args() + + +def _load_task_schedule(path: str | None) -> list[dict[str, Any]] | None: + if path is None: + return None + with Path(path).open() as f: + schedule = json.load(f) + for entry in schedule: + required = {"episode", "start_frame", "end_frame", "task"} + if not required.issubset(entry): + raise ValueError(f"task_schedule entry missing keys {required - entry.keys()}: {entry}") + return schedule + + +def _lookup_scheduled_task( + schedule: list[dict[str, Any]] | None, ep_idx: int, frame_idx: int +) -> str | None: + if schedule is None: + return None + for entry in schedule: + if entry["episode"] == ep_idx and entry["start_frame"] <= frame_idx <= entry["end_frame"]: + return entry["task"] + return None + + +def main() -> None: + args = _parse_args() + dataset_root = Path(args.dataset_root) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + episode_meta = _load_episode_index(dataset_root) + episode_indices = _parse_episode_indices(args.episode_indices) + for idx in episode_indices: + if idx not in episode_meta: + raise SystemExit(f"episode {idx} not in meta/episodes.jsonl") + logger.info( + "Dataset root=%s, processing %d episodes: %s", + dataset_root, + len(episode_indices), + episode_indices, + ) + + logger.info("Loading VisualForesightPipeline from %s ...", args.checkpoint) + pipeline = VisualForesightPipeline.from_pretrained( + find_newest_checkpoint(args.checkpoint), + ignore_mismatched_sizes=True, + _gradient_checkpointing=False, + torch_dtype=torch.bfloat16, + ) + pipeline = pipeline.to(device="cuda", dtype=torch.bfloat16) + logger.info("Pipeline loaded.") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + schedule = _load_task_schedule(args.task_schedule) + + generation_records: list[dict[str, Any]] = [] + for ep_idx in episode_indices: + meta = episode_meta[ep_idx] + tasks: list[str] = meta.get("tasks") or [] + episode_fallback_task: str | None + if args.force_task is not None: + episode_fallback_task = args.force_task + elif tasks and tasks[0].strip().lower().rstrip(".") != "finish": + episode_fallback_task = tasks[0] + elif schedule is not None and any(e["episode"] == ep_idx for e in schedule): + episode_fallback_task = None # schedule will supply the text + else: + logger.info("Skipping episode %d (task=%r)", ep_idx, tasks) + continue + video_path = ( + dataset_root / "videos" / "chunk-000" / args.camera_key / f"episode_{ep_idx:06d}.mp4" + ) + if not video_path.exists(): + logger.warning("missing video: %s", video_path) + continue + + logger.info("Decoding %s ...", video_path) + frames = _decode_mp4_frames(video_path, stride=args.stride) + logger.info( + "Episode %d: %d strided frames (length=%d) — fallback subtask=%r (schedule=%s)", + ep_idx, + len(frames), + meta.get("length"), + episode_fallback_task, + "yes" if schedule is not None else "no", + ) + + episode_dir = output_dir / f"episode_{ep_idx:06d}" + episode_dir.mkdir(parents=True, exist_ok=True) + + # Also save the actual source frames so the downstream HTML / eyeball + # comparison doesn't need to re-decode the mp4 on the local box. + src_dir = episode_dir / "actual" + src_dir.mkdir(parents=True, exist_ok=True) + + for step_idx, frame_rgb in enumerate(frames): + frame_idx = step_idx * args.stride + Image.fromarray(frame_rgb).save(src_dir / f"frame_{frame_idx:05d}.png") + seed = (args.seed, ep_idx, frame_idx).__hash__() & 0x7FFFFFFF + scheduled = _lookup_scheduled_task(schedule, ep_idx, frame_idx) + per_frame_task = scheduled if scheduled is not None else episode_fallback_task + if per_frame_task is None: + logger.warning( + "No task text for ep%d frame %d (no schedule match + no fallback); skipping", + ep_idx, + frame_idx, + ) + continue + try: + image, elapsed = _process_frame( + pipeline, + frame_rgb, + per_frame_task, + guidance_scale=args.guidance_scale, + image_guidance_scale=args.image_guidance_scale, + num_inference_steps=args.num_inference_steps, + seed=seed, + ) + except Exception as exc: + logger.warning("episode %d frame %d failed: %s", ep_idx, frame_idx, exc) + continue + + out_path = episode_dir / f"frame_{frame_idx:05d}.png" + image.save(out_path) + generation_records.append( + { + "episode_index": ep_idx, + "frame_idx": frame_idx, + "subtask_text": per_frame_task, + "output": str(out_path.relative_to(output_dir)), + "generation_time_s": round(elapsed, 3), + "seed": seed, + } + ) + + if step_idx % 3 == 0 or step_idx == len(frames) - 1: + logger.info( + "[ep%d] step %d/%d (frame %d): %.2fs", + ep_idx, + step_idx + 1, + len(frames), + frame_idx, + elapsed, + ) + + manifest_path = output_dir / "foresight_manifest.json" + with manifest_path.open("w") as f: + json.dump( + { + "dataset_root": str(dataset_root), + "camera_key": args.camera_key, + "stride": args.stride, + "hparams": { + "guidance_scale": args.guidance_scale, + "image_guidance_scale": args.image_guidance_scale, + "num_inference_steps": args.num_inference_steps, + "seed": args.seed, + "checkpoint": args.checkpoint, + }, + "records": generation_records, + }, + f, + indent=2, + ) + logger.info("Wrote %d foresight images to %s", len(generation_records), output_dir) + + latencies = [r["generation_time_s"] for r in generation_records] + if latencies: + logger.info( + "Foresight latency: mean=%.2fs min=%.2fs max=%.2fs (n=%d)", + float(np.mean(latencies)), + float(np.min(latencies)), + float(np.max(latencies)), + len(latencies), + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_nano_banana.py b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_nano_banana.py new file mode 100644 index 0000000..c289f5c --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/generate_foresight_nano_banana.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +"""Foresight image generation via Gemini 3.1 Flash Image ("Nano Banana"). + +Drop-in alternative to the ForeAct SANA+Gemma generator. Doesn't require a +GPU or any pretrained robot-specific weights — just the google-genai SDK and +a careful generic scene-rules prompt. + +Two structural choices worth knowing: + +1. **Generic scene-rules prompt + subtask slot.** A single ``SCENE_RULES_TEMPLATE`` + encodes the scene inventory, robot anatomy, grasp physics, and preservation + constraints. Per-frame subtask labels (from ``CHAIN_PHASES`` or, eventually, + a Qwen3-VL planner) slot into ``{subtask}``. This replaces the earlier + per-phase hand-crafted prompts and makes short planner output pluggable. + +2. **Two-image conditioning.** Each API call passes (reference_frame, + current_observation, prompt). The reference frame (ep0 f00, all objects on + the table, nothing occluded) gives the stateless image generator a visual + anchor for object identity. This is specifically to combat the "eggplant + morphs to apple during carry frames" failure mode we saw with single-image + conditioning when the eggplant is occluded inside the closed gripper. +""" + +from __future__ import annotations + +import argparse +import io +import json +import logging +import os +from pathlib import Path +from typing import Any, NamedTuple + +from google import genai +from google.genai import types +from PIL import Image + +from experiments.subtask_probe.droid_eval.foreact_eval._io import ( + ForesightStatus, + foresight_path, + iter_source_frames, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +DEFAULT_MODEL = "gemini-3.1-flash-image-preview" + + +# --------------------------------------------------------------------------- +# Generic scene-rules prompt + per-phase subtask slot. +# +# Rationale: the earlier version had a bespoke ~200-word descriptive prompt +# per phase (PICK_UP_PROMPT / PLACE_PROMPT / RETURN_HOME_PROMPT). That made +# each phase's prompt a one-shot artifact — you couldn't plug in a short +# subtask label from a VLM planner (Qwen3-VL) without rewriting the prompt. +# The new design splits out (a) a stable SCENE_RULES prompt that encodes the +# physics and scene inventory once, and (b) a short ``subtask`` string that +# slots in per frame. Now any short subtask label (hardcoded from +# CHAIN_PHASES or dynamic from a planner) drives the generator. +# +# Multi-image input: we also pass a REFERENCE FRAME (ep0 f00, all objects on +# the table with nothing occluded) alongside the current observation. With a +# stateless single-image generator, the eggplant's identity collapses during +# carry frames where it's occluded behind the closed gripper. The reference +# frame gives the model an anchor for what the eggplant looks like. +# --------------------------------------------------------------------------- + +SCENE_RULES_TEMPLATE = """\ +You will be given TWO images followed by a subtask instruction. + +IMAGE 1 — IDENTITY REFERENCE. An early frame of this scene, used ONLY to \ +show you what each object in the scene looks like (color, shape, size, \ +texture). It is NOT a target state to restore to, and object positions \ +shown here may be out of date. Never use IMAGE 1 to determine where an \ +object currently is. + +IMAGE 2 — CURRENT OBSERVATION. The exact current state of the scene. Its \ +pixels are the ground truth for every object's current position and \ +for the camera, lighting, surfaces, and background. Your prediction must \ +preserve IMAGE 2 exactly except where the current subtask requires a \ +change. + +SCENE OBJECTS: the set of objects in the scene is exactly the set of \ +objects visible in IMAGE 1. Do not add, remove, or substitute any \ +object. Use IMAGE 2 for every object's current position. + +ROBOT: the scene contains one robot arm that enters from one side of \ +the frame and is always continuous to that edge — no floating grippers \ +or arms that appear disconnected from the frame boundary. If the arm's \ +anatomy is visible in IMAGE 2, match its visual style exactly; if the \ +arm is not in IMAGE 2, follow the same visual style used in other \ +frames of this episode. The gripper grasps oblong objects by their \ +stem, so a held oblong object hangs BELOW the closed fingers as a \ +visible silhouette. + +CURRENT SUBTASK: {subtask} + +SUBTASK SEMANTICS: a subtask describes one action the ROBOT ARM takes \ +next. It never undoes previous work, and it never moves any object \ +unless that object is explicitly named as being moved to a specific \ +location. A subtask of the form "move X to Y" names both an object (X) \ +and a destination (Y); any other subtask — for example "return to home", \ +"go home", "retract", "finish", "idle" — names neither, so nothing \ +moves except the arm. An object's location at the moment the current \ +subtask starts is the location it keeps, unless the current subtask \ +explicitly relocates it. + +OUTPUT RULES (strict pixel preservation): treat IMAGE 2 as the baseline. \ +Only two kinds of change are allowed between IMAGE 2 and the output: \ +(1) the robot arm's pose may update to reflect progress through the \ +current subtask; (2) any object the current subtask explicitly names as \ +being moved may change position accordingly. Everything else — every \ +other object, the background, the surface, the lighting, the camera — \ +must match IMAGE 2 exactly. An object contained inside another object \ +in IMAGE 2 stays contained. An object resting on a surface in IMAGE 2 \ +stays on that surface. If the current subtask does not explicitly name \ +an object as being moved, that object does not move. + +IDENTITY ANCHORING: if an object is occluded or hard to see in IMAGE 2 \ +(e.g. hidden inside the closed gripper), use IMAGE 1 to recall its \ +identity so you do not hallucinate a different-looking object in its \ +place. Do not morph an object into a different type. + +PREDICTION HORIZON: predict the scene at roughly the half-subtask-ahead \ +point — partway through or at the end of the current subtask's action. +""" + + +class Phase(NamedTuple): + """One subtask's extent within the chain. + + ``start_frame`` / ``end_frame`` are inclusive and refer to raw frame + indices inside ``episode_index`` (same scale as filenames on disk — + stride=5 between consecutive frames). Frames outside any phase are + skipped by both the generator and the visualizer. + """ + + episode_index: int + start_frame: int + end_frame: int + subtask_label: str + + +# Boundaries tuned against the v2 golden chain's actual physics: +# * ep0 f00-25: arm not in frame yet → trim (no API call, no video frame). +# * ep0 f30 → ep1 f25: approach + grasp → "Pick up" subtask. +# * ep1 f30 → ep2 f15: eggplant in air, carried + placed → "Place" subtask. +# * ep2 f20 → f55: arm retracting to home → "Return home" subtask. +# * ep2 f60+: arm already home, scene static → trim. +CHAIN_PHASES: list[Phase] = [ + Phase(episode_index=0, start_frame=30, end_frame=100, subtask_label="Pick up the eggplant."), + Phase(episode_index=1, start_frame=0, end_frame=25, subtask_label="Pick up the eggplant."), + Phase( + episode_index=1, + start_frame=30, + end_frame=75, + subtask_label="Place the eggplant into the plate.", + ), + Phase( + episode_index=2, + start_frame=0, + end_frame=15, + subtask_label="Place the eggplant into the plate.", + ), + Phase(episode_index=2, start_frame=20, end_frame=55, subtask_label="Return to home position."), +] + + +# Default reference frame — ep0 f00 of the v2 chain has all objects on the +# table with nothing occluded, which is what we want as an identity anchor. +DEFAULT_REFERENCE_FRAME = Path( + ".experiments_cache/foreact_eval/foresight_chain_eggplant_v2/episode_000000/actual/frame_00000.png" +) + + +def lookup_phase(episode_index: int, frame_idx: int) -> Phase | None: + """Return the phase that owns (episode_index, frame_idx), or None if trimmed.""" + for phase in CHAIN_PHASES: + if phase.episode_index != episode_index: + continue + if phase.start_frame <= frame_idx <= phase.end_frame: + return phase + return None + + +def _extract_image(response: Any) -> bytes | None: + """Pull the first inline image from a Gemini generate_content response.""" + candidates = getattr(response, "candidates", None) or [] + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + inline = getattr(part, "inline_data", None) + if inline is not None and getattr(inline, "data", None): + return inline.data + return None + + +def _extract_text(response: Any) -> str: + """Concatenate any text parts — Gemini sometimes narrates alongside the image.""" + texts: list[str] = [] + candidates = getattr(response, "candidates", None) or [] + for candidate in candidates: + content = getattr(candidate, "content", None) + parts = getattr(content, "parts", None) or [] + for part in parts: + text = getattr(part, "text", None) + if text: + texts.append(text) + return "\n".join(texts).strip() + + +def _generate_one( + client: genai.Client, + model: str, + reference_bytes: bytes, + current_bytes: bytes, + prompt: str, +) -> tuple[bytes | None, str]: + """Call Gemini with (reference_frame, current_frame, prompt) and return (image, text). + + Passing two images in order lets the model anchor object identity to the + reference frame — critical for the "Place" phase where the eggplant is + occluded inside the closed gripper in the current observation. + """ + response = client.models.generate_content( + model=model, + contents=[ + types.Part.from_bytes(data=reference_bytes, mime_type="image/png"), + types.Part.from_bytes(data=current_bytes, mime_type="image/png"), + prompt, + ], + config=types.GenerateContentConfig(response_modalities=["IMAGE", "TEXT"]), + ) + return _extract_image(response), _extract_text(response) + + +def _chain_record( + *, + episode_index: int, + frame_idx: int, + subtask_text: str, + status: ForesightStatus, + output: str | None, + extra: dict[str, Any] | None = None, +) -> dict[str, Any]: + record: dict[str, Any] = { + "episode_index": episode_index, + "frame_idx": frame_idx, + "subtask_text": subtask_text, + "output": output, + "status": status, + } + if extra: + record.update(extra) + return record + + +def _run_chain( + client: genai.Client, + model: str, + v2_root: Path, + output_dir: Path, + reference_frame_path: Path, + force: bool, +) -> None: + """Generate foresight for every in-phase frame of the v2 chain. + + Each call passes (reference_frame, current_frame, generic_prompt) to + anchor object identity. Frames outside all CHAIN_PHASES (e.g. pre-arm + intro frames or post-arm-home tail frames) are trimmed — no API call, + no manifest entry. Output layout matches the ForeAct generator so + downstream visualization tools can point at this directory: + ``/episode_{id:06d}/frame_{idx:05d}.png``. + """ + if not reference_frame_path.exists(): + raise SystemExit(f"reference frame missing: {reference_frame_path}") + reference_bytes = reference_frame_path.read_bytes() + logger.info("Using reference frame: %s (%d bytes)", reference_frame_path, len(reference_bytes)) + + all_source_frames = iter_source_frames(v2_root) + frames = [ + f for f in all_source_frames if lookup_phase(f.episode_index, f.frame_idx) is not None + ] + trimmed = len(all_source_frames) - len(frames) + logger.info("Chain has %d in-phase frames (trimmed %d)", len(frames), trimmed) + + records: list[dict[str, Any]] = [] + for i, frame in enumerate(frames): + phase = lookup_phase(frame.episode_index, frame.frame_idx) + assert phase is not None # filtered above + out_path = foresight_path(output_dir, frame.episode_index, frame.frame_idx) + out_path.parent.mkdir(parents=True, exist_ok=True) + relative_out = str(out_path.relative_to(output_dir)) + progress = f"[{i + 1}/{len(frames)}] ep{frame.episode_index} f{frame.frame_idx:05d}" + + if out_path.exists() and not force: + logger.info("%s: already exists, skipping", progress) + records.append( + _chain_record( + episode_index=frame.episode_index, + frame_idx=frame.frame_idx, + subtask_text=phase.subtask_label, + status="cached", + output=relative_out, + ) + ) + continue + + current_bytes = frame.actual_path.read_bytes() + prompt = SCENE_RULES_TEMPLATE.format(subtask=phase.subtask_label) + logger.info("%s: generating (subtask=%r)", progress, phase.subtask_label) + try: + out_bytes, narration = _generate_one( + client, model, reference_bytes, current_bytes, prompt + ) + except Exception as exc: + logger.warning("%s: failed: %s", progress, exc) + records.append( + _chain_record( + episode_index=frame.episode_index, + frame_idx=frame.frame_idx, + subtask_text=phase.subtask_label, + status="failed", + output=None, + extra={"error": str(exc)}, + ) + ) + continue + if out_bytes is None: + logger.warning("%s: no image (narration=%r)", progress, narration[:200]) + records.append( + _chain_record( + episode_index=frame.episode_index, + frame_idx=frame.frame_idx, + subtask_text=phase.subtask_label, + status="refused", + output=None, + extra={"narration": narration}, + ) + ) + continue + image = Image.open(io.BytesIO(out_bytes)).convert("RGB") + image.save(out_path) + records.append( + _chain_record( + episode_index=frame.episode_index, + frame_idx=frame.frame_idx, + subtask_text=phase.subtask_label, + status="generated", + output=relative_out, + ) + ) + + manifest_path = output_dir / "foresight_manifest.json" + with manifest_path.open("w") as f: + json.dump( + { + "source_dataset": str(v2_root), + "model": model, + "chain_phases": [phase._asdict() for phase in CHAIN_PHASES], + "records": records, + }, + f, + indent=2, + ) + logger.info("Wrote manifest -> %s", manifest_path) + counts: dict[str, int] = {} + for record in records: + counts[record["status"]] = counts.get(record["status"], 0) + 1 + logger.info("Status counts: %s", counts) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Nano Banana foresight generator (generic scene-rules + 2-image conditioning)" + ) + parser.add_argument( + "--v2_root", + type=Path, + default=Path(".experiments_cache/foreact_eval/foresight_chain_eggplant_v2"), + help="Directory containing the v2 golden chain episodes with actual/ frames.", + ) + parser.add_argument( + "--output_dir", + type=Path, + default=Path(".experiments_cache/foreact_eval/foresight_nano_banana_chain"), + ) + parser.add_argument( + "--reference_frame", + type=Path, + default=DEFAULT_REFERENCE_FRAME, + help="Identity-anchor frame shown alongside each current observation. Default: ep0 f00.", + ) + parser.add_argument("--model", type=str, default=DEFAULT_MODEL) + parser.add_argument( + "--force", + action="store_true", + help="Regenerate even if an output PNG already exists.", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + if not os.environ.get("GEMINI_API_KEY") and not os.environ.get("GOOGLE_API_KEY"): + raise SystemExit("Set GEMINI_API_KEY (or GOOGLE_API_KEY) before running.") + + args.output_dir.mkdir(parents=True, exist_ok=True) + client = genai.Client() + _run_chain( + client, + args.model, + args.v2_root, + args.output_dir, + args.reference_frame, + args.force, + ) + logger.info("Done. Outputs in %s", args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks.py b/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks.py new file mode 100644 index 0000000..5cea030 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +"""Run the ForeAct VLM planner across DROID cache frames. + +Produces ``subtasks_foreact_*.json`` with the same per-frame schema other +subtask generators in this project emit, so ``visualize_foreact.py`` and any +future tooling consume it uniformly. + +Usage (OpenAI-compatible backend, paper's Qwen3-VL-8B-Instruct on local vLLM):: + + uv run python -m experiments.subtask_probe.droid_eval.foreact_eval.generate_subtasks \\ + --samples_dir ./.experiments_cache/droid_eval_ah15 \\ + --output ./.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json \\ + --backend openai_compat \\ + --base_url http://localhost:8000/v1 \\ + --model Qwen/Qwen3-VL-8B-Instruct +""" + +from __future__ import annotations + +import argparse +import json +import logging +import time +from pathlib import Path +from typing import Any, Literal, assert_never, cast, get_args + +import numpy as np +from dotenv import load_dotenv + +from experiments.subtask_probe.droid_eval.foreact_eval.planner import ( + DEFAULT_BASE_URL, + DEFAULT_GEMINI_MODEL, + DEFAULT_OPENAI_MODEL, + BasePlanner, + GeminiPlanner, + OpenAICompatPlanner, +) +from experiments.subtask_probe.droid_eval.utils import load_manifest + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +Backend = Literal["openai_compat", "gemini"] +BACKEND_CHOICES: tuple[Backend, ...] = get_args(Backend) + + +def _parse_backend(raw: str) -> Backend: + if raw in BACKEND_CHOICES: + return cast(Backend, raw) + raise ValueError(f"Unknown backend: {raw!r} (expected one of {BACKEND_CHOICES})") + + +def _build_planner(backend: Backend, args: argparse.Namespace) -> tuple[BasePlanner, str]: + """Return (planner, backend_label) where backend_label goes into the output JSON.""" + match backend: + case "openai_compat": + base_url = args.base_url or DEFAULT_BASE_URL + model = args.model or DEFAULT_OPENAI_MODEL + planner: BasePlanner = OpenAICompatPlanner( + base_url=base_url, + model=model, + api_key=args.api_key, + ) + return planner, f"{model}@{base_url}" + case "gemini": + model = args.model or DEFAULT_GEMINI_MODEL + planner = GeminiPlanner( + model=model, + thinking_budget=args.thinking_budget, + max_retries=args.max_retries, + ) + return planner, model + case _ as unreachable: + assert_never(unreachable) + + +def _process_episode( + planner: BasePlanner, + samples_dir: Path, + episode: dict[str, Any], + replan_every: int, +) -> list[dict[str, Any]]: + """Run the planner across one episode, reset()-ing state at the start. + + On replan frames, one VLM call is issued. On non-replan frames, the last + subtask is reused without a VLM call. Non-replan latency is recorded as 0s + so the output JSON clearly distinguishes replan frames from reused frames. + """ + planner.reset() + records: list[dict[str, Any]] = [] + episode_id = episode["episode_id"] + instruction = episode["instruction"] + last_subtask = "" + last_previous_finished = False + last_prompt_phase = "" + + for step_idx, frame_info in enumerate(episode["frames"]): + frame_idx = frame_info["frame_idx"] + frame_path = samples_dir / frame_info["file"] + frame_data = np.load(frame_path) + exterior_image = np.asarray(frame_data["exterior_image"], dtype=np.uint8) + + is_replan = step_idx % replan_every == 0 + elapsed = 0.0 + if is_replan: + start = time.time() + try: + result = planner.generate_subtask(instruction, exterior_image) + last_subtask = result["subtask"] + last_previous_finished = result["previous_finished"] + last_prompt_phase = result["prompt_phase"] + except Exception as exc: + logger.warning( + "Planner call failed for %s frame %d: %s", + episode_id, + frame_idx, + exc, + ) + last_subtask = "" + last_previous_finished = False + last_prompt_phase = "error" + elapsed = time.time() - start + + records.append( + { + "episode_id": episode_id, + "frame_idx": frame_idx, + "instruction": instruction, + "subtask_text": last_subtask, + "generation_time_s": round(elapsed, 2), + "server_subtask_ms": round(elapsed * 1000, 1), + "previous_finished": last_previous_finished, + "prompt_phase": last_prompt_phase, + "is_replan": is_replan, + } + ) + + if step_idx % 5 == 0 or step_idx == len(episode["frames"]) - 1 or is_replan: + logger.info( + "[%s] frame %d/%d phase=%s finished=%s: %r (%.1fs)", + episode_id, + step_idx + 1, + len(episode["frames"]), + last_prompt_phase or "reuse", + last_previous_finished, + last_subtask, + elapsed, + ) + + return records + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="ForeAct VLM planner over DROID cache frames") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument( + "--backend", + choices=list(BACKEND_CHOICES), + default="openai_compat", + help="VLM backend. Default matches the paper's Qwen3-VL-8B via local vLLM.", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID. Defaults: Qwen/Qwen3-VL-8B-Instruct (openai_compat) / " + "gemini-robotics-er-1.6-preview (gemini).", + ) + parser.add_argument( + "--base_url", + type=str, + default=None, + help="OpenAI-compatible server URL. Default: http://localhost:8000/v1", + ) + parser.add_argument( + "--api_key", + type=str, + default="none", + help="API key for openai_compat backend (local vLLM ignores this).", + ) + parser.add_argument( + "--thinking_budget", + type=int, + default=0, + help="Gemini thinking budget (gemini backend only; 0 disables).", + ) + parser.add_argument( + "--max_retries", + type=int, + default=10, + help="Per-call retry budget for 429 / transient network errors (gemini backend).", + ) + parser.add_argument( + "--replan_every", + type=int, + default=1, + help=( + "Issue a new planner call every N cached frames; intermediate frames " + "reuse the previous subtask text. Default 1 — the paper implies " + "per-observation cadence." + ), + ) + parser.add_argument( + "--max_episodes", + type=int, + default=None, + help="Only process the first N episodes (for smoke tests).", + ) + return parser.parse_args() + + +def main() -> None: + load_dotenv() + args = _parse_args() + + samples_dir = Path(args.samples_dir) + manifest = load_manifest(samples_dir) + if args.max_episodes is not None: + manifest = manifest[: args.max_episodes] + logger.info("Loaded manifest: %d episodes", len(manifest)) + + backend = _parse_backend(args.backend) + planner, backend_label = _build_planner(backend, args) + logger.info( + "Backend=%s label=%s replan_every=%d", + backend, + backend_label, + args.replan_every, + ) + + all_records: list[dict[str, Any]] = [] + for episode in manifest: + logger.info( + "Starting episode %s (%d frames): %r", + episode["episode_id"], + len(episode["frames"]), + episode["instruction"], + ) + all_records.extend( + _process_episode(planner, samples_dir, episode, replan_every=args.replan_every) + ) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump( + { + "prompt_format": "foreact_two_turn", + "backend": backend_label, + "results": all_records, + }, + f, + indent=2, + ) + logger.info("Saved %d records to %s (backend=%s)", len(all_records), output_path, backend_label) + + gen_times = [r["generation_time_s"] for r in all_records if r["generation_time_s"] > 0] + if gen_times: + logger.info( + "Planner latency: mean=%.2fs min=%.2fs max=%.2fs (replan calls only, n=%d)", + float(np.mean(gen_times)), + float(np.min(gen_times)), + float(np.max(gen_times)), + len(gen_times), + ) + unique_subtasks = {r["subtask_text"] for r in all_records if r["subtask_text"]} + logger.info("Unique subtask texts: %d / %d frames", len(unique_subtasks), len(all_records)) + empty = sum(1 for r in all_records if not r["subtask_text"]) + if empty: + logger.warning("Empty subtasks: %d / %d frames", empty, len(all_records)) + advances = sum(1 for r in all_records if r["previous_finished"]) + logger.info("Planner reported previous_finished=True on %d frames", advances) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks_v2chain.py b/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks_v2chain.py new file mode 100644 index 0000000..40aadc9 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/generate_subtasks_v2chain.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Run the ForeAct planner (Qwen3-VL) over the v2 eggplant chain. + +Walks the v2 chain's ``episode_*/actual/frame_*.png`` frames in temporal +order, keeps the planner stateful across frames (no reset between linked +episodes since they form one continuous task), and writes per-frame subtask +predictions to JSON. + +This is the "use the paper's VLM correctly" companion to +``generate_foresight_nano_banana.py`` — we've been hardcoding subtask +labels in ``CHAIN_PHASES`` for both foresight generators; this script +replaces that with what a real Qwen3-VL-8B planner would say frame by +frame. +""" + +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path +from typing import Any + +import numpy as np +from PIL import Image + +from experiments.subtask_probe.droid_eval.foreact_eval._io import iter_source_frames +from experiments.subtask_probe.droid_eval.foreact_eval.planner import OpenAICompatPlanner + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +DEFAULT_TASK = "Pick up the eggplant and place it into the blue plate." +DEFAULT_MODEL = "Qwen/Qwen3-VL-8B-Instruct" +DEFAULT_BASE_URL = "http://localhost:8000/v1" + + +def _load_rgb(path: Path) -> np.ndarray: + with Image.open(path) as img: + return np.asarray(img.convert("RGB"), dtype=np.uint8) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Qwen3-VL planner over the v2 chain") + parser.add_argument( + "--v2_root", + type=Path, + default=Path(".experiments_cache/foreact_eval/foresight_chain_eggplant_v2"), + ) + parser.add_argument( + "--output", + type=Path, + default=Path(".experiments_cache/foreact_eval/subtasks_qwen3_v2chain.json"), + ) + parser.add_argument("--task", type=str, default=DEFAULT_TASK) + parser.add_argument("--model", type=str, default=DEFAULT_MODEL) + parser.add_argument("--base_url", type=str, default=DEFAULT_BASE_URL) + parser.add_argument( + "--no_schema", + action="store_true", + help="Disable JSON schema enforcement on the VLM response (closer to Table 5).", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + frames = iter_source_frames(args.v2_root) + logger.info( + "Running Qwen3-VL planner over %d frames (task=%r, base_url=%s)", + len(frames), + args.task, + args.base_url, + ) + + planner = OpenAICompatPlanner( + base_url=args.base_url, model=args.model, use_schema=not args.no_schema + ) + logger.info("use_schema=%s", not args.no_schema) + + records: list[dict[str, Any]] = [] + for i, frame in enumerate(frames): + image = _load_rgb(frame.actual_path) + try: + result = planner.generate_subtask(args.task, image) + except Exception as exc: + logger.warning("ep%d f%05d failed: %s", frame.episode_index, frame.frame_idx, exc) + result = { + "subtask": "", + "previous_finished": False, + "prompt_phase": "error", + "error": str(exc), + } + records.append( + { + "episode_index": frame.episode_index, + "frame_idx": frame.frame_idx, + **result, + } + ) + logger.info( + "[%d/%d] ep%d f%05d (%s): previous_finished=%s subtask=%r", + i + 1, + len(frames), + frame.episode_index, + frame.frame_idx, + result.get("prompt_phase", "?"), + result.get("previous_finished"), + result.get("subtask"), + ) + + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text( + json.dumps( + { + "task": args.task, + "model": args.model, + "base_url": args.base_url, + "results": records, + }, + indent=2, + ) + ) + logger.info("Wrote %s", args.output) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/planner.py b/experiments/subtask_probe/droid_eval/foreact_eval/planner.py new file mode 100644 index 0000000..e83551b --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/planner.py @@ -0,0 +1,327 @@ +"""ForeAct VLM planner with the paper's Table 5 two-turn prompt. + +Ported verbatim from Appendix A.5 of the ForeAct paper (arxiv 2602.12322). +The prompt strings are part of the method — do not reword them. + +Per-episode state is minimal: just ``previous_subtask: str | None``. The +"reason-execute-monitor" cycle relies on the VLM re-deriving the plan +latently each turn; there is no explicit plan list or status marker. + +We additionally enforce a JSON schema on the response. The paper only says +"concise and deterministic", but our Comet experience showed free-form output +adds 500ms-3s of latency and causes semantic drift. The ``subtask`` field is +what the paper actually uses; the ``previous_finished`` bool is ours, added +purely for observability (so we can log when the planner advances). +""" + +from __future__ import annotations + +import base64 +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Literal, cast + +import numpy as np +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from experiments.subtask_probe.droid_eval.comet_style._gemini_utils import ( + call_with_retry, + encode_png, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Prompts (Table 5, Appendix A.5 of foreact.pdf). Only ``{task}`` substitutes. +# --------------------------------------------------------------------------- + +INITIAL_PROMPT_TEMPLATE = """\ +You are a robot controller. Please plan to finish the task in several steps. \ +And give instruction for each step in a concise way. +The task is to "{task}". + +RULES: +\u2022 During the job, I will continuously give you an observation image of the current state. +\u2022 Based on the observation, please judge if the last instruction has been finished. + - If yes, give me the instruction for the next step. + - If no, repeat the instruction of the ongoing subtask. +\u2022 You're not required to describe the observation. Only output the instruction for each subtask. + +Now, you are only required to output instruction for the first step.""" + + +FOLLOW_UP_PROMPT_TEMPLATE = """\ +Pay attention to the latest observation. Firstly, judge if the last instruction \ +has been finished. Secondly, if yes, give me the instruction for the next step; \ +if no, repeat the instruction of the ongoing subtask. +Your answer should be concise and deterministic. +Remember, your Overall Task is "{task}".""" + + +SUBTASK_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "subtask": {"type": "string"}, + "previous_finished": {"type": "boolean"}, + }, + "required": ["subtask", "previous_finished"], + "additionalProperties": False, +} + + +PromptPhase = Literal["initial", "follow_up"] +ImagePosition = Literal["start", "end"] + + +# --------------------------------------------------------------------------- +# Base planner with the per-episode state + dispatch logic +# --------------------------------------------------------------------------- + + +class BasePlanner(ABC): + """Stateful per-episode ForeAct planner; subclasses implement the VLM call. + + ``use_schema`` controls whether we force ``{"subtask": str, + "previous_finished": bool}`` JSON output via the backend's response-format + hook. The paper (Table 5, §3.3) doesn't mention any schema — it only asks + for "concise and deterministic" free-form text. Our schema is an addition + that helped the Comet-style reasoner but may be hurting step-level + decomposition here (see FINDINGS). Set ``False`` to reproduce the paper + more literally. + """ + + def __init__(self, use_schema: bool = True) -> None: + self.previous_subtask: str | None = None + self._use_schema = use_schema + + def reset(self) -> None: + self.previous_subtask = None + + def generate_subtask(self, task: str, current_image: np.ndarray) -> dict[str, Any]: + """Run one planner turn and update ``previous_subtask`` in place. + + Returns ``{"subtask": str, "previous_finished": bool, "prompt_phase": str}``. + On parse failure returns empty subtask_text + previous_finished=False so + the outer loop can keep marching without crashing the episode. + """ + if self.previous_subtask is None: + prompt = INITIAL_PROMPT_TEMPLATE.format(task=task) + # Table 5: "VISUAL INPUT: [Initial Observation Image]" appears at + # the END of the initial prompt block. + image_position: ImagePosition = "end" + prompt_phase: PromptPhase = "initial" + else: + prompt = FOLLOW_UP_PROMPT_TEMPLATE.format(task=task) + # Table 5: "VISUAL INPUT: [Current Observation Image]" appears at + # the START of the follow-up prompt block. + image_position = "start" + prompt_phase = "follow_up" + + raw = self._chat( + prompt=prompt, + image=current_image, + image_position=image_position, + response_schema=SUBTASK_SCHEMA if self._use_schema else None, + ) + parsed = ( + _parse_response_json(raw) + if self._use_schema + else _parse_response_freeform(raw, self.previous_subtask) + ) + if parsed["subtask"]: + self.previous_subtask = parsed["subtask"] + return {**parsed, "prompt_phase": prompt_phase} + + @abstractmethod + def _chat( + self, + prompt: str, + image: np.ndarray, + image_position: ImagePosition, + response_schema: dict[str, Any] | None, + ) -> str: + """Return the raw VLM output — JSON-shape when ``response_schema`` is set, + free-form text when it is ``None``.""" + + +def _parse_response_json(raw: str) -> dict[str, Any]: + try: + payload = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Planner response was not valid JSON: %r", raw[:200]) + return {"subtask": "", "previous_finished": False} + subtask = str(payload.get("subtask") or "").strip() + previous_finished = bool(payload.get("previous_finished", False)) + return {"subtask": subtask, "previous_finished": previous_finished} + + +def _parse_response_freeform(raw: str, previous_subtask: str | None) -> dict[str, Any]: + """Parse a free-form VLM response. + + The paper's prompt asks for a concise instruction with no other text, so + the common case is a single short sentence. We strip surrounding quotes / + whitespace and, as a safety net, take the last non-empty line if the model + was chatty. ``previous_finished`` is inferred from whether the new subtask + string differs from ``previous_subtask`` — this matches the paper's + semantics where "if finished, give the next step; if not, repeat the + current one". + """ + cleaned = raw.strip().strip("\"'`") + lines = [line.strip().strip("\"'`") for line in cleaned.splitlines() if line.strip()] + subtask = lines[-1] if lines else "" + previous_finished = bool( + previous_subtask is not None and subtask and subtask != previous_subtask + ) + return {"subtask": subtask, "previous_finished": previous_finished} + + +# --------------------------------------------------------------------------- +# Backend A — OpenAI-compatible (targets local vLLM hosting Qwen3-VL-8B-Instruct, +# which is the exact model the paper uses for both VLM+\u03c0_0 and ForeAct \u00a74.3) +# --------------------------------------------------------------------------- + + +DEFAULT_BASE_URL = "http://localhost:8000/v1" +DEFAULT_OPENAI_MODEL = "Qwen/Qwen3-VL-8B-Instruct" + + +class OpenAICompatPlanner(BasePlanner): + def __init__( + self, + base_url: str = DEFAULT_BASE_URL, + model: str = DEFAULT_OPENAI_MODEL, + api_key: str = "none", + temperature: float = 1.0, + timeout_s: float = 600.0, + use_schema: bool = True, + ) -> None: + super().__init__(use_schema=use_schema) + self._client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s) + self._model = model + self._temperature = temperature + # Conversation history for the reason-execute-monitor cycle. Each + # turn appends (user_message_text_only, assistant_response). We strip + # image parts from the stored user message — keeping 57 base64- + # encoded images in context would blow past Qwen3-VL's 32k window + # within ~20 turns. The model's plan persistence relies on its own + # earlier assistant text, not on re-observing previous frames. + self._conversation: list[ChatCompletionMessageParam] = [] + + def reset(self) -> None: + super().reset() + self._conversation = [] + + def _chat( + self, + prompt: str, + image: np.ndarray, + image_position: ImagePosition, + response_schema: dict[str, Any] | None, + ) -> str: + png_bytes = encode_png(image) + data_url = f"data:image/png;base64,{base64.b64encode(png_bytes).decode('utf-8')}" + image_part: dict[str, Any] = {"type": "image_url", "image_url": {"url": data_url}} + text_part: dict[str, Any] = {"type": "text", "text": prompt} + content = [text_part, image_part] if image_position == "end" else [image_part, text_part] + # Active call sees: full conversation history + the current turn + # (with the CURRENT image). Previous turns are text-only in history. + current_user_msg = cast( + ChatCompletionMessageParam, {"role": "user", "content": content} + ) + messages: list[ChatCompletionMessageParam] = [*self._conversation, current_user_msg] + kwargs: dict[str, Any] = { + "model": self._model, + "messages": messages, + "temperature": self._temperature, + } + if response_schema is not None: + kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "foreact_subtask", + "schema": response_schema, + "strict": True, + }, + } + response = self._client.chat.completions.create(**kwargs) + assistant_text = (response.choices[0].message.content or "").strip() + + # Persist the turn to history. Strip the image part to keep context + # cheap; the assistant's plan lives in the text it emits. + user_text_only = cast( + ChatCompletionMessageParam, {"role": "user", "content": prompt} + ) + assistant_msg = cast( + ChatCompletionMessageParam, {"role": "assistant", "content": assistant_text} + ) + self._conversation.append(user_text_only) + self._conversation.append(assistant_msg) + + return assistant_text + + +# --------------------------------------------------------------------------- +# Backend B — Gemini (optional; the paper doesn't use it, but we have keys set +# up from the Comet work and it's a cheap apples-to-apples vs. our prior runs) +# --------------------------------------------------------------------------- + + +DEFAULT_GEMINI_MODEL = "gemini-robotics-er-1.6-preview" + + +class GeminiPlanner(BasePlanner): + def __init__( + self, + model: str = DEFAULT_GEMINI_MODEL, + thinking_budget: int = 0, + max_retries: int = 10, + request_timeout_s: float = 120.0, + use_schema: bool = True, + ) -> None: + super().__init__(use_schema=use_schema) + # google-genai is an optional dependency for this backend only — + # import lazily so OpenAI-only users don't need it installed. + from google import genai + from google.genai import types + + self._types = types + self._client = genai.Client( + http_options=types.HttpOptions(timeout=int(request_timeout_s * 1000)) + ) + self._model = model + self._thinking_budget = thinking_budget + self._max_retries = max_retries + + def _chat( + self, + prompt: str, + image: np.ndarray, + image_position: ImagePosition, + response_schema: dict[str, Any] | None, + ) -> str: + types = self._types + image_part = types.Part.from_bytes(data=encode_png(image), mime_type="image/png") + contents: list[Any] = ( + [prompt, image_part] if image_position == "end" else [image_part, prompt] + ) + config_kwargs: dict[str, Any] = { + "temperature": 1.0, + "thinking_config": types.ThinkingConfig(thinking_budget=self._thinking_budget), + } + if response_schema is not None: + config_kwargs["response_mime_type"] = "application/json" + config_kwargs["response_schema"] = response_schema + config = types.GenerateContentConfig(**config_kwargs) + + def _call() -> str: + response = self._client.models.generate_content( + model=self._model, + contents=contents, + config=config, + ) + return (response.text or "").strip() + + return call_with_retry(_call, max_retries=self._max_retries) diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/visualize_foreact.py b/experiments/subtask_probe/droid_eval/foreact_eval/visualize_foreact.py new file mode 100644 index 0000000..a47e7ba --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/visualize_foreact.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +"""Render exterior | wrist | predicted foresight | subtask per DROID frame. + +Extends ``visualize_subtasks.py`` with a third image column that pulls the +foresight PNG generated by ``generate_foresight.py``. If a foresight image is +missing for a given frame (e.g. the generator skipped it), a grey placeholder +with the text "(no foresight)" is shown so the row still aligns. + +Usage:: + + uv run python -m experiments.subtask_probe.droid_eval.foreact_eval.visualize_foreact \\ + --samples_dir ./.experiments_cache/droid_eval_ah15 \\ + --subtasks ./.experiments_cache/droid_eval_ah15/subtasks_foreact_qwen8b.json \\ + --foresight_dir ./.experiments_cache/droid_eval_ah15/foresight_foreact \\ + --output_dir ./.experiments_cache/droid_eval_ah15/foreact_report \\ + [--video --fps 2] +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +import imageio.v3 as iio3 +import numpy as np +from PIL import Image + +from experiments.subtask_probe.droid_eval.utils import load_manifest, load_subtask_index +from experiments.subtask_probe.droid_eval.visualize_subtasks import ( + _composite_side_by_side, + _draw_caption, + _encode_png_base64, + _load_frame_images, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def _load_foresight_image( + foresight_dir: Path, episode_id: str, frame_idx: int, target_h: int, target_w: int +) -> np.ndarray: + """Load the foresight PNG for a given frame, resizing to (target_h, target_w). + + Returns a grey placeholder if the PNG is missing so the HTML/video grid + stays aligned even when the generator skipped frames. + """ + path = foresight_dir / episode_id / f"frame_{frame_idx:05d}.png" + if not path.exists(): + return np.full((target_h, target_w, 3), 40, dtype=np.uint8) + with Image.open(path) as img: + resized = img.convert("RGB").resize((target_w, target_h), Image.Resampling.LANCZOS) + return np.asarray(resized, dtype=np.uint8) + + +def _write_html( + out_path: Path, + episode_id: str, + instruction: str, + items: list[tuple[int, str, dict[str, np.ndarray]]], +) -> None: + """Emit a self-contained HTML with exterior | wrist | foresight | subtask.""" + rows = [] + for frame_idx, subtask, views in items: + ext_b64 = _encode_png_base64(views["exterior"]) + wrist_b64 = _encode_png_base64(views["wrist"]) + fore_b64 = _encode_png_base64(views["foresight"]) + rows.append( + f"" + f"{frame_idx}" + f"
exterior
" + f"" + f"
wrist
" + f"" + f"
predicted foresight
" + f"" + f"{subtask or '(empty)'}" + f"" + ) + html = ( + "" + f"ForeAct \u2014 {episode_id}" + "" + f"

{episode_id}

" + f"

Instruction: {instruction} · {len(items)} frames

" + "" + "" + "" + "" + "".join(rows) + "
frameexteriorwristpredicted foresightsubtask
" + "" + ) + out_path.write_text(html) + + +def _write_mp4( + out_path: Path, + frames_with_captions: list[np.ndarray], + fps: int, +) -> None: + iio3.imwrite( + out_path, + np.stack(frames_with_captions), + fps=fps, + codec="libx264", + macro_block_size=1, + ) + + +def _process_episode( + samples_dir: Path, + foresight_dir: Path, + episode: dict, + subtask_index: dict[tuple[str, int], str], + output_dir: Path, + video: bool, + fps: int, +) -> None: + episode_id = episode["episode_id"] + instruction = episode["instruction"] + + items: list[tuple[int, str, dict[str, np.ndarray]]] = [] + captioned: list[np.ndarray] = [] + + for frame_info in episode["frames"]: + frame_idx = frame_info["frame_idx"] + views = _load_frame_images(samples_dir / frame_info["file"]) + ext_h, ext_w = views["exterior"].shape[:2] + foresight = _load_foresight_image(foresight_dir, episode_id, frame_idx, ext_h, ext_w) + views["foresight"] = foresight + subtask = subtask_index.get((episode_id, frame_idx), "") + items.append((frame_idx, subtask, views)) + if video: + composite = _composite_side_by_side(views) + ext_w = views["exterior"].shape[1] + wrist_w = views["wrist"].shape[1] + separator = 4 + captioned.append( + _draw_caption( + composite, + caption=subtask, + footer=f"{episode_id} frame {frame_idx} | task: {instruction}", + camera_labels=[ + (4, "EXTERIOR"), + (ext_w + separator + 4, "WRIST"), + (ext_w + wrist_w + 2 * separator + 4, "FORESIGHT"), + ], + ) + ) + + if video: + mp4_path = output_dir / f"{episode_id}.mp4" + _write_mp4(mp4_path, captioned, fps=fps) + logger.info("%s -> %s (%d frames, %d fps)", episode_id, mp4_path, len(captioned), fps) + else: + html_path = output_dir / f"{episode_id}.html" + _write_html(html_path, episode_id, instruction, items) + logger.info("%s -> %s (%d frames)", episode_id, html_path, len(items)) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Render ForeAct foresight + subtask grid") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--subtasks", type=str, required=True) + parser.add_argument( + "--foresight_dir", + type=str, + required=True, + help="Directory containing foresight/{episode_id}/frame_{idx:05d}.png files.", + ) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--video", action="store_true") + parser.add_argument("--fps", type=int, default=2) + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + foresight_dir = Path(args.foresight_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + manifest = load_manifest(samples_dir) + subtask_index = load_subtask_index(Path(args.subtasks)) + logger.info( + "Loaded %d episodes, %d subtask records, foresight_dir=%s", + len(manifest), + len(subtask_index), + foresight_dir, + ) + + for episode in manifest: + _process_episode( + samples_dir=samples_dir, + foresight_dir=foresight_dir, + episode=episode, + subtask_index=subtask_index, + output_dir=output_dir, + video=args.video, + fps=args.fps, + ) + + logger.info("Done. Output: %s", output_dir) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/foreact_eval/visualize_nano_banana.py b/experiments/subtask_probe/droid_eval/foreact_eval/visualize_nano_banana.py new file mode 100644 index 0000000..d437aaa --- /dev/null +++ b/experiments/subtask_probe/droid_eval/foreact_eval/visualize_nano_banana.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +"""Render the nano-banana foresight chain as a side-by-side mp4. + +``actual | foresight`` per frame, subtask caption overlaid on the bottom, +episode/frame footer under it. Runs at 2 fps by default to match the +ForeAct v2 golden chain mp4. + +Caption/composite drawing is shared with ``visualize_subtasks.py`` so this +visualizer and ``visualize_foreact.py`` produce stylistically matched +videos. Episode-phase subtask labels come from +``generate_foresight_nano_banana.EPISODE_PHASES`` so the label shown in the +mp4 is literally the same string the generator wrote into its manifest. +""" + +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path + +import imageio.v3 as iio3 +import numpy as np +from PIL import Image, ImageDraw + +from experiments.subtask_probe.droid_eval.foreact_eval._io import ( + foresight_path, + iter_source_frames, +) +from experiments.subtask_probe.droid_eval.foreact_eval.generate_foresight_nano_banana import ( + lookup_phase, +) +from experiments.subtask_probe.droid_eval.visualize_subtasks import ( + _composite_side_by_side, + _draw_caption, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +FRAME_W = 480 +FRAME_H = 270 +SEPARATOR_PX = 4 + + +def _load_resized(path: Path, *, w: int, h: int) -> np.ndarray: + with Image.open(path) as img: + resized = img.convert("RGB").resize((w, h), Image.Resampling.LANCZOS) + return np.asarray(resized, dtype=np.uint8) + + +def _missing_placeholder(w: int, h: int, text: str) -> np.ndarray: + img = Image.new("RGB", (w, h), color=(40, 40, 40)) + draw = ImageDraw.Draw(img) + tw = draw.textlength(text) + draw.text(((w - tw) / 2, h / 2 - 8), text, fill=(200, 200, 200)) + return np.asarray(img, dtype=np.uint8) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Render nano-banana chain as mp4") + parser.add_argument( + "--v2_root", + type=Path, + default=Path(".experiments_cache/foreact_eval/foresight_chain_eggplant_v2"), + help="Source 'actual' frames dir (episode_*/actual/frame_*.png).", + ) + parser.add_argument( + "--foresight_dir", + type=Path, + default=Path(".experiments_cache/foreact_eval/foresight_nano_banana_chain"), + help="Nano-banana foresight dir (episode_*/frame_*.png).", + ) + parser.add_argument( + "--output_path", + type=Path, + default=Path( + ".experiments_cache/foreact_eval/foresight_nano_banana_chain/chain_nano_banana.mp4" + ), + ) + parser.add_argument("--fps", type=int, default=2) + parser.add_argument( + "--subtasks_json", + type=Path, + default=None, + help=( + "Optional per-frame subtasks JSON (shape: {'results': [{episode_index, " + "frame_idx, subtask}, ...]}). When provided, overrides CHAIN_PHASES: " + "captions come from this file and every frame with an entry is rendered " + "(no phase-based trimming)." + ), + ) + return parser.parse_args() + + +def _load_subtasks_by_frame(path: Path) -> dict[tuple[int, int], str]: + payload = json.loads(path.read_text()) + records = payload.get("results") if isinstance(payload, dict) else payload + by_frame: dict[tuple[int, int], str] = {} + for record in records or []: + subtask = (record.get("subtask") or "").strip() + if not subtask: + continue + by_frame[(record["episode_index"], record["frame_idx"])] = subtask + return by_frame + + +def main() -> None: + args = _parse_args() + all_frames = iter_source_frames(args.v2_root) + + if args.subtasks_json is not None: + per_frame_subtasks = _load_subtasks_by_frame(args.subtasks_json) + captioned = [ + (f, per_frame_subtasks[(f.episode_index, f.frame_idx)]) + for f in all_frames + if (f.episode_index, f.frame_idx) in per_frame_subtasks + ] + logger.info( + "Rendering %d frames from %s (@ %d fps)", len(captioned), args.subtasks_json, args.fps + ) + else: + phased = [(f, lookup_phase(f.episode_index, f.frame_idx)) for f in all_frames] + captioned = [(f, p.subtask_label) for (f, p) in phased if p is not None] + logger.info( + "Rendering %d in-phase frames (trimmed %d) @ %d fps", + len(captioned), + len(all_frames) - len(captioned), + args.fps, + ) + if not captioned: + raise SystemExit(f"No frames to render from {args.v2_root}") + + composites: list[np.ndarray] = [] + for frame, subtask in captioned: + actual = _load_resized(frame.actual_path, w=FRAME_W, h=FRAME_H) + foresight_file = foresight_path(args.foresight_dir, frame.episode_index, frame.frame_idx) + foresight = ( + _load_resized(foresight_file, w=FRAME_W, h=FRAME_H) + if foresight_file.exists() + else _missing_placeholder(FRAME_W, FRAME_H, "(no foresight)") + ) + composite = _composite_side_by_side( + {"actual": actual, "foresight": foresight}, separator_px=SEPARATOR_PX + ) + footer = f"episode {frame.episode_index:03d} \u00b7 frame {frame.frame_idx:05d}" + composites.append( + _draw_caption( + composite, + caption=subtask, + footer=footer, + camera_labels=[ + (4, "ACTUAL"), + (FRAME_W + SEPARATOR_PX + 4, "FORESIGHT"), + ], + ) + ) + + args.output_path.parent.mkdir(parents=True, exist_ok=True) + iio3.imwrite( + args.output_path, + np.stack(composites), + fps=args.fps, + codec="libx264", + macro_block_size=1, + ) + logger.info("Wrote %s", args.output_path) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/generate_subtasks.py b/experiments/subtask_probe/droid_eval/generate_subtasks.py new file mode 100644 index 0000000..aedbedc --- /dev/null +++ b/experiments/subtask_probe/droid_eval/generate_subtasks.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +"""Phase 1: Generate subtask text for DROID frames via deployed server. + +Sends each cached DROID frame to the deployed server with mode="subtask_only" +to get subtask text, then caches the results for Phase 2. + +Optionally configures the subtask prompt format on the server via the admin +HTTP endpoint before generation, so the same script can drive prompt-format +A/B tests against a single deployment. + +Usage: + # Use whichever subtask prompt format the server is currently configured with: + uv run python experiments/subtask_probe/droid_eval/generate_subtasks.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/subtasks.json \ + --server 43.200.36.250 + + # Override the subtask prompt format on the server before running: + uv run python experiments/subtask_probe/droid_eval/generate_subtasks.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --prompt_format '{task}' \ + --output ./.experiments_cache/droid_eval/subtasks_raw.json \ + --server 43.200.36.250 +""" + +from __future__ import annotations + +import argparse +import json +import logging +import time +from pathlib import Path +from typing import Any + +import httpx +import numpy as np + +from hosting.admin_server import DEFAULT_ADMIN_PORT +from hosting.flash_transport_policy import FlashTransportPolicy + +from .constants import DEFAULT_QUIC_PORT +from .utils import build_subtask_observation, build_warmup_observation, load_manifest + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", force=True) +logger = logging.getLogger(__name__) + + +def _admin_request( + server: str, + admin_port: int, + method: str, + body: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Send a request to the server's admin HTTP endpoint and return the parsed JSON response.""" + url = f"http://{server}:{admin_port}/config" + try: + response = httpx.request(method, url, json=body, timeout=10.0) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise RuntimeError( + f"Admin endpoint returned {exc.response.status_code} for {method} {url}: " + f"{exc.response.text}" + ) from exc + except httpx.HTTPError as exc: + raise RuntimeError(f"Could not reach admin endpoint at {url}: {exc}") from exc + return response.json() + + +def _set_server_prompt_format(server: str, admin_port: int, prompt_format: str) -> str: + """PATCH the server's subtask prompt format and verify it was applied.""" + payload = _admin_request( + server, admin_port, method="PATCH", body={"generation_prompt_format": prompt_format} + ) + actual = payload.get("generation_prompt_format") + if actual != prompt_format: + raise RuntimeError( + f"Admin endpoint did not apply prompt format. Requested {prompt_format!r}, " + f"server reports {actual!r}." + ) + return prompt_format + + +def _get_server_prompt_format(server: str, admin_port: int) -> str: + """Fetch the server's currently-configured subtask prompt format.""" + payload = _admin_request(server, admin_port, method="GET") + prompt_format = payload.get("generation_prompt_format") + if not isinstance(prompt_format, str): + raise RuntimeError( + f"Admin endpoint returned no generation_prompt_format field: {payload!r}" + ) + return prompt_format + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate subtasks for DROID frames via server") + parser.add_argument( + "--samples_dir", + type=str, + required=True, + help="Directory with extracted DROID samples (from extract_droid_samples.py)", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output JSON file for subtask cache", + ) + parser.add_argument( + "--server", + type=str, + required=True, + help="Server address (e.g., 43.200.36.250)", + ) + parser.add_argument("--port", type=int, default=DEFAULT_QUIC_PORT, help="Server QUIC port") + parser.add_argument( + "--admin_host", + type=str, + default=None, + help=( + "Host for the admin HTTP endpoint. Defaults to --server. Use 127.0.0.1 with " + "an SSH tunnel when the deployed server binds admin to localhost." + ), + ) + parser.add_argument( + "--admin_port", + type=int, + default=DEFAULT_ADMIN_PORT, + help="Admin HTTP port for runtime config (used to set/read --prompt_format)", + ) + parser.add_argument( + "--prompt_format", + type=str, + default=None, + help=( + "Subtask prompt format to install on the server before generation, " + "e.g. 'Task: {task}. Subtask: ' or '{task}'. Must contain the literal " + "'{{task}}' placeholder. If omitted, the server's current format is used." + ), + ) + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + manifest = load_manifest(samples_dir) + logger.info("Loaded manifest: %d episodes", len(manifest)) + + # Set or read the active subtask prompt format on the server before any inference. + # The runtime config is read on every generate() call server-side, so this takes + # effect on the next request. + admin_host = args.admin_host or args.server + if args.prompt_format is not None: + active_prompt_format = _set_server_prompt_format( + admin_host, args.admin_port, args.prompt_format + ) + logger.info("Set server subtask prompt format to %r", active_prompt_format) + else: + active_prompt_format = _get_server_prompt_format(admin_host, args.admin_port) + logger.info("Using server's current subtask prompt format: %r", active_prompt_format) + + # Connect to server via QUIC + policy = FlashTransportPolicy(args.server, port=args.port) + logger.info("Connected to server at %s:%d via QUIC", args.server, args.port) + + # Warm up connection + policy.infer(build_warmup_observation(mode="subtask_only")) + logger.info("Server warmup complete") + + # Process all frames + subtask_results = [] + total_frames = sum(ep["num_frames"] for ep in manifest) + processed = 0 + + for episode in manifest: + episode_id = episode["episode_id"] + instruction = episode["instruction"] + + for frame_info in episode["frames"]: + frame_file = samples_dir / frame_info["file"] + frame_data = np.load(frame_file) + + # Send raw uint8 images — the server's _normalize_image() handles + # conversion to float32 [-1, 1] and resize_with_pad to 224x224. + obs = build_subtask_observation( + exterior_image=frame_data["exterior_image"], + wrist_image=frame_data["wrist_image"], + prompt=instruction, + ) + + start_time = time.time() + result = policy.infer(obs) + elapsed = time.time() - start_time + + subtask_info = result.get("subtask", {}) + subtask_text = subtask_info.get("text", "") or result.get("subtask_text", "") + subtask_ms = subtask_info.get("ms", elapsed * 1000) + + subtask_results.append( + { + "episode_id": episode_id, + "frame_idx": frame_info["frame_idx"], + "instruction": instruction, + "subtask_text": subtask_text, + "generation_time_s": round(elapsed, 2), + "server_subtask_ms": round(subtask_ms, 1), + } + ) + + processed += 1 + if processed % 5 == 0 or processed == total_frames: + logger.info( + "[%d/%d] %s frame %d: '%s' (%.1fs)", + processed, + total_frames, + episode_id, + frame_info["frame_idx"], + subtask_text, + elapsed, + ) + + # Save results — wrap in a self-describing dict so consumers know which prompt + # format produced these subtasks. load_subtask_records() handles both this shape + # and the legacy bare-list shape for backward compatibility. + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump( + {"prompt_format": active_prompt_format, "results": subtask_results}, + f, + indent=2, + ) + + logger.info( + "Subtask generation complete: %d results saved to %s (prompt_format=%r)", + len(subtask_results), + output_path, + active_prompt_format, + ) + + # Summary stats + gen_times = [r["generation_time_s"] for r in subtask_results] + logger.info( + "Latency: mean=%.2fs, min=%.2fs, max=%.2fs", + np.mean(gen_times), + np.min(gen_times), + np.max(gen_times), + ) + unique_subtasks = {r["subtask_text"] for r in subtask_results} + logger.info( + "Unique subtask texts: %d out of %d frames", len(unique_subtasks), len(subtask_results) + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py b/experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py new file mode 100644 index 0000000..a9a695e --- /dev/null +++ b/experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +"""Phase 1 (alt): Generate subtask text for DROID frames via Gemini Robotics-ER. + +Drop-in alternative to generate_subtasks.py that swaps the deployed pi0.5 server +for Google's Gemini Robotics-ER 1.6 Preview model, called directly via the +google-genai SDK. The on-disk JSON schema is identical so compare_subtask_outputs.py +and run_action_eval.py work unchanged. + +Requires GEMINI_API_KEY in the environment (or in a .env file loaded via +python-dotenv, which is already a project dep). + +Usage: + uv run python experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py \\ + --samples_dir ./.experiments_cache/droid_eval \\ + --output ./.experiments_cache/droid_eval/subtasks_gemini.json + + # Override the prompt template: + uv run python experiments/subtask_probe/droid_eval/generate_subtasks_gemini.py \\ + --samples_dir ./.experiments_cache/droid_eval \\ + --prompt_format 'Task: {task}. What is the robot doing right now? Reply in 4 words.' \\ + --output ./.experiments_cache/droid_eval/subtasks_gemini_terse.json +""" + +from __future__ import annotations + +import argparse +import json +import logging +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +import numpy as np +from dotenv import load_dotenv +from google import genai +from google.genai import types + +from .comet_style._gemini_utils import call_with_retry, encode_png +from .utils import load_manifest + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = "gemini-robotics-er-1.6-preview" + +DEFAULT_PROMPT_FORMAT = ( + 'You are observing a robot performing: "{task}". Looking at the exterior and ' + "wrist camera views of the current moment, describe the immediate next subtask " + "the robot should do in 3 to 6 words, as a lowercase imperative phrase with no " + "trailing period. Respond with only that phrase." +) + + +def _generate_one( + client: genai.Client, + model: str, + prompt_text: str, + exterior_png: bytes, + wrist_png: bytes, + thinking_budget: int, + max_retries: int, +) -> tuple[str, float]: + """Call Gemini with retry-on-429. Returns (subtask_text, elapsed_seconds).""" + start_time = time.time() + + def _call() -> str: + response = client.models.generate_content( + model=model, + contents=[ + "Exterior camera view:", + types.Part.from_bytes(data=exterior_png, mime_type="image/png"), + "Wrist camera view:", + types.Part.from_bytes(data=wrist_png, mime_type="image/png"), + prompt_text, + ], + config=types.GenerateContentConfig( + temperature=1.0, + thinking_config=types.ThinkingConfig(thinking_budget=thinking_budget), + ), + ) + return (response.text or "").strip() + + text = call_with_retry(_call, max_retries=max_retries) + elapsed = time.time() - start_time + return text, elapsed + + +def _validate_prompt_format(prompt_format: str) -> None: + """Reject prompt formats that do not contain the required {task} placeholder. + + Mirrors the check in admin_server.py so CLI errors surface early. + """ + if "{task}" not in prompt_format: + raise ValueError( + f"--prompt_format must contain the literal '{{task}}' placeholder, got: " + f"{prompt_format!r}" + ) + + +def main() -> None: + load_dotenv() + + parser = argparse.ArgumentParser( + description="Generate subtasks for DROID frames via Gemini Robotics-ER" + ) + parser.add_argument( + "--samples_dir", + type=str, + required=True, + help="Directory with extracted DROID samples (from extract_droid_samples.py)", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output JSON file for subtask cache", + ) + parser.add_argument( + "--model", + type=str, + default=DEFAULT_MODEL, + help=f"Gemini model ID (default: {DEFAULT_MODEL})", + ) + parser.add_argument( + "--prompt_format", + type=str, + default=DEFAULT_PROMPT_FORMAT, + help=( + "Prompt template sent to Gemini. Must contain the literal '{task}' " + "placeholder, which is replaced with the episode's language instruction." + ), + ) + parser.add_argument( + "--thinking_budget", + type=int, + default=0, + help="Gemini thinking budget (0 disables thinking, higher values allow more reasoning tokens)", + ) + parser.add_argument( + "--max_workers", + type=int, + default=1, + help=( + "Max concurrent Gemini API requests. Default 1 keeps you under the " + "free-tier 5 RPM cap; raise it on a paid tier for throughput." + ), + ) + parser.add_argument( + "--max_retries", + type=int, + default=10, + help="Per-frame retry budget for 429 (RESOURCE_EXHAUSTED) responses", + ) + args = parser.parse_args() + + _validate_prompt_format(args.prompt_format) + + samples_dir = Path(args.samples_dir) + manifest = load_manifest(samples_dir) + logger.info("Loaded manifest: %d episodes", len(manifest)) + + client = genai.Client() + + # Build the flat task list up-front so results can be collected in manifest order + # regardless of which worker finishes first. + tasks: list[dict[str, Any]] = [] + for episode in manifest: + episode_id = episode["episode_id"] + instruction = episode["instruction"] + for frame_info in episode["frames"]: + tasks.append( + { + "episode_id": episode_id, + "instruction": instruction, + "frame_idx": frame_info["frame_idx"], + "frame_path": samples_dir / frame_info["file"], + } + ) + total_frames = len(tasks) + logger.info( + "Dispatching %d frames to %s (max_workers=%d)", total_frames, args.model, args.max_workers + ) + + progress_lock = threading.Lock() + progress = {"done": 0} + + def process(task: dict[str, Any]) -> dict[str, Any]: + frame_data = np.load(task["frame_path"]) + prompt_text = args.prompt_format.format(task=task["instruction"]) + try: + exterior_png = encode_png(frame_data["exterior_image"]) + wrist_png = encode_png(frame_data["wrist_image"]) + subtask_text, elapsed = _generate_one( + client=client, + model=args.model, + prompt_text=prompt_text, + exterior_png=exterior_png, + wrist_png=wrist_png, + thinking_budget=args.thinking_budget, + max_retries=args.max_retries, + ) + except Exception as exc: + logger.warning( + "Gemini call failed for %s frame %d: %s", + task["episode_id"], + task["frame_idx"], + exc, + ) + subtask_text = "" + elapsed = 0.0 + + with progress_lock: + progress["done"] += 1 + done = progress["done"] + if done % 5 == 0 or done == total_frames: + logger.info( + "[%d/%d] %s frame %d: %r (%.1fs)", + done, + total_frames, + task["episode_id"], + task["frame_idx"], + subtask_text, + elapsed, + ) + + return { + "episode_id": task["episode_id"], + "frame_idx": task["frame_idx"], + "instruction": task["instruction"], + "subtask_text": subtask_text, + "generation_time_s": round(elapsed, 2), + "server_subtask_ms": round(elapsed * 1000, 1), + } + + if args.max_workers <= 1: + subtask_results = [process(task) for task in tasks] + else: + with ThreadPoolExecutor(max_workers=args.max_workers) as executor: + # executor.map preserves input order, which matches the manifest. + subtask_results = list(executor.map(process, tasks)) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump( + { + "prompt_format": args.prompt_format, + "backend": args.model, + "results": subtask_results, + }, + f, + indent=2, + ) + + logger.info( + "Subtask generation complete: %d results saved to %s (backend=%s)", + len(subtask_results), + output_path, + args.model, + ) + + # Summary stats — mirrors generate_subtasks.py so output looks familiar. + gen_times = [r["generation_time_s"] for r in subtask_results if r["generation_time_s"] > 0] + if gen_times: + logger.info( + "Latency: mean=%.2fs, min=%.2fs, max=%.2fs", + float(np.mean(gen_times)), + float(np.min(gen_times)), + float(np.max(gen_times)), + ) + unique_subtasks = {r["subtask_text"] for r in subtask_results if r["subtask_text"]} + logger.info( + "Unique subtask texts: %d out of %d frames", len(unique_subtasks), len(subtask_results) + ) + failed = sum(1 for r in subtask_results if not r["subtask_text"]) + if failed: + logger.warning("Failed frames: %d / %d", failed, len(subtask_results)) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/run_action_eval.py b/experiments/subtask_probe/droid_eval/run_action_eval.py new file mode 100644 index 0000000..0517634 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/run_action_eval.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +"""Phase 2: Run action generation under 2 prompt conditions via deployed server. + +Sends each cached DROID frame to the deployed server with mode="action_only" +under two prompt conditions: + 1. Baseline: original task instruction only + 2. Subtask: "{instruction}. Subtask: {subtask}" (with generated subtask) + +Requires the server to be running with pi05_droid config and the DROID +checkpoint, since the action policy's transforms and normalization must +match the DROID embodiment. + +Usage: + uv run python experiments/subtask_probe/droid_eval/run_action_eval.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --subtasks ./.experiments_cache/droid_eval/subtasks.json \ + --output_dir ./.experiments_cache/droid_eval/predictions \ + --server 43.200.36.250 +""" + +from __future__ import annotations + +import argparse +import json +import logging +from pathlib import Path + +import numpy as np + +from hosting.flash_transport_policy import FlashTransportPolicy + +from .constants import DEFAULT_QUIC_PORT, DROID_ACTION_DIM +from .utils import ( + build_action_observation, + build_warmup_observation, + generate_frame_noise, + load_manifest, + load_subtask_index, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run action eval with 2 prompt conditions") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--subtasks", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--server", + type=str, + required=True, + help="Server address (e.g., 43.200.36.250)", + ) + parser.add_argument("--port", type=int, default=DEFAULT_QUIC_PORT, help="Server QUIC port") + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + manifest = load_manifest(samples_dir) + subtask_index = load_subtask_index(Path(args.subtasks)) + + # Connect to server + policy = FlashTransportPolicy(args.server, port=args.port) + logger.info("Connected to server at %s:%d via QUIC", args.server, args.port) + + policy.infer(build_warmup_observation(mode="action_only")) + logger.info("Server warmup complete") + + # Process all frames + all_predictions = [] + total_frames = sum(ep["num_frames"] for ep in manifest) + processed = 0 + + for episode in manifest: + episode_id = episode["episode_id"] + instruction = episode["instruction"] + + for frame_info in episode["frames"]: + frame_idx = frame_info["frame_idx"] + frame_file = samples_dir / frame_info["file"] + frame_data = np.load(frame_file) + + exterior_image = frame_data["exterior_image"] + wrist_image = frame_data["wrist_image"] + raw_state = frame_data["state"] + + subtask_text = subtask_index.get((episode_id, frame_idx), "") + if not subtask_text: + logger.warning( + "No subtask found for %s frame %d, using empty string", episode_id, frame_idx + ) + + # Same noise for all conditions so the only difference is the prompt + frame_noise = generate_frame_noise(episode_id, frame_idx) + + # Run all conditions + conditions = { + "baseline": instruction, + "subtask": f"{instruction}. Subtask: {subtask_text}", + } + frame_predictions = {} + for condition_name, prompt in conditions.items(): + obs = build_action_observation( + exterior_image, wrist_image, raw_state, prompt, noise=frame_noise + ) + result = policy.infer(obs) + frame_predictions[condition_name] = np.array(result["actions"])[ + :, :DROID_ACTION_DIM + ] + + # Save predictions + pred_file = output_dir / f"{episode_id}_frame_{frame_idx:05d}.npz" + np.savez_compressed(pred_file, **frame_predictions) + + all_predictions.append( + { + "episode_id": episode_id, + "frame_idx": frame_idx, + "instruction": instruction, + "subtask_text": subtask_text, + "prediction_file": str(pred_file.relative_to(output_dir)), + } + ) + + processed += 1 + if processed % 5 == 0 or processed == total_frames: + logger.info( + "[%d/%d] Processed %s frame %d", processed, total_frames, episode_id, frame_idx + ) + + # Save prediction manifest + pred_manifest_path = output_dir / "prediction_manifest.json" + with pred_manifest_path.open("w") as f: + json.dump(all_predictions, f, indent=2) + + logger.info( + "Action evaluation complete: %d predictions saved to %s", len(all_predictions), output_dir + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/utils.py b/experiments/subtask_probe/droid_eval/utils.py new file mode 100644 index 0000000..13c3cec --- /dev/null +++ b/experiments/subtask_probe/droid_eval/utils.py @@ -0,0 +1,151 @@ +"""Shared utilities for the DROID subtask evaluation pipeline.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import numpy as np + +from .constants import DROID_ACTION_DIM, MODEL_ACTION_DIM, InferenceMode + + +def load_manifest(samples_dir: Path) -> list[dict[str, Any]]: + """Load the episode manifest from a samples directory.""" + with (samples_dir / "manifest.json").open() as f: + return json.load(f) + + +def load_subtask_records(path: Path) -> list[dict[str, Any]]: + """Load the raw list of subtask records from a subtasks JSON file. + + Accepts two on-disk shapes: + * Legacy: a bare JSON list of records. + * Current: ``{"prompt_format": "...", "results": [...]}`` — written by + the prompt-format-aware ``generate_subtasks.py``. + """ + with path.open() as f: + payload = json.load(f) + if isinstance(payload, dict) and "results" in payload: + return payload["results"] + if isinstance(payload, list): + return payload + raise ValueError( + f"Unrecognized subtask JSON shape in {path}: expected list or dict with 'results' key" + ) + + +def load_subtask_index(path: Path) -> dict[tuple[str, int], str]: + """Load subtask results and index by (episode_id, frame_idx) -> subtask_text.""" + return { + (entry["episode_id"], entry["frame_idx"]): entry["subtask_text"] + for entry in load_subtask_records(path) + } + + +def load_subtask_entries(path: Path) -> dict[tuple[str, int], dict[str, Any]]: + """Load subtask results and index by (episode_id, frame_idx) -> full entry dict.""" + return { + (entry["episode_id"], entry["frame_idx"]): entry for entry in load_subtask_records(path) + } + + +def build_subtask_observation( + exterior_image: np.ndarray, + wrist_image: np.ndarray, + prompt: str, +) -> dict[str, Any]: + """Build an observation dict for subtask generation (mode="subtask_only"). + + Images are sent as raw uint8 — the server's _normalize_image() handles + conversion to float32 [-1, 1] and the camera name mapping. + """ + return { + "images": { + "base_0_rgb": exterior_image, + "left_wrist_0_rgb": wrist_image, + }, + "state": np.zeros(14, dtype=np.float32), + "prompt": prompt, + "mode": "subtask_only", + } + + +def build_action_observation( + exterior_image: np.ndarray, + wrist_image: np.ndarray, + state: np.ndarray, + prompt: str, + noise: np.ndarray | None = None, +) -> dict[str, Any]: + """Build an observation dict for action generation (mode="action_only"). + + The server's pi05_droid policy transforms handle normalization, + tokenization, and image preprocessing internally. + + Args: + noise: Pre-generated noise tensor for flow matching denoising, + shape (ACTION_HORIZON, MODEL_ACTION_DIM). When the same noise is + passed for multiple prompt conditions, actions differ only due to + the prompt, not random noise. + """ + joint_position = state[:7] + gripper_position = state[7:8] + + obs: dict[str, Any] = { + "observation/exterior_image_1_left": exterior_image, + "observation/wrist_image_left": wrist_image, + "observation/joint_position": joint_position, + "observation/gripper_position": gripper_position, + "prompt": prompt, + "mode": "action_only", + } + if noise is not None: + obs["noise"] = noise + return obs + + +def build_warmup_observation(mode: InferenceMode = "action_only") -> dict[str, Any]: + """Build a dummy observation for server warmup.""" + if mode == "subtask_only": + return build_subtask_observation( + exterior_image=np.zeros((224, 224, 3), dtype=np.uint8), + wrist_image=np.zeros((224, 224, 3), dtype=np.uint8), + prompt="warmup", + ) + return build_action_observation( + exterior_image=np.zeros((224, 224, 3), dtype=np.uint8), + wrist_image=np.zeros((224, 224, 3), dtype=np.uint8), + state=np.zeros(DROID_ACTION_DIM, dtype=np.float32), + prompt="warmup", + ) + + +def generate_frame_noise(episode_id: str, frame_idx: int) -> np.ndarray: + """Generate deterministic noise for flow matching denoising. + + Uses a hash of (episode_id, frame_idx) as the seed so that multiple + prompt conditions on the same frame get identical noise, making the + comparison fair. + """ + from .constants import ACTION_HORIZON + + rng = np.random.RandomState(hash((episode_id, frame_idx)) % (2**31)) + return rng.randn(ACTION_HORIZON, MODEL_ACTION_DIM).astype(np.float32) + + +def decode_droid_image(img_bytes: bytes | str | np.ndarray) -> np.ndarray: + """Decode a DROID image from RLDS format. + + Handles both encoded (JPEG bytes) and pre-decoded (ndarray) formats + that appear in different DROID dataset versions. + """ + import tensorflow as tf # ty: ignore[unresolved-import] + + if isinstance(img_bytes, (bytes, str)) or ( + hasattr(img_bytes, "dtype") + and (img_bytes.dtype == np.object_ or img_bytes.dtype.kind in ("S", "U")) + ): + return tf.io.decode_image(img_bytes, expand_animations=False, dtype=tf.uint8).numpy() + return np.asarray(img_bytes, dtype=np.uint8) diff --git a/experiments/subtask_probe/droid_eval/visualize_results.py b/experiments/subtask_probe/droid_eval/visualize_results.py new file mode 100644 index 0000000..c0ac065 --- /dev/null +++ b/experiments/subtask_probe/droid_eval/visualize_results.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python3 +"""Generate a self-contained HTML report for DROID evaluation results. + +Produces a single HTML file with embedded images (base64) that can be +opened in any browser and shared without dependencies. + +Sections: + 1. Summary metrics — overall L2, cosine sim, gripper accuracy + 2. Pairwise comparisons — Wilcoxon tests, % better + 3. Subtask gallery — every frame with images + generated subtask text + 4. Action trajectories — per-dimension action plots for sample frames + +Usage: + uv run python experiments/subtask_probe/droid_eval/visualize_results.py \ + --samples_dir ./.experiments_cache/droid_eval \ + --output ./.experiments_cache/droid_eval/report.html +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import json +import logging +from pathlib import Path +from typing import Literal + +import numpy as np + +from .constants import CONDITION_COLORS, JOINT_NAMES +from .utils import load_manifest, load_subtask_entries, load_subtask_records + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +ImageFormat = Literal["jpeg", "png"] + + +def _is_coherent_subtask(text: str) -> bool: + """Check if a subtask string is coherent (non-empty ASCII text, not Unicode garbage).""" + return bool(text.strip()) and text.isascii() + + +def image_to_base64(image_array: np.ndarray, fmt: ImageFormat = "jpeg", quality: int = 80) -> str: + """Convert a numpy image array (HWC, uint8) to a base64-encoded data URI.""" + from PIL import Image + + img = Image.fromarray(image_array) + buffer = io.BytesIO() + if fmt == "jpeg": + img.save(buffer, format="JPEG", quality=quality) + else: + img.save(buffer, format="PNG") + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + mime = "image/jpeg" if fmt == "jpeg" else "image/png" + return f"data:{mime};base64,{encoded}" + + +def make_action_svg( + ground_truth: np.ndarray, + predictions: dict[str, np.ndarray], + dim_idx: int, + dim_name: str, + width: int = 280, + height: int = 120, +) -> str: + """Generate an inline SVG showing action trajectories for one dimension.""" + margin_left = 35 + margin_right = 10 + margin_top = 20 + margin_bottom = 25 + plot_width = width - margin_left - margin_right + plot_height = height - margin_top - margin_bottom + + # Collect all values for axis scaling + all_values = [ground_truth[:, dim_idx]] + for pred in predictions.values(): + all_values.append(pred[:, dim_idx]) + all_flat = np.concatenate(all_values) + y_min = float(np.min(all_flat)) + y_max = float(np.max(all_flat)) + y_range = y_max - y_min + if y_range < 1e-6: + y_range = 1.0 + y_min -= y_range * 0.1 + y_max += y_range * 0.1 + y_range = y_max - y_min + + num_steps = ground_truth.shape[0] + + def to_svg_coords(step: int, value: float) -> tuple[float, float]: + x = margin_left + (step / max(num_steps - 1, 1)) * plot_width + y = margin_top + (1 - (value - y_min) / y_range) * plot_height + return x, y + + def make_polyline(values: np.ndarray, color: str, dashed: bool = False) -> str: + points = " ".join( + f"{to_svg_coords(i, v)[0]:.1f},{to_svg_coords(i, v)[1]:.1f}" + for i, v in enumerate(values) + ) + dash = ' stroke-dasharray="4,3"' if dashed else "" + return ( + f'' + ) + + lines = [f''] + lines.append(f'') + + # Title + lines.append( + f'{dim_name}' + ) + + # Y-axis labels + for frac in [0, 0.5, 1.0]: + y_val = y_min + frac * y_range + _, svg_y = to_svg_coords(0, y_val) + lines.append( + f'{y_val:.2f}' + ) + lines.append( + f'' + ) + + # Ground truth (dashed gray) + lines.append( + make_polyline(ground_truth[:, dim_idx], CONDITION_COLORS["ground_truth"], dashed=True) + ) + + # Predictions + for condition_name, pred in predictions.items(): + lines.append(make_polyline(pred[:, dim_idx], CONDITION_COLORS[condition_name])) + + lines.append("") + return "\n".join(lines) + + +def generate_html( + samples_dir: Path, + max_gallery_frames: int, + sample_action_frames: int, + subtasks_path: Path | None = None, +) -> str: + """Generate the full HTML report. + + ``subtasks_path`` overrides the default ``samples_dir/subtasks.json`` so the + same report template can visualize subtasks from different backends (e.g. + pi0.5 vs Gemini) pointing at the same cache. + """ + # Load all data + manifest = load_manifest(samples_dir) + resolved_subtasks_path = subtasks_path or (samples_dir / "subtasks.json") + subtask_index = load_subtask_entries(resolved_subtasks_path) + + # Load raw subtask list for unique count in header + subtasks = load_subtask_records(resolved_subtasks_path) + + results_path = samples_dir / "results.json" + has_results = results_path.exists() + results = {} + if has_results: + with results_path.open() as f: + results = json.load(f) + + predictions_dir = samples_dir / "predictions" + has_predictions = predictions_dir.exists() + + total_frames = sum(ep["num_frames"] for ep in manifest) + + # --- Start HTML --- + html_parts: list[str] = [] + html_parts.append(""" + + + + +DROID Subtask Evaluation Report + + + +
+""") + + # --- Header --- + html_parts.append(f""" +

DROID Subtask Evaluation Report

+

{len(manifest)} episodes · {total_frames} frames · {len({s["subtask_text"] for s in subtasks})} unique subtasks

+""") + + # --- Summary stats --- + if has_results: + html_parts.append("

Summary Metrics

") + + # Stat boxes + baseline_l2 = results["baseline"]["overall"]["l2_distance"]["mean"] + subtask_l2 = results["subtask"]["overall"]["l2_distance"]["mean"] + p_val = results.get("pairwise", {}).get("baseline_vs_subtask", {}).get("wilcoxon_p", 1.0) + + html_parts.append('
') + for label, l2_val in [("Baseline L2", baseline_l2), ("Subtask L2", subtask_l2)]: + html_parts.append( + f'
{l2_val:.3f}
{label}
' + ) + cls = "sig" if p_val < 0.05 else "not-sig" + sig_label = "significant" if p_val < 0.05 else "not significant" + html_parts.append( + f'
p={p_val:.4f}
Baseline vs Subtask ({sig_label})
' + ) + html_parts.append("
") + + # Overall metrics table + html_parts.append(""" + + +""") + for condition in ["baseline", "subtask"]: + o = results[condition]["overall"] + l2 = o["l2_distance"] + cos = o["cosine_similarity"] + grip = o["gripper_accuracy"] + html_parts.append( + f"" + f"" + f"" + f"" + f"" + ) + html_parts.append("
ConditionL2 DistanceCosine SimilarityGripper AccuracyN
{condition}{l2['mean']:.4f} ± {l2['std']:.4f}{cos['mean']:.4f} ± {cos['std']:.4f}{grip['mean']:.4f} ± {grip['std']:.4f}{o['n_frames']}
") + + # Per-dimension MAE table + html_parts.append("

Per-Dimension MAE

") + html_parts.append("") + for name in JOINT_NAMES: + html_parts.append(f"") + html_parts.append("") + for condition in ["baseline", "subtask"]: + per_dim = results[condition]["overall"]["per_dim_mae"] + html_parts.append(f"") + for name in JOINT_NAMES: + val = per_dim.get(name, 0) + html_parts.append(f"") + html_parts.append("") + html_parts.append("
Condition{name}
{condition}{val:.4f}
") + + # Pairwise comparisons + html_parts.append("

Pairwise Comparisons

") + html_parts.append( + "" + ) + for key, data in results.get("pairwise", {}).items(): + p_val = data.get("wilcoxon_p", None) + p_str = f"{p_val:.4f}" if p_val is not None else "N/A" + cls = "sig" if p_val is not None and p_val < 0.05 else "not-sig" + html_parts.append( + f"" + f"" + f"" + f'' + ) + html_parts.append("
ComparisonL2 Diff (mean)% B BetterWilcoxon p
{key}{data['l2_diff_mean']:.4f}{data['pct_b_better']:.1f}%{p_str}
") + + # --- Subtask Gallery --- + html_parts.append("

Subtask Gallery

") + html_parts.append( + '

Click an episode to expand. Coherent/garbage counts shown in the header.

' + ) + + for episode in manifest: + episode_id = episode["episode_id"] + task_type_cls = "multi" if episode["task_type"] == "multi_step" else "" + + # Count coherent vs garbage subtasks for this episode + ep_subtasks = [subtask_index.get((episode_id, f["frame_idx"])) for f in episode["frames"]] + ep_subtasks = [s for s in ep_subtasks if s is not None] + coherent_count = sum(1 for s in ep_subtasks if _is_coherent_subtask(s["subtask_text"])) + garbage_count = len(ep_subtasks) - coherent_count + quality_label = f"{coherent_count}/{len(ep_subtasks)} coherent" + if garbage_count > 0: + quality_label += f", {garbage_count} garbage" + quality_color = ( + "#4caf50" + if garbage_count == 0 + else ("#ff9800" if coherent_count > garbage_count else "#f44336") + ) + + html_parts.append(f""" +
garbage_count else ""}> + +
+ {episode_id} + {episode["task_type"].replace("_", " ")} + {quality_label} + {episode["num_frames"]} frames +
+
“{episode["instruction"]}”
+
+
+""") + + for frame_info in episode["frames"]: + frame_idx = frame_info["frame_idx"] + subtask_entry = subtask_index.get((episode_id, frame_idx)) + + # Load images + frame_path = samples_dir / frame_info["file"] + if frame_path.exists(): + frame_data = np.load(frame_path) + ext_uri = image_to_base64(frame_data["exterior_image"]) + wrist_uri = image_to_base64(frame_data["wrist_image"]) + else: + ext_uri = "" + wrist_uri = "" + + subtask_text = subtask_entry["subtask_text"] if subtask_entry else "(no subtask)" + gen_time = subtask_entry.get("generation_time_s", 0) if subtask_entry else 0 + is_garbage = not _is_coherent_subtask(subtask_text) + border_color = "#f44336" if is_garbage else "#eee" + + html_parts.append(f""" +
+
+ exterior + wrist +
+
Subtask: {subtask_text}
+
Frame {frame_idx} · {gen_time:.2f}s
+
+""") + + html_parts.append("
") # frame-grid, details + + # --- Action Trajectories --- + if has_predictions: + html_parts.append("

Action Trajectories (Sample Frames)

") + html_parts.append(""" +
+
Ground truth
+
Baseline
+
Subtask
+
+""") + + action_frame_count = 0 + for episode in manifest: + episode_id = episode["episode_id"] + # Pick evenly spaced frames from this episode + frames = episode["frames"] + if len(frames) <= 2: + selected_frames = frames + else: + step = max(1, len(frames) // 2) + selected_frames = frames[::step][:3] + + for frame_info in selected_frames: + if action_frame_count >= sample_action_frames: + break + + frame_idx = frame_info["frame_idx"] + pred_file = predictions_dir / f"{episode_id}_frame_{frame_idx:05d}.npz" + frame_file = samples_dir / frame_info["file"] + + if not pred_file.exists() or not frame_file.exists(): + continue + + pred_data = np.load(pred_file) + frame_data = np.load(frame_file) + ground_truth = frame_data["ground_truth_actions"] + + predictions = { + "baseline": pred_data["baseline"], + "subtask": pred_data["subtask"], + } + + # Trim to common horizon + min_h = min(ground_truth.shape[0], *(p.shape[0] for p in predictions.values())) + ground_truth_trimmed = ground_truth[:min_h] + predictions_trimmed = {k: v[:min_h] for k, v in predictions.items()} + + subtask_entry = subtask_index.get((episode_id, frame_idx)) + subtask_text = subtask_entry["subtask_text"] if subtask_entry else "" + ext_uri = image_to_base64(frame_data["exterior_image"]) + + html_parts.append(f""" +
+
+ {episode_id} / frame {frame_idx} + Subtask: {subtask_text} +
+
+ +
+""") + num_dims = min(ground_truth_trimmed.shape[1], len(JOINT_NAMES)) + for dim_idx in range(num_dims): + svg = make_action_svg( + ground_truth_trimmed, predictions_trimmed, dim_idx, JOINT_NAMES[dim_idx] + ) + html_parts.append(svg) + + html_parts.append("
") + action_frame_count += 1 + + if action_frame_count >= sample_action_frames: + break + + # --- Footer --- + html_parts.append(""" +
+ Generated by experiments/subtask_probe/droid_eval/visualize_results.py +
+
+ + +""") + + return "".join(html_parts) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate HTML evaluation report") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output HTML file (default: samples_dir/report.html)", + ) + parser.add_argument( + "--max_gallery_frames", type=int, default=100, help="Max frames to show in subtask gallery" + ) + parser.add_argument( + "--sample_action_frames", + type=int, + default=15, + help="Number of frames to show action trajectories for", + ) + parser.add_argument( + "--subtasks", + type=str, + default=None, + help="Path to the subtasks JSON (defaults to samples_dir/subtasks.json)", + ) + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + output_path = Path(args.output) if args.output else samples_dir / "report.html" + subtasks_path = Path(args.subtasks) if args.subtasks else None + + logger.info("Generating report from %s", samples_dir) + html = generate_html( + samples_dir, + max_gallery_frames=args.max_gallery_frames, + sample_action_frames=args.sample_action_frames, + subtasks_path=subtasks_path, + ) + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(html) + logger.info("Report saved to %s (%.1f MB)", output_path, output_path.stat().st_size / 1e6) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/droid_eval/visualize_subtasks.py b/experiments/subtask_probe/droid_eval/visualize_subtasks.py new file mode 100644 index 0000000..0198e2a --- /dev/null +++ b/experiments/subtask_probe/droid_eval/visualize_subtasks.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +"""Render frames + per-frame subtask text without running the action eval. + +Two output modes: + + * **HTML** (default): one self-contained HTML file per episode with the + exterior frame and its subtask caption inlined side-by-side. Good for + scrolling through a full episode. + + * **Video** (``--video``): one mp4 per episode with the subtask string + drawn onto the exterior frame, played back at ``--fps`` (default 2 Hz, + matching the DROID cache subsample rate so wall-clock ≈ real time). + +Usage:: + + uv run python -m experiments.subtask_probe.droid_eval.visualize_subtasks \\ + --samples_dir ./.experiments_cache/droid_eval_ah15 \\ + --subtasks ./.experiments_cache/droid_eval_ah15/subtasks_comet_qwen30b.json \\ + --output_dir ./.experiments_cache/droid_eval_ah15/subtask_videos \\ + --video --fps 2 +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import logging +from pathlib import Path + +import imageio.v3 as iio3 +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from experiments.subtask_probe.droid_eval.utils import load_manifest, load_subtask_index + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def _load_frame_images(frame_path: Path) -> dict[str, np.ndarray]: + """Load both camera views from a cached .npz as uint8 HxWx3 arrays. + + Returns a dict keyed by camera name (``exterior``, ``wrist``). Both views + exist on every cached DROID frame; surfacing both makes plan-tracking + failures easier to diagnose (e.g. gripper is clearly closed in the wrist + view but the reasoner still emits "grasp the cube"). + """ + data = np.load(frame_path) + return { + "exterior": np.asarray(data["exterior_image"], dtype=np.uint8), + "wrist": np.asarray(data["wrist_image"], dtype=np.uint8), + } + + +def _composite_side_by_side(views: dict[str, np.ndarray], separator_px: int = 4) -> np.ndarray: + """Concatenate multiple same-height views horizontally with a dark divider. + + Returns a single uint8 HxWx3 array so downstream caption/encode code can + treat the multi-view frame as one image. + """ + images = list(views.values()) + heights = {img.shape[0] for img in images} + if len(heights) != 1: + raise ValueError(f"camera views must share a height; got {heights}") + h = images[0].shape[0] + divider = np.full((h, separator_px, 3), 24, dtype=np.uint8) + pieces: list[np.ndarray] = [] + for i, img in enumerate(images): + if i > 0: + pieces.append(divider) + pieces.append(img) + return np.concatenate(pieces, axis=1) + + +def _draw_caption( + image: np.ndarray, + caption: str, + footer: str | None = None, + camera_labels: list[tuple[int, str]] | None = None, +) -> np.ndarray: + """Overlay a subtask caption on the bottom of the frame plus a dark banner. + + ``camera_labels`` is an optional list of ``(x_offset, label)`` pairs drawn + in the top-left of each view, useful when the image is a composite of + multiple camera feeds stitched together by ``_composite_side_by_side``. + """ + img = Image.fromarray(image).convert("RGB") + draw = ImageDraw.Draw(img, "RGBA") + w, h = img.size + + # Try a bundled TrueType font; fall back to default if unavailable. + try: + font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", size=max(16, h // 18)) + except OSError: + font = ImageFont.load_default() + label_font = ImageFont.load_default() + + if camera_labels: + for x_offset, label in camera_labels: + draw.rectangle( + [(x_offset, 0), (x_offset + 8 + 6 * len(label), 18)], + fill=(0, 0, 0, 180), + ) + draw.text((x_offset + 4, 3), label, fill=(220, 220, 220, 255), font=label_font) + + banner_h = int(h * 0.22) + draw.rectangle([(0, h - banner_h), (w, h)], fill=(0, 0, 0, 180)) + + text = caption or "" + pad = 10 + draw.text((pad, h - banner_h + pad), text, fill=(255, 255, 255, 255), font=font) + + if footer: + draw.text((pad, h - 14), footer, fill=(200, 200, 200, 255), font=label_font) + + return np.asarray(img) + + +def _encode_png_base64(image: np.ndarray) -> str: + buf = io.BytesIO() + Image.fromarray(image).save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + +def _write_html( + out_path: Path, + episode_id: str, + instruction: str, + items: list[tuple[int, str, dict[str, np.ndarray]]], +) -> None: + """Emit a self-contained HTML file with both camera views + subtask per step.""" + rows = [] + for frame_idx, subtask, views in items: + ext_b64 = _encode_png_base64(views["exterior"]) + wrist_b64 = _encode_png_base64(views["wrist"]) + rows.append( + f"" + f"{frame_idx}" + f"" + f"
exterior
" + f"" + f"" + f"" + f"
wrist
" + f"" + f"" + f"{subtask or '(empty)'}" + f"" + ) + html = ( + "" + f"Subtasks — {episode_id}" + "" + f"

{episode_id}

" + f"

Instruction: {instruction} · {len(items)} frames

" + "" + "" + "" + "".join(rows) + "
frameexteriorwristsubtask
" + "" + ) + out_path.write_text(html) + + +def _write_mp4( + out_path: Path, + frames_with_captions: list[np.ndarray], + fps: int, +) -> None: + """Write an mp4 of the given captioned frames using imageio-ffmpeg.""" + iio3.imwrite( + out_path, + np.stack(frames_with_captions), + fps=fps, + codec="libx264", + macro_block_size=1, # Allow odd dims without forced rescaling + ) + + +def _process_episode( + samples_dir: Path, + episode: dict, + subtask_index: dict[tuple[str, int], str], + output_dir: Path, + video: bool, + fps: int, +) -> None: + episode_id = episode["episode_id"] + instruction = episode["instruction"] + + items: list[tuple[int, str, dict[str, np.ndarray]]] = [] + captioned: list[np.ndarray] = [] + + for frame_info in episode["frames"]: + frame_idx = frame_info["frame_idx"] + views = _load_frame_images(samples_dir / frame_info["file"]) + subtask = subtask_index.get((episode_id, frame_idx), "") + items.append((frame_idx, subtask, views)) + if video: + composite = _composite_side_by_side(views) + ext_w = views["exterior"].shape[1] + captioned.append( + _draw_caption( + composite, + caption=subtask, + footer=f"{episode_id} frame {frame_idx} | task: {instruction}", + camera_labels=[(4, "EXTERIOR"), (ext_w + 8, "WRIST")], + ) + ) + + if video: + mp4_path = output_dir / f"{episode_id}.mp4" + _write_mp4(mp4_path, captioned, fps=fps) + logger.info("%s -> %s (%d frames, %d fps)", episode_id, mp4_path, len(captioned), fps) + else: + html_path = output_dir / f"{episode_id}.html" + _write_html(html_path, episode_id, instruction, items) + logger.info("%s -> %s (%d frames)", episode_id, html_path, len(items)) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Render subtasks over DROID frames") + parser.add_argument("--samples_dir", type=str, required=True) + parser.add_argument("--subtasks", type=str, required=True) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Directory to write per-episode HTML or mp4 files into.", + ) + parser.add_argument( + "--video", + action="store_true", + help="Emit mp4 per episode (caption drawn on each frame) instead of HTML.", + ) + parser.add_argument( + "--fps", + type=int, + default=2, + help="Frames-per-second for mp4 output. Default 2 (~DROID cache subsample rate).", + ) + args = parser.parse_args() + + samples_dir = Path(args.samples_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + manifest = load_manifest(samples_dir) + subtask_index = load_subtask_index(Path(args.subtasks)) + logger.info("Loaded %d episodes, %d subtask records", len(manifest), len(subtask_index)) + + for episode in manifest: + _process_episode( + samples_dir=samples_dir, + episode=episode, + subtask_index=subtask_index, + output_dir=output_dir, + video=args.video, + fps=args.fps, + ) + + logger.info("Done. Output: %s", output_dir) + + +if __name__ == "__main__": + main() diff --git a/experiments/subtask_probe/dual_runtime_benchmark.py b/experiments/subtask_probe/dual_runtime_benchmark.py new file mode 100644 index 0000000..949f1f3 --- /dev/null +++ b/experiments/subtask_probe/dual_runtime_benchmark.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 +"""Test running JAX subtask generation + PyTorch action generation simultaneously. + +Both models loaded on the same L40S GPU at the same time. +Requires: XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 (or similar) to prevent JAX from +grabbing all VRAM. + +Usage: + XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 uv run python experiments/subtask_probe/dual_runtime_test.py +""" + +from __future__ import annotations + +import string +import sys +import time +from pathlib import Path + +import numpy as np + +OPENPI_SRC = Path(__file__).resolve().parents[2] / "src" +HOSTING_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(HOSTING_ROOT)) +sys.path.insert(0, str(OPENPI_SRC)) + + +def check_jax_memory_fraction() -> None: + """Verify JAX memory fraction is set before importing JAX.""" + import os + + fraction = os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION") + preallocate = os.environ.get("XLA_PYTHON_CLIENT_PREALLOCATE") + if fraction is None and preallocate != "false": + print( + "ERROR: Set XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 (or XLA_PYTHON_CLIENT_PREALLOCATE=false)" + ) + print(" Without this, JAX will grab all GPU memory and PyTorch won't load.") + sys.exit(1) + print(f"JAX memory config: MEM_FRACTION={fraction}, PREALLOCATE={preallocate}") + + +check_jax_memory_fraction() + +import jax # noqa: E402 +import jax.numpy as jnp # noqa: E402 +import openpi.shared.download as download # noqa: E402 +import safetensors.torch # noqa: E402 +import sentencepiece # noqa: E402 +import torch # noqa: E402 +from openpi.models import model as _model # noqa: E402 +from openpi.models.pi0 import Pi0, make_attn_mask # noqa: E402 +from openpi.models.pi0_config import Pi0Config # noqa: E402 +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 + + +def load_tokenizer() -> sentencepiece.SentencePieceProcessor: + path = download.maybe_download( + "gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"} + ) + with path.open("rb") as f: + return sentencepiece.SentencePieceProcessor(model_proto=f.read()) # ty: ignore[unknown-argument] + + +def report_gpu_memory(label: str) -> None: + """Print GPU memory usage from both JAX and PyTorch perspectives.""" + # PyTorch view + if torch.cuda.is_available(): + allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024 + reserved_mb = torch.cuda.memory_reserved() / 1024 / 1024 + print( + f" [{label}] PyTorch: {allocated_mb:.0f} MB allocated, {reserved_mb:.0f} MB reserved" + ) + + # JAX view + try: + for device in jax.devices(): + stats = device.memory_stats() + if stats: + used_mb = stats.get("bytes_in_use", 0) / 1024 / 1024 + limit_mb = stats.get("bytes_limit", 0) / 1024 / 1024 + print( + f" [{label}] JAX ({device}): {used_mb:.0f} MB used / {limit_mb:.0f} MB limit" + ) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# JAX subtask generation +# --------------------------------------------------------------------------- + + +def generate_subtask_jax( + jax_model: Pi0, + tokenizer: sentencepiece.SentencePieceProcessor, + task_prompt: str, + max_tokens: int = 20, +) -> tuple[str, float]: + """Generate subtask text using JAX. Returns (subtask_text, elapsed_seconds).""" + cleaned = task_prompt.lower().strip().replace("_", " ").replace("\n", " ") + if cleaned and cleaned[-1] in string.punctuation: + cleaned = cleaned[:-1] + prefix_str = f"Task: {cleaned}. Subtask: " + tokens = tokenizer.encode(prefix_str, add_bos=True) # ty: ignore[unresolved-attribute] + num_real = len(tokens) + max_len = 200 + if num_real < max_len: + mask = [True] * num_real + [False] * (max_len - num_real) + tokens = tokens + [0] * (max_len - num_real) + else: + tokens = tokens[:max_len] + mask = [True] * max_len + num_real = max_len + + tokens_np = np.asarray(tokens, dtype=np.int32) + mask_np = np.asarray(mask, dtype=np.bool_) + + action_dim = 32 + zero_img = np.zeros((224, 224, 3), dtype=np.float32) + obs = _model.Observation( + images={ + "base_0_rgb": jnp.array(zero_img[None]), + "left_wrist_0_rgb": jnp.array(zero_img[None]), + "right_wrist_0_rgb": jnp.array(zero_img[None]), + }, + image_masks={ + "base_0_rgb": jnp.array([True]), + "left_wrist_0_rgb": jnp.array([True]), + "right_wrist_0_rgb": jnp.array([True]), + }, + state=jnp.array(np.zeros(action_dim, dtype=np.float32)[None]), + tokenized_prompt=jnp.array(tokens_np[None]), + tokenized_prompt_mask=jnp.array(mask_np[None]), + ) + + start = time.monotonic() + + obs = _model.preprocess_observation(None, obs, train=False) + prefix_tokens, prefix_mask, prefix_ar_mask = jax_model.embed_prefix(obs) + prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) + positions = jnp.cumsum(prefix_mask, axis=1) - 1 # ty: ignore[invalid-argument-type] + + (prefix_out, _), kv_cache = jax_model.PaliGemma.llm( + [prefix_tokens, None], + mask=prefix_attn_mask, + positions=positions, + adarms_cond=[None, None], + ) + + B, prefix_S = prefix_tokens.shape[:2] + seq_indices = jnp.arange(prefix_S)[None, :] + last_pos = jnp.max(jnp.where(prefix_mask, seq_indices, -1), axis=1).astype(jnp.int32) # ty: ignore[no-matching-overload] + last_hidden = prefix_out[jnp.arange(B), last_pos, :] + + embed_table = jax_model.PaliGemma.llm.embedder["input_embedding"].value # ty: ignore[unresolved-attribute] + logits = jnp.dot(last_hidden, embed_table.T) + + generated_ids = [] + next_pos = jnp.array([num_real], dtype=jnp.int32) + + for step in range(max_tokens): + token_id = int(jnp.argmax(logits[0])) + generated_ids.append(token_id) + if token_id in (0, 1): + break + + token_emb = jax_model.PaliGemma.llm(jnp.array([[token_id]]), method="embed") + gen_count = step + 1 + gen_mask = jnp.ones((1, gen_count), dtype=jnp.bool_) + full_mask = jnp.concatenate([prefix_mask, gen_mask], axis=1) # ty: ignore[invalid-argument-type] + attn_mask = full_mask[:, None, :] + + (new_out, _), kv_cache = jax_model.PaliGemma.llm( + [token_emb, None], + mask=attn_mask, + positions=next_pos[:, None], + kv_cache=kv_cache, + adarms_cond=[None, None], + ) + logits = jnp.dot(new_out[:, -1, :], embed_table.T) + next_pos = next_pos + 1 + + elapsed = time.monotonic() - start + + if 1 in generated_ids: + generated_ids = generated_ids[: generated_ids.index(1)] + return tokenizer.decode(generated_ids), elapsed # ty: ignore[unresolved-attribute] + + +# --------------------------------------------------------------------------- +# PyTorch action generation +# --------------------------------------------------------------------------- + + +class FakeObservation: + def __init__( + self, + tokenized_prompt: torch.Tensor, + tokenized_prompt_mask: torch.Tensor, + device: str = "cuda", + action_dim: int = 32, + ) -> None: + self.images = { + "base_0_rgb": torch.zeros(1, 3, 224, 224, device=device), + "left_wrist_0_rgb": torch.zeros(1, 3, 224, 224, device=device), + "right_wrist_0_rgb": torch.zeros(1, 3, 224, 224, device=device), + } + self.image_masks = {k: torch.ones(1, dtype=torch.bool, device=device) for k in self.images} + self.state = torch.zeros(1, action_dim, device=device) + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = None + self.token_loss_mask = None + + +def make_action_observation( + task_prompt: str, + tokenizer: sentencepiece.SentencePieceProcessor, + device: str = "cuda", + action_dim: int = 32, +) -> FakeObservation: + """Build observation with standard pi0.5 action prompt format.""" + cleaned = task_prompt.strip().replace("_", " ").replace("\n", " ") + state = np.zeros(action_dim, dtype=np.float32) + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + state_str = " ".join(map(str, discretized_state)) + full_prompt = f"Task: {cleaned}, State: {state_str};\nAction: " + + tokens = tokenizer.encode(full_prompt, add_bos=True) # ty: ignore[unresolved-attribute] + num_real = len(tokens) + max_len = 200 + if num_real < max_len: + mask = [True] * num_real + [False] * (max_len - num_real) + tokens = tokens + [0] * (max_len - num_real) + else: + tokens = tokens[:max_len] + mask = [True] * max_len + + tok_t = torch.tensor([tokens], dtype=torch.long, device=device) + mask_t = torch.tensor([mask], dtype=torch.bool, device=device) + return FakeObservation(tok_t, mask_t, device, action_dim) + + +def generate_actions_pytorch( + pt_model: PI0Pytorch, + obs: FakeObservation, + device: str = "cuda", +) -> tuple[np.ndarray, float]: + """Run action inference. Returns (actions, elapsed_seconds).""" + start = time.monotonic() + with torch.no_grad(): + actions = pt_model.sample_actions(device, obs, num_steps=10) # ty: ignore[missing-argument, invalid-argument-type] + elapsed = time.monotonic() - start + return actions[0].detach().cpu().numpy(), elapsed + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser(description="Dual runtime test: JAX + PyTorch on same GPU") + parser.add_argument( + "--jax_checkpoint", + type=str, + default=str(Path.home() / ".cache/openpi/openpi-assets/checkpoints/pi05_base"), + ) + parser.add_argument( + "--pytorch_checkpoint", + type=str, + default=str(Path.home() / "checkpoints/pi05_base_pytorch/model.safetensors"), + ) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + + tokenizer = load_tokenizer() + + report_gpu_memory("before loading") + + # ============================================= + # Load BOTH models simultaneously + # ============================================= + print("\n=== Loading JAX model ===") + jax_config = Pi0Config(pi05=True) + jax_model = jax_config.create(jax.random.key(0)) + params = _model.restore_params(f"{args.jax_checkpoint}/params", dtype=jnp.bfloat16) + import flax.nnx as nnx + + nnx.update(jax_model, nnx.State(params)) + jax_model.eval() + print("JAX model loaded.") + report_gpu_memory("after JAX load") + + print("\n=== Loading PyTorch model ===") + pt_config = Pi0Config(pi05=True, pytorch_compile_mode=None) + pt_model = PI0Pytorch(pt_config) + safetensors.torch.load_model(pt_model, args.pytorch_checkpoint) + pt_model = pt_model.to(args.device).eval() + pt_model.requires_grad_(False) + print("PyTorch model loaded.") + report_gpu_memory("after both loaded") + + # ============================================= + # Two-phase inference loop + # ============================================= + test_cases = [ + "pick up the red cup and place it on the shelf", + "fold the towel neatly", + "open the drawer and put the block inside", + "wipe the table with the sponge", + ] + + print(f"\n{'=' * 70}") + print(" TWO-PHASE INFERENCE: JAX subtask → PyTorch actions") + print(f"{'=' * 70}") + + for task_prompt in test_cases: + print(f'\n Task: "{task_prompt}"') + + # Phase 1: JAX subtask generation + subtask_text, subtask_time = generate_subtask_jax(jax_model, tokenizer, task_prompt) + print(f' Phase 1 (JAX subtask): "{subtask_text}" [{subtask_time * 1000:.0f}ms]') + + # Phase 2: PyTorch action generation (with subtask in prompt) + hybrid_prompt = f"{task_prompt}. Subtask: {subtask_text}" + obs = make_action_observation(hybrid_prompt, tokenizer, args.device) + actions, action_time = generate_actions_pytorch(pt_model, obs, args.device) + action_norm = np.linalg.norm(actions) + print(f" Phase 2 (PT actions): norm={action_norm:.4f} [{action_time * 1000:.0f}ms]") + print(f" Total latency: {(subtask_time + action_time) * 1000:.0f}ms") + + # ============================================= + # Latency benchmark (5 rounds) + # ============================================= + print(f"\n{'=' * 70}") + print(" LATENCY BENCHMARK (5 rounds, same prompt)") + print(f"{'=' * 70}") + + benchmark_prompt = "pick up the red cup and place it on the shelf" + subtask_times = [] + action_times = [] + + for i in range(5): + subtask_text, st = generate_subtask_jax(jax_model, tokenizer, benchmark_prompt) + hybrid_prompt = f"{benchmark_prompt}. Subtask: {subtask_text}" + obs = make_action_observation(hybrid_prompt, tokenizer, args.device) + _, at = generate_actions_pytorch(pt_model, obs, args.device) + subtask_times.append(st) + action_times.append(at) + print( + f" Round {i + 1}: subtask={st * 1000:.0f}ms action={at * 1000:.0f}ms total={(st + at) * 1000:.0f}ms" + ) + + avg_subtask = np.mean(subtask_times) * 1000 + avg_action = np.mean(action_times) * 1000 + print( + f"\n Average: subtask={avg_subtask:.0f}ms action={avg_action:.0f}ms total={avg_subtask + avg_action:.0f}ms" + ) + + report_gpu_memory("end of benchmark") + + print(f"\n{'=' * 70}") + print(" DONE — Both models coexisted on the same GPU successfully.") + print(f"{'=' * 70}") + + +if __name__ == "__main__": + main() diff --git a/experiments/transport_bench/benchmark.py b/experiments/transport_bench/benchmark.py new file mode 100644 index 0000000..e6a5237 --- /dev/null +++ b/experiments/transport_bench/benchmark.py @@ -0,0 +1,382 @@ +"""Same-host transport benchmark for flash-transport vs. openpi WebSocket. + +Runs serial inference calls at a target pacing against a running server and +records per-call latency. Designed to be invoked from inside a Docker +container that already has the openpi-flash wheel + its dependencies, so the +client dependency on ``openpi_client`` and ``hosting.flash_transport_policy`` +is satisfied without a separate install step. + +See ``run_matrix.sh`` for the wrapper that runs the full profile matrix with +bidirectional netem shaping (ifb-mirrored ingress). For a one-shot invocation +the shape is: + + python benchmark.py --transport quic --host 127.0.0.1 --port 5555 \\ + --target-rate-hz 20 --min-samples 200 --min-duration-s 30 \\ + --max-duration-s 600 --output /tmp/quic-clean.json + +The server is assumed to be running locally (action slot on 8000/TCP + 5555/UDP). +""" + +from __future__ import annotations + +import argparse +import json +import statistics +import sys +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Literal, Protocol + +import numpy as np + +# Reach into the openpi-flash install tree. These import paths assume we are +# running inside the openpi-flash container image (or a local dev checkout +# with ``uv sync``). +from hosting.warmup import make_droid_observation # noqa: E402 +from openpi_client import websocket_client_policy as _websocket_client_policy # noqa: E402 + +Transport = Literal["ws", "quic"] + + +class InferablePolicy(Protocol): + def infer(self, obs: dict[str, Any]) -> dict[str, Any]: ... + def get_server_metadata(self) -> dict[str, Any]: ... + + +@dataclass +class CallSample: + """Per-call timing row.""" + + iteration_index: int + send_unix_s: float + client_round_trip_ms: float + server_infer_ms: float + policy_forward_ms: float + failed: bool = False + error_message: str | None = None + + @property + def network_overhead_ms(self) -> float: + if self.failed: + return 0.0 + return max(0.0, self.client_round_trip_ms - self.server_infer_ms) + + +@dataclass +class RunSummary: + transport: Transport + host: str + port: int + target_rate_hz: float + min_samples: int + min_duration_s: float + max_duration_s: float + warmup_iterations: int + total_iterations: int + successful_iterations: int + failure_count: int + wall_clock_duration_s: float + stop_reason: str + effective_rate_hz: float + client_round_trip_ms_p50: float + client_round_trip_ms_p95: float + client_round_trip_ms_p99: float + client_round_trip_ms_mean: float + client_round_trip_ms_stddev: float + client_round_trip_ms_max: float + server_infer_ms_mean: float + network_overhead_ms_mean: float + network_overhead_ms_p95: float + network_overhead_ms_p99: float + samples: list[dict[str, Any]] = field(default_factory=list) + + +def _percentile(values: list[float], percentile: float) -> float: + if not values: + return 0.0 + if len(values) == 1: + return values[0] + sorted_values = sorted(values) + rank = (percentile / 100.0) * (len(sorted_values) - 1) + lower_index = int(rank) + upper_index = min(lower_index + 1, len(sorted_values) - 1) + fractional = rank - lower_index + return sorted_values[lower_index] + fractional * ( + sorted_values[upper_index] - sorted_values[lower_index] + ) + + +def build_policy(transport: Transport, host: str, port: int) -> InferablePolicy: + if transport == "ws": + return _websocket_client_policy.WebsocketClientPolicy(host=host, port=port) + if transport == "quic": + from hosting.flash_transport_policy import FlashTransportPolicy + + return FlashTransportPolicy(host=host, port=port) + raise ValueError(f"Unknown transport: {transport}") + + +def run_benchmark( + *, + transport: Transport, + host: str, + port: int, + target_rate_hz: float, + min_samples: int, + min_duration_s: float, + max_duration_s: float, + warmup_iterations: int, + prompt: str, + seed: int, +) -> RunSummary: + rng = np.random.default_rng(seed) + + def make_observation() -> dict[str, Any]: + # Fresh random frames per call so server-side preprocessing can't + # short-circuit. + observation = make_droid_observation(prompt=prompt) + observation["observation/joint_position"] = rng.random(7) + return observation + + print(f"Building {transport} policy → {host}:{port}", flush=True) + policy = build_policy(transport, host, port) + try: + metadata = policy.get_server_metadata() + print(f"Server metadata: {metadata}", flush=True) + + print(f"Warmup: {warmup_iterations} iteration(s) ...", flush=True) + warmup_start = time.monotonic() + for _ in range(warmup_iterations): + policy.infer(make_observation()) + print( + f"Warmup done in {1000 * (time.monotonic() - warmup_start):.0f}ms", + flush=True, + ) + + inter_call_interval_s = 1.0 / target_rate_hz if target_rate_hz > 0 else 0.0 + wall_clock_start = time.monotonic() + min_deadline = wall_clock_start + min_duration_s + max_deadline = wall_clock_start + max_duration_s + + samples: list[CallSample] = [] + iteration_index = 0 + next_send_monotonic = wall_clock_start + stop_reason = "max_duration_s_reached" + + while True: + now_for_stop_check = time.monotonic() + successful_so_far = sum(1 for sample in samples if not sample.failed) + # Stop when we have enough samples AND the minimum runtime has + # elapsed (so target-rate pacing gets enough wall-clock time to + # stabilize). Hard cap on max_duration_s regardless. + if now_for_stop_check >= max_deadline: + stop_reason = "max_duration_s_reached" + break + if successful_so_far >= min_samples and now_for_stop_check >= min_deadline: + stop_reason = "min_samples_reached" + break + # Sleep until next scheduled send (serial, paced). + now_monotonic = time.monotonic() + if now_monotonic < next_send_monotonic: + time.sleep(next_send_monotonic - now_monotonic) + + send_unix = time.time() + call_start_monotonic = time.monotonic() + try: + action = policy.infer(make_observation()) + client_round_trip_ms = 1000 * (time.monotonic() - call_start_monotonic) + server_infer_ms = float( + action.get("server_timing", {}).get("infer_ms", 0.0) + ) + policy_forward_ms = float( + action.get("policy_timing", {}).get("infer_ms", 0.0) + ) + samples.append( + CallSample( + iteration_index=iteration_index, + send_unix_s=send_unix, + client_round_trip_ms=client_round_trip_ms, + server_infer_ms=server_infer_ms, + policy_forward_ms=policy_forward_ms, + ) + ) + except Exception as error: # noqa: BLE001 + client_round_trip_ms = 1000 * (time.monotonic() - call_start_monotonic) + samples.append( + CallSample( + iteration_index=iteration_index, + send_unix_s=send_unix, + client_round_trip_ms=client_round_trip_ms, + server_infer_ms=0.0, + policy_forward_ms=0.0, + failed=True, + error_message=f"{type(error).__name__}: {error}", + ) + ) + + iteration_index += 1 + next_send_monotonic += inter_call_interval_s + # If the call itself ran over the interval, skip ahead so we + # don't queue up a burst of back-to-back sends. + if time.monotonic() > next_send_monotonic: + next_send_monotonic = time.monotonic() + + wall_clock_duration_s = time.monotonic() - wall_clock_start + finally: + close_fn = getattr(policy, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception: # noqa: BLE001 + pass + + successful_samples = [s for s in samples if not s.failed] + failure_count = len(samples) - len(successful_samples) + client_round_trip_values = [s.client_round_trip_ms for s in successful_samples] + server_infer_values = [s.server_infer_ms for s in successful_samples] + network_overhead_values = [s.network_overhead_ms for s in successful_samples] + + summary = RunSummary( + transport=transport, + host=host, + port=port, + target_rate_hz=target_rate_hz, + min_samples=min_samples, + min_duration_s=min_duration_s, + max_duration_s=max_duration_s, + warmup_iterations=warmup_iterations, + total_iterations=len(samples), + successful_iterations=len(successful_samples), + failure_count=failure_count, + wall_clock_duration_s=wall_clock_duration_s, + stop_reason=stop_reason, + effective_rate_hz=( + len(successful_samples) / wall_clock_duration_s if wall_clock_duration_s else 0.0 + ), + client_round_trip_ms_p50=_percentile(client_round_trip_values, 50), + client_round_trip_ms_p95=_percentile(client_round_trip_values, 95), + client_round_trip_ms_p99=_percentile(client_round_trip_values, 99), + client_round_trip_ms_mean=( + statistics.fmean(client_round_trip_values) if client_round_trip_values else 0.0 + ), + client_round_trip_ms_stddev=( + statistics.pstdev(client_round_trip_values) + if len(client_round_trip_values) > 1 + else 0.0 + ), + client_round_trip_ms_max=max(client_round_trip_values) if client_round_trip_values else 0.0, + server_infer_ms_mean=( + statistics.fmean(server_infer_values) if server_infer_values else 0.0 + ), + network_overhead_ms_mean=( + statistics.fmean(network_overhead_values) if network_overhead_values else 0.0 + ), + network_overhead_ms_p95=_percentile(network_overhead_values, 95), + network_overhead_ms_p99=_percentile(network_overhead_values, 99), + samples=[asdict(sample) for sample in samples], + ) + + return summary + + +def print_summary_table(summary: RunSummary) -> None: + print("\n--- Benchmark summary ---") + print( + f"transport={summary.transport} host={summary.host}:{summary.port} " + f"target_rate={summary.target_rate_hz:.1f}Hz " + f"min_samples={summary.min_samples} " + f"min_duration={summary.min_duration_s:.0f}s " + f"max_duration={summary.max_duration_s:.0f}s " + f"stop={summary.stop_reason} " + f"wall_clock={summary.wall_clock_duration_s:.1f}s" + ) + print( + f"iterations: total={summary.total_iterations} " + f"ok={summary.successful_iterations} fail={summary.failure_count} " + f"effective_rate={summary.effective_rate_hz:.2f}Hz" + ) + print("Client round-trip ms:") + print(f" mean={summary.client_round_trip_ms_mean:.1f} stddev={summary.client_round_trip_ms_stddev:.1f}") + print( + f" p50={summary.client_round_trip_ms_p50:.1f} " + f"p95={summary.client_round_trip_ms_p95:.1f} " + f"p99={summary.client_round_trip_ms_p99:.1f} " + f"max={summary.client_round_trip_ms_max:.1f}" + ) + print(f"Server infer ms mean: {summary.server_infer_ms_mean:.1f}") + print( + f"Network overhead ms: mean={summary.network_overhead_ms_mean:.1f} " + f"p95={summary.network_overhead_ms_p95:.1f} " + f"p99={summary.network_overhead_ms_p99:.1f}" + ) + + +def parse_cli_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Transport benchmark harness") + parser.add_argument("--transport", choices=["ws", "quic"], required=True) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, required=True) + parser.add_argument( + "--min-samples", + type=int, + default=200, + help="Stop once this many successful calls have been collected (and --min-duration-s has elapsed).", + ) + parser.add_argument( + "--min-duration-s", + type=float, + default=30.0, + help="Minimum wall-clock run time before early termination is allowed.", + ) + parser.add_argument( + "--max-duration-s", + type=float, + default=600.0, + help="Hard wall-clock cap. Run terminates at this deadline even if min_samples is not reached.", + ) + parser.add_argument("--target-rate-hz", type=float, default=20.0) + parser.add_argument("--warmup-iterations", type=int, default=3) + parser.add_argument("--prompt", default="pick up the red cup") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--output", + type=Path, + default=None, + help="If set, write raw samples + summary JSON to this path.", + ) + parser.add_argument( + "--tag", + default=None, + help="Optional free-form label written to the output JSON (e.g. 'quic-50ms-1pct').", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_cli_args(argv) + summary = run_benchmark( + transport=args.transport, + host=args.host, + port=args.port, + target_rate_hz=args.target_rate_hz, + min_samples=args.min_samples, + min_duration_s=args.min_duration_s, + max_duration_s=args.max_duration_s, + warmup_iterations=args.warmup_iterations, + prompt=args.prompt, + seed=args.seed, + ) + print_summary_table(summary) + if args.output is not None: + payload = asdict(summary) + if args.tag is not None: + payload["tag"] = args.tag + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(json.dumps(payload, indent=2)) + print(f"\nWrote {args.output}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/experiments/transport_bench/run_matrix.sh b/experiments/transport_bench/run_matrix.sh new file mode 100755 index 0000000..f4dc04c --- /dev/null +++ b/experiments/transport_bench/run_matrix.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash +# Run the transport benchmark across a matrix of netem profiles × transports, +# with **bidirectional** shaping (egress + ifb-redirected ingress) and a +# **sample-target** stop condition so heavy-loss cells still gather enough +# samples for stable tail percentiles. +# +# Requires: +# - openpi-flash action slot reachable on the docker host, with ports +# published to 0.0.0.0 ({8000,5555}) — the client container reaches the +# server via the docker bridge gateway ("host.docker.internal" on Linux +# with --add-host=host-gateway). +# - benchmark.py mounted into the client container +# - iproute2 available in the client container (installed on demand below) +# +# IMPORTANT: the client container uses bridge networking (NOT host networking) +# so that `tc qdisc add dev eth0 ...` inside the container shapes only the +# container's veth, not the host's real network interface. Using host +# networking here would disrupt SSH. +# +# Why bidi shaping: shaping only the container's egress leaves the return +# path (server→client) unshaped, which is asymmetric and biases the +# comparison for response-heavy workloads. We fold container ingress into an +# IFB device and apply the same netem profile there so both directions see +# the same delay/loss. +# +# Why no jitter: the first-pass matrix (2026-04-21) used `delay 25ms 2ms +# distribution normal`, which applies per-packet jitter that can reorder +# packets on the wire. QUIC's reorder detection triggered spurious +# retransmits on ~300-packet DROID observation messages, producing a +# misleading 940ms p50 at only +25ms delay while TCP absorbed it cleanly. +# Zero-jitter profiles isolate the transport comparison from this artifact. +# A "jitter study" profile set can be added separately if we want to +# characterize reorder sensitivity explicitly. +# +# Usage (on the EC2 host): +# export IMAGE=438136598620.dkr.ecr.us-west-2.amazonaws.com/openpi-flash:latest +# bash run_matrix.sh /tmp/bench/results +set -euo pipefail + +OUTPUT_DIR="${1:-/tmp/bench/results}" +IMAGE="${IMAGE:?IMAGE env var must be set to the client container image}" +# Reach the server on the docker bridge gateway. host-gateway resolves to the +# host's IP on the docker bridge (typically 172.17.0.1). +HOST="${HOST:-host.docker.internal}" +TARGET_RATE_HZ="${TARGET_RATE_HZ:-20}" +WARMUP_ITERATIONS="${WARMUP_ITERATIONS:-3}" +# Sample-target termination — run each cell until MIN_SAMPLES successful +# calls have been collected AND MIN_DURATION_S has elapsed, capped at +# MAX_DURATION_S. See benchmark.py for details. +MIN_SAMPLES="${MIN_SAMPLES:-200}" +MIN_DURATION_S="${MIN_DURATION_S:-30}" +MAX_DURATION_S="${MAX_DURATION_S:-600}" +WS_PORT="${WS_PORT:-8000}" +QUIC_PORT="${QUIC_PORT:-5555}" + +mkdir -p "$OUTPUT_DIR" + +# Each entry is: profile_name|tc-netem-args (empty args => no shaping). +# Args are passed to `netem` on BOTH the container's egress (eth0) and the +# ifb device carrying its mirrored ingress, so both request and response +# experience the profile. Zero jitter — see header comment. +PROFILES=( + "clean|" + "delay25ms|delay 25ms" + "delay75ms_loss0_1pct|delay 75ms loss 0.1%" + "delay150ms_loss0_5pct|delay 150ms loss 0.5%" +) + +TRANSPORTS=( + "ws ${WS_PORT}" + "quic ${QUIC_PORT}" +) + +run_one() { + local profile_name="$1" + local tc_args="$2" + local transport="$3" + local port="$4" + local output_path="${OUTPUT_DIR}/${transport}_${profile_name}.json" + + echo "=== ${transport} @ ${profile_name} → ${output_path} ===" + + # Build the tc setup inside-container. For non-clean profiles we: + # 1. Apply netem on eth0 root → shapes container egress (client → server) + # 2. Create ifb0 in the container's netns + # 3. Add an ingress qdisc on eth0 and mirror ingress packets to ifb0 + # 4. Apply the same netem on ifb0 root → shapes container ingress (server → client) + # Net result: both directions see the same delay/loss profile. + local tc_setup="" + if [[ -n "${tc_args}" ]]; then + tc_setup="tc qdisc add dev eth0 root netem ${tc_args}; + ip link add ifb0 type ifb; + ip link set ifb0 up; + tc qdisc add dev eth0 handle ffff: ingress; + tc filter add dev eth0 parent ffff: protocol all u32 match u32 0 0 action mirred egress redirect dev ifb0; + tc qdisc add dev ifb0 root netem ${tc_args}; + echo [tc] eth0:; tc qdisc show dev eth0; + echo [tc] ifb0:; tc qdisc show dev ifb0;" + fi + + docker run --rm --cap-add=NET_ADMIN \ + --add-host=host.docker.internal:host-gateway \ + -v "$(pwd)/benchmark.py:/tmp/benchmark.py:ro" \ + -v "${OUTPUT_DIR}:/out" \ + "${IMAGE}" \ + bash -lc "set -euo pipefail; + if ! command -v tc >/dev/null 2>&1; then + apt-get update -qq && apt-get install -y --no-install-recommends iproute2 >/dev/null + fi; + ${tc_setup} + python /tmp/benchmark.py \ + --transport ${transport} \ + --host ${HOST} \ + --port ${port} \ + --target-rate-hz ${TARGET_RATE_HZ} \ + --min-samples ${MIN_SAMPLES} \ + --min-duration-s ${MIN_DURATION_S} \ + --max-duration-s ${MAX_DURATION_S} \ + --warmup-iterations ${WARMUP_ITERATIONS} \ + --tag ${transport}_${profile_name} \ + --output /out/${transport}_${profile_name}.json" +} + +for profile_entry in "${PROFILES[@]}"; do + profile_name="${profile_entry%%|*}" + tc_args="${profile_entry#*|}" + for transport_entry in "${TRANSPORTS[@]}"; do + read -r transport port <<<"${transport_entry}" + run_one "${profile_name}" "${tc_args}" "${transport}" "${port}" + done +done + +echo "All runs complete. Results in ${OUTPUT_DIR}:" +ls -l "${OUTPUT_DIR}"