feat: attention visualization script for player encoder transformer steps #94

Draft
Copilot wants to merge 3 commits into better_arch_v2 from copilot/visualize-attention-scores

Copilot AI commented May 9, 2026

Summary

Adds a script (scripts/viz_attention.py) that visualises the per-head attention scores at every transformer step inside the player encoder, making attention bottlenecks immediately visible.

Changes

rl/model/modules.py

Added two sow calls inside MultiHeadAttention.__call__ after computing attn_probs and attn_entropy:

self.sow("intermediates", "attn_weights", attn_probs.astype(jnp.float32))
self.sow("intermediates", "attn_entropy", jnp.nan_to_num(attn_entropy, nan=0.0).astype(jnp.float32))

These calls are effectively no-ops during normal training and inference: Flax stores sown values only when mutable=['intermediates'] is explicitly passed to apply, and otherwise sow returns without writing anything.

scripts/viz_attention.py

New standalone visualisation script that:

  • Accepts --ckpt, --generation, --output, --traj-step CLI arguments
  • Loads a single game-state example via get_ex_player_step()
  • Calls Encoder._embed_local_timestep directly (bypassing vmap) to capture local_timestep_decoder attention
  • Calls Encoder._batched_forward directly (bypassing the outer vmap) to capture input_decoder, history_decoder, and all state_transformer attention steps
  • Produces per-head heatmaps (PNG) for all 9 MHA calls across the 6 attention blocks listed below:
    • local_timestep_decoder – field tokens → relevant entity/edge tokens
    • input_decoder – latent queries → current game-state tokens
    • history_decoder – latent queries → past timestep embeddings
    • state_transformer_kv_enc (×2 layers) – self-attention on the latent KV sequence
    • state_transformer_q_self (×2 layers) – self-attention on output-state tokens
    • state_transformer_q_cross (×2 layers) – cross-attention output-state → latent
  • Produces an entropy summary bar chart (attn_entropy_summary.png) — bars close to 0 indicate heads with highly concentrated (bottlenecked) attention
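
To make the entropy reading concrete, here is a small sketch. It assumes the summary is based on the Shannon entropy of each attention row; the PR does not show how attn_entropy is computed, and `row_entropy` is a hypothetical helper written in plain Python for illustration.

```python
import math


def row_entropy(probs):
    """Shannon entropy (in nats) of one attention row; zero-probability terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


uniform = [0.25, 0.25, 0.25, 0.25]  # attention spread evenly over 4 keys
peaked = [0.97, 0.01, 0.01, 0.01]   # attention concentrated on one key
one_hot = [1.0, 0.0, 0.0, 0.0]      # fully collapsed head

h_uniform = row_entropy(uniform)    # equals log(4), the maximum for 4 keys
h_peaked = row_entropy(peaked)      # small but positive
h_one_hot = row_entropy(one_hot)    # zero
```

Under this assumption, a bar near 0 corresponds to a head like `one_hot`, while a bar near the log of the number of keys corresponds to a head like `uniform`. In array code, 0 * log(0) evaluates to NaN, which may be why the sow call wraps attn_entropy in jnp.nan_to_num.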

Usage

# Random params (good for checking shapes)
python -m scripts.viz_attention

# With a checkpoint, custom output directory, specific trajectory step
python -m scripts.viz_attention \
    --ckpt ./ckpts/gen9/ckpt_0001.pkl \
    --output ./my_viz \
    --traj-step 2 \
    --generation 9
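
For reference, the command-line interface above could be declared roughly as follows. The flag names come from this description; the types, defaults, and help strings are assumptions, and `build_parser` is a hypothetical name rather than a function known to exist in scripts/viz_attention.py.

```python
import argparse


def build_parser():
    # Flag names follow the PR description; types and defaults are assumed.
    parser = argparse.ArgumentParser(description="Visualise player-encoder attention.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Checkpoint path; None means randomly initialised params.")
    parser.add_argument("--generation", type=int, default=9)
    parser.add_argument("--output", type=str, default="./attn_viz")
    parser.add_argument("--traj-step", type=int, default=0)
    return parser


# Parse the same arguments as the second usage example.
args = build_parser().parse_args(
    ["--ckpt", "./ckpts/gen9/ckpt_0001.pkl", "--output", "./my_viz",
     "--traj-step", "2", "--generation", "9"]
)
```

argparse converts `--traj-step` to the attribute `args.traj_step`, so the script would read the trajectory step under that name.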

Copilot AI and others added 2 commits May 9, 2026 06:19