Can you give me suggestions on, or a critique of, my approach to this neural network fitting?

I am hoping to get some constructive suggestions on how to improve my code, on whether there is a better way to approach what I am trying to do, or on any other area I should investigate. Thanks in advance.
I am creating a neural network to fit one variable as a function of seven other variables. I have a very large table of all eight variables, and I train the network like this:
ANN=fitnet(50,'trainrp'); % one hidden layer of 50 neurons, resilient backpropagation
ANN=train(ANN,Input,Output);
Here, Input is the table of the 7 input variables and Output is the table of the remaining variable. The end goal is to give the ANN any combination of the 7 input values and get an accurate estimate of the output (in effect, interpolating the table). However, I do this iteratively. After training the network on the data, I generate a new, different table and use the network to check whether it predicts the output values within 10% relative error. If it does not, I take the rows where it did poorly, add them to the original table, and repeat the command:
ANN=train(ANN,Input,Output);
This time, however, Input and Output contain the original table plus the rows from the new table where the network did poorly. I repeat this process many times; it is automated, not manual.
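For readers following along, here is a minimal base-MATLAB sketch of the augmentation step described above. It assumes the table is a numeric matrix with one sample per row (columns 1 to 7 are inputs, column 8 is the output), so it is transposed into the one-sample-per-column layout that fitnet and train expect. All names and values (data, newInput, newOutput, newPred, relErr, bad) are hypothetical stand-ins; in a real run newPred would come from ANN(newInput).

```matlab
% Hypothetical table: one sample per ROW, columns 1-7 inputs, column 8 output.
% The toolbox expects one sample per COLUMN, so transpose.
data = [rand(100, 7), rand(100, 1)];
Input  = data(:, 1:7)';   % 7-by-N
Output = data(:, 8)';     % 1-by-N

% Stand-ins for a freshly generated table and the network's predictions on it
% (in practice: newPred = ANN(newInput);).
newInput  = rand(7, 5);
newOutput = [1.0 2.0 4.0 5.0 10.0];
newPred   = [1.05 2.5 4.1 4.0 10.2];

% Relative error per sample; flag samples outside the 10% tolerance.
relErr = abs(newPred - newOutput) ./ abs(newOutput);
bad = relErr > 0.10;      % -> [false true false true false]

% Append only the poorly predicted samples to the training table,
% then retrain (in practice: ANN = train(ANN, Input, Output);).
Input  = [Input,  newInput(:, bad)];
Output = [Output, newOutput(bad)];
```

Note that relative error is undefined for targets equal to zero, so if the output can be zero or near zero, an absolute tolerance may be a safer acceptance test.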

Accepted Answer

Mrutyunjaya Hiremath on 17 Aug 2023
Try this:
% Load your data
% Example: load('your_data.mat');
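% Assuming Input is numFeatures-by-numSamples and Output is 1-by-numSamples
% (one sample per column, as the toolbox expects)
% Normalize inputs to zero mean and unit standard deviation per feature.
% Keep meanInput and stdInput: any new input must be scaled the same way.
meanInput = mean(Input,2);
stdInput = std(Input,0,2);
stdInput(stdInput == 0) = 1; % guard against constant features (division by zero)
Input = (Input - meanInput) ./ stdInput;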
% Split data into training (70%) and validation (30%) sets
[trainInd,valInd] = dividerand(size(Input,2),0.7,0.3,0);
trainInput = Input(:,trainInd);
trainOutput = Output(:,trainInd);
valInput = Input(:,valInd);
valOutput = Output(:,valInd);
% Initialize the neural network with multiple hidden layers
% For instance, one with 100 neurons and another with 50
hiddenLayers = [100, 50];
ANN = fitnet(hiddenLayers,'trainrp');
% Set early stopping parameters
ANN.divideParam.trainRatio = 0.7;
ANN.divideParam.valRatio = 0.3;
ANN.divideParam.testRatio = 0;
% Regularization (to prevent overfitting)
ANN.performParam.regularization = 0.1;
% Convergence criteria
threshold = 0.02; % stop if mean relative error is below this
max_iterations = 20; % maximum number of training iterations
prevError = inf; % initialize with a high value
tolerance = 0.001; % minimal change to consider convergence
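added = false(1, size(valInput,2)); % validation samples already copied into training
for i = 1:max_iterations
    % Train (or continue training) the neural network
    [ANN, tr] = train(ANN, trainInput, trainOutput);
    % Validate: relative error per validation sample
    % (named relError so it does not shadow MATLAB's built-in error function;
    % assumes no validation target is zero)
    predictions = ANN(valInput);
    relError = abs(predictions - valOutput) ./ abs(valOutput);
    meanError = mean(relError);
    % Check convergence
    if meanError < threshold || abs(prevError - meanError) < tolerance
        break; % convergence achieved
    end
    % Copy poorly predicted validation samples (>10% error) into the training set,
    % skipping any already added so they are not duplicated on every iteration.
    % Note: once copied, these samples no longer give an unbiased validation estimate.
    newIdx = find(relError > 0.1 & ~added);
    trainInput = [trainInput, valInput(:,newIdx)];
    trainOutput = [trainOutput, valOutput(:,newIdx)];
    added(newIdx) = true;
    prevError = meanError;
end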
Adapt this to your own input and output data.
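One detail that is easy to miss with manual normalization like the code above: the mean and standard deviation must be computed from the training data once and then reused, unchanged, for every later input, including each new table used for checking. A small sketch with made-up numbers (trainRaw, newRaw and newScaled are hypothetical names):

```matlab
% Hypothetical training inputs: 2 features x 4 samples, kept small for clarity
trainRaw = [1 2 3 4; 10 20 30 40];

% Statistics computed once, on the training data only
meanInput = mean(trainRaw, 2);      % [2.5; 25]
stdInput  = std(trainRaw, 0, 2);

% A new query point is scaled with the SAME statistics before prediction
newRaw = [2.5; 25];
newScaled = (newRaw - meanInput) ./ stdInput;   % -> [0; 0]
% prediction = ANN(newScaled);  % network trained on normalized inputs
```

If you prefer the toolbox helpers, mapstd does the same job: [y, ps] = mapstd(x) returns a settings structure that can be reapplied to new data with mapstd('apply', xNew, ps).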
Comment
Ali Almakhmari on 17 Aug 2023
Edited: Ali Almakhmari on 17 Aug 2023
I am already doing something similar, except for a few things. First, I never normalized my data, which is a good idea. However, I am not sure I should, because the transferFcn of my first hidden layer is "tansig", which I believe already does the normalization (although I am not 100% sure). Second, I never considered multiple layers, so that's a good suggestion to investigate. Third, I never considered "ANN.performParam.regularization"; I will read more about it. Also, why did you not scale the output?
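On the two questions in this comment: tansig is the hyperbolic tangent sigmoid transfer function. It squashes each neuron's weighted sum into (-1, 1), but it does not rescale the data fed into the network. Also worth checking: as far as I know, fitnet networks by default already map inputs and targets to [-1, 1] with mapminmax (see ANN.inputs{1}.processFcns and ANN.outputs{end}.processFcns) and undo the output mapping automatically when you call ANN(x), so explicit scaling may be partly redundant. If you do scale the targets yourself, you must invert that scaling on the predictions. A minimal sketch, assuming a 1-by-N target vector (meanOutput, stdOutput, OutputScaled, scaledPred and pred are hypothetical names):

```matlab
% Hypothetical targets (1 x N)
Output = [2 4 6 8];

% Z-score the targets; keep the statistics for inverting predictions later
meanOutput = mean(Output);                  % 5
stdOutput  = std(Output);
OutputScaled = (Output - meanOutput) / stdOutput;
% ANN = train(ANN, Input, OutputScaled);

% Map network outputs back to the original units
scaledPred = OutputScaled;      % stand-in for scaledPred = ANN(newInput)
pred = scaledPred * stdOutput + meanOutput;
```

If you combine this with the 10% acceptance test, compute the relative error on pred in original units: z-scored targets sit near zero, where relative error blows up.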

Categories: Sequence and Numeric Feature Data Workflows
Release: R2022b
