adamimos commented Sep 30, 2025
- configs: interpolate vocab size from generator in penzai transformer/gruformer YAMLs
- run: inject vocab from generator; add BOS/EOS and seq_len consistency checks; route PyTorch models to torch trainer with n_ctx alignment
- torch: move batch to model device; ensure token dtype long in training and evaluation
- configs: add TransformerLens example (predictive_model, training, evaluation) and mess3_085 generator; include convenience top-level TL config
- docs: add TransformerLens example usage and notes
Code Review for PR #82: feat/context consistency and tl

Thank you for this PR! The implementation of TransformerLens support and context consistency checks adds valuable functionality. Here's my detailed review:

✅ Strengths
🔍 Suggestions for Improvement
@claude codex did this PR and it's ugly. You are so much better at it. Please do your work.
Claude finished @adamimos's task

Code Review for PR #82: feat/context consistency and tl

I'll provide a comprehensive analysis of this PR's implementation of TransformerLens support and context consistency improvements.
🎯 Overall Assessment

This PR successfully integrates TransformerLens support into the codebase while adding robust consistency checks. The implementation is thoughtful and maintains backward compatibility while introducing new capabilities.

🏗️ Architecture & Design

Strengths
Areas for Improvement

1. Exception Handling Specificity

Issue: The broad exception catching reduces debuggability and masks potential issues.

```python
# Current (run.py:97-98)
try:
    import torch.nn as nn
except Exception:
    nn = None
```

Recommendation: Be specific about expected failures:

```python
try:
    import torch.nn as nn
except ImportError:
    nn = None  # PyTorch not installed, use JAX path
```

This pattern appears in multiple places (lines 54-56, 111-113) and should be consistently improved.

2. Configuration Validation Architecture

Issue: The vocab size setting logic (lines 41-56) is complex, with nested conditionals that are hard to test.

Recommendation: Extract it to a dedicated function:

```python
def inject_vocab_size(cfg: DictConfig, vocab_size: int) -> None:
"""Inject vocab size into model config, handling various config structures."""
inst = cfg.predictive_model.instance
# Direct attribute
if hasattr(inst, "vocab_size") or (isinstance(inst, DictConfig) and "vocab_size" in inst):
inst.vocab_size = vocab_size
# Penzai-style
if isinstance(inst, DictConfig) and "config" in inst and "vocab_size" in inst["config"]:
inst.config.vocab_size = vocab_size
# TransformerLens-style
if isinstance(inst, DictConfig) and "cfg" in inst and "d_vocab" in inst["cfg"]:
        inst.cfg.d_vocab = vocab_size
```
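Pulling this into a helper also makes the branching unit-testable. A rough sketch of such a test, assuming OmegaConf configs and that `inject_vocab_size` ends up somewhere importable (the config shape below mirrors the TransformerLens-style branch; it is not taken from the repo):

```python
# Sketch only: assumes inject_vocab_size is importable from run.py (or wherever it lands).
from omegaconf import OmegaConf

def test_inject_vocab_size_transformerlens_style():
    cfg = OmegaConf.create(
        {"predictive_model": {"instance": {"cfg": {"d_vocab": -1}}}}
    )
    inject_vocab_size(cfg, vocab_size=512)
    # The TransformerLens-style branch should have rewritten d_vocab.
    assert cfg.predictive_model.instance.cfg.d_vocab == 512
```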
3. Device Handling Robustness

Issue: Device detection in evaluation (evaluate_pytorch_model.py:28-31) could fail on models without parameters.

```python
try:
    device = next(model.parameters()).device
except Exception:
    device = None
```

Recommendation: Handle StopIteration explicitly:

```python
try:
    device = next(model.parameters()).device
except StopIteration:
    # Model has no parameters, default to CPU
    device = torch.device('cpu')
```

🔍 Detailed Code Analysis

run.py Changes

Lines 41-56: The vocab size injection is well thought out, but it could benefit from logging when injection succeeds or fails, for debugging.
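For instance, the extracted helper sketched above could report which branch (if any) applied; a minimal sketch, assuming the standard logging module and an OmegaConf config:

```python
import logging

from omegaconf import DictConfig

logger = logging.getLogger(__name__)

def inject_vocab_size(cfg: DictConfig, vocab_size: int) -> None:
    """Inject vocab size into the model config, logging which branch (if any) applied."""
    inst = cfg.predictive_model.instance
    injected = False
    # Direct attribute
    if hasattr(inst, "vocab_size") or (isinstance(inst, DictConfig) and "vocab_size" in inst):
        inst.vocab_size = vocab_size
        injected = True
    # Penzai-style nested config
    if isinstance(inst, DictConfig) and "config" in inst and "vocab_size" in inst["config"]:
        inst.config.vocab_size = vocab_size
        injected = True
    # TransformerLens-style cfg with d_vocab
    if isinstance(inst, DictConfig) and "cfg" in inst and "d_vocab" in inst["cfg"]:
        inst.cfg.d_vocab = vocab_size
        injected = True
    if injected:
        logger.info("Injected vocab_size=%d into predictive_model config", vocab_size)
    else:
        logger.warning("No vocab size field found in predictive_model config; nothing injected")
```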
Line 42: `model = typed_instantiate(cfg.predictive_model.instance, object)  # Use object as base type` — using `object` as the base type effectively opts out of the type constraint that `typed_instantiate` is meant to provide; a more specific base type would make the intent clearer.

Lines 65-74: Good defensive programming with assertions, but consider raising a ValueError with a descriptive message instead:

```python
if cfg.validation is not None and cfg.validation.sequence_len != cfg.training.sequence_len:
    raise ValueError(
        f"validation.sequence_len ({cfg.validation.sequence_len}) must match "
        f"training.sequence_len ({cfg.training.sequence_len}) for consistent context"
    )
```

PyTorch Training/Evaluation

Good Practices:
Potential Improvement: In both files, the device/dtype handling pattern is duplicated. Consider a utility function:

```python
def prepare_batch_for_model(model: torch.nn.Module, inputs: torch.Tensor, labels: torch.Tensor):
    """Prepare batch tensors for model processing."""
    # next(..., None) avoids truthiness checks on Parameter tensors and handles parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    inputs = inputs.to(device).long()
    labels = labels.to(device).long()
    return inputs, labels
```
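Both files could then share it. A sketch of the call from a training step, where the criterion and the assumed (batch, seq_len, vocab) output shape are illustrative, not taken from the repo:

```python
# Illustrative usage in a training step; the loss and output shape are assumptions.
import torch
import torch.nn.functional as F

def train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
               inputs: torch.Tensor, labels: torch.Tensor) -> float:
    inputs, labels = prepare_batch_for_model(model, inputs, labels)
    logits = model(inputs)  # assumed shape: (batch, seq_len, vocab)
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```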
Configuration Files

transformer_lens_2L2H.yaml:

transformerlens.yaml:
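For orientation, the kind of model a 2-layer, 2-head TransformerLens config describes can also be constructed directly in code; every concrete dimension below is an illustrative guess, not a value read from these YAMLs:

```python
# Illustrative only: the dimensions are guesses, not values from the YAML.
from transformer_lens import HookedTransformer, HookedTransformerConfig

tl_cfg = HookedTransformerConfig(
    n_layers=2,      # the "2L" in transformer_lens_2L2H
    n_heads=2,       # the "2H"
    d_model=64,      # assumed
    d_head=32,       # assumed
    n_ctx=16,        # should line up with training.sequence_len per the run.py checks
    d_vocab=3,       # placeholder; overwritten at runtime by the generator's vocab injection
    act_fn="relu",   # assumed
)
model = HookedTransformer(tl_cfg)
```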