
predict

Predict responses using regression neural network

Since R2021a

    Description

    yfit = predict(Mdl,X) returns predicted response values for the predictor data in the table or matrix X using the trained regression neural network model Mdl.


    yfit = predict(Mdl,X,Name=Value) specifies additional options using one or more name-value arguments. For example, specify that columns in the predictor data correspond to observations.


    Examples


    Predict test set response values by using a trained regression neural network model.

    Load the patients data set. Create a table from the data set. Each row corresponds to one patient, and each column corresponds to a diagnostic variable. Use the Systolic variable as the response variable, and the rest of the variables as predictors.

    load patients
    tbl = table(Diastolic,Height,Smoker,Weight,Systolic);

    Separate the data into a training set tblTrain and a test set tblTest by using a nonstratified holdout partition. The software reserves approximately 30% of the observations for the test data set and uses the rest of the observations for the training data set.
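A holdout partition amounts to a random split of the row indices into two disjoint sets. The following is a minimal sketch in base MATLAB. It assumes a plain random permutation and rounds the test-set size, and it is not how cvpartition is implemented. The names n, p, testIdx, and trainIdx are hypothetical.

```matlab
% Hypothetical illustration of a nonstratified holdout split (not cvpartition itself)
rng(0)                      % fix the random seed for a repeatable split
n = 100;                    % hypothetical number of observations
p = 0.30;                   % fraction of observations to hold out for testing

perm = randperm(n);         % random ordering of the row indices 1..n
nTest = round(p*n);         % assumed rounding rule for the test-set size

testIdx = false(n,1);       % logical mask, playing the role of test(c)
testIdx(perm(1:nTest)) = true;
trainIdx = ~testIdx;        % remaining rows form the training set, like training(c)

fprintf("%d training rows, %d test rows\n",nnz(trainIdx),nnz(testIdx))
```

Because the split is random, fixing the seed (as the example does with rng("default")) is what makes the partition reproducible.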

    rng("default") % For reproducibility of the partition
    c = cvpartition(size(tbl,1),"Holdout",0.30);
    trainingIndices = training(c);
    testIndices = test(c);
    tblTrain = tbl(trainingIndices,:);
    tblTest = tbl(testIndices,:);

    Train a regression neural network model using the training set. Specify the Systolic column of tblTrain as the response variable. Specify to standardize the numeric predictors, and set the iteration limit to 50. By default, the neural network model has one fully connected layer with 10 outputs, excluding the final fully connected layer.
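For a single observation x with p predictors, the default single-response architecture can be written as shown below. This is a sketch that assumes the default ReLU activation and no output-layer activation (the defaults that appear as Activations: 'relu' and OutputLayerActivation: 'none' in the model display on this page). W_1, b_1, W_2, and b_2 correspond to Mdl.LayerWeights{1}, Mdl.LayerBiases{1}, Mdl.LayerWeights{end}, and Mdl.LayerBiases{end}. When Standardize is true, x is assumed to be the standardized predictor vector.

```latex
\begin{aligned}
z &= W_1 x + b_1, && W_1 \in \mathbb{R}^{10 \times p},\; b_1 \in \mathbb{R}^{10} \\
h &= \max(z, 0) && \text{(ReLU, applied elementwise)} \\
\hat{y} &= W_2 h + b_2, && W_2 \in \mathbb{R}^{1 \times 10},\; b_2 \in \mathbb{R}
\end{aligned}
```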

    Mdl = fitrnet(tblTrain,"Systolic", ...
        "Standardize",true,"IterationLimit",50);

    Predict the systolic blood pressure levels for patients in the test set.

    predictedY = predict(Mdl,tblTest);

    Visualize the results by using a scatter plot with a reference line. Plot the predicted values along the vertical axis and the true response values along the horizontal axis. Points on the reference line indicate correct predictions.

    plot(tblTest.Systolic,predictedY,".")
    hold on
    plot(tblTest.Systolic,tblTest.Systolic)
    hold off
    xlabel("True Systolic Blood Pressure Levels")
    ylabel("Predicted Systolic Blood Pressure Levels")

    [Figure: scatter plot of predicted versus true systolic blood pressure levels, with a reference line]

    Because many of the points are far from the reference line, the default neural network model, with a single fully connected layer of size 10, does not appear to predict systolic blood pressure levels well.

    Perform feature selection by comparing test set losses and predictions. Compare the test set metrics for a regression neural network model trained using all the predictors to the test set metrics for a model trained using only a subset of the predictors.

    Load the sample file fisheriris.csv, which contains iris data including sepal length, sepal width, petal length, petal width, and species type. Read the file into a table.

    fishertable = readtable('fisheriris.csv');

    Separate the data into a training set trainTbl and a test set testTbl by using a nonstratified holdout partition. The software reserves approximately 30% of the observations for the test data set and uses the rest of the observations for the training data set.

    rng("default")
    c = cvpartition(size(fishertable,1),"Holdout",0.3);
    trainTbl = fishertable(training(c),:);
    testTbl = fishertable(test(c),:);

    Train one regression neural network model using all the predictors in the training set, and train another model using all the predictors except PetalWidth. For both models, specify PetalLength as the response variable, and standardize the predictors.

    allMdl = fitrnet(trainTbl,"PetalLength","Standardize",true);
    subsetMdl = fitrnet(trainTbl,"PetalLength ~ SepalLength + SepalWidth + Species", ...
        "Standardize",true);

    Compare the test set mean squared error (MSE) of the two models. Smaller MSE values indicate better performance.
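As a rough sketch of what the MSE measures, the test set MSE with equal observation weights is the average squared vertical distance between the predictions and the reference line. Equal weights are an assumption here. The vectors below are hypothetical values, not output from allMdl or subsetMdl.

```matlab
% Hypothetical true and predicted responses (not from allMdl or subsetMdl)
yTrue = [1.4; 4.5; 5.1; 1.3];
yPred = [1.5; 4.2; 5.3; 1.3];

residuals = yTrue - yPred;          % vertical distances from the reference line
mseValue = mean(residuals.^2)       % mean squared error with equal weights, about 0.035
```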

    allMSE = loss(allMdl,testTbl)
    allMSE = 0.0834
    
    subsetMSE = loss(subsetMdl,testTbl)
    subsetMSE = 0.0884
    

    For each model, compare the test set predicted petal lengths to the true petal lengths. Plot the predicted petal lengths along the vertical axis and the true petal lengths along the horizontal axis. Points on the reference line indicate correct predictions.

    tiledlayout(2,1)
    
    % Top axes
    ax1 = nexttile;
    allPredictedY = predict(allMdl,testTbl);
    plot(ax1,testTbl.PetalLength,allPredictedY,".")
    hold on
    plot(ax1,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax1,"True Petal Length")
    ylabel(ax1,"Predicted Petal Length")
    title(ax1,"All Predictors")
    
    % Bottom axes
    ax2 = nexttile;
    subsetPredictedY = predict(subsetMdl,testTbl);
    plot(ax2,testTbl.PetalLength,subsetPredictedY,".")
    hold on
    plot(ax2,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax2,"True Petal Length")
    ylabel(ax2,"Predicted Petal Length")
    title(ax2,"Subset of Predictors")

    [Figure: two stacked scatter plots of predicted versus true petal length, each with a reference line, titled "All Predictors" (top) and "Subset of Predictors" (bottom)]

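    Because both models seem to perform well, with predictions scattered near the reference line, consider using the model trained using all predictors except PetalWidth. That model needs one fewer predictor, and its test set MSE is only slightly larger than the MSE of the model trained using all predictors.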
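One way to put the two test set losses side by side is the relative increase in MSE from dropping PetalWidth. The sketch below simply reuses the rounded values displayed earlier in this example; with the unrounded values returned by loss, the result can differ slightly.

```matlab
% Rounded test set MSE values displayed earlier in this example
allMSE = 0.0834;
subsetMSE = 0.0884;

% Relative increase in test MSE when PetalWidth is removed from the predictors
relIncrease = (subsetMSE - allMSE)/allMSE;
fprintf("Relative MSE increase: %.1f%%\n",100*relIncrease)
```

Whether a difference of this size matters depends on the application and on how much simpler the reduced model needs to be.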

    See how the layers of a regression neural network model work together to predict the response value for a single observation.

    Load the sample file fisheriris.csv, which contains iris data including sepal length, sepal width, petal length, petal width, and species type. Read the file into a table, and display the first few rows of the table.

    fishertable = readtable('fisheriris.csv');
    head(fishertable)
        SepalLength    SepalWidth    PetalLength    PetalWidth     Species  
        ___________    __________    ___________    __________    __________
    
            5.1           3.5            1.4           0.2        {'setosa'}
            4.9             3            1.4           0.2        {'setosa'}
            4.7           3.2            1.3           0.2        {'setosa'}
            4.6           3.1            1.5           0.2        {'setosa'}
              5           3.6            1.4           0.2        {'setosa'}
            5.4           3.9            1.7           0.4        {'setosa'}
            4.6           3.4            1.4           0.3        {'setosa'}
              5           3.4            1.5           0.2        {'setosa'}
    

    Train a regression neural network model using the data set. Specify the PetalLength variable as the response and use the other numeric variables as predictors.

    Mdl = fitrnet(fishertable,"PetalLength ~ SepalLength + SepalWidth + PetalWidth");

    Select the fifteenth observation from the data set. See how the layers of the neural network take the observation and return a predicted response value newPointResponse.

    newPoint = Mdl.X{15,:}
    newPoint = 1×3
    
        5.8000    4.0000    0.2000
    
    
    firstFCStep = (Mdl.LayerWeights{1})*newPoint' + Mdl.LayerBiases{1};
    reluStep = max(firstFCStep,0);
    
    finalFCStep = (Mdl.LayerWeights{end})*reluStep + Mdl.LayerBiases{end};
    
    newPointResponse = finalFCStep
    newPointResponse = 1.6716
    

    Check that the prediction matches the one returned by the predict object function.

    predictedY = predict(Mdl,newPoint)
    predictedY = 1.6716
    
    isequal(newPointResponse,predictedY)
    ans = logical
       1
    
    

    The two results match.
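The layer-by-layer computation above extends to many observations at once if the observations are placed in columns. The sketch below uses small hypothetical weights (a hidden layer with 2 outputs rather than 10) in place of Mdl.LayerWeights and Mdl.LayerBiases, so that it runs without a trained model. Like the model above, it assumes a ReLU activation and no standardization.

```matlab
% Hypothetical weights and biases standing in for Mdl.LayerWeights and Mdl.LayerBiases
W1 = [0.5 -0.2 0.1; -0.3 0.4 0.2];   % first fully connected layer (2 outputs, 3 predictors)
b1 = [0.1; -0.1];
W2 = [1.0 -0.5];                     % final fully connected layer (1 response)
b2 = 0.2;

X = [5.8 4.0 0.2; 4.9 3.0 0.2; 6.3 3.3 1.6];   % hypothetical observations in rows

% Forward pass for all observations at once (observations in columns)
Z = W1*X' + b1;                      % implicit expansion adds b1 to every column
H = max(Z,0);                        % ReLU
yBatch = (W2*H + b2)';               % one predicted response per observation

% Same pass, one observation at a time, mirroring the steps above
yLoop = zeros(size(X,1),1);
for i = 1:size(X,1)
    z = W1*X(i,:)' + b1;
    yLoop(i) = W2*max(z,0) + b2;
end
```

Both loops and the batched version produce the same vector of predictions; the batched form simply replaces the per-observation loop with matrix products.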

    Since R2024b

    Create a regression neural network with more than one response variable.

    Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s. Create a table containing the predictor variables Displacement, Horsepower, and so on, as well as the response variables Acceleration and MPG. Display the first eight rows of the table.

    load carbig
    cars = table(Displacement,Horsepower,Model_Year, ...
        Origin,Weight,Acceleration,MPG);
    head(cars)
        Displacement    Horsepower    Model_Year    Origin     Weight    Acceleration    MPG
        ____________    __________    __________    _______    ______    ____________    ___
    
            307            130            70        USA         3504           12        18 
            350            165            70        USA         3693         11.5        15 
            318            150            70        USA         3436           11        18 
            304            150            70        USA         3433           12        16 
            302            140            70        USA         3449         10.5        17 
            429            198            70        USA         4341           10        15 
            454            220            70        USA         4354            9        14 
            440            215            70        USA         4312          8.5        14 
    

    Remove rows of cars where the table has missing values.

    cars = rmmissing(cars);
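For numeric data, removing rows with missing values amounts to keeping only the rows that contain no NaN. The sketch below shows this on a hypothetical numeric matrix. rmmissing also recognizes the standard missing values of other data types and works on tables, which this sketch does not cover.

```matlab
% Hypothetical numeric data with missing entries marked as NaN
M = [307 130 18;
     350 NaN 15;
     318 150 NaN;
     304 150 16];

keepRows = ~any(isnan(M),2);   % true for rows with no NaN in any column
Mclean = M(keepRows,:)         % rows 1 and 4 remain
```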

    Categorize the cars based on whether they were made in the USA.

    cars.Origin = categorical(cellstr(cars.Origin));
    cars.Origin = mergecats(cars.Origin,["France","Japan",...
        "Germany","Sweden","Italy","England"],"NotUSA");

    Partition the data into training and test sets. Use approximately 85% of the observations to train a neural network model, and 15% of the observations to test the performance of the trained model on new data. Use cvpartition to partition the data.

    rng("default") % For reproducibility
    c = cvpartition(height(cars),"Holdout",0.15);
    carsTrain = cars(training(c),:);
    carsTest = cars(test(c),:);

    Train a multiresponse neural network regression model by passing the carsTrain training data to the fitrnet function. For better results, specify to standardize the predictor data.

    Mdl = fitrnet(carsTrain,["Acceleration","MPG"], ...
        Standardize=true)
    Mdl = 
      RegressionNeuralNetwork
               PredictorNames: {'Displacement'  'Horsepower'  'Model_Year'  'Origin'  'Weight'}
                 ResponseName: {'Acceleration'  'MPG'}
        CategoricalPredictors: 4
            ResponseTransform: 'none'
              NumObservations: 334
                   LayerSizes: 10
                  Activations: 'relu'
        OutputLayerActivation: 'none'
                       Solver: 'LBFGS'
              ConvergenceInfo: [1x1 struct]
              TrainingHistory: [1000x7 table]
    
    
    

    Mdl is a trained RegressionNeuralNetwork model. You can use dot notation to access the properties of Mdl. For example, you can specify Mdl.ConvergenceInfo to get more information about the model convergence.

    Evaluate the performance of the regression model on the test set by computing the test mean squared error (MSE). Smaller MSE values indicate better performance. Return the loss for each response variable separately by setting the OutputType name-value argument to "per-response".

    testMSE = loss(Mdl,carsTest,["Acceleration","MPG"], ...
        OutputType="per-response")
    testMSE = 1×2
    
        1.5341    4.8245
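With two response columns, a per-response MSE is a column-wise average of squared residuals. The following sketch assumes equal observation weights and uses hypothetical values rather than output from Mdl. It returns one MSE per response, in the same 1-by-2 shape as testMSE.

```matlab
% Hypothetical true and predicted responses: columns are Acceleration and MPG
Ytrue = [12.0 18; 11.5 15; 11.0 18];
Ypred = [11.0 17; 12.5 15; 11.0 21];

perResponseMSE = mean((Ytrue - Ypred).^2,1)   % 1-by-2: one MSE per response column
```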
    
    

    Predict the response values for the observations in the test set. Return the predicted response values as a table.

    predictedY = predict(Mdl,carsTest,OutputType="table")
    predictedY=58×2 table
        Acceleration     MPG  
        ____________    ______
    
           9.3612       13.567
           15.655       21.406
           17.921       17.851
           11.139       13.433
           12.696        10.32
           16.498       17.977
           16.227       22.016
           12.165       12.926
           12.691       12.072
           12.424       14.481
           16.974       22.152
           15.504       24.955
           11.068       13.874
           11.978       12.664
           14.926       10.134
           15.638       24.839
          ⋮
    
    

    Input Arguments


    Mdl - Trained regression neural network

    Trained regression neural network, specified as a RegressionNeuralNetwork model object or CompactRegressionNeuralNetwork model object returned by fitrnet or compact, respectively.

    X - Predictor data

    Predictor data used to generate responses, specified as a numeric matrix or table.

    By default, each row of X corresponds to one observation, and each column corresponds to one variable.

    • For a numeric matrix:

      • The variables in the columns of X must have the same order as the predictor variables that trained Mdl.

      • If you train Mdl using a table (for example, Tbl) and Tbl contains only numeric predictor variables, then X can be a numeric matrix. To treat numeric predictors in Tbl as categorical during training, identify categorical predictors by using the CategoricalPredictors name-value argument of fitrnet. If Tbl contains heterogeneous predictor variables (for example, numeric and categorical data types) and X is a numeric matrix, then predict throws an error.

    • For a table:

      • predict does not support multicolumn variables or cell arrays other than cell arrays of character vectors.

      • If you train Mdl using a table (for example, Tbl), then all predictor variables in X must have the same variable names and data types as the variables that trained Mdl (stored in Mdl.PredictorNames). However, the column order of X does not need to correspond to the column order of Tbl. Also, Tbl and X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.

      • If you train Mdl using a numeric matrix, then the predictor names in Mdl.PredictorNames must be the same as the corresponding predictor variable names in X. To specify predictor names during training, use the PredictorNames name-value argument of fitrnet. All predictor variables in X must be numeric vectors. X can contain additional variables (response variables, observation weights, and so on), but predict ignores them.

    If you set Standardize=true in fitrnet when training Mdl, then the software standardizes the numeric columns of the predictor data using the corresponding means (Mdl.Mu) and standard deviations (Mdl.Sigma).
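Standardizing with stored statistics means subtracting a mean and dividing by a standard deviation, column by column. The sketch below uses hypothetical values mu and sigma, computed with mean and std from a small training matrix, in the roles of Mdl.Mu and Mdl.Sigma. It assumes all predictors are numeric and does not claim to reproduce exactly how fitrnet computes or applies these values.

```matlab
% Hypothetical numeric training predictors (rows are observations)
Xtrain = [5.1 3.5; 4.9 3.0; 6.3 3.3; 5.8 2.7];

mu = mean(Xtrain,1);      % column means, playing the role of Mdl.Mu
sigma = std(Xtrain,0,1);  % column standard deviations, playing the role of Mdl.Sigma

Xnew = [5.0 3.4; 6.0 2.9];          % hypothetical new observations
XnewStd = (Xnew - mu)./sigma;       % standardize each column with the training statistics
```

New data is standardized with the training statistics, not with its own mean and standard deviation, so that it lands on the same scale the network was trained on.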

    Note

    If you orient your predictor matrix so that observations correspond to columns and specify ObservationsIn="columns", then you might experience a significant reduction in computation time. You cannot specify ObservationsIn="columns" for predictor data in a table or for multiresponse regression.

    Data Types: single | double | table

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: predict(Mdl,X,ObservationsIn="columns") indicates that columns in the predictor data correspond to observations.

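    ObservationsIn - Predictor data observation dimension

    Predictor data observation dimension, specified as "rows" or "columns". By default, each row of the predictor data corresponds to one observation.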


    Data Types: char | string

    OutputType - Output type for predicted responses

    Since R2024b

    Output type for the predicted responses yfit, specified as "matrix" or "table". This argument is valid only when Mdl is a model with multiple response variables.

    Example: OutputType="table"

    Data Types: char | string

    PredictionForMissingValue - Predicted response for observations with missing predictor values

    Since R2023b

    Predicted response value to use for observations with missing predictor values, specified as "median", "mean", or a numeric scalar.

    Value             Description
    "median"          predict uses the median of the observed response values in the training data as the predicted response value for observations with missing predictor values.
    "mean"            predict uses the mean of the observed response values in the training data as the predicted response value for observations with missing predictor values.
    Numeric scalar    predict uses this value as the predicted response value for observations with missing predictor values.

    Example: PredictionForMissingValue="mean"

    Example: PredictionForMissingValue=NaN

    Data Types: single | double | char | string
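The documented "median" behavior can be mimicked on plain arrays as follows. This is a hypothetical sketch, not how predict is implemented: yTrain stands in for the training responses, yfit holds hypothetical raw predictions, and rows of X that contain NaN receive the median of yTrain.

```matlab
% Hypothetical training responses and new predictor data with a missing value
yTrain = [1.4; 1.5; 4.7; 5.1; 4.5];
X = [5.1 3.5; NaN 3.0; 6.3 3.3];
yfit = [1.45; 0; 5.02];      % hypothetical raw predictions (row 2 is a placeholder)

hasMissing = any(isnan(X),2);          % observations with any missing predictor
yfit(hasMissing) = median(yTrain);     % "median" option: median training response (4.5)
```

Replacing median with mean, or with a fixed scalar such as NaN, would mirror the other two options in the table.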

    Output Arguments


    yfit - Predicted responses

    Predicted responses, returned as a numeric vector, matrix, or table.

    • If yfit is a numeric vector, then entry i in yfit corresponds to observation i in X.

    • If yfit is a numeric matrix or table, then row i in yfit corresponds to observation i in X.

    Alternative Functionality

    Simulink Block

    To integrate the prediction of a neural network regression model into Simulink®, you can use the RegressionNeuralNetwork Predict block in the Statistics and Machine Learning Toolbox™ library or a MATLAB® Function block with the predict function. For examples, see Predict Responses Using RegressionNeuralNetwork Predict Block and Predict Class Labels Using MATLAB Function Block.

    When deciding which approach to use, consider the following:

    • If you use the Statistics and Machine Learning Toolbox library block, you can use the Fixed-Point Tool (Fixed-Point Designer) to convert a floating-point model to fixed point.

    • Support for variable-size arrays must be enabled for a MATLAB Function block with the predict function.

    • If you use a MATLAB Function block, you can perform preprocessing and post-processing with MATLAB functions in the same MATLAB Function block that calls predict.

    Extended Capabilities

    Version History

    Introduced in R2021a
