Define Custom Training Loops, Loss Functions, and Networks

For most deep learning tasks, you can use a pretrained neural network and adapt it to your own data. For an example showing how to use transfer learning to retrain a convolutional neural network to classify a new set of images, see Train Deep Learning Network to Classify New Images. Alternatively, you can create and train neural networks from scratch using the trainnet, trainNetwork, and trainingOptions functions.

If the trainingOptions function does not provide the training options that you need for your task, then you can create a custom training loop using automatic differentiation. To learn more, see Define Deep Learning Network for Custom Training Loops.

If Deep Learning Toolbox™ does not provide the layers you need for your task (including output layers that specify loss functions), then you can create a custom layer. To learn more, see Define Custom Deep Learning Layers. For loss functions that cannot be specified using an output layer, you can specify the loss in a custom training loop. To learn more, see Specify Loss Functions. For networks that cannot be created using layer graphs, you can define custom networks as a function. To learn more, see Define Network as Model Function.

For more information about which training method to use for which task, see Train Deep Learning Model in MATLAB.

Define Deep Learning Network for Custom Training Loops

Define Network as dlnetwork Object

For most tasks, you can control the training algorithm details using the trainingOptions, trainnet, and trainNetwork functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom learning rate schedule), then you can define your own custom training loop using a dlnetwork object. A dlnetwork object allows you to train a network specified as a layer graph using automatic differentiation.

For networks specified as a layer graph, you can create a dlnetwork object from the layer graph by using the dlnetwork function directly.

net = dlnetwork(lgraph);
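For illustration, a minimal sketch of the full conversion, assuming a small image classification architecture (the layers and sizes here are illustrative, not from a specific example):

layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(3,16,Padding="same")
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];

lgraph = layerGraph(layers);
net = dlnetwork(lgraph);   % dlnetwork objects do not contain output layers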

For an example showing how to train a network with a custom learning rate schedule, see Train Network Using Custom Training Loop.

Define Network as Model Function

For architectures that cannot be created using layer graphs (for example, a twin neural network that requires shared weights), you can define the model as a function of the form [Y1,...,YM] = model(parameters,X1,...,XN), where parameters contains the network parameters, X1,...,XN corresponds to the input data for the N model inputs, and Y1,...,YM corresponds to the M model outputs. To train a deep learning model defined as a function, use a custom training loop. For an example, see Train Network Using Model Function.
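As an illustrative sketch only (the parameter names fc1 and fc2 and the two-layer architecture are assumptions), a model function for a simple classifier could chain fullyconnect, relu, and softmax operations:

function Y = model(parameters,X)
    % X is assumed to be a formatted dlarray (for example, "CB") so that
    % softmax can find the channel dimension.
    Y = fullyconnect(X,parameters.fc1.Weights,parameters.fc1.Bias);
    Y = relu(Y);
    Y = fullyconnect(Y,parameters.fc2.Weights,parameters.fc2.Bias);
    Y = softmax(Y);
end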

When you define a deep learning model as a function, you must manually initialize the layer weights. For more information, see Initialize Learnable Parameters for Model Function.
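For example, a minimal sketch that initializes the parameters for the hypothetical two-layer model function above (the sizes and the scaled random initialization are assumptions; see the linked topic for the recommended initializers):

inputSize  = 784;    % assumed flattened input size
hiddenSize = 128;
numClasses = 10;

% Store the learnable parameters as dlarray objects in a struct.
parameters.fc1.Weights = dlarray(randn(hiddenSize,inputSize)*sqrt(2/inputSize));
parameters.fc1.Bias    = dlarray(zeros(hiddenSize,1));
parameters.fc2.Weights = dlarray(randn(numClasses,hiddenSize)*sqrt(2/hiddenSize));
parameters.fc2.Bias    = dlarray(zeros(numClasses,1));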

If you define a custom network as a function, then the model function must support automatic differentiation. You can use the deep learning operations in the following table; the table lists only a subset. For a complete list of functions that support dlarray input, see List of Functions with dlarray Support. A short usage sketch follows the table.

Function | Description
attention | The attention operation focuses on parts of the input using weighted multiplication operations.
avgpool | The average pooling operation performs downsampling by dividing the input into pooling regions and computing the average value of each region.
batchnorm | The batch normalization operation normalizes the input data across all observations for each channel independently. To speed up training of the convolutional neural network and reduce the sensitivity to network initialization, use batch normalization between convolution and nonlinear operations such as relu.
crossentropy | The cross-entropy operation computes the cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks.
crosschannelnorm | The cross-channel normalization operation uses local responses in different channels to normalize each activation. Cross-channel normalization typically follows a relu operation. Cross-channel normalization is also known as local response normalization.
ctc | The CTC operation computes the connectionist temporal classification (CTC) loss between unaligned sequences.
dlconv | The convolution operation applies sliding filters to the input data. Use the dlconv function for deep learning convolution, grouped convolution, and channel-wise separable convolution.
dlode45 | The neural ordinary differential equation (ODE) operation returns the solution of a specified ODE.
dltranspconv | The transposed convolution operation upsamples feature maps.
embed | The embed operation converts numeric indices to numeric vectors, where the indices correspond to discrete data. Use embeddings to map discrete data such as categorical values or words to numeric vectors.
fullyconnect | The fully connect operation multiplies the input by a weight matrix and then adds a bias vector.
gelu | The Gaussian error linear unit (GELU) activation operation weights the input by its probability under a Gaussian distribution.
groupnorm | The group normalization operation normalizes the input data across grouped subsets of channels for each observation independently. To speed up training of the convolutional neural network and reduce the sensitivity to network initialization, use group normalization between convolution and nonlinear operations such as relu.
gru | The gated recurrent unit (GRU) operation allows a network to learn dependencies between time steps in time series and sequence data.
huber | The Huber operation computes the Huber loss between network predictions and target values for regression tasks. When the TransitionPoint option is 1, this is also known as smooth L1 loss.
instancenorm | The instance normalization operation normalizes the input data across each channel for each observation independently. To improve the convergence of training the convolutional neural network and reduce the sensitivity to network hyperparameters, use instance normalization between convolution and nonlinear operations such as relu.
l1loss | The L1 loss operation computes the L1 loss given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean absolute error (MAE).
l2loss | The L2 loss operation computes the L2 loss (based on the squared L2 norm) given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean squared error (MSE).
layernorm | The layer normalization operation normalizes the input data across all channels for each observation independently. To speed up training of recurrent and multilayer perceptron neural networks and reduce the sensitivity to network initialization, use layer normalization after the learnable operations, such as LSTM and fully connect operations.
leakyrelu | The leaky rectified linear unit (ReLU) activation operation performs a nonlinear threshold operation, where any input value less than zero is multiplied by a fixed scale factor.
lstm | The long short-term memory (LSTM) operation allows a network to learn long-term dependencies between time steps in time series and sequence data.
maxpool | The maximum pooling operation performs downsampling by dividing the input into pooling regions and computing the maximum value of each region.
maxunpool | The maximum unpooling operation unpools the output of a maximum pooling operation by upsampling and padding with zeros.
mse | The half mean squared error operation computes the half mean squared error loss between network predictions and target values for regression tasks.
onehotdecode | The one-hot decode operation decodes probability vectors, such as the output of a classification network, into classification labels. The input A can be a dlarray. If A is formatted, the function ignores the data format.
relu | The rectified linear unit (ReLU) activation operation performs a nonlinear threshold operation, where any input value less than zero is set to zero.
sigmoid | The sigmoid activation operation applies the sigmoid function to the input data.
softmax | The softmax activation operation applies the softmax function to the channel dimension of the input data.
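For illustration, a short sketch (the input size, filter size, and number of filters are assumptions) that chains a few of these operations on a formatted dlarray:

% Random image batch: spatial, spatial, channel, batch ("SSCB").
X = dlarray(rand(28,28,3,16,"single"),"SSCB");

weights = dlarray(rand(3,3,3,8,"single"));   % 3-by-3 filters, 3 input channels, 8 filters
bias    = dlarray(zeros(8,1,"single"));

Y = dlconv(X,weights,bias,Padding="same");   % convolution
Y = relu(Y);                                 % nonlinearity
Y = maxpool(Y,2,Stride=2);                   % downsampling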

Specify Loss Functions

When you use a custom training loop, you must calculate the loss in the model loss function. Use the loss value when computing gradients for updating the network weights. To compute the loss, you can use the following functions.

Function | Description
softmax | The softmax activation operation applies the softmax function to the channel dimension of the input data.
sigmoid | The sigmoid activation operation applies the sigmoid function to the input data.
crossentropy | The cross-entropy operation computes the cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks.
l1loss | The L1 loss operation computes the L1 loss given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean absolute error (MAE).
l2loss | The L2 loss operation computes the L2 loss (based on the squared L2 norm) given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean squared error (MSE).
huber | The Huber operation computes the Huber loss between network predictions and target values for regression tasks. When the TransitionPoint option is 1, this is also known as smooth L1 loss.
mse | The half mean squared error operation computes the half mean squared error loss between network predictions and target values for regression tasks.
ctc | The CTC operation computes the connectionist temporal classification (CTC) loss between unaligned sequences.

Alternatively, you can use a custom loss function by creating a function of the form loss = myLoss(Y,T), where Y and T correspond to the network predictions and targets, respectively, and loss is the returned loss.
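For example, a sketch of a hypothetical composite loss that combines two of the built-in loss operations with an assumed weighting factor of 0.1 (Y is assumed to be a formatted dlarray so the loss functions can infer its dimensions):

function loss = myLoss(Y,T)
    % Half mean squared error plus a weighted L1 term.
    loss = mse(Y,T) + 0.1*l1loss(Y,T);
end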

For an example showing how to train a generative adversarial network (GAN) that generates images using a custom loss function, see Train Generative Adversarial Network (GAN).

Update Learnable Parameters Using Automatic Differentiation

When you train a deep learning model with a custom training loop, the software minimizes the loss with respect to the learnable parameters. To minimize the loss, the software uses the gradients of the loss with respect to the learnable parameters. To calculate these gradients using automatic differentiation, you must define a model loss function.

Define Model Loss Function

For a model specified as a dlnetwork object, create a function of the form [loss,gradients] = modelLoss(net,X,T), where net is the network, X is the network input, T contains the targets, and loss and gradients are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated network state).

For a model specified as a function, create a function of the form [loss,gradients] = modelLoss(parameters,X,T), where parameters contains the learnable parameters, X is the model input, T contains the targets, and loss and gradients are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated model state).

To learn more about defining model loss functions for custom training loops, see Define Model Loss Function for Custom Training Loop.
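For example, a minimal sketch of a model loss function for a classification network specified as a dlnetwork object (the cross-entropy loss and the extra state output are illustrative choices):

function [loss,gradients,state] = modelLoss(net,X,T)
    % Forward pass through the network, also returning the updated state.
    [Y,state] = forward(net,X);

    % Cross-entropy loss between predictions and targets.
    loss = crossentropy(Y,T);

    % Gradients of the loss with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end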

Update Learnable Parameters

To evaluate the model loss function using automatic differentiation, use the dlfeval function, which evaluates a function with automatic differentiation enabled. For the first input of dlfeval, pass the model loss function specified as a function handle. For the following inputs, pass the required variables for the model loss function. For the outputs of the dlfeval function, specify the same outputs as the model loss function.

To update the learnable parameters using the gradients, you can use the following functions. A sketch of a single update step follows the table.

Function | Description
adamupdate | Update parameters using adaptive moment estimation (Adam)
rmspropupdate | Update parameters using root mean squared propagation (RMSProp)
sgdmupdate | Update parameters using stochastic gradient descent with momentum (SGDM)
lbfgsupdate | Update parameters using limited-memory BFGS (L-BFGS)
dlupdate | Update parameters using custom function
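Putting these pieces together, a sketch of a single iteration of a custom training loop for a dlnetwork object (the variable names avgGrad, avgSqGrad, and iteration are illustrative; initialize the averages as empty arrays before the loop):

% Evaluate the model loss and gradients with automatic differentiation enabled.
[loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
net.State = state;

% Update the network parameters using the Adam solver.
[net,avgGrad,avgSqGrad] = adamupdate(net,gradients,avgGrad,avgSqGrad,iteration);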
