Using transformer neural network for classification task

numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
    sequenceInputLayer(numChannels,Name="input")
    positionEmbeddingLayer(numChannels,maxPosition,Name="pos-emb")
    additionLayer(2,Name="add")
    selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal')
    selfAttentionLayer(numHeads,numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,"input","add/in2");
maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = "auto"; % chooses local GPU if available, otherwise CPU
options = trainingOptions(solver, ...
    'Plots','training-progress', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle',shuffle, ...
    'InitialLearnRate',learningRate, ...
    'GradientThreshold',gradientThreshold, ...
    'ExecutionEnvironment',executionEnvironment);
The input size is 12, so each time step has 12 features.
numClasses is 4, so I am classifying each sequence into one of 4 classes.
But I get the following error when I try to run it:
"
Error in test123_20240727 (line 195)
net=trainNetwork(XTrain, YTrain, layers, options);
Caused by:
Layer 'add': Unconnected input. Each layer input must be connected to the output of another layer.
"
Line 195 is "net=trainNetwork(XTrain, YTrain, layers, options);"
Can anyone help me with this?
  7 Comments
Umar on 29 Jul 2024
Hi @haohaoxuexi1,
If you are still having issues with modifying your code, please let us know. We will be happy to help you out.
haohaoxuexi1 on 29 Jul 2024
@Umar Hi Umar, I am good at the moment. I will let you know if I have further questions.


Accepted Answer

Joss Knight on 29 Jul 2024

You've passed layers instead of lgraph to trainNetwork.

  2 Comments
Umar on 29 Jul 2024
@Joss Knight, thanks for jumping in. Please advise how to pass lgraph to trainNetwork by providing a code snippet. Again, thanks for your cooperation.
Joss Knight on 13 Aug 2024
net=trainNetwork(XTrain, YTrain, lgraph, options);
instead of
net=trainNetwork(XTrain, YTrain, layers, options);
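
For context, here is a minimal sketch of why the change matters, assuming the values stated in the question (12 features, 4 classes). connectLayers returns a modified layer graph; the original layers array is left untouched, so 'add/in2' is still unconnected if the array is what gets passed to trainNetwork. The exist guard at the end is only an illustrative convenience so the snippet runs without the question's data; XTrain, YTrain and options are assumed to be defined as in the question.

```matlab
% Values assumed from the question: 12 input features, 4 classes.
numChannels = 12;
numClasses = 4;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;

layers = [
    sequenceInputLayer(numChannels,Name="input")
    positionEmbeddingLayer(numChannels,maxPosition,Name="pos-emb")
    additionLayer(2,Name="add")
    selfAttentionLayer(numHeads,numKeyChannels,AttentionMask="causal")
    selfAttentionLayer(numHeads,numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

% The array only describes the serial chain, so "pos-emb" feeds 'add/in1'
% and 'add/in2' has no source yet.
lgraph = layerGraph(layers);

% Add the skip connection. The result lives in lgraph, not in layers.
lgraph = connectLayers(lgraph,"input","add/in2");

% Inspect the connections: "input" should feed both "pos-emb" and 'add/in2'.
disp(lgraph.Connections)

% Train with the graph rather than the array.
% XTrain, YTrain and options are assumed to exist as in the question.
if exist("XTrain","var") && exist("YTrain","var") && exist("options","var")
    net = trainNetwork(XTrain,YTrain,lgraph,options);
end
```

Calling analyzeNetwork(lgraph) before training is another way to catch unconnected inputs like this one.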



Release

R2024a
