This repository contains code for reproducing results reported in our preprint on RPL, a JEPA model of predictive processing.
We provide the code as a Python package. To set up your Python environment and run the code, clone this repository and follow these steps:
# Go to the directory
cd recurrent-predictive-learning
# Install the package including all the dependencies via pip
pip install .
We recommend installing the package in a separate, project-specific virtual environment.
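For example, a minimal setup using Python's built-in venv module (the environment name .venv is just a placeholder):
# Create and activate a project-specific virtual environment
python -m venv .venv
source .venv/bin/activate
# Install the package into the activated environment
pip install .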
Note that you will need CUDA-compatible hardware to run the code smoothly; MPS support for Mac users is currently unavailable.
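To check that PyTorch can see your GPU before starting any long training runs, you can run a quick sanity check (this assumes the package and its PyTorch dependency are already installed):
# Should print "True" on a machine with a working CUDA setup
python -c "import torch; print(torch.cuda.is_available())"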
You can train all models on the corresponding datasets using the provided bash scripts:
- `bash bash_scripts/run_moving_animals.sh` trains and evaluates all models on the moving animal videos (cf. Fig. 2 and Supplementary Fig. S2)
- `bash bash_scripts/run_mnist_triplets.sh` trains and evaluates all models on the MNIST triplet sequences (cf. Fig. 4)
- `bash bash_scripts/run_mouse.sh` trains and evaluates all models on the videos of behaving mice (cf. Fig. 5a,b)
- `bash bash_scripts/run_libri.sh` trains all models on the LibriSpeech corpus (cf. Fig. 5c,d)
- `bash bash_scripts/run_oddballs.sh` trains the model for the local-global oddball paradigm (cf. Fig. 7)
- `bash bash_scripts/run_hRPL.sh` trains and evaluates hRPL models and the frozen control models on the moving animal videos (cf. Fig. 8)
To evaluate each model individually, you can use:
# For RPL
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 repl/scripts/offline_eval.py [options]
# For hRPL
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 repl/scripts/offline_greedy_eval.py [options]
Use `--help` to see all available arguments, and make sure to pass `--model_path <PATH-TO-MODEL>/model_final.pt`.
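For example, to evaluate an RPL model trained with one of the scripts above (replace the placeholder with the directory your run actually wrote its checkpoint to):
# Evaluate a trained RPL checkpoint
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 repl/scripts/offline_eval.py \
    --model_path <PATH-TO-MODEL>/model_final.pt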
Use the provided Jupyter notebooks in the notebooks folder to generate the figures in the paper or to run further analyses. The notebooks are named after the corresponding figure numbers in the manuscript.
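To open the notebooks, launch Jupyter from the repository root (this assumes Jupyter is available in your environment; install it with pip if it is not already):
# Install Jupyter if needed, then browse the notebooks folder
pip install notebook
jupyter notebook notebooks/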