annahedstroem/MERA-steering

Inference-time Steering for Language Models


This repository contains the code and experiments for the paper "To Steer or Not to Steer? Mechanistic Error Reduction with Abstention for Language Models" by Hedström et al. (2025).


Please note that this repository is under active development!

Citation

If you find this work interesting or useful in your research, please use the following BibTeX entry to cite us:

@inproceedings{
  hedstrom2025to,
  title={To Steer or Not to Steer? Mechanistic Error Reduction with Abstention for Language Models},
  author={Anna Hedstr{\"o}m and Salim I. Amoukou and Tom Bewley and Saumitra Mishra and Manuela Veloso},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=fUCPq5RvmH}
}

This work was published at the International Conference on Machine Learning (ICML) 2025.

Repository overview

The repository is organised as follows:

  • The src/ folder contains all necessary functions.
  • The nbs/ folder includes notebooks for generating the plots in the paper and for benchmarking experiments.
  • The tests/ folder contains the tests.

Paper highlights 📚

Our approach consists of three main steps. First, we use linear probes to obtain an effective direction for minimising the predicted error. Second, this direction is scaled using the closed-form solution at both the token and layer levels. Third, we calibrate the steering threshold against the probe's error on a calibration dataset, informed by the user's tolerance for uncertainty.
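The third step can be sketched as a simple quantile calibration. This is a toy illustration only: the variable names and the uniform dummy data are assumptions, and the paper's actual calibration additionally provides the (ε, δ) guarantee described below.

```python
import numpy as np

# Toy sketch of threshold calibration: pick the steering threshold alpha as a
# quantile of the probe's estimated errors on a held-out calibration set,
# at a level chosen from the user's tolerance. Dummy data, illustrative names.
rng = np.random.default_rng(0)
cal_errors = rng.uniform(size=500)   # probe's estimated errors on calibration data
tolerance = 0.9                      # user-chosen tolerance level
alpha = np.quantile(cal_errors, tolerance)
```

At inference time, steering would then only fire for tokens whose estimated error exceeds `alpha`.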

The main benefits of MERA are:

  • Selective steering — steer only if the probe's estimated error is larger than a calibrated threshold α
  • Adaptive strength — the steering intensity λ scales with the probe's estimated error
  • Global abstention — steer only when confident, such that performance is at least ε with probability 1-δ

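The first two benefits can be condensed into a single rule, sketched below. All names here (`maybe_steer`, `probe_w`, the particular choice of λ) are hypothetical stand-ins, not the repository's actual API; the real implementation lives in src/.

```python
import torch

def maybe_steer(h, probe_w, probe_b, direction, alpha):
    """Toy sketch of selective, adaptive steering for one token activation.

    h:         residual-stream activation, shape (d,)
    probe_w/b: linear probe estimating the error from h
    direction: steering direction that reduces predicted error
    alpha:     calibrated steering threshold
    """
    est_error = torch.sigmoid(h @ probe_w + probe_b)  # probe's error estimate
    if est_error <= alpha:           # selective: abstain from steering
        return h
    lam = est_error - alpha          # adaptive: strength grows with error (toy choice)
    return h + lam * direction
```

In the paper, λ is obtained from a closed-form solution rather than this toy excess-error heuristic.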
Installation

Install the necessary packages using the provided requirements.txt:

conda create --name mera python=3.10
conda activate mera 
pip install torch --extra-index-url https://download.pytorch.org/whl/cu122  
pip install -r requirements.txt

If you want to run SAE experiments, also install:

pip install -e git+https://github.com/jbloomAus/SAELens.git#egg=sae-lens
pip install --force-reinstall --no-cache-dir cffi

Package requirements

Required packages are:

python
torch
transformers
datasets
huggingface_hub
accelerate
wandb

Getting started

If you want to try using MERA with your own dataset and model, go to the following notebook nbs/getting_started.py.

To steer with MERA on any of the existing datasets and models (see supported datasets and models here), run the following script:

python mera.py --dataset_names yes_no_question --model_name google/gemma-2-2b

How to reproduce experimental results

In the following, we describe how to reproduce the results in the paper. This requires:

  • access to wandb (and that you pass your API key)
  • that the datasets are downloaded and saved to a MERA-steering/hf_cache folder
  • a runs/ folder to save the results

Create a runs/ folder at the root and go to the src/ folder.

mkdir runs/
cd src
Step 1. Prepare datasets for probe training

For each model, to prepare the datasets for probe training (see supported datasets and models here), run the following script:

python -m cache.cache_run --dataset_names sentiment_analysis yes_no_question mmlu_high_school sms_spam --nr_samples 3000 --model_name meta-llama/Llama-3.2-1B-Instruct --hf_token INSERT_KEY --device cuda:1
python -m cache.cache_run --dataset_names mmlu_professional --nr_samples 2601 --model_name Qwen/Qwen2.5-3B-Instruct --hf_token INSERT_KEY --device cuda:4

Rerun the same command for the other models (see supported datasets and models here).

Next, post-process the cached data, i.e., subselect activation values based on token positions ("last" token of the prompt and "exact" token of the answer). This makes the cached files significantly smaller in preparation for probe training.

python -m cache.cache_postprocess --save_cache_key 2601 --model_names "Qwen/Qwen2.5-3B-Instruct" --dataset_names mmlu_professional 
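The subselection step above can be pictured as follows. The tensor layout, function name, and position arguments are assumptions for illustration, not the repository's cache format:

```python
import torch

def subselect(acts, last_prompt_pos, answer_pos):
    """Toy sketch: keep only two token positions from cached activations.

    acts: (num_layers, seq_len, d_model) activations for one example.
    Returns the "last" prompt-token and "exact" answer-token slices,
    discarding the rest of the sequence to shrink the cache.
    """
    return {
        "last": acts[:, last_prompt_pos, :],   # final token of the prompt
        "exact": acts[:, answer_pos, :],       # token of the answer
    }
```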

Step 2. Train linear probes

For each model, to train the linear probes (error estimators), run the following script:
python -m probes.probes_train --dataset_names sms_spam --save_name trans --transform_targets True --save_cache_key 3000

If you want to change any of the hyperparameters, please edit the script probes_train.py directly.

To analyse the performance of the probes, go to the following notebook nbs/evaluate_probes.py.
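Conceptually, a probe here is just a linear map from cached activations to "was the model's answer wrong?" labels. The sketch below uses synthetic data and a plain least-squares fit as a stand-in; the repository's probes_train.py handles the real layers, caching keys, and hyperparameters:

```python
import numpy as np

# Toy sketch of a linear probe as an error estimator on synthetic data.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))              # cached activations: (n_samples, d_model)
y = (X[:, 0] > 0).astype(float)             # toy binary error labels
Xb = np.hstack([X, np.ones((len(X), 1))])   # add a bias column
w, *_ = np.linalg.lstsq(Xb, y, rcond=None)  # fit the probe
est_error = Xb @ w                          # probe's estimated error per sample
acc = ((est_error > 0.5) == (y > 0.5)).mean()
```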

Step 3. Benchmark steering methods

For each model, to benchmark the steering methods, run the following script:
python -m steering.steering_run --steering_methods optimal_probe --dataset_names sms_spam --model_names "meta-llama/Llama-3.2-1B-Instruct" --fname custom_experiment --probe_token_pos exact --wandb_key INSERT_KEY

python -m steering.steering_run --steering_methods no_steering additive_probe additive_logistic_probe vanilla_contrastive prompt_steering optimal_probe optimal_logistic_probe optimal_contrastive --dataset_names sentiment_analysis  --fname final --probe_token_pos exact --wandb_key INSERT_KEY --probe_file_name df_probes_trans --nr_test_samples 250 --nr_ref_samples 250 --device cuda:6

To analyse the performance of the steering methods, go to the following notebook nbs/evaluate_steering.py.

Dataset and models

How to download datasets

To download the datasets, please follow the instructions here. If you want to include additional datasets, please follow the guide here.

Supported datasets and models

Currently, we support these datasets:

  • sentiment_analysis
  • yes_no_question
  • mmlu_high_school
  • sms_spam

Our experiments work with these models:

  • google/gemma-2-2b
  • google/gemma-2-2b-it
  • Qwen/Qwen2.5-3B
  • Qwen/Qwen2.5-3B-Instruct
  • meta-llama/Llama-3.2-1B
  • meta-llama/Llama-3.2-1B-Instruct

Other Hugging Face decoder-only LMs that contain blocks with a residual stream should also be compatible with our current implementation (see register_hooks here).
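The compatibility requirement boils down to being able to attach forward hooks on the block modules and modify their residual-stream outputs. A minimal sketch, where `TinyBlock` stands in for a Hugging Face decoder block and `add_steering_hooks` is a hypothetical name (not the repo's register_hooks):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Stand-in for a transformer block with a residual connection."""
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Linear(d, d)

    def forward(self, x):
        return x + self.ff(x)  # residual stream

def add_steering_hooks(blocks, direction, lam):
    """Return hook handles that add lam * direction to each block's output."""
    def hook(module, inputs, output):
        return output + lam * direction  # returned tensor replaces the output
    return [b.register_forward_hook(hook) for b in blocks]
```

Calling `handle.remove()` on each returned handle restores the unsteered model.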

Thank you

We hope our repository is beneficial to your work and research. If you have any feedback, questions, or ideas, please feel free to raise an issue in this repository. Alternatively, you can reach out to us directly via email for discussions or suggestions.

📧 Contact us:

Thank you for your interest and support!

About

Code and notebooks for Mechanistic Error Reduction with Abstention (MERA) — for effective and safe language model steering 🔥
