Applying MATLAB Functions to Check for Overfitting
This chapter covers the following functions used in regularization and cross-validation: lasso, lassoglm, fitclinear, fitckernel, cvpartition, and crossval.
Regularization techniques are used to prevent statistical overfitting in a predictive model. By introducing additional information into the model, regularization algorithms can handle multicollinearity and redundant predictors, making the model more parsimonious and accurate. These algorithms typically work by applying a penalty for complexity, such as adding the magnitudes of the model coefficients to the quantity being minimized or including a roughness penalty. For example, B = lasso(X,y) returns fitted least-squares regression coefficients for linear models of the data in X and the response y, with an L1 penalty on the coefficients.
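As a minimal sketch of this penalized fitting, the following example applies lasso to synthetic data with a sparse set of true coefficients (the data and coefficient values here are illustrative assumptions, not from the text):

```matlab
% Sketch: lasso regularization of a linear model on synthetic data.
rng default                                % for reproducibility
X = randn(100,5);                          % 100 observations, 5 predictors
y = X*[0; 2; 0; -3; 0] + randn(100,1);     % only predictors 2 and 4 matter

% 'CV',10 adds 10-fold cross-validation over the lambda sequence
[B,FitInfo] = lasso(X,y,'CV',10);

% B has one column of coefficients per lambda; pick the lambda that
% minimizes the cross-validated mean squared error
bestCoef = B(:,FitInfo.IndexMinMSE);
```

With a sparse true model, many entries of bestCoef are driven exactly to zero, which is the parsimony the penalty is designed to produce.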
Regularization for logistic regression can be performed simply in Statistics and Machine Learning Toolbox™ by using the lassoglm function. lassoglm is a model-fitting function that updates the weight and bias values according to coordinate descent optimization. It minimizes a combination of the fitting error and a penalty on the parameters, and then determines the combination that produces a model that generalizes well.
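A brief sketch of lassoglm for regularized logistic regression follows; the synthetic data and the choice of 10 folds are assumptions for illustration:

```matlab
% Sketch: lasso-regularized logistic regression with cross-validation.
rng default
X = randn(100,10);                          % 100 observations, 10 predictors
y = X(:,2) - 2*X(:,7) + randn(100,1) > 0;   % synthetic binary response

% 'binomial' selects logistic regression; 'CV',10 uses 10-fold
% cross-validation to assess each candidate lambda
[B,FitInfo] = lassoglm(X,y,'binomial','CV',10);

% Coefficients at the lambda minimizing cross-validated deviance
idx       = FitInfo.IndexMinDeviance;
coef      = B(:,idx);
intercept = FitInfo.Intercept(idx);
```

The returned deviance curve (FitInfo.Deviance) can also be plotted against lambda to see how regularization strength trades off fit against complexity.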
Additional classification models in Statistics and Machine Learning Toolbox provide optional regularization arguments that can be used to regularize a model during the training process:
fitclinear (binary linear classifier)
- “Lambda”: regularization term strength
- “Regularization”: lasso (L1) or ridge (L2)
fitckernel (binary Gaussian kernel classifier)
- “Regularization”: ridge (L2)
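The arguments above are passed as name-value pairs at training time. A short sketch, with illustrative data and an assumed penalty strength of 1e-3:

```matlab
% Sketch: regularizing classifiers during training via name-value pairs.
rng default
X = randn(200,5);                          % predictor matrix
y = sign(X(:,1) + 0.5*randn(200,1));       % binary labels (+1/-1)

% Binary linear classifier with a lasso (L1) penalty
MdlLinear = fitclinear(X,y,'Lambda',1e-3,'Regularization','lasso');

% Binary Gaussian kernel classifier; ridge (L2) is the supported penalty
MdlKernel = fitckernel(X,y,'Regularization','ridge');
```

Both fitted models can then be assessed with loss(Mdl,X,y) or passed to cross-validation, as described in the next section.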
MATLAB Cross-Validation Functions
Statistics and Machine Learning Toolbox has two functions that are particularly useful when performing cross-validation:
You use the cvpartition function to create a cross-validation partition for data. Using a k-fold cross-validation example, the syntax would be:

c = cvpartition(n,'KFold',k) constructs an object c of the cvpartition class, defining a random partition for k-fold cross-validation on n observations.

c = cvpartition(group,'KFold',k) creates a random partition for a stratified k-fold cross-validation.
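Both forms can be sketched as follows; the values of n, k, and the group vector are assumptions for illustration:

```matlab
% Sketch: creating k-fold cross-validation partitions.
rng default
n = 100;                        % number of observations
k = 5;                          % number of folds
c = cvpartition(n,'KFold',k);   % random 5-fold partition of 100 observations

% Logical indices of the training and test sets for the first fold
trainIdx = training(c,1);
testIdx  = test(c,1);

% Stratified version: each fold preserves the class proportions in group
group  = [ones(70,1); zeros(30,1)];
cStrat = cvpartition(group,'KFold',k);
```

The training and test methods of the cvpartition object give the per-fold index vectors needed to fit and evaluate a model fold by fold.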
You use crossval to perform a loss estimate using cross-validation. An example of the syntax for this is:

vals = crossval(fun,X) performs 10-fold cross-validation for the function fun, applied to the data in X. fun is a function handle to a function with two inputs, the training subset of X, XTRAIN, and the test subset of X, XTEST, as follows:

testval = fun(XTRAIN,XTEST)
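Putting this together, the sketch below cross-validates the test error of a linear fit; storing the response in the last column of X and measuring mean squared error are illustrative choices, not requirements of crossval:

```matlab
% Sketch: 10-fold cross-validated loss via a function handle.
rng default
X = [randn(100,3), randn(100,1)];   % predictors in columns 1:3, response in 4

% fun fits a linear model on the training subset and returns the
% mean squared error on the test subset
fun = @(XTRAIN,XTEST) mean( ...
    (XTEST(:,4) - XTEST(:,1:3)*regress(XTRAIN(:,4),XTRAIN(:,1:3))).^2);

vals     = crossval(fun,X);   % one loss value per fold (10-by-1 vector)
meanLoss = mean(vals);        % overall cross-validated estimate
```

Because fun never sees the test subset during fitting, meanLoss estimates how the model generalizes, which is exactly the check for overfitting this chapter describes.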