Runs one step of Uncalibrated Hamiltonian Monte Carlo.
Inherits From: TransitionKernel
tfp.mcmc.UncalibratedHamiltonianMonteCarlo( target_log_prob_fn, step_size, num_leapfrog_steps, state_gradients_are_stopped=False, store_parameters_in_results=False, experimental_shard_axis_names=None, name=None )

For more details on UncalibratedHamiltonianMonteCarlo, see HamiltonianMonteCarlo.
Args | |
|---|---|
target_log_prob_fn | Python callable which takes an argument like current_state (or *current_state if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. |
step_size | Tensor or Python list of Tensors representing the step size for the leapfrog integrator. Must broadcast with the shape of current_state. Larger step sizes lead to faster progress, but too-large step sizes make rejection exponentially more likely. When possible, it's often helpful to match per-variable step sizes to the standard deviations of the target distribution in each variable. |
num_leapfrog_steps | Integer number of steps to run the leapfrog integrator for. Total progress per HMC step is roughly proportional to step_size * num_leapfrog_steps. |
state_gradients_are_stopped | Python bool indicating that the proposed new state be run through tf.stop_gradient. This is particularly useful when combining optimization over samples from the HMC chain. Default value: False (i.e., do not apply stop_gradient). |
store_parameters_in_results | If True, then step_size and num_leapfrog_steps are written to and read from eponymous fields in the kernel results objects returned from one_step and bootstrap_results. This allows wrapper kernels to adjust those parameters on the fly. |
experimental_shard_axis_names | A structure of string names indicating how members of the state are sharded. |
name | Python str name prefixed to Ops created by this function. Default value: None (i.e., 'hmc_kernel'). |
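
To make the roles of step_size and num_leapfrog_steps concrete, here is a minimal pure-Python sketch of a leapfrog trajectory for a one-dimensional standard normal target. It is not TFP's implementation: the helper names (`leapfrog`, `grad_log_prob`, `energy`), the scalar state, and the fixed initial momentum are assumptions made for illustration. A real HMC step draws the momentum at random. An uncalibrated kernel returns the end point of such a trajectory as the next state, without a Metropolis-Hastings accept/reject correction.

```python
import math


def target_log_prob(x):
    # Unnormalized log-density of a standard normal (illustrative target).
    return -0.5 * x * x


def grad_log_prob(x):
    # Gradient of target_log_prob with respect to x.
    return -x


def leapfrog(x, momentum, step_size, num_leapfrog_steps):
    """Integrates Hamiltonian dynamics with the leapfrog scheme."""
    # Initial half step for the momentum.
    momentum += 0.5 * step_size * grad_log_prob(x)
    for i in range(num_leapfrog_steps):
        # Full step for the position.
        x += step_size * momentum
        # Full momentum step between position updates (not after the last).
        if i < num_leapfrog_steps - 1:
            momentum += step_size * grad_log_prob(x)
    # Final half step for the momentum.
    momentum += 0.5 * step_size * grad_log_prob(x)
    return x, momentum


def energy(x, momentum):
    # Hamiltonian: potential energy plus kinetic energy.
    return -target_log_prob(x) + 0.5 * momentum * momentum


x0, p0 = 1.0, 0.5  # fixed momentum for reproducibility; normally random
x1, p1 = leapfrog(x0, p0, step_size=0.1, num_leapfrog_steps=10)
energy_error = abs(energy(x1, p1) - energy(x0, p0))
```

The simulated trajectory length here is step_size * num_leapfrog_steps = 1.0. The small but nonzero `energy_error` is the discretization error that a calibrated kernel would correct with an accept/reject step.
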
Attributes | |
|---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. For this uncalibrated kernel the value is False. |
name | |
num_leapfrog_steps | Returns the num_leapfrog_steps parameter. If store_parameters_in_results is True, the value in effect is the one carried in the kernel results. |
parameters | Return dict of __init__ arguments and their values. |
state_gradients_are_stopped | |
step_size | Returns the step_size parameter. If store_parameters_in_results is True, the value in effect is the one carried in the kernel results. |
target_log_prob_fn | |
Methods
bootstrap_results
bootstrap_results( init_state )

Creates initial previous_kernel_results using a supplied state.
copy
copy( **override_parameter_kwargs )

Non-destructively creates a deep copy of the kernel.
| Args | |
|---|---|
**override_parameter_kwargs | Python string/value dictionary of initialization arguments to override with new values. |
| Returns | |
|---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value in override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
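
The merge rule described above can be illustrated with plain dictionaries. This is only a sketch of the dict(self.parameters, **override_parameter_kwargs) expression. The `parameters` dictionary below is a hypothetical stand-in for a kernel's parameters attribute, not output from a real kernel.

```python
# Hypothetical stand-in for `kernel.parameters` (values are illustrative).
parameters = {
    "step_size": 0.1,
    "num_leapfrog_steps": 3,
    "state_gradients_are_stopped": False,
}

# Arguments that would be passed as copy(step_size=0.05).
override_parameter_kwargs = {"step_size": 0.05}

# Shared keys take the override value; all other keys are carried over.
new_parameters = dict(parameters, **override_parameter_kwargs)
```

The original dictionary is left unchanged, which mirrors the non-destructive behavior described for copy.
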
experimental_with_shard_axes
experimental_with_shard_axes( shard_axis_names )

Returns a copy of the kernel with the provided shard axis names.
| Args | |
|---|---|
shard_axis_names | a structure of strings indicating the shard axis names for each component of this kernel's state. |
| Returns | |
|---|---|
| A copy of the current kernel with the shard axis information. |
one_step
one_step( current_state, previous_kernel_results, seed=None )

Runs one iteration of Hamiltonian Monte Carlo.
| Args | |
|---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). |
previous_kernel_results | collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
| Returns | |
|---|---|
next_state | Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state. |
kernel_results | collections.namedtuple of internal calculations used to advance the chain. |
| Raises | |
|---|---|
ValueError | If step_size is neither a single step size nor a list with the same length as current_state. |
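
The step_size condition above can be sketched in plain Python. This is an assumption-laden illustration rather than the library's code: `check_step_sizes` is a hypothetical helper, and plain floats stand in for Tensors.

```python
def check_step_sizes(current_state, step_size):
    """Pairs each state part with a step size, or raises ValueError."""
    # Treat a single state part or step size as a one-element list.
    if isinstance(current_state, (list, tuple)):
        state_parts = list(current_state)
    else:
        state_parts = [current_state]
    if isinstance(step_size, (list, tuple)):
        step_sizes = list(step_size)
    else:
        step_sizes = [step_size]

    # A single step size is shared by every state part.
    if len(step_sizes) == 1:
        step_sizes = step_sizes * len(state_parts)

    if len(step_sizes) != len(state_parts):
        raise ValueError(
            "There should be exactly one step_size or one per state part."
        )
    return step_sizes


shared = check_step_sizes([0.0, 0.0, 0.0], 0.1)          # one size, reused
per_part = check_step_sizes([0.0, 0.0], [0.1, 0.2])      # one size per part
```

Passing two step sizes for a three-part state would raise ValueError under this sketch.
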