Install with the command that matches your hardware. If JAX isn't already installed, we recommend one of the `dl4bi[<jax-version>]` installs, which also install JAX for the corresponding platform.
pip install dl4bi # dl4bi
pip install dl4bi[cpu] # dl4bi + jax for CPU
pip install dl4bi[cuda12] # dl4bi + jax for CUDA-12
pip install dl4bi[cuda13] # dl4bi + jax for CUDA-13
pip install dl4bi[benchmarks,cpu] # benchmark deps + jax for CPU
pip install dl4bi[benchmarks,cuda12] # benchmark deps + jax for CUDA-12
pip install dl4bi[benchmarks,cuda13] # benchmark deps + jax for CUDA-13

The `benchmarks` extra installs the additional packages used by the benchmark scripts, particularly those under `benchmarks/meta_learning` and `benchmarks/vae`.
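After installing, you can check which JAX build is active, for example that a `cuda12` or `cuda13` install actually sees a GPU. The snippet below is a minimal sketch built on JAX's own `jax.default_backend()` and `jax.devices()`; the helper name `jax_backend_summary` is hypothetical and not part of dl4bi.

```python
import importlib.util


def jax_backend_summary():
    """Describe the active JAX backend, or return None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None  # no JAX yet: install one of the dl4bi[<jax-version>] extras
    import jax

    backend = jax.default_backend()  # platform name of the default backend, e.g. "cpu" or "gpu"
    devices = jax.devices()  # devices visible to that default backend
    names = ", ".join(str(device) for device in devices)
    return f"{backend}: {len(devices)} device(s) [{names}]"


if __name__ == "__main__":
    print(jax_backend_summary() or "JAX is not installed")
```

If a CUDA install reports `cpu`, JAX did not find a usable GPU, which usually points to a driver or CUDA version mismatch rather than a dl4bi problem.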
To browse the API documentation locally, clone the repository and serve it with pdoc:
git clone git@github.com:MLGlobalHealth/dl4bi.git
cd dl4bi
uv run --with pdoc pdoc --docformat google --math dl4bi

If you're using this package or some of its code, please cite the relevant paper(s):
@inproceedings{jenson2026scalable,
title = {Scalable Spatiotemporal Inference with Biased Scan Attention Transformer Neural Processes},
author = {Jenson, Daniel and Navott, Jhonathan and Grynfelder, Piotr and Zhang, Mengyan and Sharma, Makkunda and Semenova, Elizaveta and Flaxman, Seth},
booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics},
series = {Proceedings of Machine Learning Research},
volume = {300},
address = {Tangier, Morocco},
year = {2026},
publisher = {PMLR}
}
@inproceedings{navott2026deeprv,
title = {{DeepRV}: Accelerating Spatiotemporal Inference with Pre-trained Neural Priors},
author = {Navott, Jhonathan and Jenson, Daniel and Flaxman, Seth and Semenova, Elizaveta},
booktitle = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics},
series = {Proceedings of Machine Learning Research},
volume = {300},
address = {Tangier, Morocco},
year = {2026},
publisher = {PMLR}
}

Benchmarks are available for BSA-TNP here and for DeepRV here.
- Install `uv`: `curl -LsSf https://astral.sh/uv/install.sh | sh`
- Clone the repository and `cd` into it: `git clone git@github.com:MLGlobalHealth/dl4bi.git && cd dl4bi`
- Install the latest Python with `uv`: `uv python install`
- Sync the project environment:
  - CPU JAX: `uv sync --extra cpu`
  - CUDA 12 JAX: `uv sync --extra cuda12`
  - CUDA 13 JAX: `uv sync --extra cuda13`
  - Benchmark deps + CPU JAX: `uv sync --extra benchmarks --extra cpu`
  - Benchmark deps + CUDA 12 JAX: `uv sync --extra benchmarks --extra cuda12`
  - Benchmark deps + CUDA 13 JAX: `uv sync --extra benchmarks --extra cuda13`
- `uv sync` creates `.venv`, installs the project in editable mode, includes the default `dev` dependency group, and picks a Python interpreter compatible with the project's `requires-python`.
- Before making changes, install the shared development hooks: `uv run pre-commit install --install-hooks`
- Verify the hook setup once per clone: `uv run pre-commit run --all-files`
- Keep the hooks installed for local development; commits on `main` run `pytest -q tests` through the shared `pre-commit` setup.
- Run project commands through `uv`, e.g. `uv run pytest` or `uv run python gp.py`.
- If you want to activate the virtualenv directly, use `source .venv/bin/activate`.
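To confirm that a command is running inside the environment `uv sync` created, you can compare the interpreter prefix with the project's `.venv`. This is a minimal sketch, assuming it is run from the repository root; `using_project_venv` and the file name `check_env.py` are hypothetical, not part of the repository.

```python
import sys
from pathlib import Path


def using_project_venv(project_root="."):
    """Return True if this interpreter lives in <project_root>/.venv (the env `uv sync` creates)."""
    expected = (Path(project_root) / ".venv").resolve()
    # Inside a virtualenv, sys.prefix points at the environment directory.
    return Path(sys.prefix).resolve() == expected


if __name__ == "__main__":
    # Hypothetical usage from the repo root: `uv run python check_env.py`
    print("project venv" if using_project_venv() else f"other interpreter: {sys.prefix}")
```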
Create a local `.env` file with the publish tokens:
TEST_PYPI_TOKEN=pypi-...
PYPI_TOKEN=pypi-...

Run the release helper from a clean `main` checkout:
uv run python scripts/release.py .env "AISTATS 2026"

The helper bumps the patch version, commits and tags `v<version> <message>`,
rebuilds `dist/`, publishes to TestPyPI and PyPI, pushes `main` and the tag,
and smoke-tests the published install targets.
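For orientation, the token loading and patch bump the helper describes can be sketched as follows. This is not the contents of `scripts/release.py`: it is a hypothetical sketch assuming plain `MAJOR.MINOR.PATCH` versions and a `KEY=VALUE` `.env` file like the one above, and the function names `parse_env` and `bump_patch` are illustrative.

```python
import re

# Assumption: versions are plain MAJOR.MINOR.PATCH (no pre-release or build suffixes).
_VERSION = re.compile(r"^(\d+)\.(\d+)\.(\d+)$")


def parse_env(text: str) -> dict[str, str]:
    """Parse KEY=VALUE lines, skipping blanks and # comments (hypothetical helper)."""
    values = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)  # split once so values may contain "="
        values[key.strip()] = value.strip()
    return values


def bump_patch(version: str) -> str:
    """Increment the PATCH component, e.g. 1.2.9 -> 1.2.10."""
    match = _VERSION.match(version)
    if match is None:
        raise ValueError(f"not a MAJOR.MINOR.PATCH version: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return f"{major}.{minor}.{patch + 1}"


if __name__ == "__main__":
    tokens = parse_env("TEST_PYPI_TOKEN=pypi-test\nPYPI_TOKEN=pypi-prod\n")
    new_version = bump_patch("1.2.9")
    print(sorted(tokens), f"v{new_version}")  # ['PYPI_TOKEN', 'TEST_PYPI_TOKEN'] v1.2.10
```

Parsing the version numerically rather than by string manipulation keeps `1.2.9` from becoming `1.2.10` by accident of text concatenation and rejects malformed versions early.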