Setting
I am training a Variational Autoencoder (VAE) on the CIFAR10 dataset, which contains RGB images. The VAE uses convolution and transposed convolution layers as well as linear layers to encode and decode the data. The final prediction is done with a tanh, so the output is in the range (-1, 1).
The loss is calculated as
```python
import torch


def elbo_loss(
    x: torch.Tensor,
    x_hat: torch.Tensor,
    mus: torch.Tensor,
    sigmas: torch.Tensor,
) -> torch.Tensor:
    """Calculates the ELBO loss.

    Parameters
    ----------
    x : torch.Tensor
        Z-score normalized input images.
    x_hat : torch.Tensor
        Reconstruction of the input images.
    mus : torch.Tensor
        Mu values of the latent space.
    sigmas : torch.Tensor
        Sigma values of the latent space.
    """
    # Per-sample reconstruction error, summed over channels and pixels.
    mse = torch.nn.functional.mse_loss(x_hat, x, reduction='none').sum(dim=(1, 2, 3))
    # We want the distribution of the latent space to be as close as
    # possible to a standard normal distribution.
    z_dist = torch.distributions.Normal(mus, sigmas)
    target_dist = torch.distributions.Normal(
        torch.zeros_like(mus), torch.ones_like(sigmas)
    )
    # Per-sample KL divergence, summed over the latent dimensions so that
    # its shape matches the per-sample reconstruction error.
    d_kl = torch.distributions.kl_divergence(z_dist, target_dist).sum(dim=1)
    beta = 1
    elbo = mse + beta * d_kl
    return elbo.mean()
```

So the loss penalizes bad reconstructions as well as deviations of the latent distributions from the standard normal distribution. It is based on the original VAE paper.
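As a sanity check, `torch.distributions.kl_divergence` can be compared against the closed-form KL between N(mu, sigma) and N(0, 1) given in Appendix B of the original VAE paper; the mu/sigma values below are arbitrary illustrations:

```python
import torch

# Closed-form KL between N(mu, sigma) and N(0, 1) from the original VAE
# paper: D_KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
mus = torch.tensor([[0.3, -1.2], [0.0, 0.5]])
sigmas = torch.tensor([[0.8, 1.5], [1.0, 0.2]])

closed_form = -0.5 * (1 + torch.log(sigmas**2) - mus**2 - sigmas**2)

# The same quantity via torch.distributions.
z_dist = torch.distributions.Normal(mus, sigmas)
target = torch.distributions.Normal(torch.zeros_like(mus), torch.ones_like(sigmas))
via_torch = torch.distributions.kl_divergence(z_dist, target)

assert torch.allclose(closed_form, via_torch, atol=1e-5)
```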
Problem
I find the pre-processing of the images worrisome. So far, I have almost always used z-score normalization for image processing, i.e., x = (x - mean) / std. However, in the case of this VAE, part of the loss calculation is the MSE between the reconstructed image and the (normalized) input. Due to the tanh, the output is in the range (-1, 1), while the input is centered around zero (same as tanh) but with a different range of values. I imagine this is problematic for the optimization and the quality of the training.
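To make the mismatch concrete, here is a sketch comparing z-score normalization with a fixed rescaling to [-1, 1] for 8-bit pixel values; the mean/std values are placeholders, not the actual CIFAR10 statistics:

```python
import torch

# Dummy batch of 8-bit RGB images with values in [0, 255].
imgs = torch.randint(0, 256, (4, 3, 32, 32)).float()

# Option 1: z-score normalization (mean/std are placeholders, not the
# true CIFAR10 statistics). The result is centered near zero, but its
# range is not bounded to (-1, 1), so a tanh output cannot match it.
mean, std = 120.0, 64.0
z_scored = (imgs - mean) / std

# Option 2: fixed affine rescaling to [-1, 1], which matches the tanh
# output range exactly.
scaled = imgs / 127.5 - 1.0

print(z_scored.min().item(), z_scored.max().item())  # can exceed (-1, 1)
print(scaled.min().item(), scaled.max().item())      # always within [-1, 1]
```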
Another problem is that for the second part of the loss, we calculate the KL divergence between the latent distribution and the standard normal distribution, which has zero mean and unit standard deviation. If the input images have a different range of pixel values, I imagine this is suboptimal regarding the weight that the KL divergence carries within the overall loss.
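This scale dependence can be checked numerically: multiplying the inputs (and reconstructions) by a constant c scales the MSE term by c², while the KL term, which only depends on the latent mus/sigmas, is unchanged, so the effective weight of the KL term shrinks as the input range grows. A sketch with dummy tensors:

```python
import torch

torch.manual_seed(0)
# Dummy inputs and slightly perturbed "reconstructions".
x = torch.randn(8, 3, 32, 32)
x_hat = x + 0.1 * torch.randn_like(x)

mse = torch.nn.functional.mse_loss(x_hat, x)

# Rescale inputs and reconstructions by a constant factor c.
c = 4.0
mse_scaled = torch.nn.functional.mse_loss(c * x_hat, c * x)

# MSE grows quadratically with the input scale.
ratio = (mse_scaled / mse).item()
print(ratio)  # ~= c**2 == 16
```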
Questions
- Is the difference between the range of values of the pre-processed input images and the output of the tanh layer really problematic?
- If so, how to pre-process the images (or post-process the output?) to fit the tanh output layer?
- Should I change the output layer from tanh to something else?
- Does the relative weight of the two loss terms change depending on the range of values of the input images? If so, how do I account for that?
I guess parts of the answers generalize to other neural networks or use cases as well. I have also found this topic difficult when working with Transformers on semantic segmentation.