# kfoldPredict

Predict labels for observations not used for training

## Description

`Label = kfoldPredict(CVMdl)` returns cross-validated class labels predicted by the cross-validated, binary, linear classification model `CVMdl`. That is, for every fold, `kfoldPredict` predicts class labels for observations that it holds out when it trains using all other observations.

`Label` contains predicted class labels for each regularization strength in the linear classification models that compose `CVMdl`.

`[Label,Score] = kfoldPredict(CVMdl)` also returns cross-validated classification scores for both classes. `Score` contains classification scores for each regularization strength in `CVMdl`.
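As a quick illustration of the two syntaxes, the following minimal sketch cross-validates a linear classifier on synthetic data. The data, the variable names, and the choice of five folds are assumptions made for illustration only; they do not come from the examples on this page.

```matlab
% Minimal sketch on synthetic data (assumed values, for illustration only).
rng(0);                                  % for reproducibility
n = 200;
X = randn(n,5);                          % n-by-5 predictors, observations in rows
Y = X(:,1) + 0.5*X(:,2) > 0;             % logical class labels

% Cross-validate a binary, linear classification model using 5 folds.
CVMdl = fitclinear(X,Y,'KFold',5);

% First syntax: out-of-fold labels only.
Label = kfoldPredict(CVMdl);

% Second syntax: out-of-fold labels and classification scores.
[Label2,Score] = kfoldPredict(CVMdl);

size(Label)   % n-by-1, because the model has one regularization strength
size(Score)   % n-by-2, one column per class in CVMdl.ClassNames
```

Each row of `Label` and `Score` comes from the fold model that did not see that observation during training.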

## Input Arguments

`CVMdl` - Cross-validated, binary, linear classification model

`ClassificationPartitionedLinear` model object

Cross-validated, binary, linear classification model, specified as a `ClassificationPartitionedLinear` model object. You can create a `ClassificationPartitionedLinear` model using `fitclinear` and specifying any one of the cross-validation name-value pair arguments, for example, `CrossVal`.

To obtain estimates, `kfoldPredict` applies the same data used to cross-validate the linear classification model (`X` and `Y`).

## Output Arguments

`Label` - Cross-validated, predicted class labels

categorical array | character array | logical matrix | numeric matrix | cell array of character vectors

Cross-validated, predicted class labels, returned as a categorical or character array, logical or numeric matrix, or cell array of character vectors.

In most cases, `Label` is an *n*-by-*L* array of the same data type as the observed class labels (see `Y`) used to create `CVMdl`. (The software treats string arrays as cell arrays of character vectors.) *n* is the number of observations in the predictor data (see `X`) and *L* is the number of regularization strengths in `CVMdl.Trained{1}.Lambda`. That is, `Label(i,j)` is the predicted class label for observation *i* using the linear classification model that has regularization strength `CVMdl.Trained{1}.Lambda(j)`.

If `Y` is a character array and *L* > 1, then `Label` is a cell array of class labels.

`Score` - Cross-validated classification scores

numeric array

Cross-validated classification scores, returned as an *n*-by-2-by-*L* numeric array. *n* is the number of observations in the predictor data that created `CVMdl` (see `X`) and *L* is the number of regularization strengths in `CVMdl.Trained{1}.Lambda`. `Score(i,k,j)` is the score for classifying observation *i* into class *k* using the linear classification model that has regularization strength `CVMdl.Trained{1}.Lambda(j)`. `CVMdl.ClassNames` stores the order of the classes.

If `CVMdl.Trained{1}.Learner` is `'logistic'`, then classification scores are posterior probabilities.
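The indexing convention `Score(i,k,j)` can be sketched on a small synthetic problem with several regularization strengths. The data, the three `Lambda` values, and the indices `i` and `j` below are assumptions chosen for illustration.

```matlab
% Sketch of indexing into Label and Score (synthetic data, assumed values).
rng(0);
n = 150;
X = randn(n,4);
Y = X(:,1) - X(:,3) > 0;                 % logical class labels
Lambda = [1e-4 1e-2 1];                  % three regularization strengths

CVMdl = fitclinear(X,Y,'KFold',5,'Learner','logistic','Lambda',Lambda);
[Label,Score] = kfoldPredict(CVMdl);     % Label: n-by-3, Score: n-by-2-by-3

i = 7;                                   % observation index
j = 2;                                   % regularization strength Lambda(2)
k = find(CVMdl.ClassNames == true);      % column of the positive class

pPositive = Score(i,k,j);                % posterior of the positive class
rowTotal  = sum(Score(i,:,j));           % two class posteriors sum to 1
predicted = Label(i,j);                  % label from the same model
```

Because the learner is `'logistic'`, the two scores for an observation under a given regularization strength are posterior probabilities and sum to 1.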

## Examples

### Predict *k*-fold Cross-Validation Labels

Load the NLP data set.

`load nlpdata`

`X` is a sparse matrix of predictor data, and `Y` is a categorical vector of class labels. There are more than two classes in the data.

The models should identify whether the word counts in a web page are from the Statistics and Machine Learning Toolbox™ documentation. So, identify the labels that correspond to the Statistics and Machine Learning Toolbox™ documentation web pages.

`Ystats = Y == 'stats';`

Cross-validate a binary, linear classification model using the entire data set, which can identify whether the word counts in a documentation web page are from the Statistics and Machine Learning Toolbox™ documentation.

    rng(1); % For reproducibility
    CVMdl = fitclinear(X,Ystats,'CrossVal','on');
    Mdl1 = CVMdl.Trained{1}

    Mdl1 = 
      ClassificationLinear
          ResponseName: 'Y'
            ClassNames: [0 1]
        ScoreTransform: 'none'
                  Beta: [34023x1 double]
                  Bias: -1.0008
                Lambda: 3.5193e-05
               Learner: 'svm'

`CVMdl` is a `ClassificationPartitionedLinear` model. By default, the software implements 10-fold cross-validation. You can alter the number of folds using the `'KFold'` name-value pair argument.

Predict labels for the observations that `fitclinear` did not use in training the folds.

label = kfoldPredict(CVMdl);

Because there is one regularization strength in `Mdl1`, `label` is a column vector of predictions containing as many rows as observations in `X`.

Construct a confusion matrix.

ConfusionTrain = confusionchart(Ystats,label);

The model misclassifies 15 `'stats'` documentation pages as being outside of the Statistics and Machine Learning Toolbox documentation, and misclassifies nine pages as `'stats'` pages.
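The two kinds of misclassification read from a confusion chart can also be counted directly from logical label vectors. The sketch below uses short made-up vectors `yTrue` and `yPred` (hypothetical stand-ins for `Ystats` and `label`) so that it runs without the data set.

```matlab
% Count confusion-matrix cells by hand (made-up labels, for illustration).
yTrue = logical([1 1 0 0 1 0 0 1]');    % stand-in for Ystats
yPred = logical([1 0 0 0 1 1 0 1]');    % stand-in for label

falseNegatives = sum(yTrue & ~yPred);   % positives predicted as negative
falsePositives = sum(~yTrue & yPred);   % negatives predicted as positive
trueNegatives  = sum(~yTrue & ~yPred);
truePositives  = sum(yTrue & yPred);

% Rows are true classes (0 then 1); columns are predicted classes.
C = [trueNegatives  falsePositives;
     falseNegatives truePositives];
```

With the real data, replacing `yTrue` and `yPred` by `Ystats` and `label` would give the counts that the chart displays.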

### Estimate *k*-fold Cross-Validation Posterior Class Probabilities

Linear classification models return posterior probabilities for logistic regression learners only.

Load the NLP data set and preprocess it as in Predict k-fold Cross-Validation Labels. Transpose the predictor data matrix.

    load nlpdata
    Ystats = Y == 'stats';
    X = X';

Cross-validate binary, linear classification models using 5-fold cross-validation. Optimize the objective function using SpaRSA. Lower the tolerance on the gradient of the objective function to `1e-8`.

    rng(10); % For reproducibility
    CVMdl = fitclinear(X,Ystats,'ObservationsIn','columns', ...
        'KFold',5,'Learner','logistic','Solver','sparsa', ...
        'Regularization','lasso','GradientTolerance',1e-8);

Predict the posterior class probabilities for observations not used to train each fold.

    [~,posterior] = kfoldPredict(CVMdl);
    CVMdl.ClassNames

    ans = 2x1 logical array
       0
       1

Because there is one regularization strength in `CVMdl`, `posterior` is a matrix with 2 columns and rows equal to the number of observations. Column *i* contains posterior probabilities of `CVMdl.ClassNames(i)` given a particular observation.

Compute the performance metrics (true positive rates and false positive rates) for a ROC curve and find the area under the ROC curve (AUC) value by creating a `rocmetrics` object.

rocObj = rocmetrics(Ystats,posterior,CVMdl.ClassNames);

Plot the ROC curve for the second class by using the `plot` function of `rocmetrics`.

plot(rocObj,ClassNames=CVMdl.ClassNames(2))

The ROC curve indicates that the model classifies the validation observations almost perfectly.

### Find Good Lasso Penalty Using Cross-Validated AUC

To determine a good lasso-penalty strength for a linear classification model that uses a logistic regression learner, compare cross-validated AUC values.

Load the NLP data set. Preprocess the data as in Estimate k-fold Cross-Validation Posterior Class Probabilities.

    load nlpdata
    Ystats = Y == 'stats';
    X = X';

There are 31572 observations in the data set, all of which are used for cross-validation.

Create a set of 11 logarithmically spaced regularization strengths from $$10^{-6}$$ through $$10^{-0.5}$$.

Lambda = logspace(-6,-0.5,11);

Cross-validate binary, linear classification models that use each of the regularization strengths and 5-fold cross-validation. Optimize the objective function using SpaRSA. Lower the tolerance on the gradient of the objective function to `1e-8`.

    rng(10) % For reproducibility
    CVMdl = fitclinear(X,Ystats,'ObservationsIn','columns', ...
        'KFold',5,'Learner','logistic','Solver','sparsa', ...
        'Regularization','lasso','Lambda',Lambda,'GradientTolerance',1e-8)

    CVMdl = 
      ClassificationPartitionedLinear
        CrossValidatedModel: 'Linear'
               ResponseName: 'Y'
            NumObservations: 31572
                      KFold: 5
                  Partition: [1x1 cvpartition]
                 ClassNames: [0 1]
             ScoreTransform: 'none'

    Mdl1 = CVMdl.Trained{1}

    Mdl1 = 
      ClassificationLinear
          ResponseName: 'Y'
            ClassNames: [0 1]
        ScoreTransform: 'logit'
                  Beta: [34023x11 double]
                  Bias: [-13.2936 -13.2936 -13.2936 -13.2936 -13.2936 -6.8954 -5.4359 -4.7170 -3.4108 -3.1566 -2.9792]
                Lambda: [1.0000e-06 3.5481e-06 1.2589e-05 4.4668e-05 1.5849e-04 5.6234e-04 0.0020 0.0071 0.0251 0.0891 0.3162]
               Learner: 'logistic'

`Mdl1` is a `ClassificationLinear` model object. Because `Lambda` is a sequence of regularization strengths, you can think of `Mdl1` as 11 models, one for each regularization strength in `Lambda`.

Predict the cross-validated labels and posterior class probabilities.

    [label,posterior] = kfoldPredict(CVMdl);
    CVMdl.ClassNames;
    [n,K,L] = size(posterior)

n = 31572

K = 2

L = 11

posterior(3,1,5)

ans = 1.0000

`label` is a 31572-by-11 matrix of predicted labels. Each column corresponds to the predicted labels of the model trained using the corresponding regularization strength. `posterior` is a 31572-by-2-by-11 array of posterior class probabilities. Columns correspond to classes and pages correspond to regularization strengths. For example, `posterior(3,1,5)` indicates that the posterior probability that the first class (label `0`) is assigned to observation 3 by the model that uses `Lambda(5)` as a regularization strength is 1.0000.

For each model, compute the AUC by using `rocmetrics`.

    auc = 1:numel(Lambda); % Preallocation
    for j = 1:numel(Lambda)
        rocObj = rocmetrics(Ystats,posterior(:,:,j),CVMdl.ClassNames);
        auc(j) = rocObj.AUC(1);
    end

Higher values of `Lambda` lead to predictor variable sparsity, which is a good quality of a classifier. For each regularization strength, train a linear classification model using the entire data set and the same options as when you trained the model. Determine the number of nonzero coefficients per model.

    Mdl = fitclinear(X,Ystats,'ObservationsIn','columns', ...
        'Learner','logistic','Solver','sparsa','Regularization','lasso', ...
        'Lambda',Lambda,'GradientTolerance',1e-8);
    numNZCoeff = sum(Mdl.Beta~=0);

In the same figure, plot the cross-validated AUC values and frequency of nonzero coefficients for each regularization strength. Plot all variables on the log scale.

    figure
    yyaxis left
    plot(log10(Lambda),log10(auc),'o-')
    ylabel('log_{10} AUC')
    yyaxis right
    plot(log10(Lambda),log10(numNZCoeff + 1),'o-')
    ylabel('log_{10} nonzero-coefficient frequency')
    xlabel('log_{10} Lambda')
    title('Cross-Validated Statistics')
    hold off

Choose the index of the regularization strength that balances predictor variable sparsity and high AUC. In this case, a value between $$10^{-3}$$ and $$10^{-1}$$ should suffice.

idxFinal = 9;

Select the model from `Mdl` with the chosen regularization strength.

MdlFinal = selectModels(Mdl,idxFinal);

`MdlFinal` is a `ClassificationLinear` model containing one regularization strength. To estimate labels for new observations, pass `MdlFinal` and the new data to `predict`.
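The final step, passing a selected model and new data to `predict`, might look like the following sketch. It uses small synthetic data in place of the NLP data set; the sizes, the `Lambda` grid, the chosen index, and the name `Xnew` are assumptions for illustration.

```matlab
% Sketch: select one regularization strength, then predict new observations.
rng(0);
X = randn(6,300);                        % 6 predictors, 300 observations in columns
Y = X(1,:)' - X(2,:)' > 0;               % logical class labels, 300-by-1
Lambda = logspace(-4,-1,4);              % assumed grid of four strengths

Mdl = fitclinear(X,Y,'ObservationsIn','columns','Learner','logistic', ...
    'Solver','sparsa','Regularization','lasso','Lambda',Lambda);

idxFinal = 3;                            % hypothetical choice of index
MdlFinal = selectModels(Mdl,idxFinal);   % model with one regularization strength

Xnew = randn(6,10);                      % 10 new observations in columns
[labelNew,scoreNew] = predict(MdlFinal,Xnew,'ObservationsIn','columns');
```

Because the training data is oriented with observations in columns, the same `'ObservationsIn','columns'` option is passed to `predict` so that `Xnew` is interpreted consistently.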

## More About

### Classification Score

For linear classification models, the raw *classification score* for classifying the observation *x*, a row vector, into the positive class is defined by

$$f_j(x) = x\beta_j + b_j.$$

For the model with regularization strength *j*, $$\beta_j$$ is the estimated column vector of coefficients (the model property `Beta(:,j)`) and $$b_j$$ is the estimated, scalar bias (the model property `Bias(j)`).

The raw classification score for classifying *x* into
the negative class is –*f*(*x*).
The software classifies observations into the class that yields the
positive score.

If the linear classification model consists of logistic regression learners, then the software applies the `'logit'` score transformation to the raw classification scores (see `ScoreTransform`).
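The score definition above can be checked by hand with plain matrix arithmetic. In this sketch the observation `x`, the coefficient matrix `Beta`, and the bias vector `Bias` are made-up values for two regularization strengths, and the `'logit'` transformation is written out as 1/(1 + exp(-s)).

```matlab
% Hand computation of raw and logit-transformed scores (made-up numbers).
x    = [0.5 -1.2 2.0];           % one observation as a row vector
Beta = [ 1.0  0.8;
        -0.5  0.0;
         0.3  0.1];              % one column per regularization strength
Bias = [-0.2 0.1];               % one bias per regularization strength

f = x*Beta + Bias;               % raw positive-class scores, here [1.5 0.7]
rawScores = [-f; f];             % row 1: negative class, row 2: positive class

logit = @(s) 1./(1 + exp(-s));   % 'logit' score transformation
posterior = logit(rawScores);    % each column sums to 1

predictedPositive = f > 0;       % class with the positive raw score
```

For a fitted `ClassificationLinear` model, `Beta` and `Bias` would be the model properties of the same names.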

## Extended Capabilities

### GPU Arrays

Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.

This function fully supports GPU arrays. For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).

## Version History

**Introduced in R2016a**

### R2024a: Specify GPU arrays (requires Parallel Computing Toolbox)

`kfoldPredict` fully supports GPU arrays.

### R2023b: Observations with missing predictor values are used in resubstitution and cross-validation computations

Starting in R2023b, the following classification model object functions use observations with missing predictor values as part of resubstitution ("resub") and cross-validation ("kfold") computations for classification edges, losses, margins, and predictions.

In previous releases, the software omitted observations with missing predictor values from the resubstitution and cross-validation computations.
