How do I debug a convolutional neural network with a custom training loop that is not learning?

Hello! I have been trying to design a CNN for image analysis. The CNN is trained on simulated images of size 132 x 132 x 6 (spatial, spatial, channel). The simulated images are computed using a bi-exponential equation of the form $S = S_0\left[f e^{-b D_f} + (1-f) e^{-b D_s}\right]$. In the CNN, the input images are forward passed through the network to generate four feature maps ($f$, $D_f$, $S_0$, and $D_s$), which are then scaled and used to calculate the predicted image signals, $\hat{S}$. The predicted image signals are then compared to the input image signals $S$ using the mean squared error loss function, and the network parameters are updated from the gradients. The problem is that the network is not learning. After some inspection I noticed that the gradients are all going to zero, but I'm not sure how to fix this. I have tried changing the learning rate, Adam vs. SGDM optimizers, and the mini-batch size, but I encounter the same problem. Any advice/feedback is greatly appreciated!
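For reference, the model that the code below implements can be written out in full. Here $y_1, \dots, y_4$ (notation introduced only for this summary) are the four channels of the final sigmoid layer, and the scale factors are the ones used in modelLoss. In the simulation, $S_0 = 1$ (the tissue mask is all ones) and $f$, $D_f$, $D_s$ are drawn uniformly from $[0.1, 0.5]$, $[0.0017, 0.107]$ and $[0.0003, 0.0017]$ respectively.

```latex
\hat{S}(b) = S_0\left[f\,e^{-b D_f} + (1-f)\,e^{-b D_s}\right],
\qquad b \in \{50, 100, 150, 250, 500, 800\}

f = 0.5\,y_1, \quad D_f = 0.107\,y_2, \quad S_0 = 0.6\,y_3 + 0.7, \quad D_s = 0.0017\,y_4,
\qquad y_k \in (0, 1)
```

With this scaling, every simulated parameter value lies inside the range the scaled outputs can reach, and $S_0 = 1$ sits inside $(0.7, 1.3)$.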
Also, I have removed parts of the code to make it as simple as possible for the time being, but will add in validation and testing loops.
% Image Parameters
rng(1);
imageSize = [132, 132];
bValue = [50 100 150 250 500 800]; % non-zero diffusion weightings
numbVal = length(bValue);
minDf = 0.0017;
maxDf = 0.107;
minf = 0.1;
maxf = 0.5;
minDs = 0.0003;
maxDs = 0.0017;
DfSim = minDf + (maxDf-minDf).*rand(10,1);
fSim = minf + (maxf-minf).*rand(10,1);
DsSim = minDs + (maxDs-minDs).*rand(10,1);
numIm = length(DfSim) * length(fSim) * length(DsSim); % number of 132 x 132 x 6 images
tissue = ones(imageSize);
bValue = reshape(bValue, [1,1,numbVal]); % Reshape bValue for matrix operation
% Prepare a directory to store the simulated images
outputDir = fullfile(tempdir, 'SimulatedDW-MRI');
if ~exist(outputDir, 'dir')
mkdir(outputDir);
end
% Initialize a table to store the image file paths and parameters
fprintf('Total simulated images: %d\n', numIm);
imageData = table('Size', [0 4],...
'VariableTypes', {'cell', 'double', 'double', 'double'},...
'VariableNames', {'imageFilePath', 'DfSim', 'fSim', 'DsSim'});
% Start the timer
tic;
% Loop through each combination of DfSim, fSim, and DsSim
imageIdx = 0;
S = zeros([imageSize length(bValue) numIm]);
for DfIdx = 1:length(DfSim)
for fIdx = 1:length(fSim)
for DsIdx = 1:length(DsSim)
imageIdx = imageIdx + 1;
% Calculate the diffusion signal for each b value for each channel
S(:,:,:,imageIdx) = tissue .* ((fSim(fIdx) .* exp(-bValue .* DfSim(DfIdx))) + ((1-fSim(fIdx)) .* exp(-bValue .* DsSim(DsIdx))));
% Track progress
fprintf('Processing image %d out of %d\n', imageIdx, numIm);
end
end
end
for imageIdx = 1:numIm
fileName = sprintf('%s/image%d.mat', outputDir, imageIdx); % Write the image to a .mat file
S_single = S(:,:,:,imageIdx);
save(fileName, 'S_single');
DfIdx = ceil(imageIdx / (length(fSim)*length(DsSim)));
fIdx = ceil((imageIdx - (DfIdx-1)*length(fSim)*length(DsSim)) / length(DsSim));
DsIdx = imageIdx - (DfIdx-1)*length(fSim)*length(DsSim) - (fIdx-1)*length(DsSim);
imageData(imageIdx, :) = {fileName, DfSim(DfIdx), fSim(fIdx), DsSim(DsIdx)};
fprintf('Saving image %d out of %d\n', imageIdx, numIm);
end
elapsedTime = toc;
fprintf('Computation time: %.2f seconds\n', elapsedTime);
%% Split data in training, validation, and testing sets
trainSplit = 0.8;
valSplit = 0.1;
testSplit = 0.1;
n = height(imageData);
idx = randperm(n);
trainIdx = idx(1:round(trainSplit*n));
valIdx = idx(round(trainSplit*n)+1:round((trainSplit+valSplit)*n));
testIdx = idx(round((trainSplit+valSplit)*n)+1:end);
imageDataTrain = imageData(trainIdx, :);
imageDataVal = imageData(valIdx, :);
imageDataTest = imageData(testIdx, :);
trainImds = fileDatastore(imageDataTrain.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
trainLabelsDatastore = arrayDatastore(imageDataTrain{:, {'DfSim', 'fSim', 'DsSim'}});
trainCombinedDatastore = combine(trainImds, trainLabelsDatastore);
valImds = fileDatastore(imageDataVal.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
valLabelsDatastore = arrayDatastore(imageDataVal{:, {'DfSim', 'fSim', 'DsSim'}});
valCombinedDatastore = combine(valImds, valLabelsDatastore);
testImds = fileDatastore(imageDataTest.imageFilePath, ...
'ReadFcn' , @(filename) double(load(filename).S_single), ...
'FileExtensions', '.mat');
testLabelsDatastore = arrayDatastore(imageDataTest{:, {'DfSim', 'fSim', 'DsSim'}});
testCombinedDatastore = combine(testImds, testLabelsDatastore);
%% Define the network layers
lgraph = layerGraph();
Layers = [
imageInputLayer([132 132 6],"Name","imageinput","Normalization","none")
convolution2dLayer([1 1],32,"Name","conv_1","Padding","same")
batchNormalizationLayer("Name","batchnorm_1")
leakyReluLayer("Name","relu_1")
dropoutLayer(0.02,"Name","dropout_1")
convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
leakyReluLayer("Name","relu_2")
dropoutLayer(0.02,"Name","dropout_2")
convolution2dLayer([1 1],64,"Name","conv_3","Padding","same")
batchNormalizationLayer("Name","batchnorm_2")
leakyReluLayer("Name","relu_3")
dropoutLayer(0.02,"Name","dropout_3")
convolution2dLayer([3 3],64,"Name","conv_4","Padding","same")
leakyReluLayer("Name","relu_4")
dropoutLayer(0.02,"Name","dropout_4")
convolution2dLayer([1 1],128,"Name","conv_5","Padding","same")
batchNormalizationLayer("Name","batchnorm_3")
leakyReluLayer("Name","relu_5")
dropoutLayer(0.02,"Name","dropout_5")
convolution2dLayer([3 3],128,"Name","conv_6","Padding","same")
leakyReluLayer("Name","relu_6")
dropoutLayer(0.02,"Name","dropout_6")
convolution2dLayer([1 1],64,"Name","conv_7","Padding","same")
batchNormalizationLayer("Name","batchnorm_4")
leakyReluLayer("Name","relu_7")
dropoutLayer(0.02,"Name","dropout_7")
convolution2dLayer([3 3],64,"Name","conv_8","Padding","same")
leakyReluLayer("Name","relu_8")
dropoutLayer(0.02,"Name","dropout_8")
convolution2dLayer([1 1],32,"Name","conv_9","Padding","same")
batchNormalizationLayer("Name","batchnorm_5")
leakyReluLayer("Name","relu_9")
dropoutLayer(0.02,"Name","dropout_9")
convolution2dLayer([3 3],32,"Name","conv_10","Padding","same")
leakyReluLayer("Name","relu_10")
dropoutLayer(0.02,"Name","dropout_10")
convolution2dLayer([1 1],4,"Name","conv_11","Padding","same")
sigmoidLayer("Name","sigmoid")];
lgraph = addLayers(lgraph,Layers);
dlnet = dlnetwork(lgraph);
plot(lgraph);
%% Training loop
numEpochs = 200;
miniBatchSize = 10;
initialLearnRate = 0.01;
decay = 0.00001;
gradDecay = 0.9;
sqGradDecay = 0.999;
mbq = minibatchqueue(trainCombinedDatastore,...
'MiniBatchSize', miniBatchSize,...
'MiniBatchFormat', {'SSCB', 'CB'}, ...
'OutputAsDlarray', [1, 1],...
'OutputEnvironment', 'auto');
averageGrad = [];
averageSqGrad = [];
numObservationsTrain = height(imageDataTrain); % training set size only; imageIdx would count every simulated image
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
plots = 'training-progress';
if strcmp(plots, 'training-progress')
figure
lineLossTrain = animatedline;
xlabel("Total Iterations")
ylabel("Loss")
end
epoch = 0;
iteration = 0;
start = tic;
% Loop over epochs.
while epoch < numEpochs
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq)
iteration = iteration + 1;
% Read mini-batch of data.
[dlX, dlT] = next(mbq);
[loss, gradients, state] = dlfeval(@modelLoss,dlnet,dlX);
dlnet.State = state;
% Determine learning rate for time-based decay learning rate schedule.
learnRate = initialLearnRate/(1 + decay*iteration);
% Update network parameters
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients,averageGrad,averageSqGrad,...
iteration, learnRate, gradDecay, sqGradDecay);
% Extract weights of first convolution layer
conv1Weights = dlnet.Layers(2).Weights;
% Print or save the weights
disp('Weights of conv_1 layer:');
disp(conv1Weights);
if strcmp(plots, 'training-progress')
D = duration(0,0,toc(start),'Format','hh:mm:ss');
addpoints(lineLossTrain, iteration, double(gather(extractdata(loss))));
title("Epoch: " + epoch + " , Elapsed: " + string(D));
drawnow
end
end
end
%% Custom loss function
function [loss, gradients,state] = modelLoss(dlnet, dlX)
% Forward data through network.
[dlY, state] = forward(dlnet, dlX);
% Calculate parameter maps
fMap = dlY(:,:,1,:).*0.5;
DfMap = dlY(:,:,2,:).*0.107;
S0Map = (dlY(:,:,3,:).*0.6) + 0.7;
DsMap = dlY(:,:,4,:).*0.0017;
% diffusion weightings
dlB = [50 100 150 250 500 800];
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
% Calculate the mse loss
loss = mse(Spred, dlX);
% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss, dlnet.Learnables);
end
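A quick way to confirm the symptom described above (all-zero gradients) is to tabulate the L2 norm of each entry in the gradients table returned by modelLoss. The sketch below is a hypothetical diagnostic, not part of the original code: it builds a small stand-in table with the same variables as dlnet.Learnables (Layer, Parameter, Value) so that it runs on its own. In the training loop you would pass the gradients output of dlfeval instead, whose Value entries are dlarray objects.

```matlab
% Hypothetical stand-in for the table returned by
% dlgradient(loss, dlnet.Learnables). Plain doubles are used so the
% sketch runs without a network; in training, Value holds dlarrays.
wGrad = zeros(1, 1, 6, 32);          % e.g. conv_1 weights, all zero
bGrad = zeros(1, 1, 32);
bGrad(1:2) = [3 4];                  % e.g. conv_1 bias, norm 5
gradients = table(["conv_1"; "conv_1"], ["Weights"; "Bias"], {wGrad; bGrad}, ...
    'VariableNames', {'Layer', 'Parameter', 'Value'});

% L2 norm of each learnable's gradient
gradNorms = zeros(height(gradients), 1);
for k = 1:height(gradients)
    g = gradients.Value{k};
    if isa(g, 'dlarray')
        g = extractdata(g);          % strip the dlarray wrapper
    end
    gradNorms(k) = norm(double(gather(g(:))));   % gather brings GPU data to the CPU
end

report = [gradients(:, {'Layer', 'Parameter'}), ...
    table(gradNorms, 'VariableNames', {'GradNorm'})];
disp(report)
if all(gradNorms == 0)
    warning('All gradients are zero: the loss may not depend on the learnables.');
end
```

If every row reports zero right from the first iteration (rather than decaying over time), that points to a broken dependency between the loss and the network output rather than to vanishing gradients.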

Accepted Answer

Richard on 26 Jun 2023
These lines of code:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
are creating a variable, Spred, that does not contain a traced dependency on the output of the network. This means that your mse() call is in fact only tracing a dependency on the original input dlX, so the gradients of the loss with respect to the learnables are zero.
Try this instead to create an Spred that incorporates the dependency on the network output:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX), 'like', dlX);
for b = 1:length(dlB)
Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
The 'like' syntax for zeros() constructs a zeros dlarray that is tracing, like its input, and your indexing within the loop will then be captured. In your original version, because Spred is created as a plain double array, the indexing which places values into Spred(:,:,b,:) is casting the computed and traced right-hand side into a plain double value which loses the trace information that dlgradient depends on.
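The casting described above is the general rule MATLAB applies to indexed assignment: the existing array keeps its class, and the right-hand side is converted to it. As an analogy only (single stands in for a traced dlarray here, because the mechanism is easier to see with a built-in class), compare a plain zeros buffer with one created using 'like':

```matlab
% Analogy only: single stands in for a traced dlarray.
x = single([1.5 2.5 3.5]);

plainBuf = zeros(1, 3);            % double buffer, like zeros(size(dlX))
plainBuf(:) = x;                   % indexed assignment converts x to double

likeBuf = zeros(1, 3, 'like', x);  % buffer created with the class of x
likeBuf(:) = x;                    % x keeps its class

fprintf('plain buffer: %s, like buffer: %s\n', class(plainBuf), class(likeBuf));
```

The values are identical in both buffers; only the class differs. With dlarray, losing the class also means losing the trace that dlgradient needs.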
Incidentally, I think you can also remove the loop entirely by reshaping dlB into a 1-by-1-by-6 array and relying on implicit expansion, which should be faster:
dlB = reshape(dlB, 1,1,[]);
Spred = S0Map .* (fMap.*exp(-dlB.*DfMap) + (1 - fMap).*exp(-dlB.*DsMap));
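To check that the vectorized expression gives the same values as the loop, the sketch below compares both on small random parameter maps. It uses plain double arrays with made-up sizes (4-by-4 pixels, a batch of 2) rather than dlarray, so it verifies the arithmetic and the implicit expansion along dimension 3, not gradient tracing.

```matlab
rng(0);
B = 2;                               % toy mini-batch size
sz = [4 4 1 B];                      % small spatial size for the check
fMap  = 0.5    * rand(sz);
DfMap = 0.107  * rand(sz);
S0Map = 0.6    * rand(sz) + 0.7;
DsMap = 0.0017 * rand(sz);
dlB = [50 100 150 250 500 800];

% Loop version: one channel per b-value
SLoop = zeros(4, 4, numel(dlB), B);
for b = 1:numel(dlB)
    SLoop(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end

% Vectorized version: b-values along dimension 3, implicit expansion
bVec = reshape(dlB, 1, 1, []);
SVec = S0Map .* (fMap.*exp(-bVec.*DfMap) + (1 - fMap).*exp(-bVec.*DsMap));

maxAbsDiff = max(abs(SLoop(:) - SVec(:)));
fprintf('size of SVec: %s, max abs difference: %g\n', mat2str(size(SVec)), maxAbsDiff);
```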
  2 Comments
Marissa Brown on 26 Jun 2023
Hi Richard! Thank you for the response. I added the trace dependency to the output variable as you suggested, and it now appears that the network is learning! The issue now is that the loss function decreases to about 200, where it levels off. I assume there may be another bug somewhere in the code. I've tried adjusting the hyperparameters (learning rate, mini batch size) as well as using the sgdm and adam optimizers, but got roughly the same results. The data I'm using is all simulated, so I believe the bug is within the network structure itself or the training loop. I also noticed a lot of variation in the loss during training, however I am not sure if this is normal or not.
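One way to put a loss of about 200 in perspective is to convert it to a per-element error. This sketch assumes that mse() uses its documented half-mean-squared-error definition with the default batch-size normalization, loss = sum((Y - T).^2, 'all') / (2N) for N observations; please check the mse reference page for your release before relying on it. Under that assumption, each observation contributes 132 x 132 x 6 squared errors:

```matlab
% Assumption: loss = sum((Y - T).^2, 'all') / (2*N), N = observations.
loss = 200;                          % plateau value reported above
elementsPerObs = 132 * 132 * 6;      % pixels x b-values per image
perElementMSE = 2 * loss / elementsPerObs;
perElementRMSE = sqrt(perElementMSE);
fprintf('Per-element MSE: %.4g, RMSE: %.4g\n', perElementMSE, perElementRMSE);
```

Under that assumption, a plateau of 200 corresponds to a per-element RMSE of roughly 0.062 on signals that lie between 0 and 1, which is a more useful number to compare against than the raw loss.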
Richard on 26 Jun 2023
Hi Marissa,
10 samples is quite a small mini-batch size, and I think this is causing you to see a lot of noise in the gradients. When I increase the mini-batch size to 64, I see a much smoother loss curve.
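The effect of mini-batch size on gradient noise can be illustrated generically: if per-sample gradients are treated as independent draws with a common standard deviation, the spread of their mini-batch average shrinks like 1/sqrt(B). The sketch below is a toy simulation under that independence assumption (it does not use the network above) and compares B = 10 with B = 64.

```matlab
% Toy model: each per-sample gradient is an independent N(0,1) draw.
rng(0);
numTrials = 20000;
g10 = mean(randn(10, numTrials), 1);   % mini-batch means, batch size 10
g64 = mean(randn(64, numTrials), 1);   % mini-batch means, batch size 64
ratio = std(g10) / std(g64);
fprintf('spread ratio: %.3f (sqrt(64/10) = %.3f)\n', ratio, sqrt(64/10));
```

Real per-sample gradients are correlated, so the reduction in practice may be smaller, but the direction of the effect matches the smoother curve observed with 64 samples.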


More Answers (1)

Aniketh on 25 Jun 2023
A very probable cause for this, and one I have experienced myself a few times, is initialization: check how your network's weights are initialized. If the weights are initialized too small, this can lead to vanishing gradients. Consider using a suitable initialization method, such as Xavier or He initialization, which helps keep the weights in a reasonable range.
Another thing to consider is your network architecture: evaluate its depth and complexity. Very deep networks are more susceptible to vanishing gradients. If your network is too deep, consider reducing the number of layers or introducing skip connections (e.g., residual connections) to facilitate gradient flow.
  1 Comment
Marissa Brown on 26 Jun 2023
Hi Aniketh, thank you for your response! From what I've read, the convolution layers use Glorot/Xavier as the default initialization, however He initialization performs better for some networks, so I tried using a leaky He initialization function on the convolution layers.
%% He Initializer function
function weights = leakyHe(sz,scale)
% If not specified, then use default scale = 0.1
if nargin < 2
scale = 0.1;
end
filterSize = [sz(1) sz(2)];
numChannels = sz(3);
numIn = filterSize(1) * filterSize(2) * numChannels;
varWeights = 2 / ((1 + scale^2) * numIn);
weights = randn(sz) * sqrt(varWeights);
end
I didn't notice any big changes in the loss function between the two methods.
[Figure: training loss with Xavier (Glorot) initialization]
[Figure: training loss with He initialization]
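One detail worth checking with the leakyHe function above: its default scale is 0.1, whereas leakyReluLayer uses a default Scale of 0.01, so the initializer targets a slightly different slope than the layers actually apply. The sketch below evaluates the same variance formula, 2 / ((1 + a^2) * numIn), for a filter bank shaped like conv_2 (3x3 filters, 32 input channels, 32 filters) and checks that randn-based weights reach that variance. The difference between the two slopes is only about 1% in variance, so it is unlikely to explain the loss plateau; passing the slope explicitly, for example 'WeightsInitializer', @(sz) leakyHe(sz, 0.01), simply keeps the two consistent.

```matlab
% Variance targeted by a leaky-ReLU He initializer (same formula as leakyHe)
heVar = @(sz, a) 2 ./ ((1 + a.^2) .* prod(sz(1:3)));

sz = [3 3 32 32];                      % conv_2: 3x3 filters, 32 in, 32 out
varLayerDefault = heVar(sz, 0.01);     % leakyReluLayer default Scale
varLeakyHeDefault = heVar(sz, 0.1);    % leakyHe default scale
fprintf('relative variance difference: %.2f%%\n', ...
    100 * (varLayerDefault / varLeakyHeDefault - 1));

rng(0);
w = randn(sz) * sqrt(varLayerDefault); % same sampling as leakyHe
fprintf('target %.4g, sample %.4g\n', varLayerDefault, var(w(:)));
```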


Categories

Sequence and Numeric Feature Data Workflows

Release

R2021b
