Deep learning: predict and forward give very inconsistent results with batch normalisation

Hi there,
I have a dlnetwork with batch normalization layers. I call [Uf, statef] = forward(dnn_net, XT); the "statef" returned by forward should contain the learned mean and variance of the BN layers. I then set dnn_net.State = statef so that the network state is updated with the learned mean and variance from the forward call, and then call Up = predict(dnn_net, XT) on the same data XT used in the forward call. When I compare mse(Up, Uf), the error turns out to be quite large (e.g. 2.xx). This is totally not expected. Please help.
I have created a simple test script to show the problem below:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output using the forward function; statef contains the
% batch normalization layers' learned means and variances.
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the BN layers' learned means and variances
Up = predict(dnn_net,XT);
% The DNN outputs from the predict call and the forward call should be the
% same in this case (because dnn_net.State is updated with the same learned
% mean/variance from the forward call). However, the error is quite large.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
    0.3134

Answers (1)

Matt J on 25 Aug 2024
Edited: Matt J on 25 Aug 2024
To get agreement, you need to allow the trained mean and variance to converge:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Repeat the forward call so that the batch normalization layers' learned
% means and variances in statef converge over the loop.
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the BN layers' learned means and variances
end
Up = predict(dnn_net,XT);
% Once the state has converged, the predict call and the forward call produce
% the same output and the error is tiny.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
   2.6776e-12
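Why the loop is needed: for batch normalization layers, forward normalizes with the statistics of the current mini-batch and only blends those statistics into the running TrainedMean/TrainedVariance stored in State, whereas predict normalizes with the running statistics. After a single forward pass the running statistics are still close to their initial values, which is why the original script sees a large error; repeating the forward pass on the same data lets them converge to the batch statistics. Below is a minimal sketch of that convergence, assuming a moving-average update of roughly the form shown (the decay value 0.1 is illustrative, cf. the layer's MeanDecay/VarianceDecay properties):
decay = 0.1;           % illustrative value (assumption); cf. MeanDecay/VarianceDecay
batchMean = 3;         % statistic of the (fixed) mini-batch
runningMean = 0;       % initial running (trained) mean
for k = 1:500
    % assumed moving-average form; the exact update is documented for batchNormalizationLayer
    runningMean = (1-decay)*runningMean + decay*batchMean;
end
disp(runningMean)      % approaches 3, i.e. the batch statistic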
  6 Comments
Tom on 26 Aug 2024
Hi Matt,
Thanks a lot for your explanation. Is there a way to set up the BN layer so that it uses a supplied mean and variance for normalization when calling forward()? In my application, it is important for the BN to use the mean and variance as constants when differentiating U with respect to X and T. Thanks in advance.
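One possible direction (a sketch under assumptions, not taken from this thread): bypass the batch normalization layer's training-mode statistics and normalize manually inside a model function, passing in a fixed mean and variance (for example, values read from dnn_net.State) so that they enter the computation as plain constants during differentiation. The helper name and signature below are hypothetical:
function Z = myBatchNormFixed(X, mu, sigma2, offset, scale)
% Hypothetical helper: batch-normalization-style transform with supplied,
% fixed statistics, so mu and sigma2 act as constants for differentiation.
epsilon = 1e-5;                          % small constant for numerical stability
Z = (X - mu) ./ sqrt(sigma2 + epsilon);  % normalize with the supplied statistics
Z = scale .* Z + offset;                 % learnable affine transform
end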
