Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/test_lenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,74 @@ def test_tuned_lens_generate_smoke(random_small_model: trf.PreTrainedModel):
)
assert tokens.shape[-1] <= 11
assert tokens.shape[-1] > 1


# --- Tests for negative indexing ---


def test_tuned_lens_negative_index(random_tuned_lens: TunedLens):
    """Negative index -1 should return the last translator."""
    num_layers = len(random_tuned_lens)
    via_negative = random_tuned_lens[-1]
    via_positive = random_tuned_lens[num_layers - 1]
    assert via_negative is via_positive


def test_tuned_lens_negative_index_minus_n(random_tuned_lens: TunedLens):
    """Negative index -N should return the first translator."""
    num_layers = len(random_tuned_lens)
    assert random_tuned_lens[-num_layers] is random_tuned_lens[0]


def test_tuned_lens_index_out_of_range(random_tuned_lens: TunedLens):
    """Out-of-range indices should raise IndexError."""
    num_layers = len(random_tuned_lens)
    # Just past the end in both directions.
    for bad_idx in (num_layers, -(num_layers + 1)):
        with pytest.raises(IndexError):
            random_tuned_lens[bad_idx]


def test_tuned_lens_forward_negative_idx(random_tuned_lens: TunedLens):
    """forward() should accept negative layer indices."""
    hidden = th.randn(1, 10, 128)
    last_idx = len(random_tuned_lens) - 1
    assert th.allclose(
        random_tuned_lens.forward(hidden, -1),
        random_tuned_lens.forward(hidden, last_idx),
    )


def test_tuned_lens_transform_hidden_negative_idx(random_tuned_lens: TunedLens):
    """transform_hidden() should accept negative layer indices."""
    hidden = th.randn(1, 10, 128)
    last_idx = len(random_tuned_lens) - 1
    assert th.allclose(
        random_tuned_lens.transform_hidden(hidden, -1),
        random_tuned_lens.transform_hidden(hidden, last_idx),
    )


# --- Tests for forward_all ---


def test_logit_lens_forward_all(logit_lens):
    """forward_all() should return one logit tensor per layer."""
    num_layers = 3
    hidden_states = [th.randn(1, 10, 128) for _ in range(num_layers)]
    logits = logit_lens.forward_all(hidden_states)
    assert len(logits) == num_layers
    assert all(layer_logits.shape == (1, 10, 100) for layer_logits in logits)


def test_tuned_lens_forward_all(random_tuned_lens: TunedLens):
    """forward_all() should return one logit tensor per layer."""
    num_layers = 3
    layer_inputs = [th.randn(1, 10, 128) for _ in range(num_layers)]
    assert len(random_tuned_lens.forward_all(layer_inputs)) == num_layers


def test_forward_all_matches_sequential(random_tuned_lens: TunedLens):
    """forward_all() results should match sequential forward() calls."""
    hidden_states = [th.randn(1, 10, 128) for _ in range(3)]
    batched = random_tuned_lens.forward_all(hidden_states)
    sequential = [
        random_tuned_lens.forward(h, i) for i, h in enumerate(hidden_states)
    ]
    for got, want in zip(batched, sequential):
        assert th.allclose(got, want)
85 changes: 80 additions & 5 deletions tuned_lens/nn/lenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Generator, Optional, Union
from typing import Dict, Generator, Optional, Sequence, Union

import torch as th
from transformers import PreTrainedModel
Expand Down Expand Up @@ -47,6 +47,25 @@ def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
"""Decode hidden states into logits."""
...

@th.inference_mode()
def forward_all(
    self, hidden_states: Sequence[th.Tensor]
) -> list[th.Tensor]:
    """Decode hidden states from all layers into logits.

    Convenience method that applies :meth:`forward` to each layer's
    hidden states in a single call under ``torch.inference_mode``.

    Layer indices are assigned by position, starting at 0 for the first
    element of ``hidden_states``.

    NOTE(review): a full Hugging Face ``outputs.hidden_states`` tuple
    appears to contain one more entry than this lens has translators, so
    callers presumably need to slice it (e.g. ``hidden_states[:-1]``)
    before passing it here — confirm against call sites.

    Args:
        hidden_states: Sequence of hidden state tensors, one per layer.
            Each tensor should have shape ``(batch, seq_len, d_model)``.

    Returns:
        List of logit tensors, one per layer, each of shape
        ``(batch, seq_len, vocab_size)``.
    """
    return [self.forward(h, idx=i) for i, h in enumerate(hidden_states)]
Comment on lines +50 to +67
Copy link

Copilot AI Feb 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forward_all() enumerates indices starting at 0, which will raise IndexError for TunedLens if callers pass a full Hugging Face hidden_states tuple (commonly includes an extra final element and is typically used as hidden_states[:-1] elsewhere in the repo). Please clarify this expectation in the docstring (e.g., mention passing outputs.hidden_states[:-1]) or adjust the API to accept explicit layer indices/offset so "all layers" is unambiguous.

Copilot uses AI. Check for mistakes.


class LogitLens(Lens):
"""Unembeds the residual stream into logits."""
Expand Down Expand Up @@ -169,8 +188,22 @@ def __init__(
)

def __getitem__(self, item: int) -> th.nn.Module:
    """Get the probe module at the given index.

    Supports Python-style negative indexing (e.g., ``-1`` for the last
    layer translator).

    Args:
        item: Layer index. Negative values count from the end.

    Returns:
        The translator module for the requested layer.

    Raises:
        IndexError: If the index is out of range.
    """
    return self.layer_translators[self._resolve_idx(item)]

def __iter__(self) -> Generator[th.nn.Module, None, None]:
"""Get iterator over the translators within the lens."""
Expand Down Expand Up @@ -303,15 +336,57 @@ def save(
with open(path / config, "w") as f:
json.dump(self.config.to_dict(), f)

def _resolve_idx(self, idx: int) -> int:
"""Normalize a possibly-negative layer index.

Args:
idx: Layer index. Negative values count from the end.

Returns:
A non-negative layer index.

Raises:
IndexError: If the resolved index is out of range.
"""
num_layers = len(self.layer_translators)
resolved = idx if idx >= 0 else num_layers + idx
if resolved < 0 or resolved >= num_layers:
raise IndexError(
f"Layer index {idx} out of range for lens with "
f"{num_layers} translators."
)
return resolved

def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor:
    """Transform hidden state from layer ``idx``.

    Supports negative indexing (e.g., ``-1`` for the last layer).

    Args:
        h: Hidden state tensor of shape ``(batch, seq_len, d_model)``.
        idx: Layer index. Negative values count from the end.

    Returns:
        Transformed hidden state, same shape as input.
    """
    layer = self._resolve_idx(idx)
    # Note that we add the translator output residually, in contrast to the formula
    # in the paper. By parametrizing it this way we ensure that weight decay
    # regularizes the transform toward the identity, not the zero transformation.
    return h + self[layer](h)

def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
    """Transform and then decode the hidden states into logits.

    Supports negative indexing (e.g., ``-1`` for the last layer).

    Args:
        h: Hidden state tensor of shape ``(batch, seq_len, d_model)``.
        idx: Layer index. Negative values count from the end.

    Returns:
        Logit tensor of shape ``(batch, seq_len, vocab_size)``.
    """
    transformed = self.transform_hidden(h, idx)
    return self.unembed.forward(transformed)

Expand Down