This project implements a minimal transformer framework to study the grokking phenomenon on modular arithmetic tasks. Grokking refers to delayed generalization behavior where a model first memorizes training data and only later achieves strong performance on unseen data.
We train a small transformer model to learn modular addition:
(x + y) mod p
Each input is represented as a sequence: [x, y, =]. The model predicts the result at the final token position.
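As a rough illustration of the data setup described above, the sketch below enumerates every pair (x, y) with 0 ≤ x, y < p, encodes each pair as the three-token sequence [x, y, =], and splits the pairs by a training fraction. It is not taken from src/dataset.py: giving "=" the token id p, the function name make_splits, and the seeded shuffle are all assumptions.

```python
import random


def make_splits(p=97, train_frac=0.5, seed=0):
    """Hypothetical sketch: build all (x + y) mod p examples and split them.

    Assumption: x and y use token ids 0..p-1 and "=" gets the extra id p,
    giving a vocabulary of p + 1 tokens.
    """
    eq_token = p  # assumed id for the "=" token
    examples = []
    for x in range(p):
        for y in range(p):
            tokens = [x, y, eq_token]  # input sequence [x, y, =]
            label = (x + y) % p        # target predicted at the final position
            examples.append((tokens, label))

    rng = random.Random(seed)  # seeded shuffle for a reproducible split
    rng.shuffle(examples)
    n_train = int(train_frac * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = make_splits(p=97, train_frac=0.5)
print(len(train), len(test))  # 4704 4705
```

Because the full addition table has only p² = 9409 entries for p = 97, the train fraction directly controls how much of that table the model ever sees.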
This repository is designed to:
- Reproduce grokking behavior
- Compare optimizers (Adam, AdamW, SGD, L-BFGS)
- Study the effect of regularization and dataset size
- Visualize training vs. test accuracy dynamics
Grokking/
├── main.py # Entry point for training
├── plot.py # Plot accuracy curves from logs
├── logs/ # Saved experiment logs and plots
└── src/
    ├── attention.py # Multi-head attention implementation
    ├── dataset.py # Dataset generation and splitting
    ├── model.py # Toy Transformer model
    ├── train.py # Training loop and evaluation
    └── logger.py # Logging utilities
- Python 3.8+
- PyTorch
- Matplotlib
pip install torch matplotlib
Run the default experiment: python main.py
Key Default Settings:
- Modulus: p = 97
- Model dimension: d_model = 128
- Attention heads: 4
- Optimizer: AdamW
- Learning rate: 1e-3
- Weight decay: 1.0
- Train fraction: 0.5
- Training steps: 1e6
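For reference, the defaults above and the command-line flags documented below can be summarized as an argparse parser. This is a hypothetical reconstruction, not the actual contents of main.py; the argument types, and the absence of a flag for the modulus p, are assumptions.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the documented defaults; main.py may differ.
    parser = argparse.ArgumentParser(description="Grokking on (x + y) mod p")
    parser.add_argument("--optimizer", default="adamw",
                        choices=["adam", "adamw", "sgd", "lbfgs"])
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1.0)
    parser.add_argument("--regularizer", default="none",
                        choices=["l1", "l2", "none"])
    parser.add_argument("--train_frac", type=float, default=0.5)
    parser.add_argument("--steps", type=int, default=1_000_000)
    parser.add_argument("--d_model", type=int, default=128)
    parser.add_argument("--nheads", type=int, default=4)
    return parser


args = build_parser().parse_args(["--optimizer", "sgd", "--lr", "1e-2"])
print(args.optimizer, args.lr, args.weight_decay)  # sgd 0.01 1.0
```

With d_model = 128 and 4 heads, each head works in a 128 / 4 = 32-dimensional subspace, assuming the usual equal split of the model dimension across heads.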
You can override parameters from the command line:
python main.py \
--optimizer adamw \
--lr 1e-3 \
--weight_decay 1.0 \
--regularizer none \
--train_frac 0.5
Available Options:
--optimizer: adam, adamw, sgd, lbfgs
--regularizer: l1, l2, none
--train_frac: fraction of dataset used for training
--steps: number of training iterations
--d_model: model dimension
--nheads: number of attention heads
After training, logs are saved as JSON files in the logs/ directory. To generate plots:
python plot.py logs/your_log.json
Optional arguments:
python plot.py logs/your_log.json [max_step] [smooth_window]
Example:
python plot.py logs/adamw_reg-none_frac-0.5.json 500000 40
This will produce a PDF showing training vs. test accuracy.
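The two optional arguments are not described in detail here; a plausible reading is that max_step truncates the x-axis and smooth_window sets the width of a moving average applied to the accuracy curves. The sketch below implements that interpretation in plain Python. The function names are hypothetical, and plot.py may smooth differently.

```python
def truncate(steps, values, max_step):
    """Keep only points logged at or before max_step (assumed meaning of max_step)."""
    kept = [(s, v) for s, v in zip(steps, values) if s <= max_step]
    return [s for s, _ in kept], [v for _, v in kept]


def moving_average(values, window):
    """Trailing moving average; output point i is the mean of values[i:i + window]."""
    window = min(window, len(values))
    if window <= 1:
        return list(values)
    running = sum(values[:window])
    out = [running / window]
    for i in range(window, len(values)):
        running += values[i] - values[i - window]  # slide the window by one point
        out.append(running / window)
    return out


steps = [0, 100, 200, 300, 400]
test_acc = [0, 10, 20, 90, 100]  # synthetic percentages, not real results
s, v = truncate(steps, test_acc, max_step=300)
print(s, moving_average(v, window=2))  # [0, 100, 200, 300] [5.0, 15.0, 55.0]
```

A wide window (such as 40 in the example command) hides step-to-step noise but also blurs how abruptly the test accuracy jumps, so it is worth checking the unsmoothed curve as well.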
Typical training dynamics include:
- Training accuracy reaches ~100% quickly.
- Test accuracy remains low for many steps.
- A sudden transition occurs where test accuracy rapidly improves.
This delayed generalization is known as grokking.
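One simple way to quantify this delay is to compare the first logged step at which training accuracy crosses a threshold with the first step at which test accuracy does. The sketch assumes accuracies are fractions in [0, 1] recorded at the same steps; the 0.99 threshold and the function names are arbitrary choices, not part of this repository.

```python
def first_crossing(steps, accs, threshold=0.99):
    """Return the first step whose accuracy is >= threshold, or None if never reached."""
    for step, acc in zip(steps, accs):
        if acc >= threshold:
            return step
    return None


def grokking_delay(steps, train_acc, test_acc, threshold=0.99):
    """Steps between memorization (train crosses threshold) and generalization (test crosses it)."""
    t_train = first_crossing(steps, train_acc, threshold)
    t_test = first_crossing(steps, test_acc, threshold)
    if t_train is None or t_test is None:
        return None  # at least one curve never reached the threshold
    return t_test - t_train


# Synthetic curves (not real results): train saturates early, test jumps late.
steps = [0, 1_000, 2_000, 50_000, 100_000]
train_acc = [0.01, 0.60, 1.00, 1.00, 1.00]
test_acc = [0.01, 0.02, 0.03, 0.20, 0.995]
print(grokking_delay(steps, train_acc, test_acc))  # 98000
```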
Each run produces a JSON log file containing:
- Loss: cross-entropy loss and regularization loss
- Accuracy: training and test accuracy
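The regularization loss presumably depends on --regularizer. A common form, shown here as an assumption rather than the repository's exact definition, adds a coefficient times the sum of absolute weights (l1) or of squared weights (l2) to the cross-entropy. The coefficient name reg_coef is hypothetical, and real code would operate on model parameter tensors rather than a flat Python list.

```python
def regularization_loss(weights, kind="none", reg_coef=1e-4):
    """Hypothetical penalty term; `weights` is a flat list of floats.

    l1: reg_coef * sum(|w|); l2: reg_coef * sum(w**2); none: 0.0.
    """
    if kind == "none":
        return 0.0
    if kind == "l1":
        return reg_coef * sum(abs(w) for w in weights)
    if kind == "l2":
        return reg_coef * sum(w * w for w in weights)
    raise ValueError(f"unknown regularizer: {kind!r}")


def total_loss(cross_entropy, weights, kind="none", reg_coef=1e-4):
    # The two terms would typically be logged separately, as listed above.
    return cross_entropy + regularization_loss(weights, kind, reg_coef)
```

This explicit penalty is separate from the decoupled weight decay that AdamW applies through --weight_decay. The default run uses AdamW with weight decay 1.0 and --regularizer none, so its regularization comes from the optimizer alone.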
Filenames encode the experiment settings (for example, adamw_reg-none_frac-0.5.json), which makes runs easy to compare.
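Judging from the example name adamw_reg-none_frac-0.5.json, a filename appears to combine the optimizer, the regularizer, and the train fraction. The helper below parses names of that assumed shape so runs can be grouped for comparison. It is hypothetical and will reject names that include other settings (such as lr or weight_decay) if the real naming scheme adds them.

```python
import os


def parse_log_name(path):
    """Parse '<optimizer>_reg-<regularizer>_frac-<train_frac>.json' (assumed pattern)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    optimizer, reg_part, frac_part = stem.split("_")
    if not (reg_part.startswith("reg-") and frac_part.startswith("frac-")):
        raise ValueError(f"unexpected log name: {path!r}")
    return {
        "optimizer": optimizer,
        "regularizer": reg_part[len("reg-"):],
        "train_frac": float(frac_part[len("frac-"):]),
    }


print(parse_log_name("logs/adamw_reg-none_frac-0.5.json"))
# {'optimizer': 'adamw', 'regularizer': 'none', 'train_frac': 0.5}
```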