-
Notifications
You must be signed in to change notification settings - Fork 67
feat: attribute-based architecture detection with better errors #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Tools for finding and modifying components in a transformer model.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from contextlib import contextmanager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Generator, TypeVar, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Generator, Optional, TypeVar, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import transformer_lens as tl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -70,6 +70,93 @@ def assign_key_path(model: T, key_path: str, value: Any) -> Generator[T, None, N | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ordered list of (attribute_name, model_classes) for final norm detection. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Attribute-based lookup is tried first; isinstance is the fallback. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _NORM_PATHS = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ("norm", ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.llama.modeling_llama.LlamaModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.mistral.modeling_mistral.MistralModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gemma.modeling_gemma.GemmaModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ("ln_f", ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.bloom.modeling_bloom.BloomModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt2.modeling_gpt2.GPT2Model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt_neo.modeling_gpt_neo.GPTNeoModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gptj.modeling_gptj.GPTJModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ("final_layer_norm", ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt_neox.modeling_gpt_neox.GPTNeoXModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ordered list of (attribute_name, model_classes) for layer detection. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _LAYER_PATHS = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ("layers", ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.llama.modeling_llama.LlamaModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.mistral.modeling_mistral.MistralModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gemma.modeling_gemma.GemmaModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt_neox.modeling_gpt_neox.GPTNeoXModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ("h", ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.bloom.modeling_bloom.BloomModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt2.modeling_gpt2.GPT2Model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gpt_neo.modeling_gpt_neo.GPTNeoModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| models.gptj.modeling_gptj.GPTJModel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| )), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+73
to
+108
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ordered list of (attribute_name, model_classes) for final norm detection. | |
| # Attribute-based lookup is tried first; isinstance is the fallback. | |
| _NORM_PATHS = [ | |
| ("norm", ( | |
| models.llama.modeling_llama.LlamaModel, | |
| models.mistral.modeling_mistral.MistralModel, | |
| models.gemma.modeling_gemma.GemmaModel, | |
| )), | |
| ("ln_f", ( | |
| models.bloom.modeling_bloom.BloomModel, | |
| models.gpt2.modeling_gpt2.GPT2Model, | |
| models.gpt_neo.modeling_gpt_neo.GPTNeoModel, | |
| models.gptj.modeling_gptj.GPTJModel, | |
| )), | |
| ("final_layer_norm", ( | |
| models.gpt_neox.modeling_gpt_neox.GPTNeoXModel, | |
| )), | |
| ] | |
| # Ordered list of (attribute_name, model_classes) for layer detection. | |
| _LAYER_PATHS = [ | |
| ("layers", ( | |
| models.llama.modeling_llama.LlamaModel, | |
| models.mistral.modeling_mistral.MistralModel, | |
| models.gemma.modeling_gemma.GemmaModel, | |
| models.gpt_neox.modeling_gpt_neox.GPTNeoXModel, | |
| )), | |
| ("h", ( | |
| models.bloom.modeling_bloom.BloomModel, | |
| models.gpt2.modeling_gpt2.GPT2Model, | |
| models.gpt_neo.modeling_gpt_neo.GPTNeoModel, | |
| models.gptj.modeling_gptj.GPTJModel, | |
| )), | |
| ] |
Copilot
AI
Feb 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
| raise ValueError( | |
| f"Model {type(model).__name__} does not have a `base_model` attribute. " | |
| f"Available attributes: {available[:15]}. " | |
| shown_attributes = available[:15] + (["..."] if len(available) > 15 else []) | |
| raise ValueError( | |
| f"Model {type(model).__name__} does not have a `base_model` attribute. " | |
| f"Available attributes (showing up to 15): {shown_attributes}. " |
Copilot
AI
Feb 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
| f"Available attributes: {available[:15]}. " | |
| f"Available attributes (first 15): {available[:15]}. " |
Copilot
AI
Feb 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The codebase has inconsistent usage of type hints for tuples. Some files use Tuple from the typing module (e.g., trajectory_plotting.py:3), while this change uses the modern tuple[...] syntax. Although both are valid for Python 3.9+, consider standardizing on one approach for consistency. The modern lowercase syntax is preferred in PEP 585 and works with Python 3.9+, which matches the project's minimum version requirement.
Copilot
AI
Feb 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring states that the function returns "A tuple of (path_components, module_list)" but the actual return type shows tuple[list[str], th.nn.ModuleList]. The term "path_components" in the docstring should be clarified to indicate it's a list of strings, not a single string path. Consider changing the Returns section to: "A tuple of (path_components, module_list) where path_components is a list of attribute names, or None."
| A tuple of ``(path_components, module_list)`` or ``None``. | |
| A tuple of ``(path_components, module_list)`` where ``path_components`` | |
| is a list of attribute names, or ``None``. |
Copilot
AI
Feb 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error messages display up to 15 available attributes but don't clarify that the list may be truncated. Consider adding an indication when the list is truncated, such as changing available[:15] to available[:15] + (['...'] if len(available) > 15 else []) or adding text like "first 15 attributes" to make it clear that not all attributes are shown.
| raise NotImplementedError( | |
| f"Unsupported model architecture: {type(base_model).__name__}. " | |
| f"Could not auto-detect transformer layers via attribute probing. " | |
| f"Available attributes: {available[:15]}. " | |
| displayed_available = available[:15] + (["..."] if len(available) > 15 else []) | |
| raise NotImplementedError( | |
| f"Unsupported model architecture: {type(base_model).__name__}. " | |
| f"Could not auto-detect transformer layers via attribute probing. " | |
| f"Available attributes (showing up to 15): {displayed_available}. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new tests only verify that attribute-based probing succeeds for known models in the test fixtures, but don't test the fallback behavior to isinstance checks. Consider adding a test case where attribute-based detection returns None (e.g., by mocking a model with non-standard attribute names) to verify that the isinstance fallback is correctly invoked. This would ensure both code paths work as intended.