Runs one step of the Metropolis-adjusted Langevin algorithm.
Inherits From: TransitionKernel
```python
tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
    target_log_prob_fn,
    step_size,
    volatility_fn=None,
    parallel_iterations=10,
    experimental_shard_axis_names=None,
    name=None
)
```

The Metropolis-adjusted Langevin algorithm (MALA) is a Markov chain Monte Carlo (MCMC) algorithm that uses one step of a discretised Langevin diffusion as its proposal. This class implements one step of MALA, discretising the diffusion with the Euler-Maruyama method, for a given current_state and a diagonal preconditioning volatility matrix. Mathematical details and derivations can be found in [Roberts and Rosenthal (1998)][1] and [Xifara et al. (2013)][2].
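For orientation, a textbook form of MALA with a diagonal preconditioner is sketched below in generic notation. The symbols are assumptions introduced here, not TFP parameters: h is the Euler-Maruyama time step, σ(x) is the elementwise volatility, π is the target density and ξ is standard normal noise. How TFP scales its step_size and volatility_fn relative to h and σ is not specified on this page.

$$
y = x + \frac{h}{2}\,\sigma(x)^{2} \odot \nabla \log \pi(x) + \sqrt{h}\,\sigma(x) \odot \xi, \qquad \xi \sim \mathcal{N}(0, I)
$$

$$
q(y \mid x) = \mathcal{N}\!\left(y;\; x + \frac{h}{2}\,\sigma(x)^{2} \odot \nabla \log \pi(x),\; h \operatorname{diag}\!\big(\sigma(x)^{2}\big)\right)
$$

$$
\alpha(x, y) = \min\left\{1,\; \frac{\pi(y)\, q(x \mid y)}{\pi(x)\, q(y \mid x)}\right\}
$$

The proposal y is accepted with probability α(x, y); otherwise the chain stays at x. When σ depends on x, Xifara et al. (2013) add a further drift term built from derivatives of σ(x)². That term is omitted from this sketch, but because q(x | y) is evaluated with σ(y), the acceptance step still leaves π invariant.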
See the UncalibratedLangevin class description for details on the proposal-generating step of the algorithm.
The one_step function can update multiple chains in parallel. It assumes that all leftmost dimensions of current_state index independent chain states (and are therefore updated independently). The output of target_log_prob_fn(*current_state) should reduce log-probabilities across all event dimensions. Slices along the rightmost dimensions may have different target distributions; for example, current_state[0, :] could have a different target distribution from current_state[1, :]. These semantics are governed by target_log_prob_fn(*current_state). (The number of independent chains is tf.size(target_log_prob_fn(*current_state)).)
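A minimal numpy sketch of these batch semantics, assuming a hypothetical target in which each of four chains has a three-dimensional normal target with its own location (so current_state[0, :] and current_state[1, :] really do have different targets). It is plain numpy rather than TFP; the .size of the returned array plays the role of tf.size(target_log_prob_fn(*current_state)).

```python
import numpy as np

num_chains, event_size = 4, 3
# Hypothetical per-chain locations: chain i targets N(i, I) in 3 dimensions.
locs = np.arange(num_chains, dtype=np.float64)[:, np.newaxis]  # shape [4, 1]


def target_log_prob(state):
    # state has shape [num_chains, event_size]; reduce over the rightmost
    # (event) axis only, leaving one log-probability per chain.
    log_density = -0.5 * (state - locs) ** 2 - 0.5 * np.log(2 * np.pi)
    return np.sum(log_density, axis=-1)


current_state = np.zeros([num_chains, event_size])
lp = target_log_prob(current_state)
num_independent_chains = lp.size  # analogue of tf.size(...)
print(lp.shape)                   # (4,)
print(num_independent_chains)     # 4
```

Because the reduction covers only the event axis, the leftmost axis survives and each entry of lp belongs to one independent chain.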
Examples:
Simple chain with warm-up.
In this example we sample from a standard univariate normal distribution using MALA with step_size equal to 0.75.
```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt

tfd = tfp.distributions

dtype = np.float32

# Target distribution is Standard Univariate Normal
target = tfd.Normal(loc=dtype(0), scale=dtype(1))

def target_log_prob(x):
    return target.log_prob(x)

# Define MALA sampler with `step_size` equal to 0.75
samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target_log_prob,
        step_size=0.75),
    num_burnin_steps=500,
    trace_fn=None,
    seed=42)

sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))

print('sample mean', sample_mean)
print('sample standard deviation', sample_std)

plt.title('Traceplot')
plt.plot(samples.numpy(), 'b')
plt.xlabel('Iteration')
plt.ylabel('Position')
plt.show()
```

Sample from a 3-D Multivariate Normal distribution.
In this example we also consider a non-constant volatility function.
```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions

dtype = np.float32
true_mean = dtype([0, 0, 0])
true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
num_results = 500
num_chains = 500

# Target distribution is defined through the Cholesky decomposition
chol = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

# Here we define the volatility function to be non-constant
def volatility_fn(x):
    # Elementwise volatility: it shrinks as |x| grows
    return 1. / (0.5 + 0.1 * tf.math.abs(x))

# Initial state of the chain
init_state = np.ones([num_chains, 3], dtype=dtype)

# Run MALA with normal proposal for `num_results` iterations for
# `num_chains` independent chains:
states = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_state,
    kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target.log_prob,
        step_size=.1,
        volatility_fn=volatility_fn),
    num_burnin_steps=200,
    num_steps_between_results=1,
    trace_fn=None,
    seed=42)

sample_mean = tf.reduce_mean(states, axis=[0, 1])
x = (states - sample_mean)[..., tf.newaxis]
sample_cov = tf.reduce_mean(
    tf.matmul(x, tf.transpose(x, [0, 1, 3, 2])), [0, 1])

print('sample mean', sample_mean.numpy())
print('sample covariance matrix', sample_cov.numpy())
```

References
[1]: Gareth Roberts and Jeffrey Rosenthal. Optimal Scaling of Discrete Approximations to Langevin Diffusions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 60: 255-268, 1998. https://doi.org/10.1111/1467-9868.00123
[2]: T. Xifara et al. Langevin diffusions and the Metropolis-adjusted Langevin algorithm. arXiv preprint arXiv:1309.2983, 2013. https://arxiv.org/abs/1309.2983
Raises | |
|---|---|
ValueError | if step_size is neither a single value nor a list with the same length as current_state. |
TypeError | if volatility_fn is not callable. |
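The two error conditions can be pictured with a small validation helper. check_kernel_args is hypothetical and written only to mirror the documented rules, including the assumption that a single step_size (a scalar or a one-element list) is shared by every state part; it is not TFP's source.

```python
def check_kernel_args(step_size, num_state_parts, volatility_fn=None):
    """Hypothetical check mirroring the documented ValueError and TypeError.

    Returns one step size per state part.
    """
    if isinstance(step_size, (list, tuple)):
        step_sizes = list(step_size)
    else:
        step_sizes = [step_size]
    if len(step_sizes) == 1:
        # Assumed behaviour: a single step size is shared by all state parts.
        step_sizes = step_sizes * num_state_parts
    if len(step_sizes) != num_state_parts:
        raise ValueError(
            'Expected 1 or {} step sizes, got {}.'.format(
                num_state_parts, len(step_sizes)))
    if volatility_fn is not None and not callable(volatility_fn):
        raise TypeError(
            'volatility_fn must be callable, got {!r}.'.format(volatility_fn))
    return step_sizes


print(check_kernel_args(0.1, num_state_parts=2))         # [0.1, 0.1]
print(check_kernel_args([0.1, 0.5], num_state_parts=2))  # [0.1, 0.5]
```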
Attributes | |
|---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | |
parallel_iterations | |
parameters | Return dict of __init__ arguments and their values. |
step_size | |
target_log_prob_fn | |
volatility_fn | |
Methods
bootstrap_results
```python
bootstrap_results(
    init_state
)
```

Creates initial previous_kernel_results using a supplied state.
copy
```python
copy(
    **override_parameter_kwargs
)
```

Non-destructively creates a deep copy of the kernel.
| Args | |
|---|---|
**override_parameter_kwargs | Python String/value dictionary of initialization arguments to override with new values. |
| Returns | |
|---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
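The merge rule for copy can be illustrated with plain Python dictionaries. The values below are placeholders chosen for illustration (a string stands in for the target callable); with a real kernel the equivalent call would be kernel.copy(step_size=0.1).

```python
# Placeholder stand-in for kernel.parameters; keys follow the constructor
# signature, values are illustrative only.
parameters = {
    'target_log_prob_fn': 'target_log_prob',  # stands in for a callable
    'step_size': 0.75,
    'volatility_fn': None,
    'parallel_iterations': 10,
    'experimental_shard_axis_names': None,
    'name': None,
}
override_parameter_kwargs = {'step_size': 0.1}

# Shared keys take the override value; all other keys are inherited.
new_parameters = dict(parameters, **override_parameter_kwargs)
print(new_parameters['step_size'])            # 0.1 (overridden)
print(new_parameters['parallel_iterations'])  # 10 (inherited)
print(parameters['step_size'])                # 0.75 (original untouched)
```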
experimental_with_shard_axes
```python
experimental_with_shard_axes(
    shard_axis_names
)
```

Returns a copy of the kernel with the provided shard axis names.
| Args | |
|---|---|
shard_axis_names | a structure of strings indicating the shard axis names for each component of this kernel's state. |
| Returns | |
|---|---|
| A copy of the current kernel with the shard axis information. |
one_step
```python
one_step(
    current_state, previous_kernel_results, seed=None
)
```

Runs one iteration of MALA.
| Args | |
|---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
| Returns | |
|---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
| Raises | |
|---|---|
ValueError | if step_size is neither a single value nor a list with the same length as current_state or diffusion_drift. |
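To make the bootstrap_results/one_step protocol concrete, here is a toy numpy kernel that follows the same calling pattern and performs a MALA step with an elementwise volatility. Everything in it is hypothetical: NumpyMala, MalaResults and run_chain are illustrative names, a numpy Generator replaces TFP's seed argument, h is an Euler-Maruyama time step whose relation to TFP's step_size is not documented here, and the drift omits the volatility-derivative correction of Xifara et al. (2013). The cached target_log_prob and grad in the results tuple show why one_step takes previous_kernel_results: the next step reuses them instead of re-evaluating the target at the current state.

```python
import collections

import numpy as np

# Hypothetical results tuple: values cached between calls to one_step.
MalaResults = collections.namedtuple(
    'MalaResults', ['target_log_prob', 'grad', 'is_accepted'])


class NumpyMala:
    """Toy MALA kernel with an elementwise (diagonal) volatility.

    Mirrors the bootstrap_results/one_step calling pattern only; it is not
    TFP's implementation. The drift omits the volatility-derivative term of
    Xifara et al. (2013), but the Metropolis-Hastings ratio uses the exact
    proposal densities, so the target stays invariant.
    """

    def __init__(self, log_prob_and_grad, h, volatility_fn=None):
        self.log_prob_and_grad = log_prob_and_grad  # x -> (log pi(x), grad)
        self.h = h  # Euler-Maruyama time step (assumed parameterisation)
        self.volatility_fn = volatility_fn or (lambda x: np.ones_like(x))

    def _mean_and_scale(self, x, grad):
        # Mean and per-coordinate standard deviation of the proposal from x.
        sigma = self.volatility_fn(x)
        return x + 0.5 * self.h * sigma**2 * grad, np.sqrt(self.h) * sigma

    @staticmethod
    def _log_q(y, mean, scale):
        # log N(y; mean, diag(scale**2)) summed over the event axis; the
        # -0.5*log(2*pi) terms are dropped because they cancel in the ratio.
        z = (y - mean) / scale
        return np.sum(-0.5 * z**2 - np.log(scale), axis=-1)

    def bootstrap_results(self, init_state):
        lp, grad = self.log_prob_and_grad(init_state)
        return MalaResults(lp, grad, np.zeros(lp.shape, dtype=bool))

    def one_step(self, current_state, previous_kernel_results, rng):
        x = current_state
        lp_x = previous_kernel_results.target_log_prob
        grad_x = previous_kernel_results.grad
        # Euler-Maruyama proposal.
        mean_x, scale_x = self._mean_and_scale(x, grad_x)
        y = mean_x + scale_x * rng.standard_normal(x.shape)
        lp_y, grad_y = self.log_prob_and_grad(y)
        mean_y, scale_y = self._mean_and_scale(y, grad_y)
        # Metropolis-Hastings correction for the asymmetric proposal.
        log_ratio = (lp_y + self._log_q(x, mean_y, scale_y)
                     - lp_x - self._log_q(y, mean_x, scale_x))
        accepted = np.log(rng.uniform(size=lp_x.shape)) < log_ratio
        next_state = np.where(accepted[..., np.newaxis], y, x)
        results = MalaResults(
            np.where(accepted, lp_y, lp_x),
            np.where(accepted[..., np.newaxis], grad_y, grad_x),
            accepted)
        return next_state, results


def run_chain(kernel, init_state, num_results, num_burnin_steps, seed):
    """Hypothetical driver: bootstrap once, then call one_step repeatedly."""
    rng = np.random.default_rng(seed)
    state = init_state
    results = kernel.bootstrap_results(state)
    samples = []
    for step in range(num_burnin_steps + num_results):
        state, results = kernel.one_step(state, results, rng)
        if step >= num_burnin_steps:
            samples.append(state)
    return np.stack(samples)  # [num_results, num_chains, event_size]


# Target and volatility mirror the second example (unnormalised log density).
true_cov = np.array([[1., .25, .25], [.25, 1., .25], [.25, .25, 1.]])
precision = np.linalg.inv(true_cov)


def log_prob_and_grad(x):
    px = x @ precision  # precision is symmetric
    return -0.5 * np.sum(x * px, axis=-1), -px


kernel = NumpyMala(log_prob_and_grad, h=0.1,
                   volatility_fn=lambda x: 1. / (0.5 + 0.1 * np.abs(x)))
states = run_chain(kernel, np.ones([500, 3]), num_results=500,
                   num_burnin_steps=200, seed=42)
flat = states.reshape(-1, 3)
print('sample mean', flat.mean(axis=0))
print('sample covariance matrix', np.cov(flat, rowvar=False))
```

Run on the 3-D normal and volatility function from the second example, the pooled sample mean should land close to zero and the sample covariance close to true_cov; the exact numbers depend on the seed.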