NumPyro Template: deep probabilistic programming made easier.
- Python >= 3.5 (3.6 recommended)
- numpyro >= 0.18.0
- tqdm (optional, for `test.py`)
- tensorboard >= 1.14 (see Tensorboard Visualization)
- Clear folder structure which is suitable for configurable probabilistic programming projects.
- `.yaml` config file support for convenient parameter tuning.
- Customizable command-line options for more convenient parameter tuning.
- Checkpoint saving and resuming.
- Abstract base classes for faster development:
  - `Trainer` handles the training process, logging, and more.
  - `DataModule` handles data shuffling and validation data splitting.
  - `ParaMonad` handles checkpoint saving/resuming, updating of mutable parameters, and JAX RNG keys.
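To illustrate the kind of work a `DataModule` takes over, here is a minimal, dependency-free sketch of shuffling and validation splitting. The class name `SimpleDataModule` and its parameters (`val_fraction`, `batch_size`, `seed`) are hypothetical and are not the template's actual API; the real base class in `src/data/` may be organized differently.

```python
import random
from typing import Any, Iterator, List, Sequence


class SimpleDataModule:
    """Hypothetical stand-in for a DataModule: shuffle, split, and batch."""

    def __init__(self, data: Sequence[Any], val_fraction: float = 0.1,
                 batch_size: int = 32, seed: int = 0):
        if not 0.0 <= val_fraction < 1.0:
            raise ValueError("val_fraction must be in [0, 1)")
        self.batch_size = batch_size
        self._rng = random.Random(seed)  # seeded, so splits are reproducible
        items = list(data)
        # Shuffle once before splitting so both splits are random samples.
        self._rng.shuffle(items)
        n_val = int(len(items) * val_fraction)
        self.val_data: List[Any] = items[:n_val]
        self.train_data: List[Any] = items[n_val:]

    def train_batches(self) -> Iterator[List[Any]]:
        """Reshuffle the training split each call; the last batch may be smaller."""
        order = list(self.train_data)
        self._rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]

    def val_batches(self) -> Iterator[List[Any]]:
        """Yield validation batches in a fixed order (no reshuffling)."""
        for start in range(0, len(self.val_data), self.batch_size):
            yield self.val_data[start:start + self.batch_size]
```

For example, `SimpleDataModule(range(100), val_fraction=0.2, batch_size=16, seed=42)` holds out 20 items for validation and yields five training batches of 16 per epoch. Reshuffling only the training split keeps validation metrics comparable across epochs.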
numpyro_template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
│
├── configs/ - Hydra configuration files for data modules, models, guides, parametric monads, and trainers
│ ├── data/ - data-module configurations
│ ├── guide/ - variational guide program or sampler configurations
│ ├── model/ - generative model program configurations
│ ├── monad/ - state monad configurations for RNG keys, mutable parameters and optimizer states, etc.
│ └── trainer/ - trainer class configurations
│
├── data/ - default directory for storing input data
│
├── notebooks/ - Jupyter notebooks showing off results
│ └── vae.ipynb - an example with a trained Variational Autoencoder
│
├── src/ - core source code in Python
│ ├── data/ - data modules and the core `DataModule` class
│ ├── logger/ - logger source
│ ├── model/ - model source code in Numpyro
│ ├── trainer/ - source code for `ParaMonad`, `Trainer`, and their subclasses
│ └── utils/ - small utility functions
├── util.py
└── ...
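As a rough picture of what "checkpoint saving/resuming and updating of mutable parameters" involves, the sketch below keeps all resumable state in one object and returns a fresh object on every update rather than mutating in place, loosely in the spirit of a state monad. `TrainingState`, its fields, and the pickle-based file format are hypothetical assumptions, not the template's `ParaMonad`; in particular, the real class also threads JAX PRNG keys, which are replaced here by a plain integer `seed` to keep the example dependency-free.

```python
import os
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class TrainingState:
    """Hypothetical, simplified stand-in for a ParaMonad-like state holder."""
    params: Dict[str, Any] = field(default_factory=dict)
    opt_state: Dict[str, Any] = field(default_factory=dict)
    step: int = 0
    seed: int = 0  # placeholder for the JAX PRNG key a real implementation would carry

    def update(self, **new_params: Any) -> "TrainingState":
        """Return a new state with merged parameters and an incremented step."""
        merged = {**self.params, **new_params}
        return TrainingState(merged, dict(self.opt_state), self.step + 1, self.seed)

    def save(self, path: str) -> None:
        """Write to a temporary file, then rename, so a crash never leaves a half-written checkpoint."""
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(self, f)
        os.replace(tmp_path, path)  # atomic replacement on POSIX and Windows

    @staticmethod
    def load(path: str) -> "TrainingState":
        """Restore a state previously written by `save`."""
        with open(path, "rb") as f:
            return pickle.load(f)
```

Resuming then amounts to calling `TrainingState.load(path)` and continuing the loop from `state.step`.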
This project is licensed under the MIT License. See LICENSE for more details.
This project is inspired by the Tensorflow-Project-Template project by Mahmoud Gemy.