Assess Neural Network Classifier Performance
Create a feedforward neural network classifier with fully connected layers using fitcnet. Use validation data for early stopping of the training process to prevent overfitting the model. Then, use the object functions of the classifier to assess the performance of the model on test data.
Load and Preprocess Sample Data
This example uses the 1994 census data stored in census1994.mat. The data set consists of demographic information from the US Census Bureau that you can use to predict whether an individual makes over $50,000 per year.
Load the sample data census1994, which contains the training data adultdata and the test data adulttest. Preview the first few rows of the training data set.
load census1994
head(adultdata)
age    workClass           fnlwgt        education    education_num    marital_status           occupation           relationship     race     sex       capital_gain    capital_loss    hours_per_week    native_country    salary
___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______
39     State-gov           77516         Bachelors    13               Never-married            Adm-clerical         Not-in-family    White    Male      2174            0               40                United-States     <=50K
50     Self-emp-not-inc    83311         Bachelors    13               Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               13                United-States     <=50K
38     Private             2.1565e+05    HS-grad      9                Divorced                 Handlers-cleaners    Not-in-family    White    Male      0               0               40                United-States     <=50K
53     Private             2.3472e+05    11th         7                Married-civ-spouse       Handlers-cleaners    Husband          Black    Male      0               0               40                United-States     <=50K
28     Private             3.3841e+05    Bachelors    13               Married-civ-spouse       Prof-specialty       Wife             Black    Female    0               0               40                Cuba              <=50K
37     Private             2.8458e+05    Masters      14               Married-civ-spouse       Exec-managerial      Wife             White    Female    0               0               40                United-States     <=50K
49     Private             1.6019e+05    9th          5                Married-spouse-absent    Other-service        Not-in-family    Black    Female    0               0               16                Jamaica           <=50K
52     Self-emp-not-inc    2.0964e+05    HS-grad      9                Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               45                United-States     >50K
Each row contains the demographic information for one adult. The last column, salary, shows whether a person has a salary less than or equal to $50,000 per year or greater than $50,000 per year.
Delete the rows of adultdata and adulttest that contain missing values.
adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);
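The following minimal sketch shows what counts as missing. It applies rmmissing to a small hypothetical table. The variable names mirror the census data, but the values are made up. A NaN in a numeric column and a missing element in a string column each cause the whole row to be dropped.

```matlab
% Hypothetical toy table: row 2 has a NaN age, row 3 has a missing workClass
age = [39; NaN; 38];
workClass = ["State-gov"; "Private"; missing];
T = table(age,workClass);

% rmmissing drops every row that contains at least one missing entry
Tclean = rmmissing(T)
```

Only the first row remains, because it is the only row with no missing entries.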
Combine the education_num and education variables in both the training and test data to create a single ordered categorical variable that shows the highest level of education a person has achieved.
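% List the education codes and level names in order of first appearance.
% Pairing them this way assumes each level has exactly one education_num code.
edOrder = unique(adultdata.education_num,"stable");
edCats = unique(adultdata.education,"stable");
[~,edIdx] = sort(edOrder);
% Make education an ordinal categorical variable ordered by its numeric code,
% and then remove the redundant numeric column from both data sets
adultdata.education = categorical(adultdata.education, ...
    edCats(edIdx),"Ordinal",true);
adultdata.education_num = [];
adulttest.education = categorical(adulttest.education, ...
    edCats(edIdx),"Ordinal",true);
adulttest.education_num = [];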
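You can check the sort-based reordering on a handful of levels. This sketch hard-codes the education names and codes from the preview rows above, so it does not need the census data. Like the code above, it assumes that each education level has exactly one education_num code. Sorting the codes gives the category order, and the resulting ordinal categorical supports relational comparisons.

```matlab
% Education levels and codes from the preview rows, in order of first appearance
edOrder = [13; 9; 7; 14; 5];
edCats = ["Bachelors"; "HS-grad"; "11th"; "Masters"; "9th"];

% Sort the names by their numeric code to get the category order
[~,edIdx] = sort(edOrder);
orderedCats = edCats(edIdx)   % 9th, 11th, HS-grad, Bachelors, Masters

% Ordinal categorical values can be compared with < and >
edu = categorical(["HS-grad"; "Masters"; "9th"],orderedCats,"Ordinal",true);
isHigher = edu(2) > edu(1)    % true: Masters ranks above HS-grad
```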
Partition Training Data
Split the training data further using a stratified holdout partition. Create a separate validation data set to stop the model training process early. Reserve approximately 30% of the observations for the validation data set and use the rest of the observations to train the neural network classifier.
rng("default") % For reproducibility of the partition
c = cvpartition(adultdata.salary,"Holdout",0.30);
trainingIndices = training(c);
validationIndices = test(c);
tblTrain = adultdata(trainingIndices,:);
tblValidation = adultdata(validationIndices,:);
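To see what stratification does, the sketch below applies the same kind of cvpartition call to a hypothetical, imbalanced label vector (75 "<=50K" and 25 ">50K" values, not the census data). Because the labels are passed as the grouping variable, each subset keeps roughly the 25% share of the minority class. The exact subset sizes depend on how cvpartition rounds within each class, so the sketch claims no exact counts.

```matlab
% Hypothetical imbalanced labels: 75 "<=50K" and 25 ">50K"
rng("default")
y = categorical([repmat("<=50K",75,1); repmat(">50K",25,1)]);

% Passing the labels as the grouping variable stratifies the holdout
c = cvpartition(y,"Holdout",0.30);

% The ">50K" fraction should be close to 25% in both subsets
fracTrain = mean(y(training(c)) == ">50K")
fracValidation = mean(y(test(c)) == ">50K")
```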
Train Neural Network
Train a neural network classifier by using the training set. Specify the salary column of tblTrain as the response and the fnlwgt column as the observation weights, and standardize the numeric predictors. Evaluate the model at each iteration by using the validation set, and display the training information at each iteration by using the Verbose name-value argument. By default, the training process ends early if the validation cross-entropy loss is greater than or equal to the minimum validation cross-entropy loss computed so far six times in a row. To change the number of times the validation loss is allowed to be greater than or equal to the minimum, specify the ValidationPatience name-value argument.
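The training and validation losses in the training log are cross-entropy values. This rough sketch shows how observation weights enter such a loss. It computes a weighted cross-entropy for three hypothetical observations and assumes that the weights are normalized to sum to 1. fitcnet might normalize weights differently internally, so treat this code as an illustration rather than its exact formula.

```matlab
% Hypothetical posterior probabilities assigned to the true class
pTrue = [0.9; 0.6; 0.2];
% Hypothetical observation weights (playing the role of fnlwgt)
w = [1; 3; 1];

% Assumption: weights are normalized to sum to 1
wNorm = w/sum(w);

% Weighted cross-entropy: weighted average of -log(p_true)
ceLoss = -sum(wNorm .* log(pTrue))   % approximately 0.6495
```

The heavily weighted second observation contributes most of the loss, even though the third observation has a smaller true-class probability.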
Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ...
    "Standardize",true,"ValidationData",tblValidation, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|    0.326435|    0.105391|    1.174862|    0.023566|    0.325292|           0|
|           2|    0.275413|    0.024249|    0.259219|    0.086448|    0.275310|           0|
|           3|    0.258430|    0.027390|    0.173985|    0.048263|    0.258820|           0|
|           4|    0.218429|    0.024172|    0.617121|    0.033140|    0.220265|           0|
|           5|    0.194545|    0.022570|    0.717853|    0.024639|    0.197881|           0|
|           6|    0.187702|    0.030800|    0.706053|    0.013569|    0.192706|           0|
|           7|    0.182328|    0.016970|    0.175624|    0.015521|    0.187243|           0|
|           8|    0.180458|    0.007389|    0.241016|    0.022679|    0.184689|           0|
|           9|    0.179364|    0.007194|    0.112335|    0.017717|    0.183928|           0|
|          10|    0.175531|    0.008233|    0.271539|    0.014772|    0.180789|           0|
|          11|    0.167236|    0.014633|    0.941927|    0.014300|    0.172918|           0|
|          12|    0.164107|    0.007069|    0.186935|    0.013451|    0.169584|           0|
|          13|    0.162421|    0.005973|    0.226712|    0.017844|    0.167040|           0|
|          14|    0.161055|    0.004590|    0.142162|    0.024692|    0.165982|           0|
|          15|    0.159318|    0.007807|    0.438498|    0.020057|    0.164524|           0|
|          16|    0.158856|    0.003321|    0.054253|    0.024557|    0.164177|           0|
|          17|    0.158481|    0.004336|    0.125983|    0.015360|    0.163746|           0|
|          18|    0.158042|    0.004697|    0.160583|    0.015980|    0.163042|           0|
|          19|    0.157412|    0.007637|    0.304204|    0.022741|    0.162194|           0|
|          20|    0.156931|    0.003145|    0.182916|    0.030745|    0.161804|           0|
|          21|    0.156666|    0.003791|    0.089101|    0.017161|    0.161714|           0|
|          22|    0.156457|    0.003157|    0.039609|    0.015312|    0.161592|           0|
|          23|    0.156210|    0.002608|    0.081463|    0.017402|    0.161511|           0|
|          24|    0.155981|    0.003497|    0.088109|    0.023004|    0.161557|           1|
|          25|    0.155520|    0.004131|    0.181666|    0.018876|    0.161433|           0|
|          26|    0.154899|    0.002309|    0.327281|    0.013928|    0.161065|           0|
|          27|    0.154703|    0.001210|    0.055537|    0.013516|    0.160733|           0|
|          28|    0.154503|    0.002407|    0.089433|    0.015329|    0.160449|           0|
|          29|    0.154304|    0.003212|    0.118986|    0.016101|    0.160163|           0|
|          30|    0.154026|    0.002823|    0.183600|    0.029703|    0.159885|           0|
|          31|    0.153738|    0.004477|    0.405824|    0.020427|    0.159378|           0|
|          32|    0.153538|    0.003659|    0.065795|    0.013910|    0.159333|           0|
|          33|    0.153491|    0.001184|    0.017043|    0.013195|    0.159377|           1|
|          34|    0.153460|    0.000988|    0.017456|    0.017540|    0.159446|           2|
|          35|    0.153420|    0.002433|    0.032119|    0.030212|    0.159463|           3|
|          36|    0.153329|    0.003517|    0.058506|    0.019751|    0.159478|           4|
|          37|    0.153181|    0.002436|    0.116169|    0.015466|    0.159453|           5|
|          38|    0.153025|    0.001577|    0.177446|    0.024381|    0.159377|           6|
|==========================================================================================|
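The Validation Checks column follows directly from the early-stopping rule described above. The sketch below replays that rule on the validation losses for iterations 23 through 38, copied from the training log. Starting at iteration 23 is safe because every earlier validation loss in the log is larger. The variable names are hypothetical, and the loop illustrates the rule rather than reproducing fitcnet's implementation.

```matlab
% Validation losses for iterations 23-38, copied from the training log
valLoss = [0.161511 0.161557 0.161433 0.161065 0.160733 0.160449 ...
    0.160163 0.159885 0.159378 0.159333 0.159377 0.159446 ...
    0.159463 0.159478 0.159453 0.159377];
firstIter = 23;
patience = 6;       % default value of ValidationPatience

bestLoss = Inf;     % minimum validation loss seen so far
count = 0;          % consecutive iterations without a new minimum
checks = zeros(size(valLoss));
for k = 1:numel(valLoss)
    if valLoss(k) >= bestLoss
        count = count + 1;
    else
        bestLoss = valLoss(k);
        count = 0;
    end
    checks(k) = count;
    if count >= patience
        break       % stop training early
    end
end

stopIter = firstIter + k - 1         % 38
[~,bestIdx] = min(valLoss(1:k));
bestIter = firstIter + bestIdx - 1   % 32
```

The recomputed checks match the Validation Checks column of the log. The best iteration agrees with the iteration that the TrainingHistory property reports as having the minimum validation loss.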
Use the information inside the TrainingHistory property of the object Mdl to check the iteration that corresponds to the minimum validation cross-entropy loss. The final returned model Mdl is the model trained at this iteration.
iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 32
Evaluate Test Set Performance
Evaluate the performance of the trained classifier Mdl on the test set adulttest by using the predict, loss, margin, and edge object functions.
Find the predicted labels and classification scores for the observations in the test set.
[labels,Scores] = predict(Mdl,adulttest);
Create a confusion matrix from the test set results. The diagonal elements indicate the number of correctly classified instances of a given class. The off-diagonal elements are instances of misclassified observations.
confusionchart(adulttest.salary,labels)
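The chart is built from a matrix of counts. In this small sketch with made-up labels for five observations, confusionmat returns the same kind of matrix. True classes run along the rows and predicted classes along the columns, so the diagonal holds the counts of correctly classified observations.

```matlab
% Hypothetical true and predicted labels for five observations
trueLabels = categorical(["<=50K"; ">50K"; "<=50K"; ">50K"; "<=50K"]);
predLabels = categorical(["<=50K"; "<=50K"; "<=50K"; ">50K"; ">50K"]);

% Rows are true classes, columns are predicted classes
[C,order] = confusionmat(trueLabels,predLabels)   % C = [2 1; 1 1]

numCorrect = trace(C)   % 3 correctly classified observations
```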
Compute the test set classification accuracy.
testError = loss(Mdl,adulttest,"salary");
accuracy = (1-testError)*100
accuracy = 85.0172
The neural network classifier correctly classifies approximately 85% of the test set observations.
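With equal observation weights, the classification error is simply the fraction of misclassified observations. The following sketch reproduces the accuracy calculation by hand on five hypothetical labels. It assumes that loss uses its default classification-error loss function and that no Weights argument is passed, as in the call above.

```matlab
% Hypothetical true and predicted labels for five observations
trueLabels = categorical(["<=50K"; ">50K"; "<=50K"; ">50K"; "<=50K"]);
predLabels = categorical(["<=50K"; "<=50K"; "<=50K"; ">50K"; ">50K"]);

% Fraction of misclassified observations (equal weights)
misclassRate = mean(predLabels ~= trueLabels)   % 0.4
accuracy = (1 - misclassRate)*100               % 60
```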
Compute the test set classification margins for the trained neural network. Display a histogram of the margins.
The classification margins are the difference between the classification score for the true class and the classification score for the false class. Because neural network classifiers return scores that are posterior probabilities, classification margins close to 1 indicate confident classifications and negative margin values indicate misclassifications.
m = margin(Mdl,adulttest,"salary");
histogram(m)
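For two classes, you can compute the margin directly from the score matrix. In this sketch the scores and true-class indices are hypothetical, and the columns are assumed to follow the class order in Mdl.ClassNames ("<=50K", then ">50K"). The two posterior probabilities in each row sum to 1, so the margin also equals twice the true-class score minus 1. This is why margins range from -1 to 1.

```matlab
% Hypothetical posterior probabilities; columns are "<=50K" and ">50K"
scores = [0.9 0.1; 0.3 0.7; 0.6 0.4];
trueIdx = [1; 1; 2];        % column index of each observation's true class
n = size(scores,1);

% Score of the true class and of the other (false) class
trueScore = scores(sub2ind(size(scores),(1:n)',trueIdx));
falseScore = scores(sub2ind(size(scores),(1:n)',3 - trueIdx));

m = trueScore - falseScore  % 0.8, -0.4, -0.2
```

The first observation is classified correctly and confidently. The negative margins mark the other two observations as misclassified.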
Use the classification edge, or mean of the classification margins, to assess the overall performance of the classifier.
meanMargin = edge(Mdl,adulttest,"salary")
meanMargin = 0.5943
Alternatively, compute the weighted classification edge by using observation weights.
weightedMeanMargin = edge(Mdl,adulttest,"salary", ...
    "Weights","fnlwgt")
weightedMeanMargin = 0.6045
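The two edge values differ because of the weighting. This sketch assumes that the edge is a weighted average of the margins with the weights normalized to sum to 1. It compares the plain and weighted means of four hypothetical margins. Giving a large weight to a misclassified observation (negative margin) pulls the weighted edge down.

```matlab
% Hypothetical margins and observation weights
m = [0.8; -0.4; 0.6; 0.2];
w = [1; 3; 1; 1];

unweightedEdge = mean(m)          % 0.3
weightedEdge = sum(w.*m)/sum(w)   % 0.4/6, approximately 0.0667
```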
Visualize the predicted labels and classification scores using scatter plots, in which each point corresponds to an observation. Use the predicted labels to set the color of the points, and use the maximum scores to set the transparency of the points. Less transparent points correspond to predictions made with greater confidence.
First, find the maximum classification score for each test set observation.
maxScores = max(Scores,[],2);
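Each row of Scores holds two posterior probabilities that sum to 1, so the maximum score in each row always lies between 0.5 and 1. The following sketch applies the same max call to a hypothetical score matrix.

```matlab
% Hypothetical two-class posterior probabilities (each row sums to 1)
Scores = [0.9 0.1; 0.3 0.7; 0.55 0.45];

% Largest score in each row, that is, the score of the predicted class
maxScores = max(Scores,[],2)   % 0.9, 0.7, 0.55
```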
Create a scatter plot comparing maximum scores across the number of work hours per week and level of education. Because the education variable is categorical, randomly jitter (or space out) the points along the y-dimension.
Change the colormap so that points with a predicted salary less than or equal to $50,000 per year appear blue, and points with a predicted salary greater than $50,000 per year appear red.
scatter(adulttest.hours_per_week,adulttest.education,[],labels, ...
    "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ...
    "YJitter","rand");
xlabel("Number of Work Hours Per Week")
ylabel("Education")
Mdl.ClassNames
ans = 2x1 categorical
<=50K
>50K
colors = lines(2)
colors = 2×3
0 0.4470 0.7410
0.8500 0.3250 0.0980
colormap(colors);
The colors in the scatter plot indicate that, in general, the neural network predicts that people with lower levels of education (12th grade or below) have salaries less than or equal to $50,000 per year. The transparency of some of the points in the lower right of the plot indicates that the model is less confident in this prediction for people who work many hours per week (60 hours or more).
See Also
fitcnet | margin | edge | loss | predict | ClassificationNeuralNetwork | confusionchart | scatter