Deep learning: predict and forward give very inconsistent result with batch normalisation
Hi there,
I have a DNN with batch normalization layers. I called [Uf, statef] = forward(dnn_net, XT); the statef returned by forward contains the updated mean and variance of the BN layers. I then set dnn_net.State = statef, so that dnn_net.State holds the mean and variance from the forward call, and ran Up = predict(dnn_net, XT) on the same data XT. Comparing the two outputs, mse(Up, Uf) turns out to be quite large (e.g. 2.xx), which is not what I expected. Please help.
I created a simple test script that reproduces the problem:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output with forward; statef contains the batch
% normalization layers' updated mean and variance.
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the statistics from the forward call
Up = predict(dnn_net,XT);
% The outputs of predict and forward should be the same in this case,
% because dnn_net.State was updated with the mean/variance from the
% forward call. However, the error is quite large.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
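A quick way to inspect what predict will actually use (a sketch: a dlnetwork State is a table with Layer, Parameter and Value columns, and the batch normalization entries are named "TrainedMean" and "TrainedVariance"):
idx = strcmp(statef.Parameter,"TrainedMean");
bnState = statef(idx,:); % one row per batch normalization layer
disp(extractdata(bnState.Value{1})') % running mean of the first BN layer
After a single forward call these values are a weighted blend of the previous running values and the batch statistics, so predict does not normalize with exactly the statistics forward used internally.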
Answers (1)
Matt J on 25 Aug 2024 (edited)
To get agreement, you need to allow the trained mean and variance to converge. forward normalizes with each mini-batch's own statistics, while predict uses the running statistics stored in State; those running values are updated as an exponential moving average (governed by the layers' MeanDecay and VarianceDecay properties), so a single forward call does not move them all the way to the batch statistics. Repeating the forward pass on the same data lets them converge:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Run forward repeatedly on the same data so the batch normalization
% running statistics in the state converge to the batch statistics.
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % feed the updated statistics back into the network
end
Up = predict(dnn_net,XT);
% With the running statistics converged, predict and forward now agree
% and the error is negligible.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
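To see why a loop is needed at all, here is a minimal sketch of a running statistic updated as an exponential moving average (illustrative values only; the actual update inside batchNormalizationLayer is controlled by its MeanDecay/VarianceDecay properties, 0.1 by default):
decay   = 0.1; % assumed weight on the old running value, for illustration
muBatch = 3;   % batch mean seen on every pass (the same XT each time)
muRun   = 0;   % initial running mean
for k = 1:50
    muRun = decay*muRun + (1-decay)*muBatch; % exponential moving average
end
disp(muRun) % approaches muBatch, so predict then matches forward
With stacked batch normalization layers there is a second effect: updating an earlier layer's statistics changes the inputs the later layers see, so several passes are needed before the whole state settles.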