A running R-hat diagnostic.
Inherits From: AutoCompositeTensor
tfp.experimental.stats.RunningPotentialScaleReduction(
    chain_variances, independent_chain_ndims
)
RunningPotentialScaleReduction computes Gelman and Rubin's (1992) potential scale reduction, also known as R-hat, as a diagnostic of chain convergence [1].
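For reference, a commonly used form of the estimator is written out below. It assumes M chains with N draws each, where x_{m,n} is the n-th draw of chain m, and it omits the square root and any degrees-of-freedom adjustment. Whether this class applies exactly these correction terms is an assumption; this page does not state the formula.

$$
\bar{x}_m = \frac{1}{N}\sum_{n=1}^{N} x_{m,n}, \qquad
\bar{x} = \frac{1}{M}\sum_{m=1}^{M} \bar{x}_m, \qquad
s_m^2 = \frac{1}{N-1}\sum_{n=1}^{N} \left(x_{m,n} - \bar{x}_m\right)^2
$$

$$
W = \frac{1}{M}\sum_{m=1}^{M} s_m^2, \qquad
B = \frac{N}{M-1}\sum_{m=1}^{M} \left(\bar{x}_m - \bar{x}\right)^2
$$

$$
\hat{V} = \frac{N-1}{N}\,W + \frac{M+1}{MN}\,B, \qquad
\hat{R} = \frac{\hat{V}}{W}
$$

Values of R-hat close to 1 suggest the chains agree. Values well above 1 mean the between-chain spread B is large relative to the within-chain spread W, so the chains have not yet mixed.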
If multiple independent R-hat computations are desired across a latent state, use a (possibly nested) collection for the initialization parameters independent_chain_ndims and shape. Subsequent chain states used to update the streaming R-hat should mirror that same structure.
RunningPotentialScaleReduction also assumes that incoming samples have shape [Ci1, Ci2, ..., CiD] + A. Dimensions 0 through D - 1 index the Ci1 x ... x CiD independent chains to be tested for convergence to the same target. The remaining dimensions, A, represent the event shape and can therefore have any shape (even empty, which implies scalar samples). The number of independent chain dimensions, D, is set by the independent_chain_ndims parameter at initialization.
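The shape convention, together with the per-leaf structure rule for collections, can be illustrated with a batch NumPy sketch. It is not the library's implementation. It assumes the draws have already been stacked along a new leading axis, giving shape [N, Ci1, ..., CiD] + A, and it uses the common estimator R-hat = V/W with V = (N-1)/N * W + (M+1)/(MN) * B. The helper names rhat_stacked and rhat_nested are hypothetical, and nested collections are modeled here as flat dicts only.

```python
import numpy as np


def rhat_stacked(samples, independent_chain_ndims=1):
    """Hypothetical batch R-hat for draws stacked as [N, C1, ..., CD] + A.

    Returns an array with the event shape A.
    """
    samples = np.asarray(samples, dtype=np.float64)
    n = samples.shape[0]                          # draws per chain
    chain_means = samples.mean(axis=0)            # shape [C1, ..., CD] + A
    chain_vars = samples.var(axis=0, ddof=1)      # within-chain variances
    axes = tuple(range(independent_chain_ndims))  # the D chain axes
    m = int(np.prod(chain_means.shape[:independent_chain_ndims]))  # total chains
    w = chain_vars.mean(axis=axes)                # mean within-chain variance
    b = n * chain_means.var(axis=axes, ddof=1)    # between-chain variance
    return ((n - 1) / n * w + (m + 1) / (m * n) * b) / w


def rhat_nested(samples, independent_chain_ndims):
    """Hypothetical helper: applies rhat_stacked leaf by leaf.

    Both dicts must share the same keys, mirroring the rule that updates
    follow the structure used at initialization.
    """
    if samples.keys() != independent_chain_ndims.keys():
        raise ValueError("samples and independent_chain_ndims differ in structure")
    return {k: rhat_stacked(samples[k], independent_chain_ndims[k]) for k in samples}


rng = np.random.default_rng(0)
draws = {
    "a": rng.normal(size=(1000, 4)),        # 4 chains, scalar event (A = [])
    "b": rng.normal(size=(1000, 2, 3, 5)),  # 2 x 3 chains, event shape A = [5]
}
rhat = rhat_nested(draws, {"a": 1, "b": 2})
```

With D = 2, the 2 x 3 = 6 chains of leaf "b" are pooled into one R-hat per event element, so that result has shape [5]. Shifting each chain's mean apart (for example, adding a different constant to each chain) drives R-hat well above 1.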
RunningPotentialScaleReduction is intended for general-purpose streaming R-hat computation. For a specialized version tailored to streaming over MCMC samples, see PotentialScaleReductionReducer in tfp.experimental.mcmc.
References
[1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation Using Multiple Sequences. Statistical Science, 7(4):457-472, 1992.
Methods
from_example
@classmethod
from_example(
    example, independent_chain_ndims=1
)
Starts an empty RunningPotentialScaleReduction from metadata.
| Args | Description |
|---|---|
| example | A Tensor. The RunningPotentialScaleReduction will accept samples with the same dtype as the example and a shape that is broadcast-compatible with it. |
| independent_chain_ndims | Integer, or integer-typed Tensor, with value >= 1 giving the number of leading dimensions that hold independent chain results to be tested for convergence. Using a collection implies that future samples will mimic that exact structure. |
| Returns | Description |
|---|---|
| state | RunningPotentialScaleReduction representing a stream of no inputs. By convention, the supplied example is used only for initialization and is not counted as a sample. |
from_shape
@classmethod
from_shape(
    shape=(), independent_chain_ndims=1, dtype=tf.float32
)
Starts an empty RunningPotentialScaleReduction from metadata.
| Args | Description |
|---|---|
| shape | Python tuple or TensorShape representing the shape of incoming samples. Using a collection implies that future samples will mimic that exact structure. Supplying it is useful if the RunningPotentialScaleReduction will be carried by a tf.while_loop, so that broadcasting does not change the shape across loop iterations. |
| independent_chain_ndims | Integer, or integer-typed Tensor, with value >= 1 giving the number of leading dimensions that hold independent chain results to be tested for convergence. Using a collection implies that future samples will mimic that exact structure. |
| dtype | Dtype of incoming samples and the resulting statistics. Defaults to tf.float32. Any integer dtype is cast to the corresponding float type (e.g., tf.int32 is cast to tf.float32), because the intermediate calculations involve floating-point division. |
| Returns | Description |
|---|---|
| state | RunningPotentialScaleReduction representing a stream of no inputs. |
potential_scale_reduction
potential_scale_reduction()
Computes the potential scale reduction for the samples accumulated so far.
| Returns | Description |
|---|---|
| rhat | An estimate of R-hat. |
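The running computation can be sketched without TensorFlow. For the common V/W form of the estimator, per-chain running means and variances are all that R-hat needs, so each update only touches state of shape [Ci1, ..., CiD] + A. The class below, RunningRhatSketch, is hypothetical and is not the library's implementation; it uses Welford's algorithm for the per-chain statistics, and its method names mirror this page only for readability.

```python
import numpy as np


class RunningRhatSketch:
    """Hypothetical streaming R-hat built on per-chain Welford updates.

    Each sample has shape [C1, ..., CD] + A, with D = independent_chain_ndims.
    """

    def __init__(self, count, mean, m2, independent_chain_ndims):
        self.count = count      # samples seen so far (per chain)
        self.mean = mean        # running per-chain mean, shape [C1..CD] + A
        self.m2 = m2            # running sum of squared deviations from the mean
        self.independent_chain_ndims = independent_chain_ndims

    @classmethod
    def from_example(cls, example, independent_chain_ndims=1):
        # The example only fixes the shape; it is not counted as a sample.
        zeros = np.zeros(np.shape(example), dtype=np.float64)
        return cls(0, zeros, zeros.copy(), independent_chain_ndims)

    def update(self, new_sample):
        # Welford's update; returns a new object instead of mutating self.
        x = np.asarray(new_sample, dtype=np.float64)
        count = self.count + 1
        delta = x - self.mean
        mean = self.mean + delta / count
        m2 = self.m2 + delta * (x - mean)
        return RunningRhatSketch(count, mean, m2, self.independent_chain_ndims)

    def potential_scale_reduction(self):
        n = self.count
        d = self.independent_chain_ndims
        axes = tuple(range(d))                     # chain axes
        m = int(np.prod(self.mean.shape[:d]))      # total number of chains
        w = (self.m2 / (n - 1)).mean(axis=axes)    # mean within-chain variance
        b = n * self.mean.var(axis=axes, ddof=1)   # between-chain variance
        return ((n - 1) / n * w + (m + 1) / (m * n) * b) / w


rng = np.random.default_rng(2)
draws = rng.normal(size=(200, 3, 2))  # 200 draws, 3 chains, event shape (2,)
state = RunningRhatSketch.from_example(draws[0])
for x in draws:
    state = state.update(x)
rhat = state.potential_scale_reduction()
```

Because the running mean and Welford variance match their batch counterparts, the streamed result agrees (up to floating-point error) with computing the same estimator over all 200 stacked draws at once.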
tree_flatten
tree_flatten()
tree_unflatten
@classmethod
tree_unflatten(
    metadata, tensors
)
update
update(
    new_sample
)
Updates the RunningPotentialScaleReduction with a new sample.
| Args | Description |
|---|---|
| new_sample | Incoming Tensor sample, or (possibly nested) collection of Tensors, with shape and dtype compatible with those used to form the RunningPotentialScaleReduction. |
| Returns | Description |
|---|---|
| state | RunningPotentialScaleReduction updated to include the new sample. |