Choosing the Best Machine Learning Classification Model and Avoiding Overfitting

Chapter 4

Applying MATLAB Functions to Check for Overfitting

This chapter covers the following functions used in regularization and cross-validation:

  • lasso
  • lassoglm
  • fitclinear
  • fitckernel
  • cvpartition
  • crossval

Regularization techniques are used to prevent statistical overfitting in a predictive model. By introducing additional information into the model, regularization algorithms can handle multicollinearity and redundant predictors, yielding a model that is more parsimonious and more accurate.


B = lasso(X,y) returns fitted least-squares regression coefficients for linear models of the predictor data X and the response y; each column of B corresponds to a different regularization strength.

These algorithms typically work by applying a penalty for complexity, such as adding a penalty on the magnitudes of the model coefficients to the minimization objective or including a roughness penalty.
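As a minimal sketch of this idea, lasso applies such a penalty and can choose its strength by cross-validation. The data below are simulated purely for illustration:

```matlab
% Simulated data with redundant predictors (illustrative only)
rng(1)
X = randn(100,10);
X(:,6:10) = X(:,1:5) + 0.1*randn(100,5);   % near-duplicate columns
y = X(:,1) - 2*X(:,3) + 0.5*randn(100,1);

% Fit lasso over a range of penalty strengths, selecting Lambda by 10-fold CV
[B,FitInfo] = lasso(X,y,'CV',10);

% Coefficients at the Lambda that minimizes cross-validated MSE;
% many coefficients of the redundant predictors are driven exactly to zero
bestB = B(:,FitInfo.IndexMinMSE);
```

Driving redundant coefficients to exactly zero is what makes the resulting model more parsimonious.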

Regularization for logistic regression can be performed simply in Statistics and Machine Learning Toolbox™ by using the lassoglm function. lassoglm is a model-fitting function that updates the weight and bias values according to coordinate descent optimization. It minimizes a combination of the model deviance and a penalty on the coefficients, then determines the penalty strength that produces a model that generalizes well.
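A short sketch of lassoglm for regularized logistic regression, again on simulated data:

```matlab
% Simulated binary classification data (illustrative only)
rng(2)
X = randn(200,8);
p = 1./(1 + exp(-(X(:,1) - 2*X(:,2))));
y = double(rand(200,1) < p);

% Regularized logistic regression with 5-fold cross-validation
[B,FitInfo] = lassoglm(X,y,'binomial','CV',5);

% Intercept and coefficients at the Lambda minimizing cross-validated deviance
idx  = FitInfo.IndexMinDeviance;
coef = [FitInfo.Intercept(idx); B(:,idx)];
```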

Additional classification models in Statistics and Machine Learning Toolbox provide optional regularization arguments that can be used to regularize a model during the training process:

  • fitclinear (binary linear classifier)
    • “Lambda”: regularization term strength
    • “Regularization”: lasso (L1), ridge (L2)
  • fitckernel (binary Gaussian kernel classifier)
    • “Regularization”: ridge (L2)
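As a sketch, the arguments above are passed as name-value pairs at training time; the data here are simulated for illustration:

```matlab
% Illustrative binary data
rng(3)
X = randn(500,20);
Y = sign(X(:,1) + 0.5*randn(500,1));

% Lasso-regularized binary linear classifier;
% 'Lambda' controls the regularization term strength
Mdl = fitclinear(X,Y,'Regularization','lasso','Lambda',1e-2);

% Binary Gaussian kernel classifier with ridge penalty strength 'Lambda'
MdlK = fitckernel(X,Y,'Lambda',1e-3);
```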

MATLAB Cross-Validation Functions

Statistics and Machine Learning Toolbox has two functions that are particularly useful when performing cross-validation: cvpartition and crossval.

You use the cvpartition function to create a cross-validation partition for data. For k-fold cross-validation, for example, the syntax is:

c = cvpartition(n,'KFold',k) constructs an object c of the cvpartition class defining a random partition for k-fold cross-validation on n observations.

c = cvpartition(group,'KFold',k) creates a random partition for stratified k-fold cross-validation: each fold contains roughly the same class proportions as the grouping variable group.
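A brief sketch of both syntaxes, using the training and test methods of the resulting object to index a fold:

```matlab
% k-fold partition on n observations
n = 150; k = 5;
c = cvpartition(n,'KFold',k);

% Logical masks selecting the training and test observations of fold 1
trIdx = training(c,1);
teIdx = test(c,1);

% Stratified partition: class proportions are preserved in each fold
group = [repmat("a",120,1); repmat("b",30,1)];
cs = cvpartition(group,'KFold',k);
```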

Use crossval to perform a loss estimate using cross-validation. An example of the syntax for this is:

vals = crossval(fun,X) performs 10-fold cross-validation for the function fun, applied to the data in X.

fun is a function handle to a function with two inputs, the training subset of X, XTRAIN, and the test subset of X, XTEST, as follows:

testval = fun(XTRAIN,XTEST)
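Putting the two pieces together, a minimal sketch with a toy function handle (the statistic computed here is arbitrary, chosen only to show the calling pattern):

```matlab
% Illustrative data
rng(5)
X = randn(100,3);

% fun receives a training subset and a test subset of the rows of X;
% here it compares test-fold column means against training-fold column means
fun = @(XTRAIN,XTEST) mean(XTEST,1) - mean(XTRAIN,1);

% 10-fold cross-validation; vals contains one row per fold,
% each row being the value fun returned for that fold
vals = crossval(fun,X);
```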