
feat: attribute-based architecture detection with better errors #145

Open

anthony-maio wants to merge 1 commit into AlignmentResearch:main from anthony-maio:pr/attribute-detection

Conversation

@anthony-maio

Replace isinstance checks in get_final_norm() and get_transformer_layers() with attribute-based probing that tries common attribute paths (model.norm, model.layers, model.h, etc.) before falling back to the original isinstance chain.

This automatically supports new architectures (Phi, Qwen, StableLM, etc.) that follow standard naming conventions. Error messages now include the model class name and available attributes to help users debug unsupported architectures.

Copilot AI review requested due to automatic review settings February 15, 2026 09:08
@anthony-maio
Author

@codex review


Copilot AI left a comment


Pull request overview

This PR refactors the model architecture detection in model_surgery.py to use attribute-based probing as the primary method for finding transformer components, with isinstance checks as a fallback. The goal is to automatically support new transformer architectures (like Phi, Qwen, StableLM) that follow standard naming conventions without requiring explicit code changes.

Changes:

  • Replaced isinstance-first approach with attribute-based probing that checks common attribute names (norm, ln_f, layers, h, etc.)
  • Added helper functions _get_base_model, _try_attribute_norm, and _try_attribute_layers to encapsulate the detection logic
  • Enhanced error messages to include model class names and available attributes for easier debugging

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

Files reviewed:

  • tuned_lens/model_surgery.py: Refactored get_final_norm and get_transformer_layers to use attribute-based detection with an isinstance fallback; added new helper functions and improved error messages
  • tests/test_model_surgery.py: Added tests for the new attribute-based detection functions and error handling behavior


base_model: The unwrapped base model to probe.

Returns:
A tuple of ``(path_components, module_list)`` or ``None``.

Copilot AI Feb 15, 2026


The docstring states that the function returns "A tuple of (path_components, module_list)" but the actual return type shows tuple[list[str], th.nn.ModuleList]. The term "path_components" in the docstring should be clarified to indicate it's a list of strings, not a single string path. Consider changing the Returns section to: "A tuple of (path_components, module_list) where path_components is a list of attribute names, or None."

Suggested change
A tuple of ``(path_components, module_list)`` or ``None``.
A tuple of ``(path_components, module_list)`` where ``path_components``
is a list of attribute names, or ``None``.

Comment on lines +73 to +108
# Ordered list of (attribute_name, model_classes) for final norm detection.
# Attribute-based lookup is tried first; isinstance is the fallback.
_NORM_PATHS = [
    ("norm", (
        models.llama.modeling_llama.LlamaModel,
        models.mistral.modeling_mistral.MistralModel,
        models.gemma.modeling_gemma.GemmaModel,
    )),
    ("ln_f", (
        models.bloom.modeling_bloom.BloomModel,
        models.gpt2.modeling_gpt2.GPT2Model,
        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
        models.gptj.modeling_gptj.GPTJModel,
    )),
    ("final_layer_norm", (
        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
    )),
]

# Ordered list of (attribute_name, model_classes) for layer detection.
_LAYER_PATHS = [
    ("layers", (
        models.llama.modeling_llama.LlamaModel,
        models.mistral.modeling_mistral.MistralModel,
        models.gemma.modeling_gemma.GemmaModel,
        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
    )),
    ("h", (
        models.bloom.modeling_bloom.BloomModel,
        models.gpt2.modeling_gpt2.GPT2Model,
        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
        models.gptj.modeling_gptj.GPTJModel,
    )),
]



Copilot AI Feb 15, 2026


The _NORM_PATHS and _LAYER_PATHS constants are defined but never referenced in the code. They appear to document the mapping between attribute names and model classes, but since the implementation probes common attribute names regardless of model type, they serve no functional purpose. Consider either removing them or adding a comment explaining that they are kept for documentation purposes only.

Suggested change
# Ordered list of (attribute_name, model_classes) for final norm detection.
# Attribute-based lookup is tried first; isinstance is the fallback.
_NORM_PATHS = [
("norm", (
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
)),
("ln_f", (
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
)),
("final_layer_norm", (
models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
)),
]
# Ordered list of (attribute_name, model_classes) for layer detection.
_LAYER_PATHS = [
("layers", (
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
)),
("h", (
models.bloom.modeling_bloom.BloomModel,
models.gpt2.modeling_gpt2.GPT2Model,
models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
models.gptj.modeling_gptj.GPTJModel,
)),
]

Comment on lines +29 to +43
def test_attribute_norm_detection(random_small_model: PreTrainedModel):
"""Attribute-based probing should find the same norm as isinstance checks."""
norm = model_surgery._try_attribute_norm(random_small_model.base_model)
assert norm is not None
assert isinstance(norm, th.nn.Module)


def test_attribute_layer_detection(random_small_model: PreTrainedModel):
"""Attribute-based probing should find the same layers as isinstance checks."""
result = model_surgery._try_attribute_layers(random_small_model.base_model)
assert result is not None
path_components, layers = result
assert isinstance(layers, th.nn.ModuleList)
assert len(layers) == random_small_model.config.num_hidden_layers


Copilot AI Feb 15, 2026


The new tests only verify that attribute-based probing succeeds for known models in the test fixtures, but don't test the fallback behavior to isinstance checks. Consider adding a test case where attribute-based detection returns None (e.g., by mocking a model with non-standard attribute names) to verify that the isinstance fallback is correctly invoked. This would ensure both code paths work as intended.
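A fallback test along these lines could look like the following self-contained sketch. Here `try_attribute_norm` is a stand-in for `model_surgery._try_attribute_norm` (not imported), and the probed attribute names are assumed from the PR description:

```python
_NORM_ATTRS = ("norm", "ln_f", "final_layer_norm")

def try_attribute_norm(base_model, names=_NORM_ATTRS):
    # Stand-in for model_surgery._try_attribute_norm.
    for name in names:
        candidate = getattr(base_model, name, None)
        if candidate is not None:
            return candidate
    return None

class OddModel:
    """A model whose submodules use non-standard names."""
    def __init__(self):
        self.weird_norm = object()      # not norm / ln_f / final_layer_norm
        self.weird_blocks = [object()]

def test_attribute_norm_returns_none_for_odd_names():
    # Probing fails here, so get_final_norm() would have to take the
    # isinstance fallback path for this model.
    assert try_attribute_norm(OddModel()) is None

test_attribute_norm_returns_none_for_odd_names()
```

In the real test suite one would assert against `model_surgery._try_attribute_norm` directly and then check that `get_final_norm` still succeeds via the isinstance chain.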

Comment on lines +123 to +125
raise ValueError(
    f"Model {type(model).__name__} does not have a `base_model` attribute. "
    f"Available attributes: {available[:15]}. "

Copilot AI Feb 15, 2026


The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
raise ValueError(
f"Model {type(model).__name__} does not have a `base_model` attribute. "
f"Available attributes: {available[:15]}. "
shown_attributes = available[:15] + (["..."] if len(available) > 15 else [])
raise ValueError(
f"Model {type(model).__name__} does not have a `base_model` attribute. "
f"Available attributes (showing up to 15): {shown_attributes}. "

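The truncation fix suggested above can also be factored into a small helper. This is a hypothetical sketch (`summarize_attrs` is not a name from the PR); the suggested change inlines the same expression instead:

```python
def summarize_attrs(obj, limit=15):
    """List public attribute names, appending '...' when the list is truncated."""
    available = [a for a in dir(obj) if not a.startswith("_")]
    return available[:limit] + (["..."] if len(available) > limit else [])

class Tiny:
    def __init__(self):
        self.alpha = 1
        self.beta = 2

print(summarize_attrs(Tiny()))  # → ['alpha', 'beta']
```

A helper like this would keep the three error sites (base_model lookup, norm detection, layer detection) consistent instead of repeating the slicing logic.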
raise NotImplementedError(
    f"Unsupported model architecture: {type(base_model).__name__}. "
    f"Could not auto-detect a final layer norm via attribute probing. "
    f"Available attributes: {available[:15]}. "

Copilot AI Feb 15, 2026


The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
f"Available attributes: {available[:15]}. "
f"Available attributes (first 15): {available[:15]}. "

Comment on lines +328 to +331
raise NotImplementedError(
    f"Unsupported model architecture: {type(base_model).__name__}. "
    f"Could not auto-detect transformer layers via attribute probing. "
    f"Available attributes: {available[:15]}. "

Copilot AI Feb 15, 2026


The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.

Suggested change
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect transformer layers via attribute probing. "
f"Available attributes: {available[:15]}. "
displayed_available = available[:15] + (["..."] if len(available) > 15 else [])
raise NotImplementedError(
f"Unsupported model architecture: {type(base_model).__name__}. "
f"Could not auto-detect transformer layers via attribute probing. "
f"Available attributes (showing up to 15): {displayed_available}. "


def _try_attribute_layers(
    base_model: th.nn.Module,
) -> Optional[tuple[list[str], th.nn.ModuleList]]:

Copilot AI Feb 15, 2026


The codebase has inconsistent usage of type hints for tuples. Some files use Tuple from the typing module (e.g., trajectory_plotting.py:3), while this change uses the modern tuple[...] syntax. Although both are valid for Python 3.9+, consider standardizing on one approach for consistency. The modern lowercase syntax is preferred in PEP 585 and works with Python 3.9+, which matches the project's minimum version requirement.

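For reference, the two spellings this comment compares are equivalent at runtime on Python 3.9+ (a minimal demonstration, not code from the PR):

```python
from typing import Optional, Tuple

def old_style() -> Optional[Tuple[str, int]]:
    # typing.Tuple, as used in trajectory_plotting.py
    return ("layers", 12)

def new_style() -> Optional[tuple[str, int]]:
    # PEP 585 builtin generic, as used in this change
    return ("layers", 12)

assert old_style() == new_style() == ("layers", 12)
```

The difference is purely stylistic, so standardizing on the PEP 585 form would mostly be a mechanical find-and-replace plus removing the now-unused typing imports.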