
Calculate predictions with weights and bias extracted from an LSTM model

Hello,
I am trying to calculate the outputs by hand using the parameters of a trained LSTM model (recurrent weights, input weights, bias),
but the result of the code below differs from the output of "Y=predict(net,X)".
Please help me if you can spot the problem.
Thank you.
My network structure: (simple network)
layers = [
sequenceInputLayer(9,"Normalization","none") % 9 input features
lstmLayer(256)
fullyConnectedLayer(1)];
options = trainingOptions("adam", ...
MaxEpochs=2000, ...
SequencePaddingDirection="left", ...
Shuffle="every-epoch", ...
Plots="training-progress", ...
Verbose=false);
net = trainnet(X,Y,layers,"mse",options);
My code to extract the weights and bias:
R=net.Layers(2,1).RecurrentWeights;
W=net.Layers(2,1).InputWeights;
b=net.Layers(2,1).Bias;
Fc_W=net.Layers(3,1).Weights;
Fc_B=net.Layers(3,1).Bias;
Code to extract the LSTM layer parameters per gate (input, forget, cell, output):
HiddenLayersNum = 256;
W.Wi=W(1:HiddenLayersNum,:);
W.Wf=W(HiddenLayersNum+1:2*HiddenLayersNum,:);
W.Wc=W(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
W.Wo=W(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
R.Ri=R(1:HiddenLayersNum,:);
R.Rf=R(HiddenLayersNum+1:2*HiddenLayersNum,:);
R.Rc=R(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
R.Ro=R(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
B.bi=b(1,:);
B.bf=b(HiddenLayersNum+1:2*HiddenLayersNum,:);
B.bc=b(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
B.bo=b(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
h=net.State.Value{1,1}; % Hiddenstate
c=net.State.Value{2,1}; % Cellstate
Code to calculate the LSTM layer output:
% Input Gate
Z = W.Wi*X+R.Ri*h+B.bi; % x is new intput value for prediction (ex: x=[1 5 20 1 2];)
I = 1.0 ./ (1.0 + exp(-Z)); % Input gate
% Forget Gate
f =W.Wf*X+R.Rf*h+B.bf;
F = 1.0 ./ (1.0 + exp(-f)); % Forget gate
% Layer Input
g=W.Wc*X+R.Rc*h+B.bc; % Layer input
G=tanh(g);
% Output Layer
output = W.Wo*X+R.Ro*h_prev+B.bo;
output = 1.0 ./ (1.0 + exp(-output)); % Output Gate
% Cell State
cellgate=F.*c+I.*G; % Cell Gate
cellgate=cellgate;
% Output (Hidden) State
hidden=O.*tanh(cellgate); % Output State
hidden=dlarray(hidden);
L1 = relu(hidden);
Code to calculate the output of the fully connected layer:
Fc=Fc_W*L1+Fc_B

Accepted Answer

Paras Gupta on 15 Jul 2024 at 19:46
Hi James,
I understand that your hand-written forward pass for the model, which consists of an "LSTM" layer and a "Fully Connected" layer, gives different inference results from the trained model's "predict" function.
The provided code has the following three issues:
  • Incorrect Bias Indexing for Input Gate - "B.bi=b(1,:)" takes only the first element of the bias. The input-gate bias should cover the first "HiddenLayersNum" elements, that is "b(1:HiddenLayersNum,:)".
  • Incorrect Variable Used for Output State Calculation - The output gate is stored in the variable "output", so the hidden state should use "output" instead of the undefined "O".
  • Incorrect Computation and Usage of "L1" - The network has no ReLU layer between the LSTM layer and the fully connected layer, so applying "relu" to the hidden state changes the result. The hidden state should be passed to the fully connected layer directly.
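For reference, these are the per-time-step equations the hand-written code has to reproduce. This is a sketch that assumes the layer's default activations (sigmoid for the gates, tanh for the state) and the gate order input, forget, cell candidate, output used when slicing W, R and b. The symbols y_t, W_fc and b_fc are my own notation for the fully connected layer (Fc_W and Fc_B in the code) and do not come from the original post.

```latex
\begin{aligned}
i_t &= \sigma\left(W_i x_t + R_i h_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_f x_t + R_f h_{t-1} + b_f\right) \\
g_t &= \tanh\left(W_c x_t + R_c h_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_o x_t + R_o h_{t-1} + b_o\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh\left(c_t\right) \\
y_t &= W_{fc}\, h_t + b_{fc}
\end{aligned}
```

Each of the three issues breaks one line of this chain: the bias slice affects i_t, the wrong variable affects h_t, and the extra relu would sit between h_t and y_t.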
You can refer to the following modified code to obtain the correct results:
HiddenLayersNum = 256;
% Gate order in the LSTM parameters: input, forget, cell candidate, output.
% The slices go into new structs (Wg, Rg, Bg): W, R and b are numeric arrays,
% so assigning fields such as W.Wi to them directly raises an error.
Wg.Wi=W(1:HiddenLayersNum,:);
Wg.Wf=W(HiddenLayersNum+1:2*HiddenLayersNum,:);
Wg.Wc=W(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
Wg.Wo=W(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
Rg.Ri=R(1:HiddenLayersNum,:);
Rg.Rf=R(HiddenLayersNum+1:2*HiddenLayersNum,:);
Rg.Rc=R(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
Rg.Ro=R(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
Bg.bi=b(1:HiddenLayersNum,:); % corrected
Bg.bf=b(HiddenLayersNum+1:2*HiddenLayersNum,:);
Bg.bc=b(2*HiddenLayersNum+1:3*HiddenLayersNum,:);
Bg.bo=b(3*HiddenLayersNum+1:4*HiddenLayersNum,:);
h=net.State.Value{1,1}; % Hidden state
c=net.State.Value{2,1}; % Cell state
% Input Gate
Z = Wg.Wi*X+Rg.Ri*h+Bg.bi; % X is one time step of input, here assumed to be a 9-by-1 column vector
I = 1.0 ./ (1.0 + exp(-Z)); % Input gate
% Forget Gate
f = Wg.Wf*X+Rg.Rf*h+Bg.bf;
F = 1.0 ./ (1.0 + exp(-f)); % Forget gate
% Cell Candidate
g = Wg.Wc*X+Rg.Rc*h+Bg.bc; % Cell candidate pre-activation
G = tanh(g);
% Output Gate
output = Wg.Wo*X+Rg.Ro*h+Bg.bo; % uses h (h_prev was never defined)
output = 1.0 ./ (1.0 + exp(-output)); % Output gate
% Cell State
cellgate = F.*c+I.*G; % New cell state
% Output (Hidden) State
hidden = output.*tanh(cellgate); % corrected
hidden = dlarray(hidden);
% removed L1 computation as it is not required
Fc = Fc_W*hidden+Fc_B % corrected
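The code above computes a single time step. For a whole sequence, the same step is repeated, with the hidden and cell state carried forward each time. Below is a minimal sketch of that loop. So that it runs without a trained network, it uses random stand-in parameters and small hypothetical sizes (numHidden = 4, numSteps = 5), with zero initial states. The names Xseq, Yhat and the index variables are my own and are not from the original post. With a real model you would take W, R, b, Fc_W and Fc_B from net.Layers as shown in the question.

```matlab
% Stand-in parameters (assumption: random values instead of a trained net).
rng(0);
numFeatures = 9;                 % matches sequenceInputLayer(9)
numHidden   = 4;                 % hypothetical; the network in the post uses 256
numSteps    = 5;                 % hypothetical sequence length
W    = randn(4*numHidden, numFeatures);   % input weights, gates stacked row-wise
R    = randn(4*numHidden, numHidden);     % recurrent weights
b    = randn(4*numHidden, 1);             % bias
Fc_W = randn(1, numHidden);               % fully connected weights
Fc_B = randn(1, 1);                       % fully connected bias
Xseq = randn(numFeatures, numSteps);      % one sequence, features-by-time

% Initial states (assumption: zeros; a trained net may hold other values).
h = zeros(numHidden, 1);
c = zeros(numHidden, 1);

sigmoid = @(z) 1 ./ (1 + exp(-z));

% Row ranges for the four gates: input, forget, cell candidate, output.
iIdx = 1:numHidden;
fIdx = numHidden   + (1:numHidden);
gIdx = 2*numHidden + (1:numHidden);
oIdx = 3*numHidden + (1:numHidden);

Yhat = zeros(1, numSteps);
for t = 1:numSteps
    z  = W*Xseq(:,t) + R*h + b;  % all four gate pre-activations at once
    ig = sigmoid(z(iIdx));       % input gate
    fg = sigmoid(z(fIdx));       % forget gate
    gc = tanh(z(gIdx));          % cell candidate
    og = sigmoid(z(oIdx));       % output gate
    c  = fg.*c + ig.*gc;         % new cell state
    h  = og.*tanh(c);            % new hidden state
    Yhat(:,t) = Fc_W*h + Fc_B;   % fully connected layer, no ReLU in between
end
disp(size(Yhat))
```

To check a real model, something like max(abs(Yhat - Ypred), [], "all") with Ypred taken from predict(net,X) should be close to zero. Small differences are expected when the network parameters are stored in single precision, and the predict output may need reshaping to features-by-time before the comparison.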
Hope this helps.
  1 Comment
James on 19 Jul 2024 at 12:04
Thank you for your help!
I tried your code and the result matched the value of "Y=predict(net,X)".
I really appreciate it!
