Trying to get 80% and greater accuracy from network. Can someone help in editing my code to reach to 80% or close too?
4 views (last 30 days)
Show older comments
clc; clear all; close all;
load generated_data.mat
% 2289*180
% 6 classes
X1_T = X1';
rand('seed', 0)
ind = randperm(size(X1_T, 1));
X1_T = X1_T(ind, :);
Y1 = categorical(Y1(ind));
% Split Data
X1_train = X1_T;
train_X1 = X1_train(1:120,:);
train_Y1 = Y1(1:120);
% Data Batch
XTrain=(reshape(train_X1', [2289,120]));
val_X1 = X1_train(121:150,:);
val_Y1 = Y1(121:150);
XVal=(reshape(val_X1', [2289,30]));
test_X1 = X1_train(151:180,:);
test_Y1 = Y1(151:180);
XTest=(reshape(test_X1', [2289,30]));
numFeatures = size(X1_T,2);
% number of hidden units represent the size of the data
numHiddenUnits = 180;
%number of classes represent different patients normal,LIS,type2....
numClasses = length(categories(categorical(Y1)));
layers = [ ...
sequenceInputLayer(numFeatures)
dropoutLayer(0.1)
bilstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numClasses)
instanceNormalizationLayer
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'MaxEpochs',150, ...
'GradientThreshold',1, ...
'Verbose',false, ...
'ValidationData',{XVal, val_Y1},...
'LearnRateDropFactor',0.2,...
'LearnRateDropPeriod',5,...
'Plots','training-progress');
% Train
net = trainNetwork(XTrain,train_Y1,layers,options);
% Test
miniBatchSize = 27;
YPred = classify(net,XTest, ...
'MiniBatchSize',miniBatchSize,...
'ExecutionEnvironment', 'cpu');
acc = mean(YPred(:) == categorical(test_Y1(:)))
figure
t = confusionchart(categorical(test_Y1(:)),YPred(:));
2 Comments
Accepted Answer
yanqi liu
on 16 Dec 2021
Edited: yanqi liu
on 16 Dec 2021
yes,sir,may be use
clc; clear all; close all;
load generated_data.mat
% 2289*180
% 6 classes
rand('seed', 0)
X1_T = X1';
YC1 = categorical(Y1);
CS = categories(YC1);
train_index = []; val_index = []; test_index = [];
for i = 1 : length(CS)
indi = find(YC1==CS{i});
% Shuffling data
indi = indi(randperm(length(indi)));
% 2/3---train, 1/6---val, 1/6---test
index1 = round(length(indi)*2/3);
index2 = round(length(indi)*(2/3+1/6));
train_index = [train_index indi(1:index1)];
val_index = [val_index indi(1+index1:index2)];
test_index = [test_index indi(1+index2:end)];
end
ind = [train_index val_index test_index];
X1_T = X1_T(ind, :);
Y1 = categorical(Y1(ind));
% Split Data
X1_train = X1_T;
train_X1 = X1_train(1:120,:);
train_Y1 = Y1(1:120);
% Data Batch
XTrain=(reshape(train_X1', [2289,120]));
val_X1 = X1_train(121:150,:);
val_Y1 = Y1(121:150);
XVal=(reshape(val_X1', [2289,30]));
test_X1 = X1_train(151:180,:);
test_Y1 = Y1(151:180);
XTest=(reshape(test_X1', [2289,30]));
numFeatures = size(X1_T,2);
% number of hidden units represent the size of the data
numHiddenUnits = 500;
%number of classes represent different patients normal,LIS,type2....
numClasses = length(categories(categorical(Y1)));
layers = [ ...
sequenceInputLayer(numFeatures)
%dropoutLayer(0.5)
instanceNormalizationLayer
bilstmLayer(numHiddenUnits,'OutputMode','sequence')
%dropoutLayer(0.5)
instanceNormalizationLayer
bilstmLayer(round(numHiddenUnits/2),'OutputMode','sequence')
fullyConnectedLayer(numClasses)
instanceNormalizationLayer
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'MaxEpochs',100, ...
'GradientThreshold',1, ...
'Verbose',false, ...
'ValidationData',{XVal, val_Y1},...
'LearnRateDropFactor',0.2,...
'LearnRateDropPeriod',5,...
'Plots','training-progress');
% Train
net = trainNetwork(XTrain,train_Y1,layers,options);
% Test
miniBatchSize = 27;
YPred = classify(net,XTest, ...
'MiniBatchSize',miniBatchSize,...
'ExecutionEnvironment', 'cpu');
acc = mean(YPred(:) == categorical(test_Y1(:)))
figure
t = confusionchart(categorical(test_Y1(:)),YPred(:));
acc =
0.9667
>>
2 Comments
yanqi liu
on 17 Dec 2021
yes,sir,we know it is 6 classes,so just for every class,we choose 2/3、1/6、1/6 as train、val、test data,through this method,we can get the data split index
More Answers (0)
See Also
Categories
Find more on Image Data Workflows in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!