Welcome to Blackjax!

Warning

The documentation corresponds to the current state of the main branch. There may be differences with the latest released version.

Blackjax is a library of samplers for JAX that works on CPU as well as GPU. It is designed with two categories of users in mind:

  • People who just need state-of-the-art samplers that are fast, robust and well tested;

  • Researchers who can use the library’s building blocks to design new algorithms.

It integrates well with probabilistic programming languages (PPLs), as long as they can provide a (potentially unnormalized) log-probability density function that is compatible with JAX.
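
In practice, "compatible with JAX" means the function is written with `jax.numpy` operations so that JAX can trace and differentiate it; gradient-based samplers such as NUTS rely on the gradient of the log-density, which JAX computes automatically. The minimal sketch below (the function names are illustrative, not part of Blackjax) also shows why the density may be unnormalized: the normalizing constant does not depend on the position, so dropping it leaves the gradient unchanged.

```python
import jax
import jax.numpy as jnp


def unnormalized_logdensity(x):
    # Standard normal log-density, dropping the -0.5 * d * log(2 * pi) constant.
    return -0.5 * jnp.sum(x**2)


def normalized_logdensity(x):
    # Same density with its normalizing constant added back.
    d = x.shape[0]
    return unnormalized_logdensity(x) - 0.5 * d * jnp.log(2 * jnp.pi)


x = jnp.array([1.0, -2.0])
# JAX differentiates both functions; the constant term contributes zero gradient,
# so both gradients equal -x.
grad_unnorm = jax.grad(unnormalized_logdensity)(x)
grad_norm = jax.grad(normalized_logdensity)(x)
print(grad_unnorm, grad_norm)
```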

Hello World

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

observed = np.random.normal(10, 20, size=1_000)

def logdensity_fn(x):
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(1_000):
    nuts_key = jax.random.fold_in(rng_key, i)
    state, _ = step(nuts_key, state)
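
The Python `for` loop above calls the jitted step once per iteration and keeps only the last state. A common alternative is to run the whole chain inside `jax.lax.scan` and keep every visited state, which gives you the samples. The helper below is a minimal sketch: `inference_loop` is a hypothetical name, not a Blackjax function, and it only assumes a step function with the `(rng_key, state) -> (state, info)` signature used by `nuts.step` above.

```python
import jax


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run `num_samples` steps of `step_fn` with `jax.lax.scan`.

    Assumes `step_fn(rng_key, state)` returns `(new_state, info)`, like `nuts.step`.
    Returns every visited state, stacked along a new leading axis.
    """

    def one_step(state, key):
        state, _ = step_fn(key, state)
        # The first output is the carry; the second is collected by scan.
        return state, state

    # One independent key per step, instead of fold_in inside a Python loop.
    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
```

With the Hello World objects, `states = inference_loop(jax.random.key(1), nuts.step, nuts.init(initial_position), 1_000)` would return the stacked states; assuming the NUTS state exposes its position as `state.position`, `states.position["loc"]` then holds the 1,000 draws of `loc`. Because `scan` traces and compiles the loop body once, the step does not need a separate `jax.jit`.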

Note

If you want to use Blackjax with a model implemented with a PPL, go to the related tutorials in the left menu.

Installation

Pip

pip install blackjax

Conda

conda install blackjax -c conda-forge 
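
Because this documentation tracks the main branch, it can help to check which release you actually have installed before comparing it with these pages. A small sketch using only the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Prints None if BlackJAX is not installed in the current environment.
print("blackjax:", installed_version("blackjax"))
```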

GPU instructions

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX installed along with BlackJAX runs your code on CPU only. To use BlackJAX on GPU or TPU, we recommend following the official JAX installation instructions to install JAX with the relevant hardware acceleration support.
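
After installing an accelerator-enabled JAX, a quick way to confirm that JAX (and therefore BlackJAX) will use it is to ask JAX which backend and devices it sees. This sketch only uses `jax.default_backend()` and `jax.devices()`; if the backend is reported as `cpu`, JAX is not using an accelerator.

```python
import jax

# Name of the platform JAX uses by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# Devices visible to JAX on the default platform.
devices = jax.devices()

print(f"JAX default backend: {backend}")
print(f"Devices: {devices}")
```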