feat: attribute-based architecture detection with better errors #145
anthony-maio wants to merge 1 commit into AlignmentResearch:main from
Conversation
Replace isinstance checks in get_final_norm() and get_transformer_layers() with attribute-based probing that tries common attribute paths (model.norm, model.layers, model.h, etc.) before falling back to the original isinstance chain. This automatically supports new architectures (Phi, Qwen, StableLM, etc.) that follow standard naming conventions. Error messages now include the model class name and available attributes to help users debug unsupported architectures.
@codex review
Pull request overview
This PR refactors the model architecture detection in model_surgery.py to use attribute-based probing as the primary method for finding transformer components, with isinstance checks as a fallback. The goal is to automatically support new transformer architectures (like Phi, Qwen, StableLM) that follow standard naming conventions without requiring explicit code changes.
Changes:
- Replaced isinstance-first approach with attribute-based probing that checks common attribute names (norm, ln_f, layers, h, etc.)
- Added helper functions `_get_base_model`, `_try_attribute_norm`, and `_try_attribute_layers` to encapsulate the detection logic
- Enhanced error messages to include model class names and available attributes for easier debugging
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tuned_lens/model_surgery.py | Refactored get_final_norm and get_transformer_layers to use attribute-based detection with isinstance fallback; added new helper functions and improved error messages |
| tests/test_model_surgery.py | Added tests for the new attribute-based detection functions and error handling behavior |
```python
        base_model: The unwrapped base model to probe.

    Returns:
        A tuple of ``(path_components, module_list)`` or ``None``.
```
The docstring states that the function returns "A tuple of (path_components, module_list)" but the actual return type shows tuple[list[str], th.nn.ModuleList]. The term "path_components" in the docstring should be clarified to indicate it's a list of strings, not a single string path. Consider changing the Returns section to: "A tuple of (path_components, module_list) where path_components is a list of attribute names, or None."
Suggested change:

```diff
-        A tuple of ``(path_components, module_list)`` or ``None``.
+        A tuple of ``(path_components, module_list)`` where ``path_components``
+        is a list of attribute names, or ``None``.
```
```python
# Ordered list of (attribute_name, model_classes) for final norm detection.
# Attribute-based lookup is tried first; isinstance is the fallback.
_NORM_PATHS = [
    ("norm", (
        models.llama.modeling_llama.LlamaModel,
        models.mistral.modeling_mistral.MistralModel,
        models.gemma.modeling_gemma.GemmaModel,
    )),
    ("ln_f", (
        models.bloom.modeling_bloom.BloomModel,
        models.gpt2.modeling_gpt2.GPT2Model,
        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
        models.gptj.modeling_gptj.GPTJModel,
    )),
    ("final_layer_norm", (
        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
    )),
]


# Ordered list of (attribute_name, model_classes) for layer detection.
_LAYER_PATHS = [
    ("layers", (
        models.llama.modeling_llama.LlamaModel,
        models.mistral.modeling_mistral.MistralModel,
        models.gemma.modeling_gemma.GemmaModel,
        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
    )),
    ("h", (
        models.bloom.modeling_bloom.BloomModel,
        models.gpt2.modeling_gpt2.GPT2Model,
        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
        models.gptj.modeling_gptj.GPTJModel,
    )),
]
```
The _NORM_PATHS and _LAYER_PATHS constants are defined but never referenced. They appear to document the mapping between attribute names and model classes, but since the implementation probes common attribute names regardless of model type, they serve no functional purpose. Consider either removing them or adding a comment explaining that they're kept for documentation purposes only.
Suggested change:

```diff
-# Ordered list of (attribute_name, model_classes) for final norm detection.
-# Attribute-based lookup is tried first; isinstance is the fallback.
-_NORM_PATHS = [
-    ("norm", (
-        models.llama.modeling_llama.LlamaModel,
-        models.mistral.modeling_mistral.MistralModel,
-        models.gemma.modeling_gemma.GemmaModel,
-    )),
-    ("ln_f", (
-        models.bloom.modeling_bloom.BloomModel,
-        models.gpt2.modeling_gpt2.GPT2Model,
-        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
-        models.gptj.modeling_gptj.GPTJModel,
-    )),
-    ("final_layer_norm", (
-        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
-    )),
-]
-# Ordered list of (attribute_name, model_classes) for layer detection.
-_LAYER_PATHS = [
-    ("layers", (
-        models.llama.modeling_llama.LlamaModel,
-        models.mistral.modeling_mistral.MistralModel,
-        models.gemma.modeling_gemma.GemmaModel,
-        models.gpt_neox.modeling_gpt_neox.GPTNeoXModel,
-    )),
-    ("h", (
-        models.bloom.modeling_bloom.BloomModel,
-        models.gpt2.modeling_gpt2.GPT2Model,
-        models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
-        models.gptj.modeling_gptj.GPTJModel,
-    )),
-]
```
```python
def test_attribute_norm_detection(random_small_model: PreTrainedModel):
    """Attribute-based probing should find the same norm as isinstance checks."""
    norm = model_surgery._try_attribute_norm(random_small_model.base_model)
    assert norm is not None
    assert isinstance(norm, th.nn.Module)


def test_attribute_layer_detection(random_small_model: PreTrainedModel):
    """Attribute-based probing should find the same layers as isinstance checks."""
    result = model_surgery._try_attribute_layers(random_small_model.base_model)
    assert result is not None
    path_components, layers = result
    assert isinstance(layers, th.nn.ModuleList)
    assert len(layers) == random_small_model.config.num_hidden_layers
```
The new tests only verify that attribute-based probing succeeds for known models in the test fixtures, but don't test the fallback behavior to isinstance checks. Consider adding a test case where attribute-based detection returns None (e.g., by mocking a model with non-standard attribute names) to verify that the isinstance fallback is correctly invoked. This would ensure both code paths work as intended.
```python
raise ValueError(
    f"Model {type(model).__name__} does not have a `base_model` attribute. "
    f"Available attributes: {available[:15]}. "
```
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
Suggested change:

```diff
-raise ValueError(
-    f"Model {type(model).__name__} does not have a `base_model` attribute. "
-    f"Available attributes: {available[:15]}. "
+shown_attributes = available[:15] + (["..."] if len(available) > 15 else [])
+raise ValueError(
+    f"Model {type(model).__name__} does not have a `base_model` attribute. "
+    f"Available attributes (showing up to 15): {shown_attributes}. "
```
```python
raise NotImplementedError(
    f"Unsupported model architecture: {type(base_model).__name__}. "
    f"Could not auto-detect a final layer norm via attribute probing. "
    f"Available attributes: {available[:15]}. "
```
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
Suggested change:

```diff
-    f"Available attributes: {available[:15]}. "
+    f"Available attributes (first 15): {available[:15]}. "
```
```python
raise NotImplementedError(
    f"Unsupported model architecture: {type(base_model).__name__}. "
    f"Could not auto-detect transformer layers via attribute probing. "
    f"Available attributes: {available[:15]}. "
```
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
Suggested change:

```diff
-raise NotImplementedError(
-    f"Unsupported model architecture: {type(base_model).__name__}. "
-    f"Could not auto-detect transformer layers via attribute probing. "
-    f"Available attributes: {available[:15]}. "
+displayed_available = available[:15] + (["..."] if len(available) > 15 else [])
+raise NotImplementedError(
+    f"Unsupported model architecture: {type(base_model).__name__}. "
+    f"Could not auto-detect transformer layers via attribute probing. "
+    f"Available attributes (showing up to 15): {displayed_available}. "
```
```python
def _try_attribute_layers(
    base_model: th.nn.Module,
) -> Optional[tuple[list[str], th.nn.ModuleList]]:
```
The codebase has inconsistent usage of type hints for tuples. Some files use Tuple from the typing module (e.g., trajectory_plotting.py:3), while this change uses the modern tuple[...] syntax. Although both are valid for Python 3.9+, consider standardizing on one approach for consistency. The modern lowercase syntax is preferred in PEP 585 and works with Python 3.9+, which matches the project's minimum version requirement.
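For reference, the two spellings the comment contrasts are shown below. Both are valid at runtime on Python 3.9+; the function names here are purely illustrative.

```python
from typing import List, Optional, Tuple


# Pre-PEP 585 style, as used in trajectory_plotting.py:
def old_style() -> Optional[Tuple[List[str], int]]:
    return None


# PEP 585 style (Python 3.9+), as used in this PR's new helpers:
def new_style() -> Optional[tuple[list[str], int]]:
    return None
```

PEP 585 deprecates the `typing.Tuple`/`typing.List` aliases in favor of the builtin generics, so standardizing on the lowercase forms is the forward-compatible choice.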