I'm training a variational autoencoder on CelebA dataset using TensorFlow.keras
The problem I'm facing is that the generated images are not diverse enough and look kinda bad.
(new) Example:
What I think:
- it's bad because the reconstruction and KL loss are unbalanced.
- I read this question and followed its solution - read about KL annealing and tried to implement it myself but didn't work.
Note:
It's my first time working with autoencoders so maybe I missed something obvious.
it would be super appreciated if you could give a programmatic/technical solution and not a theoretical one with equations and complicated math
The loss function:
def r_loss(self, y_true, y_pred): return K.mean(K.square(y_true - y_pred), axis=[1, 2, 3]) def kl_loss(self, y_true, y_pred): return -0.5 * K.sum(1 + self.sd_layer - K.square(self.mean_layer) - K.exp(self.sd_layer), axis=1) def total_loss(self, y_true, y_pred): return K.mean(self.r_loss(y_true, y_pred) + self.kl_loss(y_true, y_pred)) The encoder:
def build_encoder(self): conv_filters = [32, 64, 64, 64] conv_kernel_size = [3, 3, 3, 3] conv_strides = [2, 2, 2, 2] # Number of Conv layers n_layers = len(conv_filters) # Define model input x = self.encoder_input # Add convolutional layers for i in range(n_layers): x = Conv2D(filters=conv_filters[i], kernel_size=conv_kernel_size[i], strides=conv_strides[i], padding='same', name='encoder_conv_' + str(i) )(x) if self.use_batch_norm: # True x = BatchNormalization()(x) x = LeakyReLU()(x) if self.use_dropout: # False x = Dropout(rate=0.25)(x) # Required for reshaping latent vector while building Decoder self.shape_before_flattening = K.int_shape(x)[1:] x = Flatten()(x) self.mean_layer = Dense(self.encoder_output_dim, name='mu')(x) self.sd_layer = Dense(self.encoder_output_dim, name='log_var')(x) # Defining a function for sampling def sampling(args): mean_mu, log_var = args epsilon = K.random_normal(shape=K.shape(mean_mu), mean=0., stddev=1.) return mean_mu + K.exp(log_var / 2) * epsilon # Using a Keras Lambda Layer to include the sampling function as a layer # in the model encoder_output = Lambda(sampling, name='encoder_output')([self.mean_layer, self.sd_layer]) return Model(self.encoder_input, encoder_output, name="VAE_Encoder") The decoder:
def build_decoder(self): conv_filters = [64, 64, 32, 3] conv_kernel_size = [3, 3, 3, 3] conv_strides = [2, 2, 2, 2] n_layers = len(conv_filters) # Define model input decoder_input = self.decoder_input # To get an exact mirror image of the encoder x = Dense(np.prod(self.shape_before_flattening))(decoder_input) x = Reshape(self.shape_before_flattening)(x) # Add convolutional layers for i in range(n_layers): x = Conv2DTranspose(filters=conv_filters[i], kernel_size=conv_kernel_size[i], strides=conv_strides[i], padding='same', name='decoder_conv_' + str(i) )(x) # Adding a sigmoid layer at the end to restrict the outputs # between 0 and 1 if i < n_layers - 1: x = LeakyReLU()(x) else: x = Activation('sigmoid')(x) # Define model output self.decoder_output = x return Model(decoder_input, self.decoder_output, name="VAE_Decoder") The combined model:
def build_autoencoder(self): self.encoder = self.build_encoder() self.decoder = self.build_decoder() # Input to the combined model will be the input to the encoder. # Output of the combined model will be the output of the decoder. self.autoencoder = Model(self.encoder_input, self.decoder(self.encoder(self.encoder_input)), name="Variational_Auto_Encoder") self.autoencoder.compile(optimizer=self.adam_optimizer, loss=self.total_loss, metrics=[self.total_loss], experimental_run_tf_function=False) self.autoencoder.summary() EDIT:
the latent size is 256 and the sample method is as follows;
def generate(self, image=None): if not os.path.exists(self.sample_dir): os.makedirs(self.sample_dir) if image is None: img = np.random.normal(size=(9, self.encoder_output_dim)) prediction = self.decoder.predict(img) op = np.vstack((np.hstack((prediction[0], prediction[1], prediction[2])), np.hstack((prediction[3], prediction[4], prediction[5])), np.hstack((prediction[6], prediction[7], prediction[8])))) print(op.shape) op = cv2.resize(op, (self.input_size * 9, self.input_size * 9), interpolation=cv2.INTER_AREA) op = cv2.cvtColor(op, cv2.COLOR_BGR2RGB) cv2.imshow("generated", op) cv2.imwrite(self.sample_dir + "generated" + str(r(0, 9999)) + ".jpg", (op * 255).astype("uint8")) else: img = cv2.imread(image, cv2.IMREAD_UNCHANGED) img = cv2.resize(img, (self.input_size, self.input_size), interpolation=cv2.INTER_AREA) img = img.astype("float32") img = img / 255 prediction = self.autoencoder.predict(img.reshape(1, self.input_size, self.input_size, 3)) img = cv2.resize(prediction[0][:, :, ::-1], (960, 960), interpolation=cv2.INTER_AREA) cv2.imshow("prediction", img) cv2.imwrite(self.sample_dir + "generated" + str(r(0, 9999)) + ".jpg", (img * 255).astype("uint8")) 