33 changes: 33 additions & 0 deletions tests/test_model_surgery.py
@@ -21,3 +21,36 @@ def test_get_layers_from_model(random_small_model: PreTrainedModel):
assert isinstance(layers, th.nn.ModuleList)
assert isinstance(path, str)
assert len(layers) == random_small_model.config.num_hidden_layers


# --- Tests for attribute-based detection ---


def test_attribute_norm_detection(random_small_model: PreTrainedModel):
"""Attribute-based probing should find the same norm as isinstance checks."""
norm = model_surgery._try_attribute_norm(random_small_model.base_model)
assert norm is not None
assert isinstance(norm, th.nn.Module)


def test_attribute_layer_detection(random_small_model: PreTrainedModel):
"""Attribute-based probing should find the same layers as isinstance checks."""
result = model_surgery._try_attribute_layers(random_small_model.base_model)
assert result is not None
path_components, layers = result
assert isinstance(layers, th.nn.ModuleList)
assert len(layers) == random_small_model.config.num_hidden_layers

Comment on lines +29 to +43

Copilot AI Feb 15, 2026

The new tests only verify that attribute-based probing succeeds for known models in the test fixtures, but don't test the fallback behavior to isinstance checks. Consider adding a test case where attribute-based detection returns None (e.g., by mocking a model with non-standard attribute names) to verify that the isinstance fallback is correctly invoked. This would ensure both code paths work as intended.

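A minimal sketch of such a fallback test (not part of this PR), assuming the existing random_small_model fixture and pytest's built-in monkeypatch fixture:

def test_isinstance_fallback_when_probing_fails(random_small_model: PreTrainedModel, monkeypatch):
    """isinstance fallback should still find the norm when attribute probing returns None."""
    # Force attribute-based probing to fail so the isinstance branch is exercised.
    monkeypatch.setattr(model_surgery, "_try_attribute_norm", lambda _base: None)
    norm = model_surgery.get_final_norm(random_small_model)
    assert isinstance(norm, th.nn.Module)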

def test_get_base_model_missing():
"""A model without base_model should raise a helpful error."""
fake_model = th.nn.Linear(10, 10)
with pytest.raises(ValueError, match="does not have a `base_model`"):
model_surgery._get_base_model(fake_model)


def test_error_message_includes_attributes():
"""Error messages should list available model attributes."""
fake_model = th.nn.Linear(10, 10)
with pytest.raises(ValueError, match="Available attributes"):
model_surgery._get_base_model(fake_model)
224 changes: 191 additions & 33 deletions tuned_lens/model_surgery.py
@@ -1,7 +1,7 @@
"""Tools for finding and modifying components in a transformer model."""

from contextlib import contextmanager
from typing import Any, Generator, TypeVar, Union
from typing import Any, Generator, Optional, TypeVar, Union

try:
import transformer_lens as tl
@@ -70,6 +70,93 @@ def assign_key_path(model: T, key_path: str, value: Any) -> Generator[T, None, N
nn.Module,
]

# Ordered list of (attribute_name, model_classes) for final norm detection.
# Attribute-based lookup is tried first; isinstance is the fallback.
_NORM_PATHS = [
("norm", (
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
)),
("ln_f", (
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
)),
("final_layer_norm", (
models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
)),
]

# Ordered list of (attribute_name, model_classes) for layer detection.
_LAYER_PATHS = [
("layers", (
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
)),
("h", (
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
)),
]


Comment on lines +73 to +108

Copilot AI Feb 15, 2026

The unused _NORM_PATHS and _LAYER_PATHS constants are defined but never referenced in the code. These constants appear to document the mapping between attribute names and model classes, but since the implementation uses attribute-based probing that tries common attribute names regardless of model type, these constants serve no functional purpose. Consider either removing them or adding a comment explaining that they're kept for documentation purposes only.

Suggested change: delete the _NORM_PATHS and _LAYER_PATHS definitions (the block on lines +73 to +108) entirely.
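
One possible way to address this without deleting the tables would be to derive the probe order from them; a sketch only, not something this PR does:

# Sketch (hypothetical): make the tables the single source of truth for probe order,
# so they are no longer dead code.
_NORM_ATTRS = tuple(attr for attr, _ in _NORM_PATHS)    # ("norm", "ln_f", "final_layer_norm")
_LAYER_ATTRS = tuple(attr for attr, _ in _LAYER_PATHS)  # ("layers", "h")

# _try_attribute_norm and _try_attribute_layers would then iterate over these
# tuples instead of the hard-coded ("norm", "ln_f", "final_layer_norm") and ("layers", "h").
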
def _get_base_model(model: Model) -> th.nn.Module:
"""Get the base model, raising a helpful error if not found.

Args:
model: A pretrained model or HookedTransformer.

Returns:
The base model module.

Raises:
ValueError: If the model has no ``base_model`` attribute.
"""
if not hasattr(model, "base_model"):
available = [a for a in dir(model) if not a.startswith("_")]
raise ValueError(
f"Model {type(model).__name__} does not have a `base_model` attribute. "
f"Available attributes: {available[:15]}. "
Comment on lines +123 to +125

Copilot AI Feb 15, 2026

The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
raise ValueError(
f"Model {type(model).__name__} does not have a `base_model` attribute. "
f"Available attributes: {available[:15]}. "
shown_attributes = available[:15] + (["..."] if len(available) > 15 else [])
raise ValueError(
f"Model {type(model).__name__} does not have a `base_model` attribute. "
f"Available attributes (showing up to 15): {shown_attributes}. "

f"If this is a custom model, please open an issue at: "
f"https://github.com/AlignmentResearch/tuned-lens/issues"
)
return model.base_model


def _try_attribute_norm(base_model: th.nn.Module) -> Optional[nn.Module]:
"""Try to find the final norm via attribute-based probing.

Checks common attribute names on the base model and its ``decoder``
sub-module (for OPT-style architectures). Returns the norm module if
found and it is an instance of ``nn.Module``, otherwise ``None``.

Args:
base_model: The unwrapped base model to probe.

Returns:
The final norm module, or ``None`` if not found.
"""
# Direct attributes on base_model (covers Llama, Mistral, Gemma, GPT-2, etc.)
for attr in ("norm", "ln_f", "final_layer_norm"):
norm = getattr(base_model, attr, None)
if norm is not None and isinstance(norm, nn.Module):
return norm

# OPT-style: base_model.decoder.final_layer_norm
decoder = getattr(base_model, "decoder", None)
if decoder is not None:
norm = getattr(decoder, "final_layer_norm", None)
if norm is not None and isinstance(norm, nn.Module):
return norm

return None


def get_unembedding_matrix(model: Model) -> nn.Linear:
"""The final linear tranformation from the model hidden state to the output."""
Expand All @@ -93,38 +180,60 @@ def get_unembedding_matrix(model: Model) -> nn.Linear:
def get_final_norm(model: Model) -> Norm:
"""Get the final norm from a model.

This isn't standardized across models, so this will need to be updated as
we add new models.
Uses attribute-based probing to detect the final normalization layer,
which makes this function forward-compatible with new architectures that
follow standard naming conventions. Falls back to ``isinstance`` checks
for known architectures.

Args:
model: A pretrained model or HookedTransformer.

Returns:
The final normalization module.

Raises:
ValueError: If the model has no ``base_model`` or the norm is ``None``.
NotImplementedError: If the architecture is not recognized.
"""
if _transformer_lens_available and isinstance(model, tl.HookedTransformer):
return model.ln_final

if not hasattr(model, "base_model"):
raise ValueError("Model does not have a `base_model` attribute.")
base_model = _get_base_model(model)

base_model = model.base_model
if isinstance(base_model, models.opt.modeling_opt.OPTModel):
final_layer_norm = base_model.decoder.final_layer_norm
elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
final_layer_norm = base_model.final_layer_norm
elif isinstance(
base_model,
(
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
),
):
final_layer_norm = base_model.ln_f
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
final_layer_norm = base_model.norm
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")
# Try attribute-based detection first (handles new architectures automatically)
final_layer_norm = _try_attribute_norm(base_model)

# Fall back to isinstance checks for known architectures
if final_layer_norm is None:
if isinstance(base_model, models.opt.modeling_opt.OPTModel):
final_layer_norm = base_model.decoder.final_layer_norm
elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
final_layer_norm = base_model.final_layer_norm
elif isinstance(
base_model,
(
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
),
):
final_layer_norm = base_model.ln_f
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
final_layer_norm = base_model.norm
else:
available = [a for a in dir(base_model) if not a.startswith("_")]
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect a final layer norm via attribute probing. "
f"Available attributes: {available[:15]}. "

Copilot AI Feb 15, 2026

The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
f"Available attributes: {available[:15]}. "
f"Available attributes (first 15): {available[:15]}. "

f"Please open an issue at: "
f"https://github.com/AlignmentResearch/tuned-lens/issues"
)

if final_layer_norm is None:
raise ValueError("Model does not have a final layer norm.")
Expand All @@ -134,24 +243,66 @@ def get_final_norm(model: Model) -> Norm:
return final_layer_norm


def _try_attribute_layers(
base_model: th.nn.Module,
) -> Optional[tuple[list[str], th.nn.ModuleList]]:

Copilot AI Feb 15, 2026

The codebase has inconsistent usage of type hints for tuples. Some files use Tuple from the typing module (e.g., trajectory_plotting.py:3), while this change uses the modern tuple[...] syntax. Although both are valid for Python 3.9+, consider standardizing on one approach for consistency. The modern lowercase syntax is preferred in PEP 585 and works with Python 3.9+, which matches the project's minimum version requirement.

"""Try to find transformer layers via attribute-based probing.

Checks common attribute names on the base model and its ``decoder``
sub-module (for OPT-style architectures). Returns the path components
and the ``ModuleList`` if found.

Args:
base_model: The unwrapped base model to probe.

Returns:
A tuple of ``(path_components, module_list)`` or ``None``.

Copilot AI Feb 15, 2026

The docstring states that the function returns "A tuple of (path_components, module_list)" but the actual return type shows tuple[list[str], th.nn.ModuleList]. The term "path_components" in the docstring should be clarified to indicate it's a list of strings, not a single string path. Consider changing the Returns section to: "A tuple of (path_components, module_list) where path_components is a list of attribute names, or None."

Suggested change
A tuple of ``(path_components, module_list)`` or ``None``.
A tuple of ``(path_components, module_list)`` where ``path_components``
is a list of attribute names, or ``None``.

"""
# Direct attributes on base_model (covers Llama, Mistral, Gemma, GPT-2, etc.)
for attr in ("layers", "h"):
layers = getattr(base_model, attr, None)
if isinstance(layers, th.nn.ModuleList):
return [attr], layers

# OPT-style: base_model.decoder.layers
decoder = getattr(base_model, "decoder", None)
if decoder is not None:
layers = getattr(decoder, "layers", None)
if isinstance(layers, th.nn.ModuleList):
return ["decoder", "layers"], layers

return None


def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
"""Get the decoder layers from a model.

Uses attribute-based probing to detect transformer layers, which makes
this function forward-compatible with new architectures. Falls back to
``isinstance`` checks for known architectures.

Args:
model: The model to search.

Returns:
A tuple containing the key path to the layer list and the list itself.

Raises:
ValueError: If no such list exists.
ValueError: If the model has no ``base_model`` attribute.
NotImplementedError: If the architecture is not recognized.
"""
# TODO implement this so that we can do hooked transformer training.
if not hasattr(model, "base_model"):
raise ValueError("Model does not have a `base_model` attribute.")
base_model = _get_base_model(model)

# Try attribute-based detection first
result = _try_attribute_layers(base_model)
if result is not None:
path_components, layers = result
path = ".".join(["base_model"] + path_components)
return path, layers

# Fall back to isinstance checks for known architectures
path_to_layers = ["base_model"]
base_model = model.base_model
if isinstance(base_model, models.opt.modeling_opt.OPTModel):
path_to_layers += ["decoder", "layers"]
elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
Expand All @@ -173,7 +324,14 @@ def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
path_to_layers += ["layers"]
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")
available = [a for a in dir(base_model) if not a.startswith("_")]
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect transformer layers via attribute probing. "
f"Available attributes: {available[:15]}. "
Comment on lines +328 to +331

Copilot AI Feb 15, 2026

The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect transformer layers via attribute probing. "
f"Available attributes: {available[:15]}. "
displayed_available = available[:15] + (["..."] if len(available) > 15 else [])
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect transformer layers via attribute probing. "
f"Available attributes (showing up to 15): {displayed_available}. "

f"Please open an issue at: "
f"https://github.com/AlignmentResearch/tuned-lens/issues"
)

path_to_layers = ".".join(path_to_layers)
return path_to_layers, get_key_path(model, path_to_layers)
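
For context, a minimal usage sketch of the two public helpers after this change; "gpt2" is only a stand-in checkpoint and is not part of the diff:

from transformers import AutoModelForCausalLM
from tuned_lens import model_surgery

model = AutoModelForCausalLM.from_pretrained("gpt2")
norm = model_surgery.get_final_norm(model)                   # GPT-2's ln_f, found via attribute probing
path, layers = model_surgery.get_transformer_layers(model)   # ("base_model.h", ModuleList of 12 blocks)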