Custom neural network error (about dlfeval)
jaehong kim on 9 Feb 2021
Commented: jaehong kim on 14 Feb 2021
I am training a custom neural network with the low-level API (dlnetwork, dlfeval, adamupdate, dlgradient).
However, calling dlfeval raises the following error:
Error using dlfeval (line 43)
First input argument must be a formatted dlarray.
Error in deep (line 639)
[gradient,loss]=dlfeval(@modelGradients,dlnet,dlX);
I think the error is related to dlarray, but as my code shows, I do convert the input to a dlarray.
My code is below (8 input features, 1 target):
clear,clc,close all
data=readmatrix('train.csv');
inputs=data(:,1:8);
targets=data(:,9);
input2=transpose(inputs);
target2=transpose(targets);
inputs2=normalize(input2,2,'range');
layers= [sequenceInputLayer([8],'Name','input')
fullyConnectedLayer(64,'Name','fc1')
tanhLayer('Name','tanh1')
fullyConnectedLayer(32,'Name','fc2')
tanhLayer('Name','tanh2')
fullyConnectedLayer(16,'Name','fc3')
tanhLayer('Name','tanh3')
fullyConnectedLayer(8,'Name','fc4')
tanhLayer('Name','tanh4')
fullyConnectedLayer(1,'Name','fc5')
];
lgraph=layerGraph(layers);
dlnet=dlnetwork(lgraph);
for it=1:5000
dlX=dlarray(inputs2);
[gradient,loss]=dlfeval(@modelGradients,dlnet,dlX);
dlnet=adamupdate(dlnet,gradient);
end
function [gradient,loss]=modelGradients(dlnet,dlx,t)
out=forward(dlnet,dlx);
loss=immse(out,t);
loss=dlarray(loss);
gradient=dlgradient(loss,dlnet.Learnables);
end
Thanks for reading my question!
Accepted Answer
Srivardhan Gadila on 13 Feb 2021
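The input dlX to forward(dlnet,dlX) must be a formatted dlarray, that is, a dlarray whose dimensions carry labels such as 'C' (channel) and 'B' (batch). In your code, dlarray(inputs2) is called without a format argument, so dlX is an unformatted dlarray, which is why dlfeval reports this error even though the input is a dlarray. Refer to the documentation of forward (specifically dlX under Input Arguments) for more information.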
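A minimal sketch of the difference, assuming R2020b or later with Deep Learning Toolbox. The random matrix X is a stand-in for inputs2 from the question, and the one-layer network built on featureInputLayer is hypothetical, chosen on the assumption that each CSV row is an independent observation rather than a time step. The 'CB' labels mark dimension 1 as channels (features) and dimension 2 as the batch (observations); with the question's sequenceInputLayer, the format would instead need a 'T' (time) label.

```matlab
% Stand-in for the 8 x N normalized feature matrix (inputs2 in the question)
X = rand(8, 100);

% Unformatted: no dimension labels. In the release used in the question,
% forward(dlnet, dlXu) rejects this with the "formatted dlarray" error.
dlXu = dlarray(X);
disp(isempty(dims(dlXu)))   % dims returns an empty char for unformatted input

% Formatted: 'C' = channels (features), 'B' = batch (observations)
dlX = dlarray(X, 'CB');
disp(dims(dlX))

% Hypothetical tiny network; featureInputLayer assumes rows are independent samples
layers = [featureInputLayer(8, 'Name', 'input')
          fullyConnectedLayer(1, 'Name', 'fc')];
dlnet = dlnetwork(layerGraph(layers));

% Forward pass on the formatted input: one output per observation
out = forward(dlnet, dlX);
disp(size(out))
```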
Also, in the code above, modelGradients takes three input arguments, dlnet, dlx and t:
function [gradient,loss]=modelGradients(dlnet,dlx,t)
but the dlfeval call in the for loop passes only dlnet and dlX, so the target data (target2, converted to a dlarray) is never provided:
[gradient,loss]=dlfeval(@modelGradients,dlnet,dlX);
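As a result, once the format issue is fixed, you will likely get another error, because t is undefined inside modelGradients when this line runs: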
loss=immse(out,t);
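A sketch of the training loop with both fixes applied, under the same assumptions as before (R2020b or later, featureInputLayer with 'CB' labels, and random stand-in data instead of train.csv; the target T is hypothetical and only gives the network something to fit). It passes the target to dlfeval as a third argument and uses mse, the Deep Learning Toolbox loss documented to accept dlarray input, in place of immse. It also carries the Adam moving averages and the iteration number between calls, because the documented adamupdate syntax takes and returns them, which the question's two-argument adamupdate(dlnet,gradient) call omits. The layer sizes follow the question; the iteration count is reduced for illustration.

```matlab
% Stand-in data (assumption): 8 features per column, 1 hypothetical target
X = rand(8, 200);
T = mean(X, 1);

layers = [featureInputLayer(8, 'Name', 'input')
          fullyConnectedLayer(64, 'Name', 'fc1')
          tanhLayer('Name', 'tanh1')
          fullyConnectedLayer(32, 'Name', 'fc2')
          tanhLayer('Name', 'tanh2')
          fullyConnectedLayer(16, 'Name', 'fc3')
          tanhLayer('Name', 'tanh3')
          fullyConnectedLayer(8, 'Name', 'fc4')
          tanhLayer('Name', 'tanh4')
          fullyConnectedLayer(1, 'Name', 'fc5')];
dlnet = dlnetwork(layerGraph(layers));

% Format inputs and targets once, outside the loop, with matching labels
dlX = dlarray(X, 'CB');
dlT = dlarray(T, 'CB');

% Adam state carried across iterations (empty on the first call)
averageGrad = [];
averageSqGrad = [];
numIterations = 500;
losses = zeros(numIterations, 1);

for iteration = 1:numIterations
    % The target is passed as the third argument of modelGradients
    [gradients, loss] = dlfeval(@modelGradients, dlnet, dlX, dlT);
    [dlnet, averageGrad, averageSqGrad] = adamupdate(dlnet, gradients, ...
        averageGrad, averageSqGrad, iteration);
    losses(iteration) = extractdata(loss);
end
fprintf('first loss %.4f, last loss %.4f\n', losses(1), losses(end));

function [gradients, loss] = modelGradients(dlnet, dlX, dlT)
    dlY = forward(dlnet, dlX);
    % mse operates on dlarray, so the loss stays traced for dlgradient
    loss = mse(dlY, dlT);
    gradients = dlgradient(loss, dlnet.Learnables);
end
```

Formatting dlX and dlT once before the loop also avoids rebuilding the same dlarray on every iteration, as the question's loop does.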