
Train Model Using Custom Backward Function

This example shows how to train a deep learning model that contains an operation with a custom backward function.

When you define a custom loss function or a custom layer forward function, or define a deep learning model as a function, and the software does not provide the deep learning operation that your task requires, you can define your own function using dlarray objects.

Most deep learning workflows use gradients to train the model. If your function uses only functions that support dlarray objects, then you can use those functions directly, and the software determines the gradients automatically using automatic differentiation. For example, you can pass dlarray object functions such as crossentropy as a loss function to the trainnet function, or use dlarray object functions such as dlconv in custom layer functions. For a list of functions that support dlarray objects, see List of Functions with dlarray Support.

If you want to use functions that do not support dlarray objects, or want to use a specific algorithm to compute the gradients, then you can define a custom deep learning operation as a differentiable function object. This example trains a simple classification neural network, defined as a function, which uses a custom SReLU [1] operation with a custom backward function.

For an example showing how to create the custom function, see Specify Custom Operation Backward Function.

Load Training Data

Load the digits data. The training set contains 5000 images of handwritten digits and their corresponding digit labels and angles of rotation.

load digitsDataTrain

View the class names of the data set.

classNames = categories(labelsTrain)
classNames = 10×1 cell

Model Parameters

Define the parameters for each of the operations and include them in a struct. Use the format parameters.OperationName.ParameterName, where parameters is the structure, OperationName is the name of the operation (for example, "conv"), and ParameterName is the name of the parameter (for example, "Weights").

Create an empty structure for the learnable parameters.

parameters = struct;

Initialize the learnable weights and biases using example functions such as initializeGlorot and initializeHe. To access these functions, open the example as a live script.

Initialize the weights and biases for the convolution operation "conv" using initializeGlorot and initializeZeros, respectively.

filterSize = [5 5];
numFilters = 20;

numChannels = size(XTrain,3);

numOut = numFilters*prod(filterSize);
numIn = numChannels*prod(filterSize);
sz = [filterSize(1) filterSize(2) numChannels numFilters];
parameters.conv.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv.Bias = initializeZeros([numFilters 1]);
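The helper functions initializeGlorot and initializeZeros ship with the example and are not listed here. As a rough sketch, assuming initializeGlorot uses the common Glorot (Xavier) uniform scheme, each weight is drawn uniformly from [-b, b] with b = sqrt(6/(numIn + numOut)). The anonymous function glorotSketch is hypothetical and returns a plain single array, whereas the shipped helper may return a dlarray.

```matlab
% Hypothetical sketch of Glorot (Xavier) uniform initialization.
% Assumes the bound sqrt(6/(numIn + numOut)); the shipped helper
% initializeGlorot may differ (for example, it may return a dlarray).
glorotSketch = @(sz,numOut,numIn) ...
    sqrt(6/(numIn + numOut)) * (2*rand(sz,'single') - 1);

filterSize = [5 5];
numFilters = 20;
numChannels = 1;                          % grayscale digit images
numOut = numFilters*prod(filterSize);     % 500
numIn = numChannels*prod(filterSize);     % 25
sz = [filterSize(1) filterSize(2) numChannels numFilters];

W = glorotSketch(sz,numOut,numIn);
bound = sqrt(6/(numIn + numOut));         % about 0.1069
fprintf('Weights size: %s, max |W| = %.4f, bound = %.4f\n', ...
    mat2str(size(W)),max(abs(W(:))),bound);
```

For the convolution weights in this example, numOut is 500 and numIn is 25, so the bound is roughly 0.107.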

Initialize the offset and scale for the layer normalization operation "layernorm" using initializeZeros and initializeOnes, respectively.

parameters.layernorm.Offset = initializeZeros([numFilters 1]);
parameters.layernorm.Scale = initializeOnes([numFilters 1]);

Initialize the parameters for the SReLU operation "srelu" using initializeHe.

numIn = numFilters;
sz = [1 1 numIn];
parameters.srelu.LeftThreshold = initializeHe(sz,numIn);
parameters.srelu.LeftSlope = initializeHe(sz,numIn);
parameters.srelu.RightThreshold = initializeHe(sz,numIn);
parameters.srelu.RightSlope = initializeHe(sz,numIn);
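Each SReLU parameter has size 1-by-1-by-numIn, so there is one value per channel. The sketch below assumes initializeHe draws from a zero-mean normal distribution with standard deviation sqrt(2/numIn) (heSketch is a hypothetical stand-in for the shipped helper) and shows how a 1-by-1-by-C array expands against an "SSCB"-shaped array through implicit expansion.

```matlab
% Hypothetical sketch of He initialization, assuming weights are drawn
% from a zero-mean normal distribution with standard deviation sqrt(2/numIn).
heSketch = @(sz,numIn) sqrt(2/numIn) * randn(sz,'single');

numFilters = 20;
numIn = numFilters;
sz = [1 1 numIn];                          % one value per channel
tl = heSketch(sz,numIn);

% A 1-by-1-by-C array expands against an S-by-S-by-C-by-B array, so each
% channel uses its own threshold at every position and for every observation.
X = randn(28,28,numFilters,4,'single');    % hypothetical mini-batch
D = X - tl;
disp(size(D))                              % 28 28 20 4
```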

Initialize the weights and biases for the fully connect operation "fc" using initializeGlorot and initializeZeros, respectively.

numClasses = numel(classNames);
numOut = numClasses;
numIn = 15680; % 28*28 spatial positions times 20 filters
sz = [numOut numIn];
parameters.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc.Bias = initializeZeros([numOut 1]);
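The value numIn = 15680 is the number of elements per observation that reach the fully connect operation. Assuming 28-by-28 input images (the size of the digits images), the convolution with Padding="same" and the default stride of 1 preserves the spatial size, and layer normalization and SReLU do not change the size, so the count is 28*28*20. A minimal check:

```matlab
% Derive the input size of the fully connect operation.
% Assumes 28-by-28 input images, as in the digits data set.
imageSize = [28 28];
numFilters = 20;

% "same" padding with stride 1 keeps the spatial size; layer normalization
% and SReLU are elementwise in size, so only the channel count changes.
numIn = prod(imageSize)*numFilters;
disp(numIn)   % 15680
```

In the example itself, you could compute the value from the data instead of hard-coding it, for instance with numIn = prod(size(XTrain,[1 2]))*numFilters, in releases where size accepts a vector of dimensions.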

Create Custom SReLU Function

Create a custom function that applies the SReLU operation. To specify a custom backward function, create a sreluFunction object using the class definition attached to this example as a supporting file. To access this file, open this example as a live script. Specify the data format using the first argument of sreluFunction.

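function Y = srelu(X,tl,al,tr,ar)

% Get the dimension labels of the input, for example "SSCB".
format = dims(X);

% Create the differentiable function object and apply it.
fcn = sreluFunction(format);
Y = fcn(X,tl,al,tr,ar);

% Make sure the output has the same data format as the input.
Y = dlarray(Y,format);

end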
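The sreluFunction class is not listed in this section. To illustrate what its forward and custom backward computations might look like, the following plain-array sketch assumes the standard piecewise-linear SReLU definition, applied per channel: Y = tr + ar(X - tr) for X >= tr, Y = X for tl < X < tr, and Y = tl + al(X - tl) for X <= tl. It does not use sreluFunction or dlarray, and all variable names and values are hypothetical. The backward pass multiplies the upstream gradient by the local derivatives and sums the parameter gradients over every dimension except the channel dimension; a central finite difference then checks one result.

```matlab
% Hypothetical plain-array sketch of SReLU forward and backward passes,
% assuming the standard piecewise-linear SReLU definition.
% Arrays are S-by-S-by-C-by-B; the four parameters are 1-by-1-by-C.
sreluForward = @(X,tl,al,tr,ar) ...
    (X >= tr).*(tr + ar.*(X - tr)) + ...
    (X > tl & X < tr).*X + ...
    (X <= tl).*(tl + al.*(X - tl));

% Sum over every dimension except the channel dimension (3).
sumToChannel = @(A) sum(sum(sum(A,1),2),4);

rng(0);
X  = randn(4,4,3,2);                       % hypothetical input
tl = reshape([-0.5 -0.4 -0.3],1,1,3);      % left thresholds
al = reshape([0.1 0.2 0.3],1,1,3);         % left slopes
tr = reshape([0.5 0.6 0.7],1,1,3);         % right thresholds
ar = reshape([1.5 2.0 2.5],1,1,3);         % right slopes

Y = sreluForward(X,tl,al,tr,ar);
dLdY = randn(size(Y));                     % hypothetical upstream gradient

% Backward pass: local derivatives times the upstream gradient.
right = X >= tr;
left  = X <= tl;
mid   = ~right & ~left;                    % tl < X < tr because tl < tr
dLdX  = dLdY.*(right.*ar + mid + left.*al);
dLdtr = sumToChannel(dLdY.*right.*(1 - ar));
dLdar = sumToChannel(dLdY.*right.*(X - tr));
dLdtl = sumToChannel(dLdY.*left.*(1 - al));
dLdal = sumToChannel(dLdY.*left.*(X - tl));

% Compare one analytic gradient with a central finite difference.
lossFcn = @(ar) sum(dLdY(:).*reshape(sreluForward(X,tl,al,tr,ar),[],1));
h = 1e-6;
e = zeros(1,1,3); e(2) = h;                % perturb channel 2 only
numericGrad = (lossFcn(ar + e) - lossFcn(ar - e))/(2*h);
fprintf('analytic %.6f, numeric %.6f\n',dLdar(2),numericGrad);
```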


Define Model Function

Create the function model that takes the learnable parameters and input data as input and returns the model output. The model applies the convolution, layer normalization, SReLU, fully connect, and softmax operations to the input data.

function Y = model(parameters,X)

% Convolution.
weights = parameters.conv.Weights;
bias = parameters.conv.Bias;
Y = dlconv(X,weights,bias,Padding="same");

% Layer normalization.
offset = parameters.layernorm.Offset;
scaleFactor = parameters.layernorm.Scale;
Y = layernorm(Y,offset,scaleFactor);

% SReLU with the custom backward function.
tl = parameters.srelu.LeftThreshold;
al = parameters.srelu.LeftSlope;
tr = parameters.srelu.RightThreshold;
ar = parameters.srelu.RightSlope;
Y = srelu(Y,tl,al,tr,ar);

% Fully connect and softmax.
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
Y = fullyconnect(Y,weights,bias);
Y = softmax(Y);

end


Define Model Loss Function

Create the function modelLoss that takes the model parameters and a mini-batch of input data X with corresponding targets T, and returns the loss and the gradients of the loss with respect to the learnable parameters.

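function [loss,gradients] = modelLoss(parameters,X,T)

% Forward the data through the model and compute the cross-entropy loss.
Y = model(parameters,X);
loss = crossentropy(Y,T);

% Compute the gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,parameters);

end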
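crossentropy and softmax are provided by the software. As a rough illustration of the loss that modelLoss differentiates, this sketch computes a column-wise softmax and the cross-entropy with one-hot targets on plain arrays, assuming the loss is summed over classes and divided by the number of observations. The scores and targets are hypothetical.

```matlab
% Hypothetical plain-array sketch of softmax followed by cross-entropy
% with one-hot targets, normalized by the number of observations.
Z = [2.0 0.5;                             % 3 classes by 2 observations
     1.0 0.5;
     0.1 3.0];
T = [1 0;                                 % one-hot targets
     0 0;
     0 1];

% Subtract the column maximum before exponentiating for numerical stability.
E = exp(Z - max(Z,[],1));
Y = E./sum(E,1);                          % each column sums to 1

N = size(Y,2);                            % number of observations
loss = -sum(T(:).*log(Y(:)))/N;
```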


Specify Training Options

Specify the training options. Train for 20 epochs with a mini-batch size of 128.

numEpochs = 20;
miniBatchSize = 128;

Train Model

Use a minibatchqueue object to process and manage the mini-batches of images. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined below) to one-hot encode the class labels.

  • Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

adsXTrain = arrayDatastore(XTrain,IterationDimension=4);
adsTTrain = arrayDatastore(labelsTrain);
cdsTrain = combine(adsXTrain,adsTTrain);

mbq = minibatchqueue(cdsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""]);
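Conceptually, in each epoch the queue walks through a shuffled ordering of the observations in chunks of miniBatchSize. The loop below is a hypothetical plain-MATLAB sketch of that bookkeeping, not the minibatchqueue implementation, and it assumes the final partial mini-batch is returned rather than discarded.

```matlab
% Hypothetical sketch of per-epoch mini-batch bookkeeping; this is not
% how minibatchqueue is implemented internally.
numObservations = 5000;
miniBatchSize = 128;

idx = randperm(numObservations);                   % shuffled order for one epoch
numBatches = ceil(numObservations/miniBatchSize);
batchSizes = zeros(1,numBatches);
for i = 1:numBatches
    startIdx = (i-1)*miniBatchSize + 1;
    endIdx = min(i*miniBatchSize,numObservations);  % last batch may be partial
    batchIdx = idx(startIdx:endIdx);                % observations in this batch
    batchSizes(i) = numel(batchIdx);
end
fprintf('%d mini-batches, final mini-batch has %d observations\n', ...
    numBatches,batchSizes(end));
```

With 5000 observations and a mini-batch size of 128, this gives 39 full mini-batches and a final mini-batch of 8 observations.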

Create a mini-batch preprocessing function that concatenates the input data and one-hot encodes the targets.

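function [X,T] = preprocessMiniBatch(dataX,dataT)

% Concatenate the images along the fourth (batch) dimension.
X = cat(4,dataX{:});

% Concatenate the labels and one-hot encode them along the first dimension.
T = cat(2,dataT{:});
T = onehotencode(T,1);

end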
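onehotencode(T,1) turns each categorical label into a column of zeros with a single one in the row of its class, following the order of the categories. The sketch below mimics that layout on plain numeric class indices, a hypothetical stand-in for categorical labels, using implicit expansion.

```matlab
% Hypothetical plain-array sketch of one-hot encoding along dimension 1.
% Labels are numeric class indices here instead of categorical values.
numClasses = 4;
labels = [3 1 4 1];                        % one label per observation

% Compare a column of class indices with a row of labels:
% row k of T is 1 wherever the label equals class k.
T = double((1:numClasses)' == labels);     % numClasses-by-numObservations
disp(T)
```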


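Train using the Adam solver. Initialize the moving averages of the parameter gradients and squared parameter gradients as empty arrays; adamupdate treats empty arrays as having no previous gradients.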

trailingAvg = [];
trailingAvgSq = [];
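adamupdate performs the Adam update for you. For intuition, this is a hypothetical plain-array sketch of a single Adam step, assuming the commonly cited default hyperparameters (learning rate 0.001, gradient decay 0.9, squared gradient decay 0.999, epsilon 1e-8) and the bias-corrected update. Empty state arrays are replaced by zeros, which corresponds to having no previous gradients.

```matlab
% Hypothetical plain-array sketch of one bias-corrected Adam step.
learnRate = 0.001;
gradDecay = 0.9;
sqGradDecay = 0.999;
epsilon = 1e-8;

p = [0.5 -0.2 1.0];          % hypothetical parameters
g = [1.0 -2.0 0.5];          % hypothetical gradients
avg = [];                    % empty state, as in the example
avgSq = [];
iteration = 1;

% Treat empty state as zero moving averages.
if isempty(avg)
    avg = zeros(size(p));
    avgSq = zeros(size(p));
end

% Update the moving averages of the gradients and squared gradients.
avg = gradDecay*avg + (1 - gradDecay)*g;
avgSq = sqGradDecay*avgSq + (1 - sqGradDecay)*g.^2;

% Fold the bias correction into the step size, then update the parameters.
stepSize = learnRate*sqrt(1 - sqGradDecay^iteration)/(1 - gradDecay^iteration);
pNew = p - stepSize*avg./(sqrt(avgSq) + epsilon);
```

On the first iteration, the step for each element is very close to learnRate times the sign of its gradient, regardless of the gradient magnitude.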

To monitor training, create a training progress monitor. To update the progress bar of the training progress monitor, calculate the total number of training iterations.

numObservations = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservations/miniBatchSize);
numIterations = numIterationsPerEpoch * numEpochs;
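Because numIterationsPerEpoch uses ceil, it counts the final partial mini-batch, so the progress reaches 100 at the last iteration as long as the queue returns partial mini-batches. A worked check with the values from this example:

```matlab
% Worked check of the iteration count used for the progress bar.
numObservations = 5000;     % number of training images
miniBatchSize = 128;
numEpochs = 20;

numIterationsPerEpoch = ceil(numObservations/miniBatchSize);  % 39.0625 rounds up to 40
numIterations = numIterationsPerEpoch*numEpochs;              % 800

% Progress (in percent) at the last iteration of each epoch.
progressAtEpochEnd = 100*(1:numEpochs)*numIterationsPerEpoch/numIterations;
```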

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");

Train the model using a custom training loop.

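epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle the data and reset the queue at the start of each epoch.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read a mini-batch and evaluate the loss and gradients.
        [X,T] = next(mbq);
        [loss,gradients] = dlfeval(@modelLoss,parameters,X,T);

        % Update the learnable parameters using Adam.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));
        monitor.Progress = 100 * iteration/numIterations;
    end
end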

Test Model

Test the model by evaluating the classification accuracy on the test data set.

load digitsDataTest
XTest = dlarray(XTest,"SSCB");
scoresTest = model(parameters,XTest);
YTest = scores2label(scoresTest,classNames);
acc = mean(labelsTest==YTest')
acc = 0.9882
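scores2label picks the class with the highest score for each observation. The sketch below shows the same idea on a hypothetical score matrix, with numeric class indices in place of class names, followed by the accuracy calculation.

```matlab
% Hypothetical scores for 3 classes (rows) and 4 observations (columns).
scores = [0.7 0.1 0.2 0.3;
          0.2 0.8 0.3 0.3;
          0.1 0.1 0.5 0.4];
trueIdx = [1 2 3 2];                 % hypothetical true class indices

[~,predIdx] = max(scores,[],1);      % highest-scoring class in each column
acc = mean(predIdx == trueIdx);      % fraction of correct predictions
```

Here the last observation is misclassified, so the accuracy is 0.75.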


References

  1. Hu, Xiaobin, Peifeng Niu, Jianmei Wang, and Xinxin Zhang. “A Dynamic Rectified Linear Activation Units.” IEEE Access 7 (2019): 180409–16.
