How can I transfer the model parameters of a well-trained NN to another one?
Wanli Wen
on 24 Nov 2019
Answered: Divya Gaddipati
on 5 Dec 2019
I have two NNs, net_1 and net_2, where net_1 is untrained and net_2 has been well trained. I want to transfer the knowledge of net_2 to net_1 so that net_1 performs as well as net_2, and I wrote the following code. However, after setting the weights and biases of net_1 to those of net_2, net_1 behaves very badly, e.g., net_2(-2) = 3.999 but net_1(-2) = 32.249. I expected net_1 to output a value very close to that of net_2. Could anyone please tell me whether there is anything wrong with my code? Thanks.
(Please note that I do not want to use the operation net_1 = net_2 to achieve this purpose.)
clear all
%%
% Task: To fit a non-linear function f(x) = x.^2
%%
D=1e4; % no. of training samples
layers_neurons=[64];
%% Net 1: untrained network
net_1 = feedforwardnet(layers_neurons);
[data1,target1] = gen_data_sample(10);
net_1 = configure(net_1, data1, target1);
%% Net 2: well-trained network
[data2,target2] = gen_data_sample(D);
net_2 = feedforwardnet(layers_neurons); % doc feedforwardnet for more details
net_2 = configure(net_2, data2, target2);
net_2 = train(net_2,data2, target2); % , 'useGPU', 'yes', 'useparallel', 'yes'
%% Transfer the knowledge of Net 2 to Net 1
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b = net_2.b;
%% Test and Compare Net 1 and Net 2
net_1(-2)
net_2(-2)
%%
function [input,output] = gen_data_sample(D)
%%
input = -20+(20-(-20))*rand(1, D);
output = input.^2;
end
Accepted Answer
Divya Gaddipati
on 5 Dec 2019
Before you assign the weights of “net_2” to “net_1”, create net_1 from net_2 using the init function, which returns the network with its weights and biases reinitialized:
net_1 = init(net_2);
This would resolve your issue.
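A minimal sketch of how this fix might slot into the script from the question, assuming the Deep Learning Toolbox is available. A likely reason the original transfer failed is that configure derives the input and output scaling settings from the data it is given, so a net_1 configured on 10 samples scales values differently from net_2, and copying IW, LW and b alone does not carry the scaling over. The fixed seed, the inlined data generation, the showWindow setting and the variable diff_out are illustrative choices, not part of the original code.

```matlab
% Sketch only: assumes the Deep Learning Toolbox is installed.
rng(0);                                   % fixed seed, only for repeatability
x = -20 + 40*rand(1, 1e4);                % same sampling as gen_data_sample
t = x.^2;                                 % target function f(x) = x.^2

net_2 = feedforwardnet(64);
net_2 = configure(net_2, x, t);
net_2.trainParam.showWindow = false;      % suppress the training GUI
net_2 = train(net_2, x, t);

net_1 = init(net_2);                      % same architecture and scaling, fresh weights
net_1.IW = net_2.IW;                      % copy input weights
net_1.LW = net_2.LW;                      % copy layer weights
net_1.b  = net_2.b;                       % copy biases

diff_out = abs(net_1(-2) - net_2(-2));    % difference between the two outputs
disp(diff_out)
```

With the scaling settings shared, the two outputs should agree up to round-off.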
Additionally, you can remove the configure step for net_1 (i.e., line 10 in your code, net_1 = configure(...)), which might not be required if you are using init.
For more information on configure and init, refer to the link below: https://www.mathworks.com/help/deeplearning/ug/create-configure-and-initialize-multilayer-neural-networks.html#bss330n-3
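If net_1 has to stay a separately constructed network, a possible alternative (an assumption on my part, not something stated in the answer above) is to configure net_1 on the same data as net_2 so that both networks end up with the same scaling, and then copy all weights and biases in one step with getwb and setwb. The names same_settings and diff_out are only illustrative.

```matlab
% Sketch only: assumes the Deep Learning Toolbox is installed.
rng(0);                                   % fixed seed, only for repeatability
x = -20 + 40*rand(1, 1e4);                % same sampling as gen_data_sample
t = x.^2;

net_2 = feedforwardnet(64);
net_2 = configure(net_2, x, t);
net_2.trainParam.showWindow = false;      % suppress the training GUI
net_2 = train(net_2, x, t);

net_1 = feedforwardnet(64);
net_1 = configure(net_1, x, t);           % same data as net_2, so the same scaling
net_1 = setwb(net_1, getwb(net_2));       % copy all weights and biases as one vector

% Diagnostic: compare the input scaling settings of the two networks.
same_settings = isequal(net_1.inputs{1}.processSettings, ...
                        net_2.inputs{1}.processSettings);
diff_out = abs(net_1(-2) - net_2(-2));    % difference between the two outputs
disp(same_settings)
disp(diff_out)
```

Comparing processSettings in the original script, where net_1 was configured on only 10 samples, is a quick way to check whether mismatched scaling is what caused the discrepancy.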
Hope this helps!