I am training a VAE on microscopy images, with 1000 training images and 253 test images. The images are resized from their original resolution of roughly 1024x720 down to either 128x128 or 256x256 inputs; the implementation shown here uses 256x256 (a rough sketch of my loading pipeline follows).
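Roughly, the preprocessing looks like this (a sketch from memory, not the exact code; the transform parameters are assumptions, but the output is a single-channel 256x256 tensor in [0, 1]):

```python
from torchvision import transforms

# Sketch of the preprocessing (assumed parameters): grayscale load,
# resize to 256x256, and scale to [0, 1] so the BCE loss below is valid.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # single-channel microscopy image
    transforms.Resize((256, 256)),                # ~1024x720 -> 256x256
    transforms.ToTensor(),                        # float tensor in [0, 1]
])
```

Then, I put the batches through my VAE: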

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SEMVAE(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        # ----------------- Encoder -----------------
        self.enc_conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)     # 256 -> 128
        self.enc_bn1   = nn.BatchNorm2d(32)
        self.enc_conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)    # 128 -> 64
        self.enc_bn2   = nn.BatchNorm2d(64)
        self.enc_conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)   # 64 -> 32
        self.enc_bn3   = nn.BatchNorm2d(128)
        self.enc_conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)  # 32 -> 16
        self.enc_bn4   = nn.BatchNorm2d(128)

        self.dropout = nn.Dropout(0.05)
        self.flatten = nn.Flatten()
        self.fc_mu     = nn.Linear(128*16*16, latent_dim)
        self.fc_logvar = nn.Linear(128*16*16, latent_dim)
        self.fc_dec    = nn.Linear(latent_dim, 128*16*16)

        # ----------------- Decoder -----------------
        self.dec_deconv1 = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1)  # 16 -> 32
        self.dec_bn1     = nn.BatchNorm2d(128)
        self.dec_deconv2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)   # 32 -> 64
        self.dec_bn2     = nn.BatchNorm2d(64)
        self.dec_deconv3 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)    # 64 -> 128
        self.dec_bn3     = nn.BatchNorm2d(32)
        self.dec_deconv4 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)     # 128 -> 256

    def encode(self, x):
        x = F.leaky_relu(self.enc_bn1(self.enc_conv1(x)), 0.1)
        x = self.dropout(x)
        x = F.leaky_relu(self.enc_bn2(self.enc_conv2(x)), 0.1)
        x = self.dropout(x)
        x = F.leaky_relu(self.enc_bn3(self.enc_conv3(x)), 0.1)
        x = self.dropout(x)
        x = F.leaky_relu(self.enc_bn4(self.enc_conv4(x)), 0.1)
        x = self.dropout(x)
        x = self.flatten(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.fc_dec(z).view(-1, 128, 16, 16)
        x = F.leaky_relu(self.dec_bn1(self.dec_deconv1(x)), 0.1)
        x = F.leaky_relu(self.dec_bn2(self.dec_deconv2(x)), 0.1)
        x = F.leaky_relu(self.dec_bn3(self.dec_deconv3(x)), 0.1)
        x = torch.sigmoid(self.dec_deconv4(x))  # [0, 1] for BCE
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar
```
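As a quick sanity check of the tensor shapes (a sketch, not part of my training script), a dummy 256x256 batch round-trips to the same resolution:

```python
# Sketch: confirm the encoder/decoder shapes line up for 256x256 inputs.
model = SEMVAE(latent_dim=64)
dummy = torch.randn(2, 1, 256, 256)
recon, mu, logvar = model(dummy)
print(recon.shape, mu.shape, logvar.shape)
# -> torch.Size([2, 1, 256, 256]) torch.Size([2, 64]) torch.Size([2, 64])
```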

I am using basic KL Loss and Reconstruction loss in my loss function shown below:

```python
class VAELoss(nn.Module):
    def __init__(self, recon_type='mse', beta=1.0):
        super().__init__()
        self.recon_type = recon_type
        self.beta = beta

    def forward(self, recon, x, mu, logvar):
        ## Reconstruction Loss ##
        if self.recon_type == 'mse':
            recon_loss = F.mse_loss(recon, x, reduction='sum')
        elif self.recon_type == 'bce':
            recon_loss = F.binary_cross_entropy(recon, x, reduction='sum')
        elif self.recon_type == 'l1':
            recon_loss = F.smooth_l1_loss(recon, x, reduction='sum')
        else:
            raise ValueError("recon_type must be bce, mse, or l1")

        ## KL Divergence ##
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        ## Loss ##
        loss = recon_loss + self.beta * kl_loss
        return loss, recon_loss, kl_loss, self.beta * kl_loss
```
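For reference, the quantity this returns (with sum reduction over the batch and latent dimensions) is the negative β-weighted ELBO, where the KL term is the closed-form divergence between the diagonal Gaussian posterior and a standard normal prior (note that the `'l1'` option actually computes smooth L1 / Huber rather than pure L1):

$$
\mathcal{L} \;=\; \underbrace{\sum \ell(\hat{x}, x)}_{\text{reconstruction}} \;+\; \beta \cdot \underbrace{\Big(-\tfrac{1}{2}\sum\big(1 + \log\sigma^2 - \mu^2 - \sigma^2\big)\Big)}_{D_{\mathrm{KL}}\left(q(z \mid x)\,\|\,\mathcal{N}(0, I)\right)}
$$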

Here is the training loop I use:

```python
import matplotlib.pyplot as plt


def train_vae(model, train_loader, test_loader, epochs=50, lr=1e-3,
              recon_type='mse', beta=1, device='cuda'):
    # Loading Model and Optimizer
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # VAELoss
    VAELoss_Calculator = VAELoss(recon_type=recon_type, beta=beta)

    # Tracking Loss
    train_loss_history, test_loss_history = [], []
    train_recon_loss, train_kl_loss, train_weighted_kl_loss = [], [], []

    # Training Loop
    for epoch in range(epochs):
        model.train()
        total_epoch_loss = 0
        total_epoch_kl_loss = 0
        total_epoch_recon_loss = 0
        total_epoch_weighted_kl_loss = 0

        for imgs in train_loader:
            batch_size = imgs.size(0)
            imgs = imgs.to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(imgs)
            loss, recon_loss, kl_loss, weighted_kl_loss = VAELoss_Calculator(recon, imgs, mu, logvar)

            total_epoch_loss += loss.item()
            total_epoch_recon_loss += recon_loss.item()
            total_epoch_kl_loss += kl_loss.item()
            total_epoch_weighted_kl_loss += weighted_kl_loss.item()

            loss.backward()
            optimizer.step()

        n = len(train_loader.dataset)
        avg_train_loss = total_epoch_loss / n
        avg_recon_loss = total_epoch_recon_loss / n
        avg_kl_loss = total_epoch_kl_loss / n
        avg_weighted_kl_loss = total_epoch_weighted_kl_loss / n

        train_loss_history.append(avg_train_loss)
        train_recon_loss.append(avg_recon_loss)
        train_kl_loss.append(avg_kl_loss)
        train_weighted_kl_loss.append(avg_weighted_kl_loss)

        # Model Evaluation
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for test_imgs in test_loader:
                test_batch_size = test_imgs.size(0)
                test_imgs = test_imgs.to(device)
                test_recon, mu, logvar = model(test_imgs)
                loss, _, _, _ = VAELoss_Calculator(test_recon, test_imgs, mu, logvar)
                test_loss += loss.item()

        n_test = len(test_loader.dataset)
        avg_test_loss = test_loss / n_test
        test_loss_history.append(avg_test_loss)

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train ELBO: {avg_train_loss:.4f} | Recon: {avg_recon_loss:.4f} | "
              f"KL: {avg_kl_loss:.4f} | KL (Weighted): {avg_weighted_kl_loss:.4f} | Test ELBO: {avg_test_loss:.4f}")

    # ---- Plot all losses ----
    plt.figure(figsize=(9, 6))
    plt.plot(train_loss_history, label='Train Total (ELBO)')
    plt.plot(train_recon_loss, label='Train Reconstruction')
    plt.plot(train_kl_loss, label='Train KL Divergence')
    plt.plot(train_weighted_kl_loss, label='Train KL Divergence (Weighted)')
    plt.plot(test_loss_history, label='Test Total (ELBO)')
    # plt.plot(train_grad_losses, label='Gradient Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('VAE Training Loss Components')
    plt.show()
```
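For completeness, training is launched roughly like this (a sketch; the DataLoader construction is omitted, and the loaders are assumed to yield batches of shape [B, 1, 256, 256] with values in [0, 1]):

```python
# Sketch of how I launch training with the hyperparameters described below.
model = SEMVAE(latent_dim=64)
train_vae(model, train_loader, test_loader,
          epochs=150, lr=1e-4, recon_type='bce', beta=1, device='cuda')
```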

After using "BCE" reconstruction, batch size of 32, learning rate of 1e-4, beta value of 1, latent dimension of 64 and training for 150 epochs, I obtain the following results:

```
Epoch [1/100] Train ELBO: 43814.1424 | Recon: 43757.6568 | KL: 564.8580 | KL (Weighted): 56.4858 | Test ELBO: 40342.6853
Epoch [2/100] Train ELBO: 38939.3434 | Recon: 38859.7775 | KL: 795.6566 | KL (Weighted): 79.5657 | Test ELBO: 37637.1522
Epoch [3/100] Train ELBO: 37398.1756 | Recon: 37349.1395 | KL: 490.3624 | KL (Weighted): 49.0362 | Test ELBO: 37155.4560
Epoch [4/100] Train ELBO: 36790.2883 | Recon: 36751.9760 | KL: 383.1231 | KL (Weighted): 38.3123 | Test ELBO: 37103.1546
Epoch [5/100] Train ELBO: 36300.4348 | Recon: 36265.6324 | KL: 348.0279 | KL (Weighted): 34.8028 | Test ELBO: 36329.0702
Epoch [6/100] Train ELBO: 35811.1598 | Recon: 35778.0838 | KL: 330.7587 | KL (Weighted): 33.0759 | Test ELBO: 36197.2446
Epoch [7/100] Train ELBO: 35541.3350 | Recon: 35506.4722 | KL: 348.6268 | KL (Weighted): 34.8627 | Test ELBO: 35781.9881
Epoch [8/100] Train ELBO: 35109.2664 | Recon: 35070.1327 | KL: 391.3363 | KL (Weighted): 39.1336 | Test ELBO: 35419.9568
Epoch [9/100] Train ELBO: 34904.9988 | Recon: 34865.2861 | KL: 397.1254 | KL (Weighted): 39.7125 | Test ELBO: 35368.0089
Epoch [10/100] Train ELBO: 34677.8689 | Recon: 34636.5064 | KL: 413.6217 | KL (Weighted): 41.3622 | Test ELBO: 35279.6201
Epoch [11/100] Train ELBO: 34567.9672 | Recon: 34524.4573 | KL: 435.0997 | KL (Weighted): 43.5100 | Test ELBO: 35029.4027
Epoch [12/100] Train ELBO: 34412.6462 | Recon: 34365.4724 | KL: 471.7382 | KL (Weighted): 47.1738 | Test ELBO: 34983.6670
Epoch [13/100] Train ELBO: 34222.2194 | Recon: 34172.9446 | KL: 492.7466 | KL (Weighted): 49.2747 | Test ELBO: 34911.6089
Epoch [14/100] Train ELBO: 34152.5446 | Recon: 34101.2380 | KL: 513.0639 | KL (Weighted): 51.3064 | Test ELBO: 34925.7342
Epoch [15/100] Train ELBO: 34086.8908 | Recon: 34034.6688 | KL: 522.2204 | KL (Weighted): 52.2220 | Test ELBO: 34980.5862
Epoch [16/100] Train ELBO: 33968.7971 | Recon: 33914.5703 | KL: 542.2691 | KL (Weighted): 54.2269 | Test ELBO: 34592.5254
Epoch [17/100] Train ELBO: 33861.9486 | Recon: 33812.2085 | KL: 497.4020 | KL (Weighted): 49.7402 | Test ELBO: 34556.8503
Epoch [18/100] Train ELBO: 33706.1630 | Recon: 33656.8590 | KL: 493.0414 | KL (Weighted): 49.3041 | Test ELBO: 34511.0188
Epoch [19/100] Train ELBO: 33802.2230 | Recon: 33750.2045 | KL: 520.1836 | KL (Weighted): 52.0184 | Test ELBO: 34526.2105
Epoch [20/100] Train ELBO: 33730.4231 | Recon: 33680.2796 | KL: 501.4360 | KL (Weighted): ....
...
Epoch [85/100] Train ELBO: 32994.7386 | Recon: 32951.0619 | KL: 436.7674 | KL (Weighted): 43.6767 | Test ELBO: 34010.8639
Epoch [86/100] Train ELBO: 32933.0684 | Recon: 32889.8569 | KL: 432.1167 | KL (Weighted): 43.2117 | Test ELBO: 34042.3177
Epoch [87/100] Train ELBO: 32934.8738 | Recon: 32893.2703 | KL: 416.0365 | KL (Weighted): 41.6036 | Test ELBO: 33952.0976
Epoch [88/100] Train ELBO: 32993.6404 | Recon: 32950.2150 | KL: 434.2497 | KL (Weighted): 43.4250 | Test ELBO: 33978.0052
Epoch [89/100] Train ELBO: 32914.6798 | Recon: 32871.9780 | KL: 427.0166 | KL (Weighted): 42.7017 | Test ELBO: 33968.7740
Epoch [90/100] Train ELBO: 32886.1937 | Recon: 32844.8323 | KL: 413.6142 | KL (Weighted): 41.3614 | Test ELBO: 34007.8426
Epoch [91/100] Train ELBO: 32932.2369 | Recon: 32890.9989 | KL: 412.3811 | KL (Weighted): 41.2381 | Test ELBO: 33970.2540
Epoch [92/100] Train ELBO: 32997.2572 | Recon: 32955.8671 | KL: 413.9017 | KL (Weighted): 41.3902 | Test ELBO: 33961.6801
Epoch [93/100] Train ELBO: 32911.2220 | Recon: 32868.1441 | KL: 430.7813 | KL (Weighted): 43.0781 | Test ELBO: 34012.0193
Epoch [94/100] Train ELBO: 32875.5987 | Recon: 32833.4682 | KL: 421.3034 | KL (Weighted): 42.1303 | Test ELBO: 33971.3513
Epoch [95/100] Train ELBO: 32991.2952 | Recon: 32950.2401 | KL: 410.5496 | KL (Weighted): 41.0550 | Test ELBO: 33978.4424
Epoch [96/100] Train ELBO: 32938.0615 | Recon: 32893.7618 | KL: 442.9997 | KL (Weighted): 44.3000 | Test ELBO: 33971.5015
Epoch [97/100] Train ELBO: 32930.5121 | Recon: 32888.0903 | KL: 424.2162 | KL (Weighted): 42.4216 | Test ELBO: 34008.6855
Epoch [98/100] Train ELBO: 32919.0241 | Recon: 32876.8566 | KL: 421.6755 | KL (Weighted): 42.1675 | Test ELBO: 34012.3174
Epoch [99/100] Train ELBO: 32906.9627 | Recon: 32866.2580 | KL: 407.0465 | KL (Weighted): 40.7046 | Test ELBO: 34066.2537
Epoch [100/100] Train ELBO: 32921.0352 | Recon: 32879.3460 | KL: 416.8920 | KL (Weighted): 41.6892 | Test ELBO: 34039.4429
```

And the final reconstruction quality is as follows:

[Figure: VAE reconstruction results (bottom row) compared to the original images from the dataset (top row)]

The top row shows the original images and the bottom row shows the reconstructions. Reconstruction accuracy is poor: the model captures the larger objects in the images but misses the finer details, and the outputs are very blurry. The reconstructions also look noisy and often miss the correct contrast and overall pixel intensity of the originals.
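For context, the reconstructions in the figure are generated in eval mode on a test batch, roughly like this (a sketch, not the exact plotting code):

```python
# Sketch: produce reconstructions of a test batch for the side-by-side figure.
model.eval()
with torch.no_grad():
    imgs = next(iter(test_loader)).to(device)
    recon, _, _ = model(imgs)
# top row: imgs, bottom row: recon (both [B, 1, 256, 256] in [0, 1])
```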

I understand that VAE reconstructions and generations are inherently somewhat blurry, but from implementations I have seen online, the accuracy (picking up small features in an image) and overall quality are usually not as bad as the reconstructions I get on my dataset.

How do I improve the reconstruction quality and accuracy enough to capture the finer-level details of my images once the VAE is trained? What is the likely culprit for the poor reconstructions: the input resolution, the architecture, the loss function, the training parameters, or a combination of these?
