hc495/ICL_head_tuning

Mechanistic Fine-tuning for In-context Learning

This repository contains the official code for the following paper, published at the EMNLP 2025 BlackBoxNLP Workshop:

Hakaze Cho, et al. "Mechanistic Fine-tuning for In-context Learning." EMNLP 2025 BlackBoxNLP Workshop, 2025.

Implemented by Hakaze Cho, the primary contributor of the paper.

Overview

Abstract

In-context Learning (ICL) uses structured demonstration-query inputs to induce few-shot learning in Language Models (LMs), which are not originally pre-trained on ICL-style data. To bridge the gap between ICL and pre-training, some approaches fine-tune LMs on large ICL-style datasets in an end-to-end paradigm with massive computational cost. To reduce this cost, we propose Attention Behavior Fine-Tuning (ABFT). Building on previous findings about the inner mechanism of ICL, ABFT defines its training objective on the attention scores instead of the final outputs: it forces attention scores to focus on the correct label tokens presented in the context and suppresses attention scores on the wrong label tokens. Our experiments on 9 modern LMs and 8 datasets find that ABFT outperforms prior methods in performance, robustness, unbiasedness, and efficiency, at only around 0.01% of their data cost. Moreover, our subsequent analysis finds that the end-to-end training objective contains the ABFT objective, suggesting an implicit bias of ICL-style data toward the emergence of induction heads. Our work demonstrates the possibility of controlling specific module sequences within LMs to improve their behavior, opening up future applications of mechanistic interpretability.

Summary figure

Diagram of the ABFT framework. (A) An example of ICL-style inputs; we build datasets from such examples to fine-tune models. (B) Feed-forward inference of ICL. We collect the attention scores of every attention head in every layer to calculate the training objective, and we only enable gradients for the $W_Q$ and $W_K$ matrices. (C) The criterion for induction heads: only attention heads whose attention scores focus significantly on the label tokens are identified as induction heads. (D) Loss calculation of ABFT. Only induction heads return a non-zero loss, and that loss combines a punishment on "wrong" attention scores to wrong label tokens with a reward on "correct" attention scores to correct label tokens.
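The per-head loss in panel (D) can be sketched roughly as follows. This is a minimal NumPy illustration, not the repository's implementation: the threshold-based induction-head test, the function name, and the exact sign convention are assumptions for exposition; $A$ and $B$ correspond to `--wrong_label_penalty` and `--correct_label_award`.

```python
import numpy as np

def abft_loss(attn, correct_label_pos, wrong_label_pos,
              A=0.5, B=1.0, induction_threshold=0.3):
    """Hypothetical sketch of the ABFT loss for a single attention head.

    attn: 1-D array of the query token's attention scores over context
    positions (sums to 1 after softmax).
    A: penalty weight on attention to wrong label tokens.
    B: reward weight on attention to correct label tokens.
    The threshold test mimics panel (C): a head contributes a non-zero
    loss only if it attends to label tokens strongly enough to count as
    an induction head (the 0.3 cutoff is an assumption).
    """
    correct_mass = attn[correct_label_pos].sum()
    wrong_mass = attn[wrong_label_pos].sum()
    if correct_mass + wrong_mass < induction_threshold:
        return 0.0  # not an induction head: excluded from the loss
    # punish attention on wrong labels, reward attention on correct labels
    return A * wrong_mass - B * correct_mass
```

Minimizing this quantity pushes attention mass away from wrong label tokens and toward the correct ones, which is the behavior the figure describes.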

Set Up

0. Requirement

  1. A GPU with more than 22GB of VRAM and CUDA (version 12.4 recommended) is strongly required to run all the experiments.
  2. A network connection to Hugging Face is needed to download the pre-trained models, and a Hugging Face user token with access to the Llama family models is recommended for part of the experiments.
  3. Anaconda or Miniconda is needed.

1. Clone the repository

```shell
git clone https://github.com/hc495/ICL_head_tuning.git
cd ICL_head_tuning
```

2. Environment Installation

```shell
conda env create -f environment.yaml
conda activate icl_head_ft
```

Experiments

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `--ICL_model_name` | str | | Name of the ICL model, from Hugging Face (please set your token in `logs/default_config.py`) |
| `--ICL_dataset_index` | int | | Index of the ICL dataset, from the StaICC library |
| `--method` | str | `"ABFT"` | Fine-tuning method, `"ABFT"` or `"E2E"` (end-to-end fine-tuning) |
| `--ICL_k` | int | `default_config.ICL_k` | Number of demonstrations in the ICL inputs |
| `--quantized` | store_true | `default_config.quantized` | Whether to use a quantized model |
| `--correct_label_award` | float | `default_config.correct_label_award` | Reward for attention to correct labels, $B$ in the paper |
| `--wrong_label_penalty` | float | `default_config.wrong_label_penalty` | Penalty for attention to wrong labels, $A$ in the paper |
| `--train_sample_num` | int | `default_config.train_sample_num` | Number of training samples |
| `--lr` | float | `default_config.lr` | Learning rate |
| `--pseudo_batch_size` | int | `default_config.pseudo_batch_size` | Pseudo batch size |
| `--epoch` | int | `default_config.epoch` | Number of training epochs |
| `--random_seed` | int | `default_config.random_seed` | Random seed |
| `--no_fore_test` | store_true | `default_config.no_fore_test` | Skip the accuracy test before training |
| `--no_post_test` | store_true | `default_config.no_post_test` | Skip the accuracy test after training |
| `--test_type` | str | `default_config.test_type` | Test type (`Normal`, `TempSensitivity`, `DemoSensitivity`); refer to the StaICC library |
| `--generalization_test` | store_true | `default_config.generalization_test` | Test generalization on all 8 datasets ($\text{ACC}_\text{OD}$) after training |
| `--dont_save_model` | store_true | `default_config.dont_save_model` | Do not save the model in the results |
| `--restricted_gradient` | store_true | `default_config.restricted_gradient` | Restrict gradient updates to the current induction head only, rather than all modules with gradients enabled |
| `--self_adaption_loss` | str | `default_config.self_adaption_loss` | Adaptive loss algorithm (`P` or `PID`); set to `None` to disable self-adaption loss |
| `--loss_adaption_factor_for_P` | float | `default_config.loss_adaption_factor` | Loss adaptation factor for the `P` method |
| `--PID_param_P` | float | `default_config.PID_param_P` | P factor in PID |
| `--PID_param_I` | float | `default_config.PID_param_I` | I factor in PID |
| `--PID_param_D` | float | `default_config.PID_param_D` | D factor in PID |
| `--lora_r` | int | `default_config.lora_r` | LoRA rank |
| `--MAUVE_test` | store_true | `default_config.MAUVE_test` | Whether to run the MAUVE test (before and after fine-tuning; slow) |
| `--MAUVE_prefix` | str | `default_config.MAUVE_test_prefix` | Prefix file for MAUVE test generation |
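As the figure caption notes, ABFT only enables gradients on the $W_Q$ and $W_K$ matrices. A minimal PyTorch sketch of that freezing step might look like the following; the module names `q_proj` and `k_proj` are an assumption that holds for Hugging Face Llama/Qwen-style architectures, and this is an illustration rather than the repository's code.

```python
import torch.nn as nn

def enable_only_qk_gradients(model: nn.Module) -> None:
    """Freeze all parameters except the query/key projections.

    The substrings "q_proj"/"k_proj" assume a Hugging Face
    Llama/Qwen-style module layout; adjust them for other models.
    """
    for name, param in model.named_parameters():
        param.requires_grad = ("q_proj" in name) or ("k_proj" in name)
```

After this call, an optimizer built from `filter(lambda p: p.requires_grad, model.parameters())` updates only the attention score computation, leaving the value/output paths and MLPs untouched.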

Example

```shell
python prototype.py \
    --ICL_model_name "Qwen/Qwen2.5-32B" \
    --ICL_k 4 \
    --ICL_dataset_index 0 \
    --method "ABFT" \
    --correct_label_award 1.0 \
    --wrong_label_penalty 0.5 \
    --train_sample_num 512 \
    --lr 2e-5 \
    --pseudo_batch_size 32 \
    --epoch 2 \
    --quantized \
    --random_seed 43 \
    --dont_save_model \
    --generalization_test \
    --no_fore_test \
    --self_adaption_loss "PID"
```
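The `--self_adaption_loss "PID"` option adapts the loss during training with a PID controller. The repository's exact formulation is not reproduced here; the sketch below only illustrates the general PID idea, and the error signal (current loss minus a target), the gain defaults, and the multiplicative form are all assumptions.

```python
class PIDLossAdapter:
    """Hypothetical PID-style loss-adaption factor.

    Treats the gap between the current loss and a target loss as the
    PID error signal and returns a multiplicative factor to apply to
    the next step's loss (proportional, integral, derivative terms
    correspond to --PID_param_P/I/D).
    """

    def __init__(self, kp=0.1, ki=0.01, kd=0.0, target_loss=0.0):
        self.kp, self.ki, self.kd = kp, ki, kd
        self.target = target_loss
        self.integral = 0.0
        self.prev_error = None

    def factor(self, loss_value: float) -> float:
        error = loss_value - self.target
        self.integral += error
        derivative = 0.0 if self.prev_error is None else error - self.prev_error
        self.prev_error = error
        # standard PID combination of the three terms, shifted so that
        # zero error yields a neutral factor of 1.0
        return 1.0 + self.kp * error + self.ki * self.integral + self.kd * derivative
```

The simpler `"P"` option corresponds to keeping only the proportional term, scaled by `--loss_adaption_factor_for_P`.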

Citation

If you find this work useful for your research, please cite our paper:

```bibtex
@inproceedings{cho2025mechanistic,
    title={Mechanistic Fine-tuning for In-context Learning},
    author={Cho, Hakaze and Luo, Peng and Kato, Mariko and Kaenbyou, Rin and Inoue, Naoya},
    booktitle={Proceedings of the 8th BlackboxNLP Workshop: Analyzing and Interpreting Neural Networks for NLP},
    year={2025}
}
```

About

[EMNLP 2025 BlackBoxNLP Workshop] Official code implementation of the paper "Mechanistic Fine-tuning for In-context Learning"
