Make Predictions Using dlnetwork Object
This example shows how to make predictions using a dlnetwork object by splitting data into mini-batches.
For large data sets, or when predicting on hardware with limited memory, make predictions by splitting the data into mini-batches. When making predictions with SeriesNetwork or DAGNetwork objects, the predict function automatically splits the input data into mini-batches. For dlnetwork objects, you must split the data into mini-batches manually.
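Here is a minimal sketch of manual splitting. It assumes the data fits in memory as a 4-D numeric array, with observations along the fourth dimension. XTest is hypothetical random data that stands in for real images. The loop computes the index range of each mini-batch, and the last mini-batch holds whatever observations remain.

```matlab
% Hypothetical in-memory test data: 1000 grayscale 28-by-28 images,
% stored as height-by-width-by-channel-by-observation.
XTest = rand(28,28,1,1000);

miniBatchSize = 128;
numObservations = size(XTest,4);
numBatches = ceil(numObservations/miniBatchSize);
batchSizes = zeros(1,numBatches);

for b = 1:numBatches
    % Indices of the observations in mini-batch b. The last mini-batch
    % is smaller when numObservations is not a multiple of miniBatchSize.
    idx = (b-1)*miniBatchSize+1 : min(b*miniBatchSize,numObservations);
    XBatch = XTest(:,:,:,idx);
    batchSizes(b) = size(XBatch,4);

    % With a trained dlnetwork object, you could predict on this batch, e.g.:
    % YBatch = predict(dlnet,dlarray(XBatch,"SSCB"));
end
```

With 1000 observations, this gives seven full mini-batches of 128 and a final mini-batch of 104. A datastore-based approach avoids holding the whole data set in memory, because it reads each mini-batch only when it is needed.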
Load dlnetwork Object
Load a trained dlnetwork object and the corresponding classes.
s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;
Load Data for Prediction
Load the digits data for prediction.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true);
Make Predictions
Loop over the mini-batches of the test data and make predictions using a custom prediction loop.
Use minibatchqueue to process and manage the mini-batches of images. Specify a mini-batch size of 128. Set the ReadSize property of the image datastore to the mini-batch size.
For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to concatenate the data into a batch and normalize the images.
- Format the images with the dimensions 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single.
- Make predictions on a GPU if one is available. By default, the minibatchqueue object converts the output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds, ...
    "MiniBatchSize",miniBatchSize, ...
    "MiniBatchFcn",@preprocessMiniBatch, ...
    "MiniBatchFormat","SSCB");
Loop over the mini-batches of data and make predictions using the predict function. Use the onehotdecode function to determine the class labels. Store the predicted class labels.
numObservations = numel(imds.Files);
predictions = [];

% Loop over mini-batches.
while hasdata(mbq)
    % Read mini-batch of data.
    dlX = next(mbq);

    % Make predictions using the predict function.
    dlYPred = predict(dlnet,dlX);

    % Determine corresponding classes.
    predBatch = onehotdecode(dlYPred,classes,1);
    predictions = [predictions predBatch];
end
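The call onehotdecode(dlYPred,classes,1) picks, for each observation, the class whose score is highest along dimension 1. The plain MATLAB sketch below reproduces that index logic with max on a small, hypothetical score matrix, so it runs without Deep Learning Toolbox. Note that onehotdecode itself returns a categorical array, not the cell array used here.

```matlab
% Hypothetical scores: 3 classes (rows) by 4 observations (columns),
% laid out like the channel-by-batch output of a classification network.
scores = [0.1 0.7 0.2 0.3;
          0.8 0.2 0.1 0.3;
          0.1 0.1 0.7 0.4];
classes = {'0','1','2'};

% Row index of the highest score in each column.
[~,classIdx] = max(scores,[],1);

% Map each index to its class name.
labels = classes(classIdx);
```

For these scores, classIdx is [2 1 3 3], so labels is {'1','0','2','2'}.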
Visualize some of the predictions.
idx = randperm(numObservations,9);

figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    label = predictions(idx(i));
    imshow(I)
    title("Label: " + string(label))
end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses the data using the following steps:
- Extract the data from the incoming cell array and concatenate it into a numeric array. Concatenating over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
- Normalize the pixel values between 0 and 1.
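function X = preprocessMiniBatch(data)
    % Extract image data from cell and concatenate.
    X = cat(4,data{:});

    % Normalize the images. Convert to single first, because dividing
    % an integer (uint8) array by 255 rounds every value to 0 or 1.
    X = single(X)/255;
end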
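To see what these two steps do to the array shape and values, the sketch below applies them to a hypothetical mini-batch of two 2-by-2 uint8 images. These stand in for the 28-by-28 digit images, which are assumed to be read as uint8, as 8-bit PNG files usually are. Concatenating along dimension 4 yields a 2-by-2-by-1-by-2 array, whose third dimension is the singleton channel dimension. The sketch also contrasts two ways of normalizing. Dividing the uint8 array directly uses integer arithmetic, which rounds each result to 0 or 1. Converting to single before dividing keeps the fractional values.

```matlab
% Hypothetical mini-batch: a cell array holding two 2-by-2 uint8 images.
data = {uint8([0 64; 128 255]), uint8([200 10; 30 40])};

% Concatenate along dimension 4: height-by-width-by-channel-by-batch.
X = cat(4,data{:});
sz = size(X);

% Integer division rounds, so every pixel becomes 0 or 1.
XInt = X/255;

% Converting to single first keeps fractional values in [0,1].
XSingle = single(X)/255;
```

Here sz is [2 2 1 2]. XInt contains only 0 and 1, while XSingle(1,2,1,1) is about 0.251 (64/255).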
See Also
dlarray | dlnetwork | predict | minibatchqueue | onehotdecode
Related Topics
- Train Generative Adversarial Network (GAN)
- Train Network Using Custom Training Loop
- Define Model Loss Function for Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- Define Custom Training Loops, Loss Functions, and Networks
- Make Predictions Using Model Function
- Specify Training Options in Custom Training Loop
- List of Deep Learning Layers
- Deep Learning Tips and Tricks
- Automatic Differentiation Background