Deep probabilistic programming in JAX with GenJAX, made easier.
A template for configurable probabilistic-programming projects: Hydra configs, a PyTorch-Lightning-style learner abstraction, amortized variational inference, and static-graph-structure introspection — all on top of GenJAX's generative-function interface (GFI) and ADEV gradient estimators.
- Python >= 3.12
- uv for environment management
- genjax (and therefore jax >= 0.7.2, < 0.8; GenJAX's PJAX internals are not yet compatible with jax >= 0.9)
- optax, hydra-core, networkx, tensorboard, torch / torchvision (data), plus the rest declared in pyproject.toml
By default genjax is referenced as a local editable checkout at ../genjax
(see [tool.uv.sources] in pyproject.toml). To use a shared checkout instead,
switch that entry to the git source noted there.
uv sync # create the environment
uv run python -m src.train # train (Hydra entrypoint; see configs/train.yaml)
uv run pytest # run the tests
Hydra makes every component swappable from the command line, e.g.:
uv run python -m src.train learner=graphical model=vae guide=vae data=mnist
- GenJAX models & guides written as @gen generative functions (src/model/model.py: a VAE with model p(x, z) and an amortized guide q(z | x)).
- .yaml config files (Hydra) for models, guides, learners, data, and trainers, with command-line overrides.
- Abstract base classes for fast iteration:
  - ParamLearner: the trainable-module interface (train_step / valid_step / test_step, save / load, parameters).
  - DataModule: data shuffling and train/val/test splitting.
  - Trainer: the training loop, checkpointing, and TensorBoard logging.
- Learners (src/learner/):
  - SviLearner: amortized stochastic variational inference. The ELBO is built from guide.simulate + model.assess wrapped in an ADEV @expectation, so grad_estimate yields reparameterized gradients; parameters are optimized with optax. Supports an IWAE bound (num_particles > 1).
  - GraphicalModelLearner: trains exactly like SviLearner, and additionally captures the model's static graph structure at setup (see below).
- Checkpoint saving / resuming and TensorBoard visualization.
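To make the SviLearner objective concrete, here is a minimal standard-library sketch of the bound it describes: draw latents from the guide, score them under the model, and average the importance weights in log space. It is not the template's code and does not call GenJAX. The toy model (z ~ N(0, 1), x | z ~ N(z, 1)) and the names `normal_logpdf`, `log_mean_exp`, `iwae_bound`, and `average_bound` are assumptions made for illustration, and gradient estimation (the ADEV part) is deliberately left out.

```python
import math
import random


def normal_logpdf(value, mean, std):
    """Log-density of a univariate normal distribution."""
    z = (value - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values) / len(values))


def iwae_bound(x, guide_mean, guide_std, num_particles, rng):
    """One stochastic estimate of the K-particle bound for the toy model
    z ~ N(0, 1), x | z ~ N(z, 1) with guide q(z | x) = N(guide_mean, guide_std).
    With num_particles == 1 this is an ordinary single-sample ELBO estimate."""
    log_weights = []
    for _ in range(num_particles):
        z = rng.gauss(guide_mean, guide_std)             # sample from the guide
        log_q = normal_logpdf(z, guide_mean, guide_std)  # guide log-density
        log_p = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)  # model joint
        log_weights.append(log_p - log_q)
    return log_mean_exp(log_weights)


def average_bound(x, guide_mean, guide_std, num_particles, repeats=2000, seed=0):
    """Monte Carlo average of the bound over many independent estimates."""
    rng = random.Random(seed)
    total = sum(
        iwae_bound(x, guide_mean, guide_std, num_particles, rng)
        for _ in range(repeats)
    )
    return total / repeats


x_obs = 1.0
# Exact log-evidence for this toy model: marginally x ~ N(0, sqrt(2)).
log_evidence = normal_logpdf(x_obs, 0.0, math.sqrt(2.0))

# A deliberately poor guide (the prior) makes the gap between bounds visible.
elbo = average_bound(x_obs, 0.0, 1.0, num_particles=1)
iwae = average_bound(x_obs, 0.0, 1.0, num_particles=10)
```

In this toy setting the 10-particle bound sits between the single-particle ELBO and the exact log-evidence, and when the guide equals the true posterior N(x/2, sqrt(1/2)) every importance weight equals p(x), so the bound is exact for any number of particles.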
GraphicalModelLearner introspects the generative model and exposes a
networkx digraph (.graph, .relations, .render_model()):
- Nodes: one per sample-site address, annotated with its distribution family, value shape, and whether it is latent or observed (observed = model sites the guide does not sample).
- Edges: A -> B iff site A's value influences site B's distribution. Edges are recovered by dataflow provenance over the per-site log-density jaxpr (not autodiff), so dependencies routed through discrete parents are captured too, e.g. a Categorical index feeding a child's mean. This mirrors NumPyro's eval_provenance and is sound across cond / scan / pjit.
See tests/test_structure.py for continuous, discrete, and independence cases.
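The idea behind provenance-based edge recovery can be sketched in a few lines of plain Python. This toy is not the template's implementation, nor GenJAX's or NumPyro's; a real version would walk a jaxpr rather than overload operators. The names `Tracked`, `sample_site`, `gather`, and `edges_into` are hypothetical. The sketch shows why a discrete index still produces an edge: provenance follows data flow, whereas a derivative with respect to an integer index does not exist.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tracked:
    """A value paired with the set of sample sites that influenced it."""
    value: object
    sites: frozenset = frozenset()

    def _combine(self, other, op):
        # Untracked operands (plain constants) carry no provenance.
        if not isinstance(other, Tracked):
            other = Tracked(other)
        return Tracked(op(self.value, other.value), self.sites | other.sites)

    def __add__(self, other):
        return self._combine(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._combine(other, lambda a, b: a * b)


def sample_site(name, value):
    """Mark `value` as originating from the sample site `name`."""
    return Tracked(value, frozenset({name}))


def gather(table, index):
    """Table lookup with a tracked index: the result inherits the index's sites."""
    return Tracked(table[index.value], index.sites)


def edges_into(child, *distribution_args):
    """Edges parent -> child for every site found in the child's distribution arguments."""
    parents = set()
    for arg in distribution_args:
        if isinstance(arg, Tracked):
            parents |= arg.sites
    return {(parent, child) for parent in parents}


# Toy model: k is a discrete index, z is continuous.
# x's mean depends on both; y's arguments are constants, so y is independent.
k = sample_site("k", 2)
z = sample_site("z", 0.5)
component_means = [-1.0, 0.0, 3.0]

x_mean = gather(component_means, k) + z * 2.0
edges = edges_into("x", x_mean, 1.0) | edges_into("y", 0.0, 1.0)
```

Here `edges` contains `("k", "x")` and `("z", "x")` and nothing pointing at `y`. An edge set of this shape could be loaded into a networkx digraph with `add_edges_from`.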
genjax_template/
│
├── src/
│ ├── train.py - training entrypoint (Hydra)
│ ├── test.py - evaluation entrypoint
│ ├── data/ - DataModule + dataset implementations (e.g. MNIST)
│ ├── inference/ - GenJAX-native inference helpers
│ ├── learner/ - ParamLearner ABC, SviLearner, GraphicalModelLearner
│ ├── logger/ - logging / TensorBoard writer
│ ├── model/ - GenJAX @gen models and guides
│ ├── trainer/ - Trainer (loop, checkpointing, logging)
│ └── utils/ - small utilities
│
├── configs/ - Hydra configs
│ ├── data/ guide/ learner/ model/ trainer/
│ ├── experiment/ extras/ hydra/ paths/
│ └── train.yaml eval.yaml
│
├── notebooks/vae.ipynb - VAE example
├── tests/ - pytest tests (run with `uv run pytest`)
├── pyproject.toml - project + uv configuration
└── data/ - default input-data directory
This project is licensed under the MIT License. See LICENSE for more details.
Ported to GenJAX from a NumPyro template, itself inspired by Tensorflow-Project-Template by Mahmoud Gemy.