View source on GitHub
Computes the diagonal of the Jacobian matrix of `ys = fn(xs)` with respect to `xs`.
```python
tfp.math.diag_jacobian(
    xs, ys=None, sample_shape=None, fn=None, parallel_iterations=10, name=None
)
```

If `ys` is a tensor or a list of tensors of the form `(ys_1, .., ys_n)` and `xs` is of the form `(xs_1, .., xs_n)`, the function `diag_jacobian` computes the diagonal of the Jacobian matrix, i.e., the partial derivatives `(dys_1/dxs_1, .., dys_n/dxs_n)`. For definition details, see https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant
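To make the definition concrete, here is a minimal pure-Python sketch that approximates the diagonal entries `dys_i/dxs_i` with central finite differences. The helper `diag_jacobian_fd`, its `eps` parameter, and the test function `fn` are hypothetical names introduced only for illustration; this is not how `tfp.math.diag_jacobian` is implemented, which relies on automatic differentiation.

```python
import math


def diag_jacobian_fd(fn, xs, eps=1e-5):
    """Approximates d fn(xs)[i] / d xs[i] for each i by central differences.

    Hypothetical helper for illustration only; not part of TensorFlow Probability.
    """
    diag = []
    for i in range(len(xs)):
        # Perturb only the i-th input and look only at the i-th output.
        upper = list(xs)
        lower = list(xs)
        upper[i] += eps
        lower[i] -= eps
        diag.append((fn(upper)[i] - fn(lower)[i]) / (2 * eps))
    return diag


def fn(xs):
    # ys_0 = x0^2 * x1 and ys_1 = sin(x1) + x0
    x0, x1 = xs
    return [x0 ** 2 * x1, math.sin(x1) + x0]


diag = diag_jacobian_fd(fn, [1.5, 0.5])
# Exact diagonal entries: dys_0/dx0 = 2 * x0 * x1 and dys_1/dx1 = cos(x1).
exact = [2 * 1.5 * 0.5, math.cos(0.5)]
print(diag, exact)
```

The off-diagonal entries (for example `dys_0/dx1`) are never evaluated, which is the point of computing only the diagonal.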
Example
Diagonal Hessian of the log-density of a 3D Gaussian distribution
In this example we compute the diagonal of the Hessian of the log-density of a 3D Gaussian distribution by applying `diag_jacobian` to the gradients of the log-density with respect to the state.
```python
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions
dtype = np.float32

with tf.Session(graph=tf.Graph()) as sess:
  true_mean = dtype([0, 0, 0])
  true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
  chol = tf.linalg.cholesky(true_cov)
  target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

  # Assume that the state is passed as a list of tensors `x` and `y`.
  # Then the target function is defined as follows:
  def target_fn(x, y):
    # Stack the input tensors together
    z = tf.concat([x, y], axis=-1) - true_mean
    return target.log_prob(z)

  sample_shape = [3, 5]
  state = [tf.ones(sample_shape + [2], dtype=dtype),
           tf.ones(sample_shape + [1], dtype=dtype)]
  fn_val, grads = tfp.math.value_and_gradient(target_fn, state)

  # We can either pass the `sample_shape` of the `state` or not, which impacts
  # computational speed of `diag_jacobian`
  _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
      xs=state, ys=grads, sample_shape=tf.shape(fn_val))
  _, diag_jacobian_shape_none = tfp.math.diag_jacobian(
      xs=state, ys=grads)

  diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
  diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

  print('hessian computed through `diag_jacobian`, sample_shape passed: ',
        np.concatenate(diag_jacobian_shape_passed_, -1))
  print('hessian computed through `diag_jacobian`, sample_shape skipped: ',
        np.concatenate(diag_jacobian_shape_none_, -1))
```

Raises | |
|---|---|
ValueError | If the lists `xs` and `ys` have different lengths, if both `ys` and `fn` are `None`, or if `fn` is `None` in eager execution mode. |
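As a sanity check on the Gaussian example above, the expected values can be worked out by hand. The log-density of a multivariate normal is `-0.5 * z^T inv(cov) z` plus a constant, so its Hessian is `-inv(cov)` at every point, and the diagonal of the Hessian should not depend on the state. The NumPy sketch below computes that diagonal for the covariance used in the example. The variable name `expected_diag_hessian` is introduced here for illustration, and the comparison assumes the example's concatenated output lists the three coordinates in the same order as `true_cov`.

```python
import numpy as np

# Covariance matrix taken from the example.
true_cov = np.array([[1.0, 0.25, 0.25],
                     [0.25, 2.0, 0.25],
                     [0.25, 0.25, 3.0]])

# Hessian of the Gaussian log-density is -inv(cov); keep only its diagonal.
expected_diag_hessian = -np.diag(np.linalg.inv(true_cov))
print(expected_diag_hessian)
```

Under these assumptions, each of the `[3, 5]` sample positions in the example's printed arrays should show approximately these three values, whether or not `sample_shape` is passed.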