Is it possible to share common weights and bias among different LSTM layers?
I am building a network that looks like the figure below.
There are three LSTM layers, namely LSTM_common_1, LSTM_common_2 and LSTM_common_3.
Can I restrict their weights and biases so that all of the LSTM_common_x layers share the same set of weights and biases?
Answers (1)
Conor Daly
on 17 Feb 2023
One way to share weights like this is to use nested layers, that is, custom layers whose learnable parameters are themselves defined by a neural network. The general idea is to create a custom layer that holds the shared sub-network (in this case just a single LSTM layer) and applies it once per branch, so that every branch uses the same set of weights and biases.
Here's an example for the case above:
classdef commonLSTMLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable

    properties (Learnable)
        % Nested network that holds the single, shared LSTM layer.
        Network
    end

    methods
        function this = commonLSTMLayer(numHiddenUnits, numOutputs, args)
            arguments
                numHiddenUnits (1,1) {mustBePositive, mustBeInteger}
                numOutputs (1,1) {mustBePositive, mustBeInteger}
                args.OutputMode {mustBeMember(args.OutputMode, ["last","sequence"])} = "sequence"
                % Default added so that omitting Name does not error below.
                args.Name {mustBeTextScalar} = ""
            end
            this.Name = args.Name;
            layer = lstmLayer(numHiddenUnits, OutputMode=args.OutputMode);
            % Leave the nested network uninitialized at construction time.
            this.Network = dlnetwork(layer, Initialize=false);
            % One output per branch, named "out1", "out2", ...
            this.NumOutputs = numOutputs;
            this.OutputNames = "out" + (1:numOutputs);
        end

        function varargout = predict(this, X)
            % Apply the same LSTM to each input channel in turn.
            % Assumes the channel dimension comes first, as in "CBT" data.
            varargout = cell(1, this.NumOutputs);
            for n = 1:this.NumOutputs
                varargout{n} = predict(this.Network, X(n,:,:));
            end
        end
    end
end
Using this layer we can construct the network as follows:
numInputChannels = 3;
numHiddenUnits = 64;

% Main branch: the first output of the shared LSTM layer feeds fc1.
layers = [
    sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")
    regressionLayer()
    ];
lg = layerGraph(layers);

% Side branches: the second and third outputs feed fc2 and fc3.
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");

% Join all three branches in the concatenation layer.
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");

analyzeNetwork(lg)
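The weight-sharing idea behind the layer can also be checked in isolation. The following is a minimal sketch rather than part of the answer above: it assumes Deep Learning Toolbox is installed, and the variable names (sharedNet, numBranches), the small sizes, and the random input are hypothetical choices made only for illustration. It applies one LSTM sub-network to each channel of a "CBT" input, which is what the predict method of commonLSTMLayer does, and then lists the learnable parameters.

```matlab
% Hypothetical sizes, chosen only to keep the example small.
numBranches    = 3;   % one branch per input channel
numHiddenUnits = 4;
numTimeSteps   = 5;

% A single shared sub-network: one LSTM layer fed one channel at a time.
sharedNet = dlnetwork([
    sequenceInputLayer(1)
    lstmLayer(numHiddenUnits, OutputMode="last")
    ]);

% Random data in channel-by-batch-by-time ("CBT") layout.
X = dlarray(rand(numBranches, 1, numTimeSteps), "CBT");

% Run every channel through the same network, so all branches share weights.
Y = cell(1, numBranches);
for n = 1:numBranches
    Y{n} = predict(sharedNet, X(n,:,:));
end

% List the learnable parameters of the shared network.
disp(sharedNet.Learnables)
```

If the sharing works as intended, the Learnables table should contain only one set of LSTM parameters (input weights, recurrent weights, and bias), however many branches are evaluated.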