feat: negative layer indexing and forward_all() for lenses #146
anthony-maio wants to merge 2 commits into AlignmentResearch:main from
Conversation
Add Python-style negative indexing to TunedLens.__getitem__, forward(), and transform_hidden() so that lens[-1] returns the last layer translator. Out-of-range indices now raise IndexError with a clear message. Also add forward_all() to the Lens base class, which decodes hidden states from all layers in a single call under torch.inference_mode().
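A torch-free stand-in can illustrate the contract of the new `forward_all()` method: it applies `forward()` to each layer's hidden state, with the layer index drawn from `enumerate()`. `StubLens` and its tuple "logits" below are illustrative only, not code from this PR:

```python
from typing import Any, Sequence


class StubLens:
    def forward(self, h: Any, idx: int):
        # A real lens would decode hidden state h at layer idx into logits;
        # here we just record which index paired with which input.
        return (idx, h)

    def forward_all(self, hidden_states: Sequence[Any]) -> list:
        # Mirrors the PR's one-liner: per-layer forward() over an enumerated sequence.
        return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)]


lens = StubLens()
print(lens.forward_all(["h0", "h1", "h2"]))  # [(0, 'h0'), (1, 'h1'), (2, 'h2')]
```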
@codex review
Pull request overview
This PR enhances the lens API by adding Python-style negative layer indexing for TunedLens layer access and decode paths, and introduces a Lens.forward_all() convenience method to decode multiple layers’ hidden states in one call under torch.inference_mode().
Changes:
- Add `Lens.forward_all()` to batch-decode a sequence of hidden states via per-layer `forward()`.
- Add negative indexing support (and clearer bounds errors) for `TunedLens.__getitem__`, `transform_hidden()`, and `forward()`.
- Add unit tests covering negative indexing behavior and `forward_all()` behavior.
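The negative-indexing change can be sketched as Python-style index normalization with explicit bounds checking. This is an illustrative sketch of the behavior the PR describes; the actual `TunedLens` code and its error message may differ:

```python
def normalize_layer_index(idx: int, num_layers: int) -> int:
    """Map a possibly-negative layer index to a 0-based one, Python-style."""
    if idx < 0:
        idx += num_layers  # e.g. -1 becomes num_layers - 1 (the last layer)
    if not 0 <= idx < num_layers:
        raise IndexError(
            f"layer index out of range for a lens with {num_layers} layers"
        )
    return idx


print(normalize_layer_index(-1, 12))  # 11, i.e. lens[-1] is the last translator
```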
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tuned_lens/nn/lenses.py | Implements forward_all() and adds negative-index normalization + bounds checking for TunedLens layer indices. |
| tests/test_lenses.py | Adds tests validating negative indexing semantics and forward_all() outputs. |
```python
@th.inference_mode()
def forward_all(
    self, hidden_states: Sequence[th.Tensor]
) -> list[th.Tensor]:
    """Decode hidden states from all layers into logits.

    Convenience method that applies :meth:`forward` to each layer's
    hidden states in a single call under ``torch.inference_mode``.

    Args:
        hidden_states: Sequence of hidden state tensors, one per layer.
            Each tensor should have shape ``(batch, seq_len, d_model)``.

    Returns:
        List of logit tensors, one per layer, each of shape
        ``(batch, seq_len, vocab_size)``.
    """
    return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)]
```
forward_all() enumerates indices starting at 0, which will raise IndexError for TunedLens if callers pass a full Hugging Face hidden_states tuple (commonly includes an extra final element and is typically used as hidden_states[:-1] elsewhere in the repo). Please clarify this expectation in the docstring (e.g., mention passing outputs.hidden_states[:-1]) or adjust the API to accept explicit layer indices/offset so "all layers" is unambiguous.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
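The reviewer's point can be shown with a small sketch of the calling convention: with `output_hidden_states=True`, Hugging Face models return one more hidden state than there are decoder layers, so one element has to be dropped (`hidden_states[:-1]`, as the review notes is done elsewhere in the repo) before the sequence lines up with layer indices `0..num_layers - 1`. The values below are stand-ins, not real tensors:

```python
# Illustrative: a model with 4 layers reports 5 hidden states.
num_layers = 4  # a real model exposes this via its config
hf_hidden_states = tuple(f"h{i}" for i in range(num_layers + 1))  # tensor stand-ins

# What forward_all() expects: exactly one hidden state per lens layer.
per_layer_inputs = hf_hidden_states[:-1]
print(len(per_layer_inputs))  # 4 == num_layers
```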