adamupdate
Syntax
Description
Update the network learnable parameters in a custom training loop using the adaptive moment estimation (Adam) algorithm.
Note
This function applies the Adam optimization algorithm to update network parameters in custom training loops. To train a neural network with the Adam solver using the trainnet function, use the trainingOptions function and set the solver to "adam".
[netUpdated,averageGrad,averageSqGrad] = adamupdate(net,grad,averageGrad,averageSqGrad,iteration)
updates the learnable parameters of the network net using the Adam algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration)
updates the learnable parameters in params using the Adam algorithm. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined using functions.
[___] = adamupdate(___,learnRate,gradDecay,sqGradDecay,epsilon)
also specifies values to use for the global learning rate, gradient decay, squared gradient decay, and small constant epsilon, in addition to the input arguments in previous syntaxes.
Examples
Update Learnable Parameters Using adamupdate
Perform a single adaptive moment estimation update step with a global learning rate of 0.05, gradient decay factor of 0.75, and squared gradient decay factor of 0.95.
Create the parameters and parameter gradients as numeric arrays.
params = rand(3,3,4); grad = ones(3,3,4);
Initialize the iteration counter, average gradient, and average squared gradient for the first iteration.
iteration = 1; averageGrad = []; averageSqGrad = [];
Specify custom values for the global learning rate, gradient decay factor, and squared gradient decay factor.
learnRate = 0.05; gradDecay = 0.75; sqGradDecay = 0.95;
Update the learnable parameters using adamupdate.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
Update the iteration counter.
iteration = iteration + 1;
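To see what this single step does numerically, you can reproduce it by hand. The following sketch assumes the bias-corrected Adam update described in [1] and treats the empty averageGrad and averageSqGrad inputs as zero-valued moving averages; it is an illustration, not the toolbox source, and the variable names (alpha, beta1, beta2, t, theta0, and so on) are hypothetical.

```matlab
% Reproduce the first adamupdate step of the example by hand.
% Assumes the bias-corrected Adam update of Kingma and Ba and
% zero-valued initial moving averages; not the toolbox source.
alpha = 0.05;      % global learning rate
beta1 = 0.75;      % gradient decay factor
beta2 = 0.95;      % squared gradient decay factor
epsilon = 1e-8;    % default small constant
t = 1;             % iteration number

theta0 = rand(3,3,4);   % parameters before the update
g = ones(3,3,4);        % gradients, as in the example

% Empty averageGrad and averageSqGrad act like zeros on the first call.
m = beta1*zeros(size(g)) + (1 - beta1)*g;      % 0.25 everywhere
v = beta2*zeros(size(g)) + (1 - beta2)*g.^2;   % about 0.05 everywhere

% Bias correction for iteration t.
mHat = m/(1 - beta1^t);   % 1 everywhere
vHat = v/(1 - beta2^t);   % 1 everywhere

% Parameter update and the resulting step size.
theta1 = theta0 - alpha*mHat./(sqrt(vHat) + epsilon);
step = theta0 - theta1;   % about 0.05 for every parameter
```

Because every gradient element equals 1, both corrected averages equal 1 and the step reduces to alpha/(1 + epsilon), so on the first iteration every parameter decreases by approximately the learning rate.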
Train Network Using adamupdate
Use adamupdate to train a network using the Adam algorithm.
Load Training Data
Load the digits training data.
[XTrain,TTrain] = digitTrain4DArrayData; classes = categories(TTrain); numClasses = numel(classes);
Define Network
Define the network and specify the average image value using the Mean option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers);
Define Model Loss Function
Create the helper function modelLoss, listed at the end of the example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128; numEpochs = 20; numObservations = numel(TTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train Network
Initialize the moving averages of the gradients and of the squared gradients.
averageGrad = []; averageSqGrad = [];
Calculate the total number of iterations for the training progress monitor.
numIterations = numEpochs * numIterationsPerEpoch;
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the adamupdate function. At the end of each iteration, display the training progress.
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
iteration = 0;
epoch = 0;
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);
    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        T = zeros(numClasses, miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end
        % Convert mini-batch of data to a dlarray.
        X = dlarray(single(X),"SSCB");
        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end
        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);
        % Update the network parameters using the Adam optimizer.
        [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad,averageSqGrad,iteration);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest,TTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.
XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YTest==TTest)
accuracy = 0.9908
Model Loss Function
The modelLoss helper function takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
Input Arguments
net
— Network
dlnetwork
object
Network, specified as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
The input argument grad must be a table of the same form as net.Learnables.
params
— Network learnable parameters
dlarray
| numeric array | cell array | structure | table
Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
The input argument grad must have exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Since R2024a, the learnables can be complex-valued. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.
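For example, a model defined using functions might store its parameters in a structure. The following sketch, which requires Deep Learning Toolbox and uses hypothetical field names W and b, passes a structure of numeric arrays together with a structure of gradients that has the same fields in the same order. It uses the default hyperparameters, so with bias correction as described in [1], the first step moves each parameter by roughly the default learning rate of 0.001 against the sign of its gradient.

```matlab
% Requires Deep Learning Toolbox. Field names W and b are hypothetical.
params.W = ones(2,2);
params.b = zeros(2,1);

% Gradients must use the same fields, in the same order, as params.
grad.W = 0.5*ones(2,2);
grad.b = -ones(2,1);

% Empty moving averages: adamupdate treats this as the first update.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,[],[],1);

% With the default learning rate of 0.001, each parameter moves by
% about 0.001 against the sign of its gradient.
disp(params.W)   % approximately 0.999
disp(params.b)   % approximately 0.001
```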
grad
— Gradients of the loss
dlarray
| numeric array | cell array | structure | table
Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to adamupdate.
Input | Learnable Parameters | Gradients |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient. For more information, see Use Automatic Differentiation In Deep Learning Toolbox.
Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
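As a minimal sketch of obtaining grad with dlfeval and dlgradient and passing it to adamupdate, the following script minimizes the illustrative loss sum(x.^2) over a dlarray x. It requires Deep Learning Toolbox; the anonymous function quadGrad, the starting point, and the learning rate of 0.05 are assumptions chosen for the illustration, not values from this page. Because dlgradient must run inside a function evaluated by dlfeval, the gradient is computed through dlfeval on every iteration.

```matlab
% Minimize the illustrative loss sum(x.^2) using dlfeval, dlgradient,
% and adamupdate. Requires Deep Learning Toolbox.
x = dlarray([1; -2; 3]);   % learnable parameters as a dlarray
averageGrad = [];
averageSqGrad = [];

% Hypothetical gradient function: dlgradient must run inside dlfeval.
quadGrad = @(z) dlgradient(sum(z.^2),z);

for iteration = 1:500
    grad = dlfeval(quadGrad,x);   % same size and type as x
    [x,averageGrad,averageSqGrad] = adamupdate(x,grad, ...
        averageGrad,averageSqGrad,iteration,0.05,0.9,0.999);
end

disp(extractdata(x))   % entries close to zero
```

The loop increments iteration by 1 on each call and feeds the averageGrad and averageSqGrad outputs back in as inputs, which is the pattern adamupdate expects.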
averageGrad
— Moving average of parameter gradients
[]
| dlarray
| numeric array | cell array | structure | table
Moving average of parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of averageGrad depends on the input network or learnable parameters. The following table shows the required format for averageGrad for possible inputs to adamupdate.
Input | Learnable Parameters | Average Gradients |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |
If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageGrad output of a previous call to adamupdate as the averageGrad input.
Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
averageSqGrad
— Moving average of squared parameter gradients
[]
| dlarray
| numeric array | cell array | structure | table
Moving average of squared parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of averageSqGrad depends on the input network or learnable parameters. The following table shows the required format for averageSqGrad for possible inputs to adamupdate.
Input | Learnable Parameters | Average Squared Gradients |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |
If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageSqGrad output of a previous call to adamupdate as the averageSqGrad input.
Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
iteration
— Iteration number
positive integer
Iteration number, specified as a positive integer. For the first call to adamupdate, use a value of 1. You must increment iteration by 1 for each subsequent call in a series of calls to adamupdate. The Adam algorithm uses this value to correct for bias in the moving averages at the beginning of a set of iterations.
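To see why the iteration number matters, consider the bias-correction factors 1 - gradDecay^iteration and 1 - sqGradDecay^iteration from the formulation in [1]. Because the moving averages start at zero, they are divided by these factors; the factors are small for early iterations and approach 1 as iteration grows. The sketch below computes them for the default decay factors. Under this formulation, restarting iteration at 1 partway through training would scale the already warmed-up averages up again.

```matlab
% Bias-correction factors from Kingma and Ba, evaluated for the
% default decay factors of adamupdate.
gradDecay = 0.9;
sqGradDecay = 0.999;
iteration = [1 10 100 1000 10000];

c1 = 1 - gradDecay.^iteration;     % divides the gradient average
c2 = 1 - sqGradDecay.^iteration;   % divides the squared-gradient average

% Rows: iteration number, gradient factor, squared-gradient factor.
disp([iteration; c1; c2])
```

The squared-gradient factor grows much more slowly than the gradient factor: at iteration 1000 it is still only about 0.63, so the correction remains significant for many iterations.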
learnRate
— Global learning rate
0.001
(default) | positive scalar
Global learning rate, specified as a positive scalar. The default value of learnRate is 0.001.
If you specify the network parameters as a dlnetwork object, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
gradDecay
— Gradient decay factor
0.9
(default) | positive scalar between 0
and 1
Gradient decay factor, specified as a positive scalar between 0 and 1. The default value of gradDecay is 0.9.
sqGradDecay
— Squared gradient decay factor
0.999
(default) | positive scalar between 0
and 1
Squared gradient decay factor, specified as a positive scalar between 0 and 1. The default value of sqGradDecay is 0.999.
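As a rough guide to choosing gradDecay and sqGradDecay, a common rule of thumb for exponential moving averages (not a statement from this page) is that a decay factor β averages over roughly the last 1/(1 - β) iterations, with the weight of older values shrinking geometrically. Under that heuristic, the default gradDecay of 0.9 tracks roughly the last 10 gradients, while the default sqGradDecay of 0.999 tracks squared gradients over roughly the last 1000 iterations. The sketch below computes these quantities.

```matlab
% Rule-of-thumb memory length of an exponential moving average:
% about 1/(1 - decay) iterations.
gradDecay = 0.9;       % default gradient decay factor
sqGradDecay = 0.999;   % default squared gradient decay factor

memGrad = 1/(1 - gradDecay)       % roughly 10 iterations
memSqGrad = 1/(1 - sqGradDecay)   % roughly 1000 iterations

% Weight that the gradient from k iterations ago carries in the average.
k = 0:4;
weights = (1 - gradDecay)*gradDecay.^k
```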
epsilon
— Small constant
1e-8
(default) | positive scalar
Small constant for preventing divide-by-zero errors, specified as a positive scalar. The default value of epsilon is 1e-8.
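The divide-by-zero case arises, for example, for a parameter whose gradient has been exactly zero so far, so that both moving averages are zero. The sketch below assumes an update direction of the form m./(sqrt(v) + epsilon), as in the Adam update of [1], and shows that epsilon turns the undefined 0/0 into a zero step.

```matlab
% A parameter whose gradients have all been zero: both averages are 0.
m = 0;            % moving average of the gradients
v = 0;            % moving average of the squared gradients
epsilon = 1e-8;   % default value

stepWithout = m./sqrt(v)            % NaN, because 0/0 is undefined
stepWith = m./(sqrt(v) + epsilon)   % 0, so the parameter does not move
```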
Output Arguments
netUpdated
— Updated network
dlnetwork
object
Updated network, returned as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object.
params
— Updated network learnable parameters
dlarray
| numeric array | cell array | structure | table
Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
Since R2024a, the learnables can be complex-valued. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.
averageGrad
— Updated moving average of parameter gradients
dlarray
| numeric array | cell array | structure | table
Updated moving average of parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
averageSqGrad
— Updated moving average of squared parameter gradients
dlarray
| numeric array | cell array | structure | table
Updated moving average of squared parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
Algorithms
Adaptive Moment Estimation
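Adaptive moment estimation (Adam) [1] uses a parameter update that is similar to RMSProp, but with an added momentum term. It keeps an element-wise moving average of both the parameter gradients and their squared values,
$$m_\ell = \beta_1 m_{\ell-1} + (1-\beta_1)\,\nabla E(\theta_\ell)$$
$$v_\ell = \beta_2 v_{\ell-1} + (1-\beta_2)\left[\nabla E(\theta_\ell)\right]^2$$
where ℓ is the iteration number, θ is the parameter vector, and ∇E(θ) is the gradient of the loss function E. The β1 and β2 decay rates are the gradient decay and squared gradient decay factors, respectively. Adam uses the moving averages to update the network parameters as
$$\theta_{\ell+1} = \theta_\ell - \frac{\alpha\, m_\ell}{\sqrt{v_\ell}+\epsilon}$$
where ε is the small constant epsilon.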
The value α is the learning rate. If gradients over many iterations are similar, then using a moving average of the gradient enables the parameter updates to pick up momentum in a certain direction. If the gradients contain mostly noise, then the moving average of the gradient becomes smaller, and so the parameter updates become smaller too. The full Adam update also includes a mechanism to correct a bias that appears at the beginning of training. For more information, see [1].
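The bias correction can be written as follows, in the form given in [1]. Here m_ℓ and v_ℓ denote the moving averages of the gradients and squared gradients at iteration ℓ (starting from zero), θ the parameters, α the learning rate, β1 and β2 the decay factors, and ε the small constant. Because the averages start at zero, their early values are biased toward zero, and dividing by 1 - β^ℓ undoes this. This sketch shows the published formulation; the toolbox may arrange the same computation differently, for example by folding the correction factors into the learning rate.

$$\hat{m}_\ell = \frac{m_\ell}{1-\beta_1^{\ell}}, \qquad \hat{v}_\ell = \frac{v_\ell}{1-\beta_2^{\ell}}, \qquad \theta_{\ell+1} = \theta_\ell - \frac{\alpha\,\hat{m}_\ell}{\sqrt{\hat{v}_\ell}+\epsilon}$$

At ℓ = 1, the corrected averages reduce to the gradient and its square, so the first step has magnitude close to α for every parameter with a nonzero gradient.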
References
[1] Kingma, Diederik, and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:
- grad
- averageGrad
- averageSqGrad
- params
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2019b
R2024a: Complex-valued learnable parameters and gradients
The learnable parameters, gradients, moving average of gradients, and moving average of squared gradients can be complex-valued. When the updated learnable parameters are complex-valued, ensure that the corresponding operations support complex-valued parameters.