Welcome to TamingSAE_GBA! This repository contains the source code for the paper *Taming Polysemanticity in LLMs: Provable Feature Recovery via Sparse Autoencoders*. Inside, you will find the GBA implementation and the training scripts for the experiments conducted in the paper.
Before running any experiments, make sure to load the Simtransformer submodule:

```shell
git submodule add https://github.com/FFishy-git/Simtransformer.git
git submodule update --init --recursive
```

**Important Note on Credentials:** For security reasons, we have removed credential information from all job-creation notebooks (`.ipynb` files). Before running any experiments, you need to specify the following in the respective notebooks:
- For Pile-Qwen experiments:
  - Model access credentials in `preprocessing_jobs.ipynb`
  - Training configuration in `create_jobs.ipynb`
- For synthetic experiments:
  - Experiment parameters in `create_jobs.ipynb`
  - Evaluation settings in `eval.ipynb`
Look for fields marked with `<>` in these notebooks and fill them with your specific values. The implementation files `SAE_model_v2.py` and `SAETran_model_v2.py` are not affected.
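As a purely hypothetical illustration (the actual field names in the notebooks may differ), a filled-in credential cell could look like:

```python
import os

# Illustrative placeholders only; replace the <> fields with your own values.
# Variable names here are assumptions, not the notebooks' actual names.
HF_TOKEN = "<your-model-access-token>"   # model access credential
WANDB_PROJECT = "<your-wandb-project>"   # project that receives training metrics

os.environ["HF_TOKEN"] = HF_TOKEN
os.environ["WANDB_PROJECT"] = WANDB_PROJECT
```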
The core implementation of the Group-Based Autoencoder (GBA) algorithm can be found in the `Group_SAE` directory:

- `SAE_model_v2.py`: Main implementation of the bias adaptation (BA) algorithm
- `SAETran_model_v2.py`: Enhanced version with neuron grouping (GBA)
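For orientation, the generic sparse-autoencoder skeleton these files build on can be sketched as follows. This is a simplified NumPy illustration with made-up dimensions, not the repo's actual implementation; `theta` stands in for the bias threshold that the BA algorithm adapts per neuron.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 16, 64  # input dimension and number of latent features (illustrative)

W_enc = rng.normal(size=(m, d)) / np.sqrt(d)
b_enc = np.zeros(m)
W_dec = W_enc.T.copy()  # tied initialization, a common SAE choice
b_dec = np.zeros(d)

def sae_forward(x, theta=0.0):
    # Encoder: ReLU with a (possibly adapted) bias threshold theta,
    # which controls how many latents fire, i.e. the sparsity level.
    z = np.maximum(W_enc @ x + b_enc - theta, 0.0)
    # Decoder: linear reconstruction of the input.
    x_hat = W_dec @ z + b_dec
    return z, x_hat

x = rng.normal(size=d)
z, x_hat = sae_forward(x, theta=0.5)
```

Raising `theta` drives more latent activations to exactly zero, which is the lever bias adaptation uses to control sparsity.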
To run experiments on the Pile-Qwen dataset:
1. Data Preprocessing:
   - Use `preprocessing_jobs.ipynb` in the `Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing` directory
   - This will launch preprocessing jobs using `data_preprocess.py`
   - The script saves the model's MLP output at specific layers over the given dataset; this takes roughly 4-5 hours on a single H100/A100 GPU
2. Training:
   - Navigate to `Pile-Qwen2.5-1.5B-hook-mlp-out-SAE`
   - Use `create_jobs.ipynb` to set up and launch training jobs
   - Training is executed via `train_entry.py`, which uses the GBA method; metrics such as training/validation loss and sparsity are logged to a wandb project for evaluation
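Conceptually, the preprocessing step boils down to hooking a layer's MLP output and caching it to disk. A toy stand-alone sketch of that pattern (this is not the actual `data_preprocess.py` code; all class and function names are illustrative):

```python
# Minimal illustration of the "hook the MLP output" idea: a hook observes the
# MLP's output for every batch without altering it, and a cache collects the
# observed activations for later SAE training.
class ToyMLP:
    def __call__(self, x):
        return [2 * v for v in x]  # stand-in for the real MLP computation

class HookedLayer:
    def __init__(self, mlp):
        self.mlp = mlp
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)

    def forward(self, x):
        out = self.mlp(x)
        for fn in self.hooks:
            fn(out)  # hooks see the MLP output; here we just cache it
        return out

cache = []
layer = HookedLayer(ToyMLP())
layer.register_hook(cache.append)
layer.forward([1, 2, 3])  # cache now holds the MLP output for this batch
```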
To run synthetic experiments (using `synthetic_freq_vs_d` as an example):
1. Dataset Creation and Training:
   - Navigate to `synthetic_freq_vs_d/d_sweep_n=65536_10M`
   - Create the synthetic dataset
   - Use `create_jobs.ipynb` to launch grid experiments using `train_synthetic.py`
2. Evaluation:
   - Use `eval.ipynb` under the same directory to analyze the results and plot the feature recovery rate (FRR) heatmap
   - The evaluation includes feature recovery analysis and visualization
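A feature recovery rate of the kind plotted in the heatmap can be sketched as follows. This is a hedged illustration assuming a cosine-similarity matching criterion with a made-up threshold of 0.9; the exact criterion used in `eval.ipynb` may differ.

```python
import numpy as np

def feature_recovery_rate(true_feats, learned_feats, thresh=0.9):
    """Fraction of ground-truth features matched (up to sign) by some
    learned feature direction with cosine similarity >= thresh."""
    T = true_feats / np.linalg.norm(true_feats, axis=0, keepdims=True)
    L = learned_feats / np.linalg.norm(learned_feats, axis=0, keepdims=True)
    sims = np.abs(T.T @ L)  # (n_true, n_learned) cosine similarities
    return float(np.mean(sims.max(axis=1) >= thresh))

rng = np.random.default_rng(0)
F = rng.normal(size=(8, 4))               # 4 ground-truth features in R^8
noisyF = F + 0.01 * rng.normal(size=F.shape)
frr = feature_recovery_rate(F, noisyF)    # 1.0 for a near-copy
```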
The repository is organized as follows:

```
TamingSAE_GBA/
├── README.md
├── .gitignore
├── .gitmodules
├── create_jobs/                      # Job creation scripts
│   ├── cancel_job.ipynb              # Job cancellation notebook
│   └── prefix.py                     # Prefix configuration script
├── Group_SAE/                        # Main implementation directory
│   ├── multi_stage_sae_v2.py
│   ├── multi_stage_v3.py
│   ├── SAE_model_v2.py
│   ├── eval_model.py
│   ├── multi_stage_sae.py
│   ├── SAETran_model_topk.py
│   ├── SAETran_model_v2.py
│   └── SAETran_model_v3.py
├── sythetic_topk_vs_sched/           # Synthetic experiments
├── visualization/                    # Visualization tools
├── synthetic_rho2/                   # Synthetic experiments
├── synthetic_ind_feat_occ/           # Synthetic experiments
├── synthetic_freq_vs_d/              # Synthetic experiments
│   ├── train.py                      # Main training script
│   ├── train_synthetic.py            # Synthetic data training script
│   ├── freq_vs_d_vis.ipynb           # Frequency vs dimension visualization
│   ├── eval_results/                 # Evaluation results
│   │   ├── scaling_TargetFreq_vs_d_*.pdf
│   │   ├── learned_feats_percentage_s=3-M=512.csv
│   │   ├── p_vs_d_s=3_n=128_M=512.pdf
│   │   └── eval_individual.ipynb
│   ├── d_sweep_n=65536_10M/          # Large scale experiments
│   └── d_sweep_n=128_1M/             # Minimal scale experiments
├── synthetic_M_vs_s/                 # Synthetic experiments
├── Simtransformer/                   # Transformer implementation
├── Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/  # Data preprocessing
│   ├── preprocessing_jobs.ipynb      # Preprocessing job management
│   ├── visualize_data_norm.ipynb     # Data normalization visualization
│   ├── data_preprocess.py            # Main preprocessing script
│   ├── gen_data_info.py              # Data information generation
│   ├── save_validation.py            # Validation data saving
│   └── data_peek.ipynb               # Data exploration notebook
└── Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/  # Training
    ├── train_entry.py                # Main training script
    ├── test_Z_score.ipynb            # Z-score analysis notebook
    ├── create_group_ablation_jobs.ipynb  # Group ablation job creation
    ├── create_multiseed_jobs.ipynb   # Multi-seed job creation
    ├── create_jobs.ipynb             # Job creation notebook
    ├── activation_compare.ipynb      # Activation comparison
    ├── scatter_Z_score_vs_fraction_pre_act_ge_zero.png
    ├── scatter_Z_score_vs_fraction.png
    ├── gen_feat_dashboard.py         # Feature dashboard generation
    ├── additional_stage.py           # Additional training stage
    ├── understanding/                # Model understanding analysis
    │   ├── scatter_Z_score_vs_max_cos_sim.pdf
    │   ├── scatter_Z_score_vs_max_proj.pdf
    │   └── scatter_Z_score_vs_fraction_pre_act_ge_zero.pdf
    ├── results/                      # Experimental results
    │   ├── tail_prob_max_cos_sim_max_proj_*.pdf
    │   ├── tail_prob_max_cos_sim_Z_score_frac_*.pdf
    │   ├── Pile-Qwen-legend.pdf
    │   ├── get_first_page.ipynb
    │   ├── tail_prob_max_cos_sim.pdf
    │   └── Pile-Qwen-L*.pdf (various layer results)
    └── eval_group_ablation/          # Group ablation studies
        ├── legend_only.pdf
        ├── sparsity_vs_num_groups.pdf
        ├── val_loss_vs_num_groups.pdf
        ├── val_loss_vs_sparsity.pdf
        └── eval.ipynb
```