Is it possible to share common weights and bias among different LSTM layers?

I am building a network that looks like the figure below.
There are three LSTM layers, namely LSTM_common_1, LSTM_common_2 and LSTM_common_3.
Can I restrict their weights and biases so that all of the LSTM_common_x layers share the same set of weights and biases?
[Figure: network diagram]

Answers (1)

Conor Daly on 17 Feb 2023
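One way to share weights like this is to use nested layers: custom layers whose learnable parameters are themselves defined by a neural network (a dlnetwork object). The idea is to create a layer that holds the shared subnetwork (in this case just a single LSTM layer) as a learnable property and applies it to every branch that needs it. Because all branches call the same subnetwork, they all use, and during training update, a single set of weights and biases.
Here is an example for the case above. The layer splits its input along the channel dimension and passes each channel through the same LSTM network, returning one output per channel: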
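```matlab
classdef commonLSTMLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable
    % Custom layer that applies one shared LSTM network to each input
    % channel separately, so every branch uses the same weights and biases.

    properties (Learnable)
        % Nested network holding the single, shared LSTM layer
        Network
    end

    methods
        function this = commonLSTMLayer(numHiddenUnits, numOutputs, args)
            arguments
                numHiddenUnits (1,1) {mustBePositive, mustBeInteger}
                numOutputs (1,1) {mustBePositive, mustBeInteger}
                args.OutputMode {mustBeMember(args.OutputMode, ["last","sequence"])} = "sequence"
                args.Name {mustBeTextScalar} = ""   % default avoids an error when Name is omitted
            end
            this.Name = args.Name;

            % Wrap the LSTM in an uninitialized dlnetwork; it is initialized
            % when the enclosing network is initialized.
            layer = lstmLayer(numHiddenUnits, OutputMode=args.OutputMode);
            this.Network = dlnetwork(layer, Initialize=false);

            % One output per input channel: out1, out2, ...
            this.NumOutputs = numOutputs;
            this.OutputNames = "out" + (1:numOutputs);
        end

        function varargout = predict(this, X)
            % Pass channel n of X through the same network to produce output n
            varargout = cell(1, this.NumOutputs);
            for n = 1:this.NumOutputs
                varargout{n} = predict(this.Network, X(n,:,:));
            end
        end
    end
end
```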
Using this layer, we can construct the network as follows:
```matlab
numInputChannels = 3;
numHiddenUnits = 64;

% Main chain: lstm/out1 -> fc1 -> cat/in1 -> regression output
layers = [ sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")
    regressionLayer() ];
lg = layerGraph(layers);

% Branches for the second and third outputs of the shared LSTM layer
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");

analyzeNetwork(lg)
```
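To confirm that only one set of LSTM parameters exists, one option is to convert the graph to a dlnetwork and inspect its learnables. This is a sketch that is not part of the original answer: it assumes a release recent enough to run the layer above, and it removes the regression output layer first because dlnetwork objects do not accept output layers. The variable names isOutputLayer, lgNoOutput and net are illustrative.

```matlab
% Sketch: check that the three branches share one set of LSTM parameters.
% Assumes lg was built as above. dlnetwork does not support output layers,
% so locate and remove the regression layer before converting.
isOutputLayer = arrayfun(@(l) isa(l, "nnet.cnn.layer.RegressionOutputLayer"), lg.Layers);
lgNoOutput = removeLayers(lg, lg.Layers(isOutputLayer).Name);

% Converting to a dlnetwork should also initialize the nested LSTM network
net = dlnetwork(lgNoOutput);

% All LSTM parameters should belong to the single layer named "lstm",
% while fc1, fc2 and fc3 keep their own separate weights.
disp(net.Learnables(:, ["Layer" "Parameter"]))
```

If the LSTM weights and bias appear only under the "lstm" layer, and not once per branch, all three branches are using the same parameters.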

Release

R2019a