
Deep Learning Tips and Tricks

This page describes various training options and techniques for improving the accuracy of deep learning networks.

Choose Network Architecture

The appropriate network architecture depends on the task and the data available. Consider these suggestions when deciding which architecture to use and whether to use a pretrained network or to train from scratch.

Images: Classification of natural images

Try different pretrained networks. For a list of pretrained deep learning networks, see Pretrained Deep Neural Networks.

To learn how to interactively prepare a network for transfer learning using Deep Network Designer, see Transfer Learning with Deep Network Designer.

Images: Regression of natural images

Try different pretrained networks. For an example showing how to convert a pretrained classification network into a regression network, see Convert Classification Network into Regression Network.

Images: Classification and regression of non-natural images (for example, tiny images and spectrograms)

For an example showing how to classify tiny images, see Train Residual Network for Image Classification.

For an example showing how to classify spectrograms, see Train Speech Command Recognition Model Using Deep Learning.

Images: Semantic segmentation

Computer Vision Toolbox™ provides tools to create deep learning networks for semantic segmentation. For more information, see Getting Started with Semantic Segmentation Using Deep Learning (Computer Vision Toolbox).

Sequences, time series, and signals: Sequence-to-label classification

For an example, see Sequence Classification Using Deep Learning.

Sequences, time series, and signals: Sequence-to-sequence classification and regression

To learn more, see Sequence-to-Sequence Classification Using Deep Learning and Sequence-to-Sequence Regression Using Deep Learning.

Sequences, time series, and signals: Sequence-to-one regression

For an example, see Sequence-to-One Regression Using Deep Learning.

Sequences, time series, and signals: Time series forecasting

For an example, see Time Series Forecasting Using Deep Learning.

Text: Classification and regression

Text Analytics Toolbox™ provides tools to create deep learning networks for text data. For an example, see Classify Text Data Using Deep Learning.

Text: Text generation

For an example, see Generate Text Using Deep Learning.

Audio: Classification and regression

Try different pretrained networks. For a list of pretrained deep learning networks, see Pretrained Models (Audio Toolbox).

To learn how to programmatically prepare a network for transfer learning, see Transfer Learning with Pretrained Audio Networks (Audio Toolbox). To learn how to interactively prepare a network for transfer learning using Deep Network Designer, see Transfer Learning with Pretrained Audio Networks in Deep Network Designer.

For an example showing how to classify sounds using deep learning, see Classify Sound Using Deep Learning (Audio Toolbox).

Choose Training Options

The trainingOptions function provides a variety of options to train your deep learning network.

Monitor training progress

To turn on the training progress plot, set the 'Plots' option in trainingOptions to 'training-progress'.
Use validation data

To specify validation data, use the 'ValidationData' option in trainingOptions.

Note

If your validation data set is too small and does not sufficiently represent the data, then the reported metrics might not help you. A validation data set that is too large can slow down training, because the network is evaluated on all of the validation data at each validation step.

For transfer learning, speed up the learning of new layers and slow down the learning in the transferred layers

Specify higher learning rate factors for new layers by using, for example, the WeightLearnRateFactor property of convolution2dLayer.

Decrease the initial learning rate using the 'InitialLearnRate' option of trainingOptions.

When performing transfer learning, you typically do not need to train for as many epochs. Decrease the number of epochs using the 'MaxEpochs' option in trainingOptions.

To learn how to interactively prepare a network for transfer learning using Deep Network Designer, see Transfer Learning with Deep Network Designer.

Shuffle your data every epoch

To shuffle your data every epoch (one full pass of the data), set the 'Shuffle' option in trainingOptions to 'every-epoch'.

Note

For sequence data, shuffling can reduce accuracy because it can increase the amount of padding or truncated data. If you have sequence data, then sorting the data by sequence length can help. To learn more, see Sequence Padding, Truncation, and Splitting.

Try different optimizers

To specify a different optimizer, set the solverName argument of trainingOptions, for example to 'sgdm' (stochastic gradient descent with momentum), 'rmsprop', or 'adam'.
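The following is a minimal sketch that combines several of the tips above into one trainingOptions call. It selects the Adam solver, turns on the training progress plot, shuffles every epoch, and supplies validation data. XValidation and YValidation are hypothetical random stand-ins for your own held-out data. The numeric values (learning rate, epochs, validation frequency) are illustrative assumptions, not recommendations.

```matlab
% Hypothetical validation data: 50 grayscale 28-by-28 images, 3 classes.
% Replace these random stand-ins with your own held-out data.
XValidation = rand(28,28,1,50);
YValidation = categorical(randi(3,50,1));

% Illustrative values only; tune them for your network and data.
options = trainingOptions('adam', ...                 % solverName: 'sgdm', 'rmsprop', or 'adam'
    'InitialLearnRate',1e-3, ...
    'MaxEpochs',20, ...
    'Shuffle','every-epoch', ...                      % reshuffle before each full pass of the data
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',30, ...                     % evaluate validation data every 30 iterations
    'Plots','training-progress', ...                  % show the training progress plot
    'Verbose',false);
```

You would then pass options to trainNetwork together with your training data and layer array.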

For more information, see Set Up Parameters and Train Convolutional Neural Network.
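The transfer learning tips can be sketched as follows, assuming you have already loaded a pretrained network and replace its final fully connected layer. numClasses and the layer name in the commented replaceLayer call are hypothetical. The factor of 10, the learning rate of 1e-4, and the 6 epochs are illustrative assumptions; the text above only recommends the direction of each change.

```matlab
numClasses = 5;   % hypothetical number of classes in the new task

% New final layer: learn rate factors above 1 make its weights and biases
% update faster than the transferred layers, whose factors remain 1.
newFC = fullyConnectedLayer(numClasses, ...
    'Name','new_fc', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);

% With a layer graph from a pretrained network, you would swap it in, e.g.:
%   lgraph = replaceLayer(lgraph,'oldFinalFC',newFC);   % 'oldFinalFC' is hypothetical

% A small global learning rate slows learning in the transferred layers,
% and fewer epochs are usually enough when fine-tuning.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',6, ...
    'MiniBatchSize',32, ...
    'Shuffle','every-epoch');
```

A layer's effective learning rate is the global learning rate multiplied by that layer's factor. In this sketch, the new layer therefore trains at 1e-3 and the transferred layers train at 1e-4.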

Improve Training Accuracy

If you notice problems during training, then consider these possible solutions.

NaNs or large spikes in the loss

Decrease the initial learning rate using the 'InitialLearnRate' option of trainingOptions.

If decreasing the learning rate does not help, then try using gradient clipping. To set the gradient threshold, use the 'GradientThreshold' option in trainingOptions.

Loss is still decreasing at the end of training

Train for longer by increasing the number of epochs using the 'MaxEpochs' option in trainingOptions.
Loss plateaus

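If the loss plateaus at an unexpectedly high value, then drop the learning rate at the plateau. To change the learning rate schedule, use the 'LearnRateSchedule' option in trainingOptions. For example, setting 'LearnRateSchedule' to 'piecewise' multiplies the learning rate by the 'LearnRateDropFactor' value every 'LearnRateDropPeriod' epochs.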

If dropping the learning rate does not help, then the model might be underfitting. Try increasing the number of parameters or layers. You can check if the model is underfitting by monitoring the validation loss.

Validation loss is much higher than the training loss

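To prevent overfitting, try one or more of the following:

  • Use data augmentation. For more information, see Train Network with Augmented Images.

  • Use dropout layers. For more information, see dropoutLayer.

  • Increase the global L2 regularization factor using the 'L2Regularization' option of trainingOptions.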

Loss decreases very slowly

Increase the initial learning rate using the 'InitialLearnRate' option of trainingOptions.

For image data, try including batch normalization layers in your network. For more information, see batchNormalizationLayer.
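As an illustrative sketch, a common arrangement places a batchNormalizationLayer between each convolution and its activation. The dropoutLayer before the final fully connected layer is a separate, standard regularization layer that you can add if the network overfits. The input size, filter counts, dropout probability of 0.5, and numClasses are all assumptions.

```matlab
numClasses = 10;   % hypothetical number of classes

layers = [
    imageInputLayer([28 28 1])

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer              % normalize activations per channel
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer

    dropoutLayer(0.5)                    % during training, zero inputs with probability 0.5
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```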

For more information, see Set Up Parameters and Train Convolutional Neural Network.
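The remedies for unstable or plateauing loss correspond to trainingOptions name-value arguments. This sketch combines a reduced initial learning rate, gradient clipping, and a piecewise schedule that multiplies the learning rate by 0.1 every 10 epochs. All numeric values are illustrative assumptions, and in practice you would usually change one setting at a time while watching the loss. The last lines compute the resulting learning rate for each epoch. They assume that the drop takes effect each time a full drop period of epochs has passed.

```matlab
initialLearnRate = 1e-3;   % lower this first if the loss spikes or becomes NaN
dropFactor = 0.1;          % multiply the learning rate by 0.1 ...
dropPeriod = 10;           % ... every 10 epochs
maxEpochs = 30;

options = trainingOptions('sgdm', ...
    'InitialLearnRate',initialLearnRate, ...
    'GradientThreshold',1, ...          % rescale gradients whose L2 norm exceeds 1
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',dropFactor, ...
    'LearnRateDropPeriod',dropPeriod, ...
    'MaxEpochs',maxEpochs);

% Learning rate used in each epoch under the piecewise schedule (assumed timing)
epochs = 1:maxEpochs;
learnRate = initialLearnRate * dropFactor.^floor((epochs-1)/dropPeriod);
```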

Fix Errors in Training

If your network does not train at all, then consider these possible solutions.

Out-of-memory error when training

The available hardware is unable to store the current mini-batch, the network weights, and the computed activations.

Try reducing the mini-batch size using the 'MiniBatchSize' option of trainingOptions.

If reducing the mini-batch size does not work, then try using a smaller network, reducing the number of layers, or reducing the number of parameters or filters in the layers.

Custom layer errors

There could be an issue with the implementation of the custom layer.

Check the validity of the custom layer and find potential issues using checkLayer.

If a test fails when you use checkLayer, then the function provides a test diagnostic and a framework diagnostic. The test diagnostic highlights any layer issues, whereas the framework diagnostic provides more detailed information. To learn more about the test diagnostics and get suggestions for possible solutions, see Diagnostics.

Training throws the error 'CUDA_ERROR_UNKNOWN'

Sometimes, the GPU throws this error when it is being used for both compute and display requests from the OS.

Try reducing the mini-batch size using the 'MiniBatchSize' option of trainingOptions.

If reducing the mini-batch size does not work, then on Windows®, try adjusting the Timeout Detection and Recovery (TDR) settings. For example, increase the TdrDelay value from 2 seconds (the default) to 4 seconds. Changing this setting requires editing the registry.

You can analyze your deep learning network using analyzeNetwork. The analyzeNetwork function displays an interactive visualization of the network architecture, detects errors and issues with the network, and provides detailed information about the network layers. Use the network analyzer to visualize and understand the network architecture, check that you have defined the architecture correctly, and detect problems before training. Problems that analyzeNetwork detects include missing or disconnected layers, mismatched or incorrect sizes of layer inputs, an incorrect number of layer inputs, and invalid graph structures.
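For example, you can check a layer array before training by passing it to analyzeNetwork. The small architecture below is hypothetical. analyzeNetwork opens an interactive window, so run this sketch in a desktop MATLAB session.

```matlab
% Hypothetical layer array to check before training
layers = [
    imageInputLayer([32 32 3])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(4)
    softmaxLayer
    classificationLayer];

% Opens the network analyzer, which reports issues such as missing or
% disconnected layers and mismatched layer input sizes
analyzeNetwork(layers)
```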

Prepare and Preprocess Data

You can improve the accuracy of your network by preprocessing your data.

Weight or Balance Classes

Ideally, all classes have an equal number of observations. However, for some tasks, classes can be imbalanced. For example, automotive datasets of street scenes tend to have more sky, building, and road pixels than pedestrian and bicyclist pixels because the sky, buildings, and roads cover more image area. If not handled correctly, this imbalance can be detrimental to the learning process because the learning is biased in favor of the dominant classes.

For classification tasks, you can specify class weights using the 'ClassWeights' option of classificationLayer. For an example, see Train Sequence Classification Network Using Data With Imbalanced Classes. For semantic segmentation tasks, you can specify class weights using the ClassWeights (Computer Vision Toolbox) property of pixelClassificationLayer (Computer Vision Toolbox).
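A minimal sketch of class weighting for a classification network follows, assuming inverse-frequency weights. This is one common choice, not the only valid weighting. The labels are hypothetical. The sketch passes Classes together with ClassWeights so that each weight is paired with a class name in the same order as categories(YTrain).

```matlab
% Hypothetical, imbalanced labels for a 3-class problem
YTrain = categorical([repmat("road",6,1); repmat("car",3,1); "person"]);

classes = categories(YTrain);          % {'car';'person';'road'} (sorted)
classFrequency = countcats(YTrain);    % [3; 1; 6]
numObservations = numel(YTrain);
numClasses = numel(classes);

% Inverse-frequency weights: rarer classes get larger weights, and
% weight times count is the same for every class.
classWeights = numObservations ./ (numClasses * classFrequency);

outputLayer = classificationLayer( ...
    'Classes',classes, ...
    'ClassWeights',classWeights);
```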

Alternatively, you can balance the classes by doing one or more of the following:

  • Add new observations from the least frequent classes.

  • Remove observations from the most frequent classes.

  • Group similar classes. For example, group the classes "car" and "truck" into the single class "vehicle".
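The grouping and removal strategies above can be sketched on a categorical label vector. mergecats merges similar classes, and random undersampling keeps the same number of observations per class. The labels are hypothetical. To subset your predictors to match, index them with keepIdx. For image datastores, countEachLabel and splitEachLabel(imds,minCount,'randomized') offer a similar way to undersample.

```matlab
rng(0)   % reproducible random selection

% Hypothetical labels: "road" dominates; "car" and "truck" are similar classes
Y = categorical([repmat("road",8,1); repmat("car",3,1); ...
    repmat("truck",2,1); repmat("person",2,1)]);

% Group similar classes into the single class "vehicle"
Y = mergecats(Y,{'car','truck'},'vehicle');

% Undersample: keep minCount randomly chosen observations from every class
classes = categories(Y);
minCount = min(countcats(Y));
keepIdx = [];
for k = 1:numel(classes)
    idx = find(Y == classes{k});                 % observations of class k
    idx = idx(randperm(numel(idx),minCount));    % random subset of size minCount
    keepIdx = [keepIdx; idx];                    %#ok<AGROW>
end
keepIdx = sort(keepIdx);
YBalanced = Y(keepIdx);
```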

Preprocess Image Data

For more information about preprocessing image data, see Preprocess Images for Deep Learning.

Resize images

To use a pretrained network, you must resize images to the input size of the network. To resize images, use augmentedImageDatastore. For example, this syntax resizes images in the image datastore imds:

auimds = augmentedImageDatastore(inputSize,imds);

Tip

Use augmentedImageDatastore for efficient preprocessing of images for deep learning, including image resizing. Do not use the ReadFcn option of ImageDatastore objects.

ImageDatastore allows batch reading of JPG or PNG image files using prefetching. If you set the ReadFcn option to a custom function, then ImageDatastore does not prefetch and is usually significantly slower.

Image augmentation

To reduce overfitting, apply random transformations to the training images (image augmentation), such as reflection, rotation, and translation. To learn more, see Train Network with Augmented Images.

Normalize regression targets

Normalize the predictors before you input them to the network. If you normalize the responses before training, then you must transform the predictions of the trained network to obtain the predictions of the original responses.

For more information, see Train Convolutional Neural Network for Regression.
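As a sketch of resizing and augmentation together, an imageDataAugmenter can be passed to augmentedImageDatastore. The in-memory arrays XTrain and YTrain, the input size, and the transformation ranges are hypothetical. With an image datastore, you would pass imds in place of XTrain and YTrain.

```matlab
% Hypothetical in-memory training data: 20 RGB images of size 40-by-40
XTrain = rand(40,40,3,20);
YTrain = categorical(randi(2,20,1));
inputSize = [32 32 3];                 % assumed network input size

% Random transformations, drawn independently each time an image is read
augmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...        % random left-right flips
    'RandRotation',[-10 10], ...       % rotation angle range in degrees
    'RandXTranslation',[-3 3], ...     % horizontal shift range in pixels
    'RandYTranslation',[-3 3]);        % vertical shift range in pixels

% Resizes to inputSize and augments on the fly; the stored data is unchanged
augimds = augmentedImageDatastore(inputSize,XTrain,YTrain, ...
    'DataAugmentation',augmenter);
```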

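For regression responses, a minimal sketch of normalizing and then undoing the normalization follows. It assumes z-score scaling; the text above does not prescribe a specific scaling. YTrain is hypothetical. YPredNorm stands in for the output of predict on a network trained with the normalized responses.

```matlab
% Hypothetical regression responses
rng(0)
YTrain = 45 + 20*randn(100,1);

% Compute statistics on the training responses only
muY = mean(YTrain);
sigmaY = std(YTrain);
YTrainNorm = (YTrain - muY) ./ sigmaY;    % train the network on these

% Stand-in for predict(net,XTest) from a network trained on YTrainNorm
YPredNorm = YTrainNorm(1:10);

% Undo the transformation to return predictions to the original units
YPred = YPredNorm .* sigmaY + muY;
```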
Preprocess Sequence Data

For more information about working with LSTM networks, see Long Short-Term Memory Neural Networks.

Normalize sequence data

To normalize sequence data, first calculate the per-feature mean and standard deviation for all the sequences. Then, for each training observation, subtract the mean value and divide by the standard deviation.

To learn more, see Normalize Sequence Data.

Reduce sequence padding and truncation

To reduce the amount of padding or discarded data when padding or truncating sequences, try sorting your data by sequence length.

To learn more, see Sequence Padding, Truncation, and Splitting.

Specify mini-batch size and padding options for prediction

When you make predictions with sequences of different lengths, the mini-batch size can impact the amount of padding added to the input data, which can result in different predicted values. Try using different values to see which works best with your network.

To specify mini-batch size and padding options, use the 'MiniBatchSize' and 'SequenceLength' options of the classify, predict, classifyAndUpdateState, and predictAndUpdateState functions.
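The normalization and sorting steps above can be sketched with base MATLAB. Sequences are stored in a cell array, with each sequence a numFeatures-by-numTimeSteps matrix. The data is hypothetical. The sketch computes per-feature statistics over every time step of every training sequence, and then sorts the sequences by length.

```matlab
% Hypothetical training data: 3 sequences, 2 features, different lengths
XTrain = {rand(2,10); rand(2,4); rand(2,7)};
YTrain = categorical(["a";"b";"a"]);

% Per-feature mean and standard deviation over all sequences
allSteps = [XTrain{:}];                 % 2-by-21 (all time steps side by side)
mu = mean(allSteps,2);
sigma = std(allSteps,0,2);
XTrain = cellfun(@(X) (X - mu)./sigma, XTrain,'UniformOutput',false);

% Sort by sequence length so that mini-batches contain similar lengths,
% which reduces padding and truncation
numTimeSteps = cellfun(@(X) size(X,2), XTrain);
[numTimeSteps,idx] = sort(numTimeSteps);
XTrain = XTrain(idx);
YTrain = YTrain(idx);
```

Normalize test data with the same mu and sigma computed from the training data. If you sort the training data, also set 'Shuffle' to 'never' in trainingOptions so that training keeps the sorted order.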

Use Available Hardware

To specify the execution environment, use the 'ExecutionEnvironment' option in trainingOptions.

Training on CPU is slow

If training is too slow on a single CPU, try using a pretrained deep learning network as a feature extractor, and then train a machine learning model on the extracted features. For an example, see Extract Image Features Using Pretrained Network.
Training LSTM on GPU is slow

The CPU is better suited for training an LSTM network using mini-batches with short sequences. To use the CPU, set the 'ExecutionEnvironment' option in trainingOptions to 'cpu'.

Software does not use all available GPUs

If you have access to a machine with multiple GPUs, set the 'ExecutionEnvironment' option in trainingOptions to 'multi-gpu'. For more information, see Deep Learning with MATLAB on Multiple GPUs.

For more information, see Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud.
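As a short sketch of the hardware settings above, the 'ExecutionEnvironment' option takes values such as 'auto' (the default), 'cpu', 'gpu', 'multi-gpu', and 'parallel'. The solver and mini-batch size here are illustrative assumptions. The multi-GPU line is left as a comment because GPU training requires Parallel Computing Toolbox and a supported GPU.

```matlab
% Train an LSTM network that uses short sequences on the CPU
optionsCPU = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MiniBatchSize',32);

% On a machine with several GPUs, you would instead use:
%   optionsGPU = trainingOptions('adam','ExecutionEnvironment','multi-gpu');
```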

Fix Errors With Loading from MAT-Files

If you are unable to load layers or a network from a MAT-file and get a warning of the form

Warning: Unable to load instances of class layerType into a 
heterogeneous array.  The definition of layerType could be
missing or contain an error.  Default objects will be
substituted. 
Warning: While loading an object of class 'SeriesNetwork':
Error using 'forward' in Layer nnet.cnn.layer.MissingLayer. The
function threw an error and could not be executed. 

then the network in the MAT-file might contain unavailable layers. This could be due to one of the following:

  • The file contains a custom layer not on the path – To load networks containing custom layers, add the custom layer files to the MATLAB® path.

  • The file contains a custom layer from a support package – To load networks using layers from support packages, install the required support package at the command line by using the corresponding function (for example, resnet18) or using the Add-On Explorer.

  • The file contains a custom layer from a documentation example that is not on the path – To load networks containing custom layers from documentation examples, open the example as a live script and copy the layer from the example folder to your working directory.

  • The file contains a layer from a toolbox that is not installed – To access layers from other toolboxes, for example, Computer Vision Toolbox or Text Analytics Toolbox, install the corresponding toolbox.

After trying the suggested solutions, reload the MAT-file.
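The following sketch shows how you might add a custom layer folder to the path, reload a MAT-file, and check for placeholder layers. The folder name customLayers is hypothetical. For self-containment, the sketch saves and reloads a small layer array in a temporary file; in practice, you would load your own MAT-file. The check assumes that placeholder layers have a class name containing MissingLayer, as the warning text above suggests.

```matlab
% Hypothetical folder that holds your custom layer class files
customLayerFolder = fullfile(pwd,'customLayers');
if isfolder(customLayerFolder)
    addpath(customLayerFolder);          % make custom layer classes visible
end

% Illustration only: save a small layer array, then reload it
matFile = [tempname '.mat'];
layers = [imageInputLayer([8 8 1]); fullyConnectedLayer(2); ...
    softmaxLayer; classificationLayer];
save(matFile,'layers');
S = load(matFile);

% Flag layers that were replaced by placeholder (missing) layers
missing = false(numel(S.layers),1);
for i = 1:numel(S.layers)
    missing(i) = contains(class(S.layers(i)),'MissingLayer');
end
fprintf('%d unavailable layer(s) found.\n',nnz(missing));
```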


Related Topics