
Train VAE for RGB image generation

debojit sharma on 17 Jun 2023
Commented: Ben on 26 Jun 2023
I am trying to use the code for training a VAE for image generation given at the following link with my own dataset of RGB images of size 200-by-200.
I am getting the following errors in the "Train Model" section:
The VAE code at that link uses MNIST images as the encoder input, and it states that the decoder outputs images of size 28-by-28-by-1. I want to train this VAE to generate 200-by-200 RGB images instead, so my input images are 200-by-200 RGB images. With them, I get the errors above in the "Train Model" section, and I have not been able to resolve them. Could someone please explain what changes I need to make to this code so that the VAE can be trained to generate 200-by-200 RGB images? Thank you.

Answers (1)

Ben on 23 Jun 2023
The error states that the VAE output Y and the training images T have different sizes when the mean-squared error (mse) loss between them is computed.
Note that the VAE output size is determined by both the input image size and the layers in the network. I think there are a few things to check first:
  1. Make sure the output of the VAE has the same number of channels as the target images: 1 for the MNIST example, 3 for RGB images.
  2. Make sure the VAE output has the same height and width as the target images, 200x200. The VAE in the example downsamples the spatial size by using Stride=2 in the two convolution layers of the encoder, then upsamples again using Stride=2 in the two transposed convolution layers of the decoder. Take care that the decoder upsamples back to the original image size.
  3. Make sure the custom projectAndReshapeLayer is configured for the spatial size your encoder downsamples to: in the example the projectionSize is [7,7,64], but for the same network on 200x200 images I would expect it to be [50,50,64].
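The three checks above come down to a little size arithmetic. The sketch below is plain MATLAB (no toolbox needed) and assumes, as described in the answer, two stride-2 convolutions with Padding="same" in the encoder and two stride-2 transposed convolutions with Cropping="same" in the decoder, with the decoder ending in a layer whose number of filters is numInputChannels. For "same" padding with a stride greater than 1, a convolution outputs ceil(inputSize/stride); a transposed convolution with "same" cropping outputs inputSize*stride. The variable names numStride2Layers, numProjectionChannels and wrongChannels are illustrative only; the 64 channels are simply kept from the example's [7,7,64].

```matlab
% Size bookkeeping for the VAE described above (a sketch, assuming two
% stride-2 encoder convolutions with Padding="same" and two stride-2
% decoder transposed convolutions with Cropping="same").
imageSize = [200 200 3];     % 200-by-200 RGB training images
numStride2Layers = 2;        % stride-2 layers in the encoder (and in the decoder)
numProjectionChannels = 64;  % kept from the example's projectionSize [7,7,64]

% Encoder: each stride-2 convolution with "same" padding maps n to ceil(n/2).
encodedHW = imageSize(1:2);
for k = 1:numStride2Layers
    encodedHW = ceil(encodedHW/2);
end

% Decoder: start the projection at the encoded spatial size ...
projectionSize = [encodedHW numProjectionChannels];

% ... then each stride-2 transposed convolution with "same" cropping maps n to 2*n.
decodedHW = encodedHW * 2^numStride2Layers;

% The last decoder layer needs one filter per image channel.
numInputChannels = imageSize(3);     % 3 for RGB images
wrongChannels = size(imageSize,1);   % 1, because imageSize is a 1-by-3 row vector

fprintf('projectionSize = %s\n', mat2str(projectionSize));
fprintf('decoder output = %s\n', mat2str([decodedHW numInputChannels]));
```

With imageSize = [28 28 1] the same arithmetic gives 28 → 14 → 7, which matches the example's [7,7,64]. If the height or width is not divisible by 4, the ceil means the round trip no longer closes (for example 30 → 15 → 8 → 32), so the decoder output would again differ from the target size.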
If you can't get this working, could you let us know whether you have modified the encoder or decoder layers at all? If not, can you make sure that all the images input to the VAE have the same size?
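One way to check that every image has the same size is to compare each one against the expected [200 200 3] before training. This is a minimal sketch on a hypothetical cell array of random images standing in for a real dataset; it assumes a MATLAB release in which size accepts a vector of dimensions (R2019b or later), so that a grayscale H-by-W image reports one channel.

```matlab
% Hypothetical stand-in for a dataset: a cell array of images, two of which
% do not match the 200-by-200-by-3 size the network expects.
images = {rand(200,200,3), rand(200,200,3), rand(180,220,3), rand(200,200)};
targetSize = [200 200 3];

% size(I,1:3) returns [H W C], with C = 1 for a grayscale H-by-W image.
isRightSize = cellfun(@(I) isequal(size(I,1:3), targetSize), images);
badIdx = find(~isRightSize);

fprintf('Images with the wrong size: %s\n', mat2str(badIdx));
```

Images that fail the check could be resized (for example with imresize from Image Processing Toolbox), or the data could be read through an augmentedImageDatastore with an output size of [200 200] and 'ColorPreprocessing','gray2rgb' so that grayscale images also come out with three channels.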
Hope that helps,
Aniketh on 25 Jun 2023
Have you tried printing the dimensions of the arguments passed to the loss function that dlfeval evaluates? The upsampling, downsampling, and projection corrections Ben pointed out should solve your issue, but the exact difference between the output dimensions of layersE and layersD should point you in the right direction.
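Printing the sizes just before the mse call makes the mismatch explicit. The sketch below uses plain zero arrays shaped like the mismatch in this thread (an MNIST-style 28x28x1 decoder output against a batch of 200x200x3 targets); these arrays are hypothetical stand-ins, not real network outputs. Inside elboLoss, the two fprintf lines could be placed before mse(Y,T); for dlarray inputs, dims(Y) additionally shows the dimension labels such as 'SSCB'.

```matlab
% Stand-ins shaped like the mismatch in this thread: an MNIST-style decoder
% output Y and a batch of five 200-by-200 RGB targets T (hypothetical data).
Y = zeros(28,28,1,5);
T = zeros(200,200,3,5);

% Print both sizes; in elboLoss these lines would go before mse(Y,T).
fprintf('size(Y) = %s\n', mat2str(size(Y)));
fprintf('size(T) = %s\n', mat2str(size(T)));

sizesMatch = isequal(size(Y), size(T));
if ~sizesMatch
    fprintf('Mismatch: check the decoder layers against the target image size.\n');
end
```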
Ben on 26 Jun 2023
@debojit sharma, I've written some code showing how this could work for 200x200x3 images. The main issue I found was that numInputChannels is computed incorrectly in the example, so perhaps that is the issue you are having. I fixed that in the code below:
numLatentChannels = 16;
imageSize = [200 200 3]; % updated for 200x200x3 images

layersE = [
    % (encoder layers as in the linked example, not shown in this post)
    ];

projectionSize = [50 50 64]; % recomputed manually
numInputChannels = imageSize(3); % fixed from the example

layersD = [
    % (decoder layers as in the linked example, not shown in this post)
    ];

netE = dlnetwork(layersE);
netD = dlnetwork(layersD);

% Test the forward pass
batchSize = 5;
imageBatch = dlarray(randn([imageSize,batchSize]),"SSCB");
latentBatch = forward(netE,imageBatch);
generatedBatch = forward(netD,latentBatch);

% Test the loss and gradients
if canUseGPU
    netE = dlupdate(@gpuArray,netE);
    netD = dlupdate(@gpuArray,netD);
    imageBatch = gpuArray(imageBatch);
end
[loss,gradE,gradD] = dlfeval(@modelLoss,netE,netD,imageBatch);

function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
[Z,mu,logSigmaSq] = forward(netE,X);
% Forward through decoder.
Y = forward(netD,Z);
% Calculate loss and gradients.
loss = elboLoss(Y,X,mu,logSigmaSq);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end

function loss = elboLoss(Y,T,mu,logSigmaSq)
% Reconstruction loss.
reconstructionLoss = mse(Y,T);
% KL divergence.
KL = -0.5 * sum(1 + logSigmaSq - mu.^2 - exp(logSigmaSq),1);
KL = mean(KL);
% Combined loss.
loss = reconstructionLoss + KL;
end
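Written out, elboLoss adds a reconstruction term to the closed-form KL divergence between a diagonal Gaussian N(μ, diag σ²) produced by the encoder and a standard normal prior. The equation below assumes that mu and logSigmaSq are numLatentChannels-by-batchSize arrays, so that sum(...,1) runs over the J latent channels and mean averages over the B observations, with σ² = exp(logSigmaSq):

```latex
\mathrm{KL}_b = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_{jb}^{2} - \mu_{jb}^{2} - \sigma_{jb}^{2}\right),
\qquad
\mathcal{L} = \mathrm{mse}(Y, T) + \frac{1}{B}\sum_{b=1}^{B}\mathrm{KL}_b
```

Only the reconstruction term involves the image size; the KL term depends only on the latent channels. That is why a decoder that does not reproduce 200x200x3 fails in mse(Y,T) rather than in the KL computation.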
Hope that helps.
