TamingSAE_GBA

Welcome to TamingSAE_GBA! This repository contains the source code for the paper Taming Polysemanticity in LLMs: Provable Feature Recovery via Sparse Autoencoders. Inside, you will find the GBA implementation and the training scripts for the experiments conducted in the paper.

Implementation and Usage

Setup

Before running any experiments, make sure to initialize the Simtransformer submodule:

git submodule add https://github.com/FFishy-git/Simtransformer.git
git submodule update --init --recursive

(If the submodule is already registered in .gitmodules, the second command alone is sufficient.)

Important Note on Credentials: For security reasons, we have removed credential information from all job-creation notebooks (.ipynb files). Before running any experiments, you need to specify the following in the respective notebooks:

  • For Pile-Qwen experiments:
    • Model access credentials in preprocessing_jobs.ipynb
    • Training configuration in create_jobs.ipynb
  • For synthetic experiments:
    • Experiment parameters in create_jobs.ipynb
    • Evaluation settings in eval.ipynb

Look for fields marked with <> in these notebooks and fill them in with your own values. The implementation files SAE_model_v2.py and SAETran_model_v2.py are not affected.
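For reference, a filled-in credentials cell typically ends up looking something like the following. The variable names here are illustrative only; check each notebook for the actual field names it expects.

```python
# Illustrative placeholders only -- replace each <> field with your own values.
HF_TOKEN = "<your-huggingface-access-token>"   # model access (e.g. for Qwen weights)
WANDB_API_KEY = "<your-wandb-api-key>"         # experiment logging
```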

GBA Algorithm Implementation

The core implementation of the Group Bias Adaptation (GBA) algorithm can be found in the Group_SAE directory:

  • SAE_model_v2.py: Main implementation of the bias adaptation (BA) algorithm
  • SAETran_model_v2.py: Enhanced version with neuron grouping (GBA)
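To give a feel for the bias-adaptation idea, here is a minimal numpy sketch: a ReLU SAE encoder with per-neuron biases, where each bias is nudged so that the neuron's firing frequency moves toward a target. This is an illustration only; the dimensions, the `adapt_bias` update rule, and all names below are assumptions, and the actual schedule in SAE_model_v2.py differs.

```python
import numpy as np

rng = np.random.default_rng(0)

d, m = 32, 128          # input dim, number of SAE neurons (illustrative sizes)
W_enc = rng.standard_normal((d, m)) / np.sqrt(d)
W_dec = W_enc.T.copy()  # tied initialization, for the sketch only
b = np.zeros(m)         # per-neuron biases, adapted during training

def encode(x):
    # ReLU encoder with a per-neuron adaptive bias
    return np.maximum(x @ W_enc + b, 0.0)

def adapt_bias(acts, target_freq=0.01, step=0.01):
    # Hypothetical adaptation rule: lower the bias of neurons that fire
    # more often than the target frequency, raise it for those that fire
    # less often. The paper's actual GBA schedule is more involved.
    freq = (acts > 0).mean(axis=0)
    return b - step * (freq - target_freq)

x = rng.standard_normal((256, d))
acts = encode(x)
b = adapt_bias(acts)
recon = encode(x) @ W_dec
```

Grouping (the "G" in GBA) would additionally partition the m neurons into groups that share an adaptation schedule; that part is omitted here.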

Running Pile-Qwen Experiments

To run experiments on the Pile-Qwen dataset:

  1. Data Preprocessing:

    • Use preprocessing_jobs.ipynb in the Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing directory
    • This will launch preprocessing jobs using data_preprocess.py
    • The script saves the model's MLP output at specified layers over the given dataset; this takes roughly 4-5 hours on a single H100/A100 GPU.
  2. Training:

    • Navigate to Pile-Qwen2.5-1.5B-hook-mlp-out-SAE
    • Use create_jobs.ipynb to set up and launch training jobs
    • The training is executed via train_entry.py, which implements the GBA method. Metrics such as training/validation loss and sparsity are logged to a wandb project for evaluation.
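The logged quantities are typically of the following form: mean squared reconstruction error and L0 sparsity (average number of active neurons per sample). This is a generic sketch; the exact definitions used by train_entry.py may differ.

```python
import numpy as np

def sae_metrics(x, x_hat, acts):
    # Mean squared reconstruction error per sample, averaged over the batch
    recon_loss = float(np.mean(np.sum((x - x_hat) ** 2, axis=-1)))
    # L0 sparsity: average number of active SAE neurons per sample
    l0 = float(np.mean(np.sum(acts > 0, axis=-1)))
    return {"recon_loss": recon_loss, "l0_sparsity": l0}

x = np.ones((4, 8))
x_hat = np.zeros((4, 8))
acts = np.array([[1.0, 0.0, 2.0, 0.0]] * 4)
print(sae_metrics(x, x_hat, acts))  # {'recon_loss': 8.0, 'l0_sparsity': 2.0}
```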

Running Synthetic Experiments

To run synthetic experiments (using synthetic_freq_vs_d with number of features $n=65536$ as an example):

  1. Dataset Creation and Training:

    • Navigate to synthetic_freq_vs_d/d_sweep_n=65536_10M
    • Create the synthetic dataset
    • Use create_jobs.ipynb to launch grid experiments using train_synthetic.py
  2. Evaluation:

    • Use eval.ipynb in the same directory to analyze the results and plot the feature recovery rate (FRR) heatmap
    • The evaluation includes feature recovery analysis and visualization
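The synthetic setting and the FRR metric can be summarized with a small sketch. Assumed setup: n unit-norm ground-truth features in ambient dimension d, with s features active per sample, and a true feature counted as recovered when some learned decoder direction exceeds a cosine-similarity threshold. The sizes, the sampling distribution, and the threshold below are illustrative; the paper's exact definitions may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, s = 64, 32, 3  # features, ambient dimension, active features per sample

# Ground-truth dictionary: n unit-norm feature directions in R^d
F = rng.standard_normal((n, d))
F /= np.linalg.norm(F, axis=1, keepdims=True)

def sample(batch):
    # Each sample is a sparse combination of s randomly chosen features
    X = np.zeros((batch, d))
    for i in range(batch):
        idx = rng.choice(n, size=s, replace=False)
        X[i] = rng.random(s) @ F[idx]
    return X

def feature_recovery_rate(F_true, D_learned, thresh=0.9):
    # A true feature counts as recovered if some learned decoder
    # direction has |cosine similarity| above `thresh` with it.
    D = D_learned / np.linalg.norm(D_learned, axis=1, keepdims=True)
    cos = np.abs(F_true @ D.T)          # (n_true, n_learned)
    return float(np.mean(cos.max(axis=1) >= thresh))

# Sanity check: a perfect dictionary recovers every feature
print(feature_recovery_rate(F, F))  # 1.0
```

The FRR heatmap in eval.ipynb sweeps this kind of quantity over the grid of experiment parameters.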

Repository Structure

TamingSAE_GBA/
├── README.md
├── .gitignore
├── .gitmodules
├── create_jobs/                    # Job creation scripts
│   ├── cancel_job.ipynb           # Job cancellation notebook
│   └── prefix.py                  # Prefix configuration script
├── Group_SAE/                      # Main implementation directory
│   ├── multi_stage_sae_v2.py
│   ├── multi_stage_v3.py
│   ├── SAE_model_v2.py
│   ├── eval_model.py
│   ├── multi_stage_sae.py
│   ├── SAETran_model_topk.py
│   ├── SAETran_model_v2.py
│   └── SAETran_model_v3.py
├── sythetic_topk_vs_sched/        # Synthetic experiments
├── visualization/                  # Visualization tools
├── synthetic_rho2/                 # Synthetic experiments
├── synthetic_ind_feat_occ/         # Synthetic experiments
├── synthetic_freq_vs_d/            # Synthetic experiments
│   ├── train.py                   # Main training script
│   ├── train_synthetic.py         # Synthetic data training script
│   ├── freq_vs_d_vis.ipynb        # Frequency vs dimension visualization
│   ├── eval_results/             # Evaluation results
│   │   ├── scaling_TargetFreq_vs_d_*.pdf
│   │   ├── learned_feats_percentage_s=3-M=512.csv
│   │   ├── p_vs_d_s=3_n=128_M=512.pdf
│   │   └── eval_individual.ipynb
│   ├── d_sweep_n=65536_10M/      # Large scale experiments
│   └── d_sweep_n=128_1M/         # Minimal scale experiments
├── synthetic_M_vs_s/               # Synthetic experiments
├── Simtransformer/                # Transformer implementation
├── Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/  # Data preprocessing
│   ├── preprocessing_jobs.ipynb    # Preprocessing job management
│   ├── visualize_data_norm.ipynb   # Data normalization visualization
│   ├── data_preprocess.py         # Main preprocessing script
│   ├── gen_data_info.py          # Data information generation
│   ├── save_validation.py        # Validation data saving
│   └── data_peek.ipynb           # Data exploration notebook
└── Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/           # Training 
    ├── train_entry.py                            # Main training script
    ├── test_Z_score.ipynb                        # Z-score analysis notebook
    ├── create_group_ablation_jobs.ipynb          # Group ablation job creation
    ├── create_multiseed_jobs.ipynb               # Multi-seed job creation
    ├── create_jobs.ipynb                         # Job creation notebook
    ├── activation_compare.ipynb                  # Activation comparison
    ├── scatter_Z_score_vs_fraction_pre_act_ge_zero.png
    ├── scatter_Z_score_vs_fraction.png
    ├── gen_feat_dashboard.py                     # Feature dashboard generation
    ├── additional_stage.py                       # Additional training stage
    ├── understanding/                            # Model understanding analysis
    │   ├── scatter_Z_score_vs_max_cos_sim.pdf
    │   ├── scatter_Z_score_vs_max_proj.pdf
    │   └── scatter_Z_score_vs_fraction_pre_act_ge_zero.pdf
    ├── results/                                  # Experimental results
    │   ├── tail_prob_max_cos_sim_max_proj_*.pdf
    │   ├── tail_prob_max_cos_sim_Z_score_frac_*.pdf
    │   ├── Pile-Qwen-legend.pdf
    │   ├── get_first_page.ipynb
    │   ├── tail_prob_max_cos_sim.pdf
    │   └── Pile-Qwen-L*.pdf (various layer results)
    └── eval_group_ablation/                      # Group ablation studies
        ├── legend_only.pdf
        ├── sparsity_vs_num_groups.pdf
        ├── val_loss_vs_num_groups.pdf
        ├── val_loss_vs_sparsity.pdf
        └── eval.ipynb
