
At train time, the KL divergence term drives $Q(z|X)$, reparameterized as $z=\mu(X)+\epsilon\odot\Sigma(X)$ with $\epsilon\sim N(0,I)$, toward $N(0,I)$. It can't drive $Q(z|X)$ all the way to exactly $N(0,I)$, because then the reconstruction loss of the encoder/decoder pair would explode (the $Q(z|X)$ network would destroy all information about $X$).

Therefore when we run the system at "generator time" using only the decoder and sampling $z$ from $N(0,I)$, won't this poorly represent the training set because $Q(z|X)$ over the training set is too different from $N(0,I)$? For example $Q(z|X)$ might look like $N(0,2\times I)$, or it might even have some nonlinear hard-to-sample shape.

edit1: To clarify and ask a better-defined question: if the distribution $Q(z|X)$ is significantly different from $\mathcal{N}(0,I)$, why do we sample from $\mathcal{N}(0,I)$ when generating samples? Won't this yield samples that poorly represent the training set?

edit2: Even more clarification. The image below shows the 10 MNIST digits mapped into a 2D latent space, and you can see the result does not match $\mathcal{N}(0,I)$. It was produced with 2 latent dimensions and a 2-hidden-layer encoder, each layer with 500 nodes. [Image: MNIST digits mapped into the 2D latent space.]


4 Answers


It's ok for $Q(z|X)$ to be different from $\mathcal{N}(0, I)$, because when we sample from the VAE, we're not trying to reconstruct $X$ anymore. Instead, we're trying to sample some $X \sim \mathcal{X}$ where $\mathcal{X}$ is the distribution of all images in the dataset.

Imagine the latent space were actually a uniform distribution over the interval $(0,10)$, and we were autoencoding MNIST digits. Suppose that images with 1 in them happened to have $Q(z|X)$ concentrated around $(0,1)$, images with 2 around $(1,2)$, and so on.

Then for any particular $X$, $Q(z|X)$ is not close to matching the uniform distribution. However, as long as the mixture $\frac{1}{n} \sum_i Q(z|X_i)$ reasonably covers and matches the uniform distribution, it's reasonable to sample $z \sim U(0,10)$ and then run the decoder, because the $z$ you got is probably close to $\mu(X)$ for some $X$.
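A tiny numerical sketch of this argument (all numbers here are made up for illustration): give each of the 10 "digit clusters" a narrow per-example posterior inside its own unit interval, and check that the mixture still roughly covers $U(0,10)$ even though no single $Q(z|X)$ does:

```python
import random
from collections import Counter

random.seed(0)

# Toy version of the argument above: the latent space is meant to be U(0, 10),
# and each digit d's per-example posterior Q(z|X) is a narrow Gaussian
# concentrated inside its own unit interval (d, d+1).
SIGMA = 0.15

def sample_mixture():
    """Draw z from the mixture (1/n) * sum_i Q(z|X_i): pick a digit, then sample."""
    d = random.randrange(10)
    return random.gauss(d + 0.5, SIGMA)

samples = [sample_mixture() for _ in range(100_000)]

# No single Q(z|X) looks uniform, but the mixture roughly covers (0, 10):
bins = Counter(min(9, max(0, int(z))) for z in samples)
for d in range(10):
    print(d, round(bins[d] / len(samples), 3))   # each unit bin holds ~0.1 of the mass
```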

edit: To answer the question of why we might expect the mixture of $Q(z|X)$ to be approximately $\mathcal{N}(0,I)$, note that we can decompose $P(z) = \int P(z|X)\,p(X)\,dX = E_X\left[ P(z|X) \right]$. By construction, the prior $P(z)$ is $\mathcal{N}(0,I)$. However, when we approximate $P(z|X)$ with the encoder $Q(z|X)$, we end up with something slightly different.

Minimizing the VAE loss is equivalent to maximizing the ELBO, $\log P(X) - \mathcal{D}_\text{KL}(Q(z|X)\,\|\,P(z|X))$. So we are simultaneously maximizing the log-likelihood of the data while also encouraging $Q(z|X)$ to be as close to $P(z|X)$ as possible. As a result, the mixture of the $Q(z|X)$ should end up very close to $\mathcal{N}(0,I)$.
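As an aside on how the KL term of the training loss is computed in practice: when $Q(z|X)$ is a diagonal Gaussian and the prior is $\mathcal{N}(0,I)$, $\mathcal{D}_\text{KL}(Q(z|X)\,\|\,P(z))$ has the well-known closed form $\frac{1}{2}\sum_j\left(\mu_j^2+\sigma_j^2-\log\sigma_j^2-1\right)$. A minimal sketch (the function name is my own):

```python
import math

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ) -- the KL term
    VAE implementations typically use, since both distributions are Gaussian."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv
                     for m, lv in zip(mu, log_var))

# The KL is zero exactly when Q(z|X) equals the prior, and grows as the
# encoder's output drifts away from N(0, I):
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))   # 0.0
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))   # 0.5
```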

  • But we can compute the KL divergence between N(0,I) and Q(z|X) in closed form if both are Gaussian, which is what we typically do. Once trained, the distribution of Q(z|X) will look like a quilt of 10 Gaussians for MNIST (analogous to the picture you described for MNIST on the uniform distribution), and the mixture will be wider than N(0,I). I guess we have to manually inspect this mixture distribution in the latent space and try to come up with a way to sample it. On the other hand, this blog claims that we should sample from N(0,I) at generator time, not Q(z|X). Commented Mar 12, 2018 at 23:17
  • the blog: towardsdatascience.com/… Commented Mar 12, 2018 at 23:18
  • I edited my answer to address your question of why you might expect the "quilt" to be approximately N(0,I). Commented Mar 13, 2018 at 18:25
  • @shimao I think it is misleading that some tutorials call $KL(Q(z|x) \Vert P(z))$ a regularizer term; it is simply a term in the lower bound. Moreover, for generative purposes, the only thing we care about is $P(x) = \int P(x|z)P(z)\,dz$. $P(z)$ is the prior and we get $P(x|z)$ by training, so we don't actually care about $Q(z|x)$. Commented May 12, 2018 at 5:23
  • @me_Tchaikovsky I'm not exactly sure what you're getting at, but there is indeed value in knowing how close $Q(z|x)$ comes to $P(z|x)$, because that determines whether it makes sense to sample from VAEs in the straightforward way (sampling the latent space and then running the decoder). Commented May 12, 2018 at 5:28

As a complement to shimao's answer, which boils down to "the mixture of $q(z|x)$ is reasonably close to the $N(0,I)$ distribution", it is worth noting that nowadays, at least in the context of using VAEs for generative modeling of images, it is common practice to train another model (after the VAE training is done) that learns the actual $q(z)$ distribution, rather than assuming it is "close enough" to the fixed prior.

This training is typically carried out with an autoregressive neural network, which in turn makes it easy to sample from this "true" $q(z)$ once that network is trained. This is the process used in VQ-VAE, VQ-VAE-2 and DALL-E, for example.

To quote from the VQ-VAE-2 paper (Section 3.2):

Fitting prior distributions using neural networks from training data has become common practice, as it can significantly improve the performance of latent variable models. This procedure also reduces the gap between the marginal posterior and the prior. Thus, latent variables sampled from the learned prior at test time are close to what the decoder network has observed during training which results in more coherent outputs.
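A self-contained sketch of this two-stage idea, substituting a kernel density estimate for the autoregressive prior network so that it fits in a few lines (the latent values, cluster positions and bandwidth are all made up for illustration):

```python
import random

random.seed(0)

# Stage 1 stand-in: pretend these are the latents q(z) produced by encoding the
# training set with an already-trained VAE (synthetic: two clusters, not N(0,1)).
train_latents = ([random.gauss(-2.0, 0.3) for _ in range(5000)] +
                 [random.gauss(3.0, 0.5) for _ in range(5000)])

# Stage 2 stand-in: instead of the autoregressive network used by VQ-VAE-2,
# fit the simplest possible density model of q(z) -- a kernel density estimate.
BANDWIDTH = 0.2

def sample_learned_prior():
    """Sample z from the fitted q(z) rather than from the original N(0, I) prior."""
    return random.gauss(random.choice(train_latents), BANDWIDTH)

# Samples now land where the decoder actually saw latents during training:
zs = [sample_learned_prior() for _ in range(10)]
print([round(z, 2) for z in zs])
```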


The answer is two-fold:

(1) The encoder network needs to be expressive enough (wide enough and deep enough) to be able to map the nonlinear input space to something close to $\mathcal{N}(0,I)$.

(2) In addition to (1) (I added a 3rd hidden layer to the MNIST example I described in the question), when I increase the number of latent dimensions, I observe that the mapping of the training data into the latent space becomes closer to $\mathcal{N}(0,I)$. In hindsight this is not surprising: with more dimensions, the system can spread information across them, so each individual latent dimension can get closer to $\mathcal{N}(0,I)$.


We also sample $z$ from the prior and decode it because we made the modeling assumption that $p_{\theta}(x) = \int p_{\theta}(x|z)\,p(z)\,dz$. Sampling from $p_{\theta}(x)$ is therefore equivalent to sampling from the joint $p_{\theta}(x,z)$ and then discarding $z$.
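This is just ancestral sampling, which can be sketched with a toy stand-in for the decoder (the linear-Gaussian $p_\theta(x|z)$ below is made up for illustration):

```python
import random

random.seed(0)

# Ancestral sampling from the joint p(x, z) = p(x|z) p(z), then discarding z.
def decoder_sample(z):
    """Toy stand-in for p_theta(x|z): mean 2*z + 1 plus observation noise."""
    return random.gauss(2.0 * z + 1.0, 0.1)

def sample_x():
    z = random.gauss(0.0, 1.0)      # z ~ p(z) = N(0, 1)
    x = decoder_sample(z)           # x ~ p(x|z)
    return x                        # discard z: x is a draw from p_theta(x)

xs = [sample_x() for _ in range(50_000)]
mean = sum(xs) / len(xs)
print(round(mean, 2))               # marginal mean is E[2z + 1] = 1
```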

