feat: negative layer indexing and forward_all() for lenses #146

Open
anthony-maio wants to merge 2 commits into AlignmentResearch:main from anthony-maio:pr/negative-indexing

Conversation

@anthony-maio

Add Python-style negative indexing to TunedLens.__getitem__, forward(), and transform_hidden() so that lens[-1] returns the last layer translator. Out-of-range indices now raise IndexError with a clear message.

Also add forward_all() to the Lens base class, which decodes hidden states from all layers in a single call under torch.inference_mode().
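
The index handling described above can be sketched as follows. This is a minimal illustration rather than the code in this PR: the helper name `normalize_layer_index` and the exact error message are assumptions, and only the behavior (negative indices count from the end, out-of-range indices raise IndexError) comes from the description.

```python
def normalize_layer_index(idx: int, num_layers: int) -> int:
    """Map a possibly negative index onto the range [0, num_layers).

    Hypothetical helper for illustration; the name and the error
    message are assumptions, not taken from the PR diff.
    """
    # Same bounds rule as Python sequences: valid indices are
    # -num_layers .. num_layers - 1 inclusive.
    if not -num_layers <= idx < num_layers:
        raise IndexError(
            f"Layer index {idx} out of range for lens with {num_layers} layers"
        )
    # Negative indices count back from the end.
    return idx + num_layers if idx < 0 else idx


last = normalize_layer_index(-1, 12)   # refers to the last of 12 layers
first = normalize_layer_index(-12, 12)  # refers to the first layer
```

Normalizing once at the entry points (`__getitem__`, `forward()`, `transform_hidden()`) would keep the bounds rule identical across all three, which is presumably why the PR groups them.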

Copilot AI review requested due to automatic review settings February 15, 2026 09:09
@anthony-maio

@codex review


Copilot AI left a comment


Pull request overview

This PR enhances the lens API by adding Python-style negative layer indexing for TunedLens layer access and decode paths, and introduces a Lens.forward_all() convenience method to decode multiple layers’ hidden states in one call under torch.inference_mode().

Changes:

  • Add Lens.forward_all() to batch-decode a sequence of hidden states via per-layer forward().
  • Add negative indexing support (and clearer bounds errors) for TunedLens.__getitem__, transform_hidden(), and forward().
  • Add unit tests covering negative indexing behavior and forward_all() behavior.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tuned_lens/nn/lenses.py | Implements forward_all() and adds negative-index normalization + bounds checking for TunedLens layer indices. |
| tests/test_lenses.py | Adds tests validating negative indexing semantics and forward_all() outputs. |


Comment on lines +50 to +67
    @th.inference_mode()
    def forward_all(
        self, hidden_states: Sequence[th.Tensor]
    ) -> list[th.Tensor]:
        """Decode hidden states from all layers into logits.

        Convenience method that applies :meth:`forward` to each layer's
        hidden states in a single call under ``torch.inference_mode``.

        Args:
            hidden_states: Sequence of hidden state tensors, one per layer.
                Each tensor should have shape ``(batch, seq_len, d_model)``.

        Returns:
            List of logit tensors, one per layer, each of shape
            ``(batch, seq_len, vocab_size)``.
        """
        return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)]

Copilot AI Feb 15, 2026


forward_all() enumerates indices starting at 0, which will raise IndexError for TunedLens if callers pass a full Hugging Face hidden_states tuple (commonly includes an extra final element and is typically used as hidden_states[:-1] elsewhere in the repo). Please clarify this expectation in the docstring (e.g., mention passing outputs.hidden_states[:-1]) or adjust the API to accept explicit layer indices/offset so "all layers" is unambiguous.
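
The off-by-one concern in this comment can be reproduced with a toy stand-in. The sketch below is not the tuned_lens implementation: `ToyLens` is a hypothetical class that uses plain floats in place of tensors and a placeholder in place of the real decode. It assumes only what the comment states, namely that the lens has one translator per layer while the hidden-states tuple carries one extra final element.

```python
from __future__ import annotations

from typing import Sequence


class ToyLens:
    """Hypothetical stand-in for a lens with one translator per layer."""

    def __init__(self, num_layers: int) -> None:
        self.num_layers = num_layers

    def forward(self, h: float, idx: int) -> float:
        # Bounds check mirroring the IndexError behavior described in the PR.
        if not -self.num_layers <= idx < self.num_layers:
            raise IndexError(
                f"layer index {idx} out of range for {self.num_layers} layers"
            )
        # Placeholder for the real decode: tag the value with its layer index.
        return h + (idx % self.num_layers)

    def forward_all(self, hidden_states: Sequence[float]) -> list[float]:
        # Same enumeration pattern as the method under review.
        return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)]


lens = ToyLens(num_layers=3)
# num_layers + 1 entries, mimicking a full hidden-states tuple.
hidden_states = (0.0, 10.0, 20.0, 30.0)

# Dropping the final element lines the indices up with the translators.
trimmed = lens.forward_all(hidden_states[:-1])

# Passing the full tuple reaches idx == num_layers and fails.
try:
    lens.forward_all(hidden_states)
    full_tuple_error = None
except IndexError as exc:
    full_tuple_error = exc
```

Either remedy suggested in the comment would address this: documenting the `[:-1]` expectation keeps the signature unchanged, while an explicit offset or index argument would remove the ambiguity at the cost of a larger API.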

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>