feat: attention visualization script for player encoder transformer steps#94
Draft
Copilot wants to merge 3 commits into
Draft
feat: attention visualization script for player encoder transformer steps#94Copilot wants to merge 3 commits into
Copilot wants to merge 3 commits into
Conversation
…tiHeadAttention Agent-Logs-Url: https://github.com/spktrm/porygon2/sessions/167e84cb-6d94-461d-960e-70f4dda5197f Co-authored-by: spktrm <72776130+spktrm@users.noreply.github.com>
Agent-Logs-Url: https://github.com/spktrm/porygon2/sessions/167e84cb-6d94-461d-960e-70f4dda5197f Co-authored-by: spktrm <72776130+spktrm@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a script (
scripts/viz_attention.py) that visualises the per-head attention scores at every transformer step inside the player encoder, making attention bottlenecks immediately visible.Changes
rl/model/modules.pyAdded two
sowcalls insideMultiHeadAttention.__call__after computingattn_probsandattn_entropy:These are zero-overhead no-ops during normal training/inference — Flax only writes them when
mutable=['intermediates']is explicitly passed toapply.scripts/viz_attention.pyNew standalone visualisation script that:
--ckpt,--generation,--output,--traj-stepCLI argumentsget_ex_player_step()Encoder._embed_local_timestepdirectly (bypassing vmap) to capturelocal_timestep_decoderattentionEncoder._batched_forwarddirectly (bypassing the outer vmap) to captureinput_decoder,history_decoder, and allstate_transformerattention stepslocal_timestep_decoder– field tokens → relevant entity/edge tokensinput_decoder– latent queries → current game-state tokenshistory_decoder– latent queries → past timestep embeddingsstate_transformer_kv_enc(×2 layers) – self-attention on the latent KV sequencestate_transformer_q_self(×2 layers) – self-attention on output-state tokensstate_transformer_q_cross(×2 layers) – cross-attention output-state → latentattn_entropy_summary.png) — bars close to 0 indicate heads with highly concentrated (bottlenecked) attentionUsage