ShahryarBQ/Grokking

Grokking with Toy Transformers

This project implements a minimal transformer framework for studying the grokking phenomenon on modular arithmetic tasks. Grokking is delayed generalization: a model first memorizes its training data and only much later, with continued training, reaches high accuracy on unseen data.


📌 Overview

We train a small transformer model to learn modular addition:

(x + y) mod p

Each input is represented as a sequence: [x, y, =]. The model predicts the result at the final token position.
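A minimal sketch of how such a dataset could be built is given here. It assumes the `=` token is assigned id `p` (one past the largest residue) and that the train/test split is a seeded random shuffle; the function name `make_modular_addition_data` is hypothetical and is not taken from `src/dataset.py`.

```python
import random


def make_modular_addition_data(p=97, train_frac=0.5, seed=0):
    """Build every (x, y) pair for (x + y) mod p and split it into train/test.

    Assumption: the '=' token gets id p, so the vocabulary has p + 1 entries.
    """
    eq_token = p
    # Each example is (input sequence [x, y, =], target label).
    examples = [([x, y, eq_token], (x + y) % p) for x in range(p) for y in range(p)]
    rng = random.Random(seed)
    rng.shuffle(examples)
    n_train = int(train_frac * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = make_modular_addition_data(p=97, train_frac=0.5)
print(len(train), len(test))  # 4704 4705
```

With p = 97 there are 97 × 97 = 9409 possible pairs, so a train fraction of 0.5 leaves roughly 4700 examples on each side of the split.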

This repository is designed to:

  • Reproduce grokking behavior
  • Compare optimizers (Adam, AdamW, SGD, L-BFGS)
  • Study the effect of regularization and dataset size
  • Visualize training vs. test accuracy dynamics

📁 Project Structure

Grokking/
├── main.py           # Entry point for training
├── plot.py           # Plot accuracy curves from logs
├── logs/             # Saved experiment logs and plots
└── src/
    ├── attention.py  # Multi-head attention implementation
    ├── dataset.py    # Dataset generation and splitting
    ├── model.py      # Toy Transformer model
    ├── train.py      # Training loop and evaluation
    └── logger.py     # Logging utilities

⚙️ Installation

Requirements

  • Python 3.8+
  • PyTorch
  • Matplotlib

Install dependencies

pip install torch matplotlib


🚀 Running Experiments

Basic Training

Run the default experiment:

python main.py

Key Default Settings:

  • Modulus: p = 97
  • Model dimension: d_model = 128
  • Attention heads: 4
  • Optimizer: AdamW
  • Learning rate: 1e-3
  • Weight decay: 1.0
  • Train fraction: 0.5
  • Training steps: 1e6

Custom Experiments

You can override parameters via command line:

python main.py \
    --optimizer adamw \
    --lr 1e-3 \
    --weight_decay 1.0 \
    --regularizer none \
    --train_frac 0.5

Available Options:

--optimizer: adam, adamw, sgd, lbfgs
--lr: learning rate
--weight_decay: weight decay coefficient
--regularizer: l1, l2, none
--train_frac: fraction of dataset used for training
--steps: number of training iterations
--d_model: model dimension
--nheads: number of attention heads
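For orientation, the flags above could be wired up with `argparse` roughly as follows. This is a hypothetical sketch, not the parser in `main.py`: the flag names and most defaults come from this README, while the default for `--regularizer` and the function name `build_parser` are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags and defaults listed in this README."""
    parser = argparse.ArgumentParser(description="Grokking on modular addition")
    parser.add_argument("--optimizer", choices=["adam", "adamw", "sgd", "lbfgs"], default="adamw")
    # Assumption: the README does not state the default regularizer.
    parser.add_argument("--regularizer", choices=["l1", "l2", "none"], default="none")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1.0)
    parser.add_argument("--train_frac", type=float, default=0.5)
    parser.add_argument("--steps", type=int, default=1_000_000)
    parser.add_argument("--d_model", type=int, default=128)
    parser.add_argument("--nheads", type=int, default=4)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--optimizer", "sgd", "--train_frac", "0.3"])
print(args.optimizer, args.train_frac, args.d_model)  # sgd 0.3 128
```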

📊 Plotting Results

After training, logs are saved as JSON files in the logs/ directory. To generate plots:

python plot.py logs/your_log.json

Optional arguments:

python plot.py logs/your_log.json [max_step] [smooth_window]

Example:

python plot.py logs/adamw_reg-none_frac-0.5.json 500000 40

This will produce a PDF showing training vs. test accuracy.
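The `smooth_window` argument suggests that the curves are smoothed before plotting. One common way to do this is a trailing moving average; the sketch below is an assumption about what such smoothing could look like, not a description of what `plot.py` does, and `moving_average` is a hypothetical helper.

```python
def moving_average(values, window):
    """Trailing moving average; early points average over the values seen so far."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed = []
    running_sum = 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


print(moving_average([0.0, 1.0, 1.0, 0.0], window=2))  # [0.0, 0.5, 1.0, 0.5]
```

A larger window hides step-to-step noise but also blurs the sharp rise in test accuracy, so it is worth comparing against the raw curve.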


📈 Grokking Behavior

Typical training dynamics include:

  1. Training accuracy reaches ~100% quickly.
  2. Test accuracy remains low for many steps.
  3. A sudden transition occurs where test accuracy rapidly improves.

This delayed generalization is known as grokking.
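One simple way to quantify the delay is to measure the gap between the step at which training accuracy first crosses a threshold and the step at which test accuracy does. The sketch below assumes accuracies are logged as fractions in [0, 1]; the helper names and the 0.99 threshold are illustrative choices, and the numbers are made up for the example rather than taken from any run in this repository.

```python
def first_step_reaching(steps, accuracies, threshold):
    """Return the first logged step whose accuracy is >= threshold, or None."""
    for step, acc in zip(steps, accuracies):
        if acc >= threshold:
            return step
    return None


def grokking_delay(steps, train_acc, test_acc, threshold=0.99):
    """Steps between fitting the training set and generalizing to the test set."""
    fit_step = first_step_reaching(steps, train_acc, threshold)
    gen_step = first_step_reaching(steps, test_acc, threshold)
    if fit_step is None or gen_step is None:
        return None
    return gen_step - fit_step


# Made-up illustrative values, not results from this repository.
steps = [0, 1000, 2000, 3000, 4000]
train_acc = [0.01, 1.00, 1.00, 1.00, 1.00]
test_acc = [0.01, 0.02, 0.05, 0.60, 0.995]
print(grokking_delay(steps, train_acc, test_acc))  # 3000
```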


📝 Logging

Each run produces a JSON log file containing:

  • Loss: cross-entropy loss and regularization loss
  • Accuracy: training and test accuracy

Filenames encode experiment settings for easy comparison.
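As a rough illustration, a log could be named and written as follows. The file name pattern follows the `adamw_reg-none_frac-0.5.json` example shown in the plotting section; the record keys (`step`, `ce_loss`, `reg_loss`, `train_acc`, `test_acc`), the helper names, and the values are assumptions made for this sketch, and the actual schema produced by `src/logger.py` may differ.

```python
import json
import os
import tempfile


def log_filename(optimizer, regularizer, train_frac):
    """Encode the settings in the file name, following the example pattern."""
    return f"{optimizer}_reg-{regularizer}_frac-{train_frac}.json"


def save_log(log_dir, name, records):
    """Write a list of per-step records to log_dir/name as JSON."""
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, name)
    with open(path, "w") as f:
        json.dump(records, f)
    return path


name = log_filename("adamw", "none", 0.5)

# Hypothetical record layout with placeholder values.
records = [
    {"step": 0, "ce_loss": 4.6, "reg_loss": 0.0, "train_acc": 0.01, "test_acc": 0.01},
]

# A temporary directory keeps the demo from touching the real logs/ folder.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = save_log(tmp_dir, name, records)
    with open(path) as f:
        loaded = json.load(f)

print(name)  # adamw_reg-none_frac-0.5.json
```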
