This is the Python JAX implementation of DiBS (Lorch et al., 2021), a fully differentiable method for joint Bayesian inference of the DAG and parameters of general, causal Bayesian networks. In this implementation, DiBS inference is performed with SVGD (Liu and Wang, 2016). Since DiBS and SVGD operate on continuous tensors and rely solely on Monte Carlo estimation and gradient ascent-like updates, the inference code is implemented entirely in JAX and leverages efficient vectorized operations, automatic differentiation, just-in-time compilation, and hardware acceleration.
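To illustrate the kind of particle update SVGD performs, here is a minimal, self-contained sketch of SVGD in JAX (Liu and Wang, 2016) on a toy 2-D Gaussian target instead of the DiBS posterior. The RBF kernel with fixed bandwidth, the step size, the number of particles and steps, and all function names are illustrative assumptions; this is not the dibs implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as random

STEP_SIZE = 0.1   # assumed step size of the ascent-like update
BANDWIDTH = 1.0   # assumed fixed RBF bandwidth h
MU = jnp.array([2.0, -1.0])  # mean of the toy target


def log_prob(x):
    # toy target: isotropic Gaussian N(MU, I), up to an additive constant
    return -0.5 * jnp.sum((x - MU) ** 2)


def rbf_kernel(x, y):
    # RBF kernel k(x, y) = exp(-||x - y||^2 / h)
    return jnp.exp(-jnp.sum((x - y) ** 2) / BANDWIDTH)


def svgd_direction(particles):
    # phi(x_i) = 1/n * sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]
    n = particles.shape[0]
    scores = jax.vmap(jax.grad(log_prob))(particles)  # [n, dim]
    # kmat[j, i] = k(x_j, x_i)
    kmat = jax.vmap(lambda xj: jax.vmap(lambda xi: rbf_kernel(xj, xi))(particles))(particles)
    # grad_k[j, i] = gradient of k(x_j, x_i) with respect to x_j
    grad_k = jax.vmap(
        lambda xj: jax.vmap(lambda xi: jax.grad(rbf_kernel)(xj, xi))(particles)
    )(particles)
    attraction = kmat.T @ scores    # kernel-weighted scores pull particles to high density
    repulsion = grad_k.sum(axis=0)  # kernel gradients push particles apart
    return (attraction + repulsion) / n


@jax.jit
def svgd_step(particles):
    return particles + STEP_SIZE * svgd_direction(particles)


particles = random.normal(random.PRNGKey(0), (50, 2))
for _ in range(1000):
    particles = svgd_step(particles)

print(particles.mean(axis=0))  # should be near MU
```

The repulsive kernel-gradient term is what keeps the particles from collapsing onto the mode, so that together they approximate the target distribution rather than a single point estimate.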
To install the latest stable release, run:
pip install dibs-lib
The documentation is linked here: Documentation.
If you work on Apple Silicon, we recommend using the compatible JAX versions jax==0.2.10 and jaxlib==0.1.60.
The following code snippet demonstrates how to use the dibs package. In this example, we use DiBS to generate 10 DAG and parameter samples from the joint posterior over Gaussian Bayes nets with means modeled by neural networks.
```python
from dibs.inference import JointDiBS
from dibs.target import make_nonlinear_gaussian_model
import jax.random as random

key = random.PRNGKey(0)

# simulate some data
key, subk = random.split(key)
data, model = make_nonlinear_gaussian_model(key=subk, n_vars=20)

# sample 10 DAG and parameter particles from the joint posterior
dibs = JointDiBS(x=data.x, interv_mask=None, inference_model=model)
key, subk = random.split(key)
gs, thetas = dibs.sample(key=subk, n_particles=10, steps=1000)
```
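The sampled graphs can be summarized, for example, by the fraction of particles that contain each edge. The sketch below assumes that gs is a binary array of shape [n_particles, d, d] in which gs[k, i, j] = 1 encodes the edge i -> j in particle k. To stay self-contained it uses a random stand-in for gs (not necessarily acyclic) instead of calling dibs; edge_probs and expected_n_edges are illustrative names, and weighting all particles equally is only one simple way to summarize them.

```python
import jax.numpy as jnp
import jax.random as random

# Hypothetical stand-in for the gs returned by dibs.sample:
# n_particles binary adjacency matrices of shape [d, d] (assumed convention:
# gs[k, i, j] = 1 encodes the edge i -> j in particle k; acyclicity not enforced here)
n_particles, d = 10, 20
gs = random.bernoulli(random.PRNGKey(1), p=0.1, shape=(n_particles, d, d)).astype(jnp.int32)
gs = gs * (1 - jnp.eye(d, dtype=jnp.int32))  # remove self-loops

# Empirical marginal edge probabilities: fraction of particles containing each edge
edge_probs = gs.mean(axis=0)  # shape [d, d], entries in [0, 1]

# Expected number of edges under the equally weighted particle approximation
expected_n_edges = edge_probs.sum()

# Edges present in at least half of the particles, as (parent, child) pairs
rows, cols = jnp.nonzero(edge_probs >= 0.5)
print(expected_n_edges, list(zip(rows.tolist(), cols.tolist())))
```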
The argument x for JointDiBS is a matrix of shape [N, d] and can be any real-world data set. The argument interv_mask is a binary mask of the same shape that indicates whether or not a variable was intervened upon in a given sample (interv_mask=None indicates observational data and is equivalent to interv_mask=jax.numpy.zeros_like(x)).
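Under a hard intervention, the value of a variable is set externally, so its conditional term is typically left out of the likelihood for that sample. The sketch below builds an interv_mask for a hypothetical experiment and shows this masking for a hypothetical linear Gaussian model. The data, the weight matrix W, the unit noise scale, and masked_log_likelihood are illustrative assumptions; the likelihoods actually used by dibs live in its model classes.

```python
import jax.numpy as jnp
import jax.random as random
from jax.scipy.stats import norm

N, d = 100, 5
x = random.normal(random.PRNGKey(4), (N, d))  # placeholder for a data set of shape [N, d]

# Observational data: equivalent to passing interv_mask=None
obs_mask = jnp.zeros_like(x)

# Hypothetical experiment: variable 2 was set by a hard intervention in the first 30 samples
interv_mask = jnp.zeros_like(x).at[:30, 2].set(1.0)

# Hypothetical linear Gaussian SEM: W[i, j] != 0 encodes the edge i -> j
# (strictly upper triangular, so the graph is a DAG)
W = jnp.triu(jnp.full((d, d), 0.5), k=1)


def masked_log_likelihood(x, mask, weights, noise_std=1.0):
    """Sum of log p(x_nj | parents) over all entries that were NOT intervened upon."""
    means = x @ weights  # [N, d]: each variable's mean is linear in its parents
    log_probs = norm.logpdf(x, loc=means, scale=noise_std)  # [N, d]
    # intervened entries get weight 0, so their conditional term is dropped
    return jnp.sum((1.0 - mask) * log_probs)


ll_obs = masked_log_likelihood(x, obs_mask, W)
ll_int = masked_log_likelihood(x, interv_mask, W)
```

With an all-zero mask, every entry contributes its conditional term; each entry marked in interv_mask removes exactly one term from the sum.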
Try out a working example notebook in Google Colab, which runs directly from your browser.
Whenever a GPU backend is available to JAX, dibs will automatically leverage it to accelerate its computations, so you can select the free GPU runtime available in Google Colab for a speed-up.
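To check which backend JAX, and hence dibs, will use in your environment, you can query JAX directly. A minimal sketch; jax.default_backend() is available in recent JAX releases but may be missing in older pinned versions.

```python
import jax

# Platform of the default backend, e.g. "cpu", "gpu", or "tpu"
backend = jax.default_backend()

# Devices of the default backend that JAX computations will run on
devices = jax.devices()

print(backend, devices)
```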
Analogous notebooks can be found in the examples/ folder.
Running the notebook generates samples from the joint posterior with DiBS and simultaneously visualizes the matrices of edge probabilities modeled by the individual particles that SVGD transports during inference.
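In DiBS (Lorch et al., 2021), each particle is a latent variable Z = (U, V) with node embeddings u_i and v_j, and the probability of the edge i -> j is sigmoid(alpha * u_i^T v_j) for i != j, where a larger alpha makes the probabilities sharper. The sketch below computes such a matrix for a single particle, assuming U and V are stored as arrays of shape [d, k]; the storage layout, the choice k = d, and edge_probabilities are illustrative assumptions, not the dibs internals.

```python
import jax.numpy as jnp
import jax.random as random
from jax.nn import sigmoid


def edge_probabilities(u, v, alpha=1.0):
    """Edge probabilities p(g_ij = 1 | Z) = sigmoid(alpha * u_i^T v_j) for i != j.

    u, v: [d, k] latent embeddings of one particle Z = (U, V).
    """
    d = u.shape[0]
    probs = sigmoid(alpha * (u @ v.T))  # entry (i, j) uses u_i and v_j
    # self-loops are excluded, so the diagonal is set to zero
    return probs * (1.0 - jnp.eye(d))


d, k = 20, 20
key_u, key_v = random.split(random.PRNGKey(2))
u = random.normal(key_u, (d, k))
v = random.normal(key_v, (d, k))

probs = edge_probabilities(u, v, alpha=1.0)
probs_sharp = edge_probabilities(u, v, alpha=10.0)  # pushed toward 0 or 1
```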
- 4 Jul 2022: Inference from interventional data via the interventional log (marginal) likelihood, assuming known, hard interventions. To model soft or random interventions, the likelihoods in the model classes in dibs/models/ can be easily modified.
- 14 Mar 2022: Published to PyPI
- 12 Mar 2022: Extended the BGe marginal likelihood to be well-defined inside the probability simplex. The computation remains exact for binary entries but is well-behaved for soft relaxations of the graph. This allows reparameterization (Gumbel-softmax) gradient estimation for the BGe score.
- 14 Dec 2021: Documentation added
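The 12 Mar 2022 entry concerns evaluating the BGe score at soft graphs whose entries lie between 0 and 1. One standard way to draw such soft graphs with reparameterized gradients is the binary Gumbel-softmax (Concrete) relaxation, sketched below. toy_score is a hypothetical placeholder, not the BGe marginal likelihood, and the temperature tau is an arbitrary choice.

```python
import jax
import jax.numpy as jnp
import jax.random as random
from jax.nn import sigmoid


def soft_graph(key, logits, tau=1.0):
    """Relaxed Bernoulli (binary Gumbel-softmax) sample of an adjacency matrix.

    Entries lie in [0, 1] and are differentiable with respect to the logits.
    """
    noise = random.logistic(key, shape=logits.shape)  # logistic = difference of two Gumbels
    g = sigmoid((logits + noise) / tau)
    return g * (1.0 - jnp.eye(logits.shape[0]))  # no self-loops


def toy_score(logits, key):
    # hypothetical stand-in for a score (such as BGe) evaluated at a soft graph
    return jnp.sum(soft_graph(key, logits, tau=0.5) ** 2)


d = 5
logits = jnp.zeros((d, d))  # edge probability 0.5 everywhere
key = random.PRNGKey(3)

g = soft_graph(key, logits)
grads = jax.grad(toy_score)(logits, key)  # pathwise gradient through the sample
```

Because the noise is drawn independently of the logits, the gradient flows through the sample itself; lowering tau makes the soft graph closer to binary at the cost of higher-variance gradients.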
The repository consists of two branches:
- master (recommended, on PyPI): Lightweight and easy-to-use package for using DiBS in your research or applications.
- full: Comprehensive code to reproduce the experimental results in (Lorch et al., 2021). The purpose of this branch is reproducibility; the branch is no longer updated and may contain outdated notation and documentation.
The latest stable release is published on PyPI, so the best way to install dibs is using pip as shown above.
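For custom installations, we recommend using conda. First, clone the code repository:
git clone https://github.com/larslorch/dibs.git
cd dibs
Next, generate a new conda environment from the environment.yml file provided in the repository and activate it:
conda env create -f environment.yml
Finally, install the dibs package from inside the repository directory with
pip install -e .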
If you find this code useful, please cite our paper:
@article{lorch2021dibs,
title={DiBS: Differentiable Bayesian Structure Learning},
author={Lorch, Lars and Rothfuss, Jonas and Sch{\"o}lkopf, Bernhard and Krause, Andreas},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}