Difficulties in training ANNs with multiple outputs: always constant outputs
Does anyone have experience with defining a neural network that has multiple outputs? I want to input a vector and output a vector as well as a matrix. Accordingly, I need a DAG network.
I realize that I need a custom training loop for this; see https://de.mathworks.com/help/deeplearning/ug/train-network-with-multiple-outputs.html
The good news is that the training code itself runs without errors.
I have chosen the following network architecture:
numNeurons = 10;
% Input
layers1 = [
featureInputLayer(size(XData,2),'Name','Param_Input',"Normalization","rescale-symmetric");
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer('Name','tanh_middle')
];
lgraph = layerGraph(layers1);
% Output 1
filterSize = dimOutput{1};
numFilters = 20;
strideSize = [1,1];
projectionSize = [1,1,size(XData,2)];
layers2 = [
fullyConnectedLayer(numNeurons,'Name','fcEF')
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer
projectAndReshapeLayerNew(projectionSize)
transposedConv2dLayer(filterSize,numFilters,'Stride',strideSize,Cropping="same")
batchNormalizationLayer
tanhLayer
transposedConv2dLayer(filterSize,1,'Stride',strideSize,'Name','Output1')
];
% Output 2
layers3 = [
fullyConnectedLayer(numNeurons,'Name','fcFreq')
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(dimOutput{2},'Name','Output2')
];
lgraph = addLayers(lgraph,layers2);
lgraph = addLayers(lgraph,layers3);
lgraph = connectLayers(lgraph,"tanh_middle","fcEF");
lgraph = connectLayers(lgraph,"tanh_middle","fcFreq");
% [...] Training
% "Assemble Multiple-Output Network for Prediction"
lgraphNew = layerGraph(trainedNet);
layerReg1 = regressionLayer(Name="regOutput1");
layerReg2 = regressionLayer(Name="regOutput2");
lgraphNew = addLayers(lgraphNew,layerReg1);
lgraphNew = addLayers(lgraphNew,layerReg2);
lgraphNew = connectLayers(lgraphNew,"Output1","regOutput1");
lgraphNew = connectLayers(lgraphNew,"Output2","regOutput2");
figure
plot(lgraphNew)
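For reference, the elided training step follows the pattern in the linked example. A minimal sketch (the function name modelLoss, the targets T1 and T2, and the use of mse are simplified placeholders, not the exact code):
function [loss,gradients] = modelLoss(net,X,T1,T2)
% Forward pass through both branches of the dlnetwork.
[Y1,Y2] = forward(net,X,'Outputs',["Output1","Output2"]);
% Mean squared error for each regression output.
loss = mse(Y1,T1) + mse(Y2,T2);
gradients = dlgradient(loss,net.Learnables);
end
% One iteration of the custom training loop:
[loss,gradients] = dlfeval(@modelLoss,net,X,T1,T2);
[net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
trailingAvg,trailingAvgSq,iteration,learnRate);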
However, the problem is that the predictions are identical for every input: all coefficients of the output vector and the output matrix come out the same. Apparently the network learns some average values rather than fitting the concrete training targets:
Output 1: all predicted matrices are identical (figure omitted).
Output 2: all predicted vectors/rows are identical (figure omitted).
Is the network architecture fundamentally unsuitable? What could be the cause? I would rule out the training data, since training succeeds when I train separate single-output networks on the same data.
Thank you and best regards.
Answers (1)
Venu
on 8 Jan 2024
In your case, I suspect your custom layer 'projectAndReshapeLayerNew'. Check its weight initialization, consider applying regularization, inspect how the projection matrix is actually learned, and verify that the reshaping operation is appropriate for this output type. It is important to rule out that this layer inadvertently drives the network toward average values rather than distinct representations for each input; see the sketch below for one way to inspect the learned weights.
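For example, you can check whether the projection weights have collapsed to a constant by inspecting the learnables of the trained dlnetwork (a sketch; replace "proj" with the actual layer name listed in trainedNet.Learnables):
% List all learnable parameters and their layer names.
trainedNet.Learnables
% Pull the weights of the custom projection layer (name assumed).
idx = trainedNet.Learnables.Layer == "proj";
W = trainedNet.Learnables.Value{find(idx,1)};
% A near-zero spread suggests the weights collapsed to a constant,
% which would produce identical outputs for every input.
std(extractdata(W),0,'all')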
Also try adding another fully connected layer to layers2 to increase its capacity. The additional depth can help the network capture more nuanced representations for the first output if the existing layers are too shallow; a sketch follows.
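One possible placement for the extra layer, before the projection (an illustrative assumption, not a prescribed position):
layers2 = [
fullyConnectedLayer(numNeurons,'Name','fcEF')
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
tanhLayer
fullyConnectedLayer(numNeurons) % additional FC block for more capacity
batchNormalizationLayer
tanhLayer
projectAndReshapeLayerNew(projectionSize)
transposedConv2dLayer(filterSize,numFilters,'Stride',strideSize,Cropping="same")
batchNormalizationLayer
tanhLayer
transposedConv2dLayer(filterSize,1,'Stride',strideSize,'Name','Output1')
];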
Udit06
on 9 Jan 2024
I would like to add one more point to the above answer. In a multi-output scenario, the total loss is typically a combination of the individual losses for each output, and the training objective is to minimize this total. If one loss dominates, the network may focus on optimizing that output at the expense of the others, giving poor performance on the remaining outputs. To handle this, assign a weight to each loss component so that their contributions to the total loss are balanced, for example as in the sketch below.
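A sketch of that weighting inside the model loss function (the weights w1, w2 and the variable names are illustrative assumptions, continuing the training sketch above):
% Per-output losses as computed in the custom training loop.
loss1 = mse(Y1,T1); % matrix output branch
loss2 = mse(Y2,T2); % vector output branch
% Weight each contribution so neither output dominates the total;
% w1 and w2 can be tuned, e.g. inversely proportional to each
% loss's typical magnitude early in training.
w1 = 1; w2 = 10; % example values (assumption)
loss = w1*loss1 + w2*loss2;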
I hope this helps.