
A saved, trained GAN model for image generation does not generate the same accurate images after the GPU is reset

When I train the flower image generation example, everything works well as long as the trained parameters remain loaded in GPU memory: I obtain easily recognizable flowers, as shown in the example. However, suppose I save the complete training workspace with the save command (for example, save('GANWorkspacefile.mat')), which also includes netG, then reset the GPU memory, and then load the workspace again (load('GANWorkspacefile.mat')). In that case the images generated with predict are blurry, with no flowers at all, and resemble the ones generated at the beginning of training. The same issue occurs when I transfer the saved workspace to another machine with the same MATLAB release (R2022b) and load it there. It seems that something is lost when the workspace variables are loaded, and this prevents the network from generating images the way it does right at the end of training. If anyone has an idea of what I'm doing wrong, I would appreciate a comment.
Thank you.

Accepted Answer

Ben on 8 Apr 2024
Moved: Walter Roberson on 23 Apr 2024
I believe this is due to a bug in the R2022b version of the custom projectAndReshapeLayer attached to the example. In particular, the initialize method replaces layer.Weights and layer.Bias with their initial values even when they already hold trained values. Because the initialize method is called when you load the saved generator network, the trained projection weights and bias are overwritten on load.
You can update the initialize method in the custom layer to the following:
function layer = initialize(layer,layout)
    % layer = initialize(layer,layout) initializes the layer
    % learnable parameters.
    %
    % Inputs:
    %   layer  - Layer to initialize
    %   layout - Data layout, specified as a
    %            networkDataLayout object
    %
    % Outputs:
    %   layer  - Initialized layer

    % Layer output size.
    outputSize = layer.OutputSize;

    % Initialize fully connect weights only if they are not already set
    % (for example, when a trained network is loaded).
    if isempty(layer.Weights)
        % Find number of channels.
        idx = finddim(layout,"C");
        numChannels = layout.Size(idx);

        % Initialize using Glorot.
        sz = [prod(outputSize) numChannels];
        numOut = prod(outputSize);
        numIn = numChannels;
        layer.Weights = initializeGlorot(sz,numOut,numIn);
    end

    % Initialize fully connect bias only if it is not already set.
    if isempty(layer.Bias)
        % Initialize with zeros.
        layer.Bias = initializeZeros([prod(outputSize) 1]);
    end
end
The if isempty(layer.Weights) and if isempty(layer.Bias) checks ensure that the trained projection is not lost on load.
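One way to check whether a saved generator is affected is to compare its learnable parameters before and after a save/load round trip. The following is only a sketch: it assumes netG is the trained generator dlnetwork in the current workspace, and the file name netGCheck.mat is an arbitrary choice. With the original initialize method, the projectAndReshapeLayer Weights and Bias rows would be expected to appear in the output; with the corrected method, no rows should be listed.

```matlab
% Assumption: netG is the trained generator dlnetwork from the example
% workspace. "netGCheck.mat" is an arbitrary, hypothetical file name.
before = netG.Learnables;

% Round-trip the network through a MAT file, as in the question.
save("netGCheck.mat","netG");
S = load("netGCheck.mat");
after = S.netG.Learnables;

% Compare each learnable parameter value. extractdata removes the dlarray
% wrapper and gather moves any gpuArray data to host memory.
same = cellfun(@(a,b) isequal(gather(extractdata(a)),gather(extractdata(b))), ...
    before.Value,after.Value);

% List the parameters whose values changed during load.
disp(before(~same,{'Layer','Parameter'}))
```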
