Compress Image Classification Network for Deployment to Resource-Constrained Embedded Devices

This example shows how to reduce the memory footprint and computation requirements of an image classification network for deployment on resource-constrained embedded devices such as the Raspberry Pi™.

In many applications where transfer learning is used to retrain an image classification network for a new task, or where a new network is trained from scratch, the optimal network architecture is not known, and the network might be overparameterized. An overparameterized network has redundancies. Network pruning is a powerful model compression tool that helps identify redundancies that can be removed with little impact on the final network output. When you use pruning in combination with network quantization, you can reduce the inference time and memory footprint of the network, making it easier to deploy to ARM® CPU platforms such as the Raspberry Pi.

This example shows how to:

  • Use transfer learning to retrain SqueezeNet, a pretrained convolutional neural network, to classify a new set of images from the CIFAR-10 data set.

  • Prune filters from the convolutional layers of the network by using first-order Taylor approximation.

  • Retrain the network after pruning to regain any loss in accuracy.

  • Evaluate the impact of pruning on classification accuracy.

  • Quantize the weights, biases, and activations of the convolution layers to 8-bit scaled integer data type.

  • Generate and deploy optimized C++ code to a Raspberry Pi.

  • Evaluate the impact of quantization on the classification accuracy of the pruned network.

Third-Party Prerequisites

Prepare Data

Download the CIFAR-10 data set [1]. The data set contains 60,000 images. Each image is 32-by-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.

datadir = tempdir; 
downloadCIFARData(datadir);

Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images. Use the CIFAR-10 test images for network validation.

[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);

You can display a random sample of the training images using the following code.

idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]);
imshow(im)

Create an augmentedImageDatastore object to use for network training. During training, the datastore randomly flips the training images along the vertical axis and randomly translates them up to four pixels horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

imageSize = [32,32,3];
pixelRange = [-4,4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,TTrain, ...
    DataAugmentation=imageAugmenter);
augimdsValidation = augmentedImageDatastore(imageSize,XValidation, ...
    TValidation);
classes = categories(TTrain);
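
The effect of the augmentation described above can be illustrated with a minimal sketch in base MATLAB. The random stand-in image, the integer shifts, and the zero fill are assumptions made for illustration only; this is not how imageDataAugmenter is implemented.

```matlab
rng(1)
img = uint8(randi([0 255],32,32,3));   % stand-in for one CIFAR-10 image (assumed data)
pixelRange = [-4 4];

% Random horizontal reflection: flip left-right with probability 0.5.
if rand > 0.5
    img = flip(img,2);
end

% Random shifts drawn from pixelRange (integers keep this sketch simple).
dx = randi(pixelRange);
dy = randi(pixelRange);

% Translate by copying the overlapping region into a zero-filled canvas.
out = zeros(size(img),"like",img);
rowsSrc = max(1,1-dy):min(32,32-dy);
colsSrc = max(1,1-dx):min(32,32-dx);
out(rowsSrc+dy,colsSrc+dx,:) = img(rowsSrc,colsSrc,:);
```

Because every epoch sees a slightly different version of each image, the network cannot simply memorize exact pixel positions.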

Retrain Network on CIFAR-10 Data Using Transfer Learning

SqueezeNet has been trained on over a million images and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). The pretrained SqueezeNet network is fine-tuned by using transfer learning. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch.

Retrain Network

Training the network takes a considerable amount of time, even on a good GPU. If you do not have a GPU, then training takes much longer. Training on a GPU or in parallel requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

To save time while running this example, load a pretrained network by setting doTraining to false. To train the network yourself, set doTraining to true.

doTraining = false;
if doTraining

    net = squeezenet;  %#ok<UNRCH> 
    lgraph = layerGraph(net);
    larray = [imageInputLayer(imageSize,'Name','data')];
    lgraph = replaceLayer(lgraph,'data',larray);
    [learnableLayer,classLayer] = findLayersToReplace(lgraph);
    numClasses = 10;
    newFirstConvLayer = convolution2dLayer([3,3], 64,'WeightLearnRateFactor', ...
    lgraph = replaceLayer(lgraph,'conv1',newFirstConvLayer);
    newConvLayer =  convolution2dLayer([1,1],numClasses, ...
    lgraph = replaceLayer(lgraph,'conv10',newConvLayer);
    newClassificatonLayer = classificationLayer('Name','new_classoutput');
    lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassificatonLayer);

    options = trainingOptions('adam', ...
        'MiniBatchSize',100, ...
        'MaxEpochs',15, ...
        'InitialLearnRate',2e-4/3, ...
        'Shuffle','every-epoch', ...
        'ValidationData',augimdsValidation, ...
        'ValidationFrequency',25, ...
        'ValidationPatience',5, ...
        'Verbose',false);

    transferNet = trainNetwork(augimdsTrain,lgraph,options);
end

Save the trained network.

save("transferNet.mat","transferNet")

Evaluate Trained Network

Calculate the final accuracy of the network on the validation set (without data augmentation).

[YValPred,probs] = classify(transferNet,XValidation);
accuracyOfTrainedNet = mean(YValPred == TValidation) * 100;
disp("Validation accuracy of trained network: " + accuracyOfTrainedNet + "%")
Validation accuracy of trained network: 60.48%

Prune Network

Prune the network using the taylorPrunableNetwork function. The network computes an importance score for each convolution filter in the network based on Taylor expansion [2][3]. Pruning is iterative; each time the loop runs, until a stopping criterion is met, the function removes a small number of the least important convolution filters and updates the network architecture.
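
To make the scoring idea concrete, the following is a minimal sketch in base MATLAB of a first-order Taylor importance score computed from activations and their gradients, in the spirit of [3]. The array sizes, the random values, and the exact reduction (magnitude of the spatial mean, averaged over the mini-batch) are illustrative assumptions; the formula used internally by taylorPrunableNetwork is not shown in this example.

```matlab
% Hypothetical activations and gradients for one convolution layer,
% laid out as height-by-width-by-filters-by-batch ("SSCB").
rng(0)
act  = randn(8,8,4,16);          % pruning activations (assumed values)
grad = 0.1*randn(8,8,4,16);      % gradient of the loss w.r.t. the activations (assumed values)
grad(:,:,3,:) = 0;               % make filter 3 irrelevant to the loss

% First-order Taylor estimate of the loss change when a filter is removed:
% average activation.*gradient over the spatial dimensions, take the
% magnitude, then average over the mini-batch.
perSample = abs(mean(act .* grad,[1 2]));   % 1-by-1-by-4-by-16
scores = squeeze(mean(perSample,4));        % 4-by-1, one score per filter

% The lowest-scoring filters are the first candidates for removal.
[~,order] = sort(scores,"ascend");
leastImportant = order(1);
```

In this toy setup, filter 3 receives a score of zero because its gradient is zero, so it is selected first.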

Specify Pruning and Fine-Tuning Options

Set the pruning options.

  • maxPruningIterations sets the maximum number of iterations to use for the pruning process.

  • maxToPrune sets the maximum number of filters to prune in each iteration of the pruning cycle.

maxPruningIterations = 30;
maxToPrune = 32;

Set the fine-tuning options.

learnRate = 1e-2/3;
momentum = 0.9;
miniBatchSize = 256;
numMinibatchUpdates  = 50;
validationFrequency = 1;

Prune Network using Custom Pruning Loop

To implement a custom pruning loop, convert the network to a dlnetwork object.

layerG = layerGraph(transferNet);
layerG = removeLayers(layerG,layerG.OutputNames);
net = dlnetwork(layerG);

Print a summary of the dlnetwork object. The summary shows whether the network is initialized, the total number of learnable parameters, and information about the network inputs.

summary(net)
   Initialized: true

   Number of learnables: 727.6k

   Inputs:
      1   'data'   32×32×3 images

Create a Taylor prunable network from the original network.

prunableNet = taylorPrunableNetwork(net);
maxPrunableFilters = prunableNet.NumPrunables;

Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatchTraining (defined at the end of this example) to convert the labels to one-hot encoded variables.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available.

mbqTrain = minibatchqueue(augimdsTrain, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);

mbqTest = minibatchqueue(augimdsValidation,...
    MiniBatchSize = miniBatchSize,...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);

Initialize the training progress plots.

tl = tiledlayout(3,1);
lossAx = nexttile;
lineLossFinetune = animatedline(Color=[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Fine-Tuning Iteration")
ylabel("Loss")
grid on
title("Mini-Batch Loss During Pruning")
xTickPos = [];

accuracyAx = nexttile;
lineAccuracyPruning = animatedline(Color=[0.098 0.325 0.85],LineWidth=2,Marker="o");
ylim([0 100])
xlabel("Pruning Iteration")
ylabel("Accuracy")
grid on
title("Validation Accuracy After Pruning")

numPrunablesAx = nexttile;
lineNumPrunables = animatedline(Color=[0.4660 0.6740 0.1880],LineWidth=2,Marker="^");
ylim([200 3000])
xlabel("Pruning Iteration")
ylabel("Prunable Filters")
grid on
title("Number of Prunable Convolution Filters After Pruning")

Prune the network by repeatedly fine-tuning the network and removing the low scoring filters.

For each pruning iteration, the following steps are used:

  • Fine-tune the network and accumulate Taylor scores for the convolution filters over numMinibatchUpdates mini-batch updates.

  • Prune the network using the updatePrunables function to remove up to maxToPrune convolution filters.

  • Compute the validation accuracy.

To fine-tune the network, loop over the mini-batches of the training data. For each mini-batch in the fine-tuning iteration, the following steps are used:

  • Evaluate the pruning loss, the gradients of the pruning activations, the pruning activations, the model gradients, and the state using the dlfeval and modelLossPruning functions.

  • Update the network state.

  • Update the network parameters using the sgdmupdate function.

  • Update the Taylor scores of the prunable network using the updateScore function.

  • Display the training progress.

start = tic;
iteration = 0;

for pruningIteration = 1:maxPruningIterations

    % Shuffle the data. Shuffling also resets the queue so that every
    % pruning iteration has mini-batches available.
    shuffle(mbqTrain);

    % Reset the SGDM velocity at the start of each pruning iteration.
    velocity = [];

    % Loop over mini-batches.
    fineTuningIteration = 0;
    while hasdata(mbqTrain)
        iteration = iteration + 1;
        fineTuningIteration = fineTuningIteration + 1;

        % The output order matches the modelLossPruning signature.
        [X, T] = next(mbqTrain);
        [loss,pruningGradients,pruningActivations,netGradients,state] = ...
            dlfeval(@modelLossPruning, prunableNet, X, T);
        prunableNet.State = state;
        [prunableNet, velocity] = sgdmupdate(prunableNet, netGradients, velocity, learnRate, momentum);
        prunableNet = updateScore(prunableNet, pruningActivations, pruningGradients);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        addpoints(lineLossFinetune, iteration, double(loss))
        title(tl,"Processing Pruning Iteration: " + pruningIteration + " of " + maxPruningIterations + ...
            ", Elapsed Time: " + string(D))
        % Synchronize the x-axis of the accuracy and numPrunables plots with the loss plot.
        xlim(accuracyAx,lossAx.XLim)
        xlim(numPrunablesAx,lossAx.XLim)
        drawnow

        % Stop the fine-tuning loop when numMinibatchUpdates is reached.
        if (fineTuningIteration > numMinibatchUpdates)
            break
        end
    end

    % Prune filters based on previously computed Taylor scores.
    prunableNet = updatePrunables(prunableNet, MaxToPrune = maxToPrune);

    % Show results on the validation data set in a subset of pruning iterations.
    isLastPruningIteration = pruningIteration == maxPruningIterations;
    if (mod(pruningIteration, validationFrequency) == 0 || isLastPruningIteration)
        accuracy = modelAccuracy(prunableNet, mbqTest, classes, augimdsValidation.NumObservations);
        addpoints(lineAccuracyPruning, iteration, accuracy)
        addpoints(lineNumPrunables, iteration, double(prunableNet.NumPrunables))
    end

    % Record the iteration at which each pruning iteration ends.
    xTickPos = [xTickPos, iteration]; %#ok<AGROW>
end
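
The parameter update performed in the fine-tuning loop can be illustrated with the standard stochastic gradient descent with momentum rule. The following is a minimal sketch in base MATLAB; the parameter and gradient values are hypothetical, and the gradient is held fixed only to keep the arithmetic easy to follow. It is not a reproduction of the sgdmupdate implementation.

```matlab
learnRate = 1e-2/3;
momentum  = 0.9;

w = [0.5; -0.3];        % hypothetical parameter values
v = zeros(size(w));     % velocity starts at zero, as when velocity = []
g = [0.2; -0.1];        % hypothetical gradient, held fixed for illustration

for step = 1:3
    v = momentum*v - learnRate*g;   % decaying sum of past gradient steps
    w = w + v;                      % move the parameters along the velocity
end
```

After three steps with a constant gradient, the parameters have moved by (1 + 1.9 + 2.71) = 5.61 times a single plain gradient step, which shows how momentum amplifies consistent gradient directions. Resetting the velocity at the start of each pruning iteration discards this history after the network structure changes.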


[Figure: three plots. Mini-Batch Loss During Pruning (Loss versus Fine-Tuning Iteration); Validation Accuracy After Pruning (Accuracy versus Pruning Iteration); Number of Prunable Convolution Filters After Pruning (Prunable Filters versus Pruning Iteration).]

In contrast to typical training where the loss decreases with each iteration, pruning may increase the loss and reduce the validation accuracy due to the change of network structure when convolution filters are pruned. To further improve the accuracy of the network, you can retrain the network.

Once pruning is complete, convert the taylorPrunableNetwork back to a dlnetwork for retraining.

prunedNet = dlnetwork(prunableNet);

Retrain Network After Pruning

Retrain the network after pruning to regain any loss in accuracy. To retrain the network using the trainNetwork function:

  • Extract the layerGraph from the dlnetwork.

  • Add the removed classification layer from the original network to the layerGraph of the pruned network.

  • Train the layerGraph network.

prunedLayerGraph = layerGraph(prunedNet);
outputLayerName = string(transferNet.OutputNames{1});
outputLayerIdx = {transferNet.Layers.Name} == outputLayerName;
prunedLayerGraph = addLayers(prunedLayerGraph,transferNet.Layers(outputLayerIdx));
prunedLayerGraph = connectLayers(prunedLayerGraph,prunedNet.OutputNames{1},outputLayerName);

Set the training options to use the Adam optimizer. Set the maximum number of retraining epochs to 10 and start the training with an initial learning rate of 2e-4/3 (approximately 6.7e-5). Use a piecewise schedule that multiplies the learning rate by 0.1 every 2 epochs.

options = trainingOptions("adam", ...
    MaxEpochs = 10, ...
    MiniBatchSize = 100, ...
    InitialLearnRate = 2e-4/3, ...
    LearnRateSchedule = "piecewise", ...
    LearnRateDropFactor = 0.1, ...
    LearnRateDropPeriod = 2, ...
    L2Regularization = 0.02, ...
    ValidationData = augimdsValidation, ...
    ValidationFrequency = 25, ...
    Verbose = false, ...
    Shuffle = "every-epoch", ...
    Plots = "training-progress");
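
The piecewise schedule configured above can be sketched in base MATLAB as follows. The values are taken from the options above; the closed-form expression is an illustration of the schedule, not code extracted from trainNetwork.

```matlab
initialLearnRate = 2e-4/3;
dropFactor = 0.1;
dropPeriod = 2;
epochs = 1:10;

% The learning rate is multiplied by dropFactor once every dropPeriod epochs.
learnRate = initialLearnRate * dropFactor.^floor((epochs-1)/dropPeriod);

% Show the epoch number next to the learning rate used in that epoch.
disp([epochs' learnRate'])
```

With these settings the learning rate in the last two epochs is four orders of magnitude smaller than the initial value, so most of the recovery in accuracy happens in the early epochs.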

Train the network.

prunedDAGNet = trainNetwork(augimdsTrain,prunedLayerGraph,options);

[Figure: Training Progress (26-Jan-2023 12:08:07), showing Loss and Accuracy (%) versus Iteration.]

Save the pruned network.

save("prunedDAGNet.mat","prunedDAGNet")

Compare Original Network and Pruned Network

Determine the impact of pruning on each layer.

[originalNetFilters,layerNames] = numConvLayerFilters(transferNet);
prunedNetFilters = numConvLayerFilters(prunedDAGNet);

Visualize the number of filters in the original network and in the pruned network.

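figure
bar([originalNetFilters,prunedNetFilters])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters Per Layer")
xticks(1:numel(layerNames))
xticklabels(layerNames)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters","Location","southoutside")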

[Figure: bar chart titled Number of Filters Per Layer (Number of Filters versus Layer), comparing Original Network Filters and Pruned Network Filters.]

Large differences between the number of filters of the two networks indicate where many of the less important filters have been pruned.

Next, compare the accuracy of the original network and the pruned network.

tic
YPredOriginal = classify(transferNet,XValidation);
toc
Elapsed time is 1.435466 seconds.
accuOriginal = mean(YPredOriginal == TValidation)
accuOriginal = 0.6048
tic
YPredPruned = classify(prunedDAGNet,XValidation);
toc
Elapsed time is 2.194408 seconds.
accuPruned = mean(YPredPruned == TValidation)
accuPruned = 0.7843

Pruning can unequally affect the classification of different classes and introduce bias into the model, which might not be apparent from the accuracy value. To assess the impact of pruning at a class level, use a confusion matrix chart.

confusionchart(TValidation,YPredOriginal,Normalization = "row-normalized");
title("Original Network")

[Figure: confusion matrix chart titled Original Network.]

confusionchart(TValidation,YPredPruned,Normalization = "row-normalized");
title("Pruned Network")

[Figure: confusion matrix chart titled Pruned Network.]
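
To show what the row-normalized view reports, here is a minimal sketch in base MATLAB that builds a confusion matrix by hand. The labels are hypothetical and use three classes instead of ten; it illustrates the normalization only and is not the confusionchart implementation.

```matlab
% Hypothetical true and predicted class indices for three classes.
trueLabels = [1 1 1 1 2 2 3 3 3 3];
predLabels = [1 1 2 3 2 2 3 3 3 1];
numClasses = 3;

% Count (true,predicted) pairs: rows are true classes, columns are predictions.
counts = accumarray([trueLabels' predLabels'],1,[numClasses numClasses]);

% Row normalization divides each row by its total, so the diagonal holds
% the fraction of each true class that is classified correctly.
rowNormalized = counts ./ sum(counts,2);
perClassRecall = diag(rowNormalized)';
overallAccuracy = sum(diag(counts))/numel(trueLabels);
```

Here the overall accuracy is 0.7, but the per-class values are 0.5, 1, and 0.75, which is the kind of imbalance that a single accuracy number hides.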

Next, estimate the model parameters for the original network and the pruned network to understand the impact of pruning on the overall network learnables and size.

analyzeNetworkMetrics(transferNet,prunedDAGNet,accuOriginal,accuPruned)
ans=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    Accuracy
                         __________________    ___________________________    ________

    Original Network         7.2763e+05                   2.7757               0.6048 
    Pruned Network           3.8997e+05                   1.4876               0.7843 
    Percentage Change           -46.406                  -46.406               29.679 

This table compares the size and classification accuracy of the original and the pruned network. A decrease in network memory together with comparable or improved accuracy indicates a successful pruning operation.
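
The figures in the table can be cross-checked with a few lines of arithmetic. This sketch uses the rounded values printed in the table and assumes 4 bytes per single-precision learnable and 1 MB = 2^20 bytes; that assumption reproduces the memory column, but it is an inference from the numbers, not a documented property of estimateNetworkMetrics.

```matlab
originalLearnables = 7.2763e5;   % rounded values from the table
prunedLearnables   = 3.8997e5;

% Relative change in the number of learnables.
pctLearnables = 100*(prunedLearnables - originalLearnables)/originalLearnables;

% Approximate parameter memory (assumed: 4 bytes per learnable, 2^20 bytes per MB).
originalMB = originalLearnables*4/2^20;
prunedMB   = prunedLearnables*4/2^20;

% The accuracy row is a relative change, not a difference in percentage points.
pctAccuracy = 100*(0.7843 - 0.6048)/0.6048;
pointsAccuracy = 100*(0.7843 - 0.6048);

fprintf("%.2f%% learnables, %.4f MB -> %.4f MB\n",pctLearnables,originalMB,prunedMB)
fprintf("%+.2f%% relative accuracy change (%+.2f percentage points)\n",pctAccuracy,pointsAccuracy)
```

Because memory is proportional to the number of learnables under this assumption, the two percentage changes in the table are identical.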

Quantize the Pruned Network

To quantize the pruned network using the dlquantizer function, specify the network you want to calibrate and the execution environment, and then calibrate with calibration data.

clear r
r = raspi;
quantOpts = dlquantizationOptions('Target',r);
quantObj = dlquantizer(prunedDAGNet,'ExecutionEnvironment','CPU'); 

Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.

calResults = calibrate(quantObj,augimdsTrain,'UseGPU','off')
### Host application produced the following standard output (stdout) and standard error (stderr) messages:
calResults=122×5 table
        Optimized Layer Name         Network Layer Name     Learnables / Activations    MinValue     MaxValue
    ____________________________    ____________________    ________________________    _________    ________

    {'new_firstconv_Weights'   }    {'new_firstconv'   }           "Weights"             -0.53081    0.50032 
    {'new_firstconv_Bias'      }    {'new_firstconv'   }           "Bias"                -0.13664     0.2061 
    {'fire2-squeeze1x1_Weights'}    {'fire2-squeeze1x1'}           "Weights"              -1.3348     1.1903 
    {'fire2-squeeze1x1_Bias'   }    {'fire2-squeeze1x1'}           "Bias"                -0.12888    0.25519 
    {'fire2-expand1x1_Weights' }    {'fire2-expand1x1' }           "Weights"             -0.71728    0.87709 
    {'fire2-expand1x1_Bias'    }    {'fire2-expand1x1' }           "Bias"               -0.065638    0.14888 
    {'fire2-expand3x3_Weights' }    {'fire2-expand3x3' }           "Weights"             -0.71899     0.6452 
    {'fire2-expand3x3_Bias'    }    {'fire2-expand3x3' }           "Bias"               -0.062058    0.08805 
    {'fire3-squeeze1x1_Weights'}    {'fire3-squeeze1x1'}           "Weights"             -0.72677    0.67948 
    {'fire3-squeeze1x1_Bias'   }    {'fire3-squeeze1x1'}           "Bias"                -0.11343    0.33745 
    {'fire3-expand1x1_Weights' }    {'fire3-expand1x1' }           "Weights"             -0.68734    0.93931 
    {'fire3-expand1x1_Bias'    }    {'fire3-expand1x1' }           "Bias"               -0.075568    0.31345 
    {'fire3-expand3x3_Weights' }    {'fire3-expand3x3' }           "Weights"              -0.5874    0.72577 
    {'fire3-expand3x3_Bias'    }    {'fire3-expand3x3' }           "Bias"               -0.066463    0.12058 
    {'fire4-squeeze1x1_Weights'}    {'fire4-squeeze1x1'}           "Weights"             -0.70607     1.0569 
    {'fire4-squeeze1x1_Bias'   }    {'fire4-squeeze1x1'}           "Bias"                -0.11843    0.14643 
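
The following minimal sketch in base MATLAB shows how a calibrated range like the ones above can be turned into an 8-bit scaled integer representation. It assumes a power-of-two scale chosen from the largest magnitude in the range; this is one common scheme, and the scheme actually applied by dlquantizer and the generated code may differ. The sample weights are hypothetical apart from the two range endpoints.

```matlab
% Range taken from the first row of the calibration table.
minVal = -0.53081;
maxVal =  0.50032;

% Pick the smallest power-of-two scale that still covers the range with a
% signed 8-bit integer (values -128 to 127).
maxAbs = max(abs([minVal maxVal]));
exponent = ceil(log2(maxAbs/127));
scale = 2^exponent;

w  = [-0.53081 -0.1 0 0.25 0.50032];             % sample weights (assumed)
q  = int8(max(min(round(w/scale),127),-128));    % quantize and saturate
wq = double(q)*scale;                            % dequantize for comparison

maxError = max(abs(w - wq));
```

For values inside the calibrated range, the rounding error is at most half of the scale, which is why layers with a narrow range lose less precision than layers with a wide one.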

Save the dlquantizer object containing the network to quantize.

save("squeezenetQuantObj.mat","quantObj")

You can use the Deep Network Quantizer app to further visualize the dynamic ranges of the calibrated layers.

Use the validate function to compare the results of the network before and after quantization using the validation data set. Examine the MetricResults.Result field of the validation output to see the accuracy of the quantized network.

validationMetricsC = validate(quantObj,augimdsValidation,quantOpts);
### Starting application: 'codegen/lib/validate_predict_int8/pil/validate_predict_int8.elf'
    To terminate execution: clear validate_predict_int8_pil
### Launching application validate_predict_int8.elf...
### Host application produced the following standard output (stdout) and standard error (stderr) messages:
validationMetricsC.MetricResults.Result
ans=2×2 table
    NetworkImplementation    MetricOutput
    _____________________    ____________

     {'Floating-Point'}          0.765   
     {'Quantized'     }         0.7641   

Generate and Deploy INT8 C++ Code to Raspberry Pi

The predictResponses.m entry-point function takes an image input and runs prediction on the image using the specified network. The function uses a persistent object mynet to load the network object and reuses the persistent object for prediction on subsequent calls.

type predictResponses.m
function out = predictResponses(net,in)

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork(net);
end

out = predict(mynet, in);

end


To generate a PIL MEX function, create a code configuration object for a static library and set the verification mode to 'PIL'. Set the target language to C++. Create a coder.Hardware object for Raspberry Pi and attach it to the code generation configuration object.

cfg = coder.config('lib', 'ecoder', true);
cfg.VerificationMode = 'PIL';
cfg.TargetLang = 'C++';
cfg.Hardware = coder.hardware('Raspberry Pi');

Create a deep learning configuration object for the ARM Compute Library. Specify the library version and the ARM architecture. For this example, suppose that the ARM Compute Library in the Raspberry Pi hardware is version 20.02.1.

dlcfg = coder.DeepLearningConfig('arm-compute');
dlcfg.ArmComputeVersion = '20.02.1';
dlcfg.ArmArchitecture = 'armv7';

Set the properties of dlcfg to generate code for INT8 inference.

dlcfg.CalibrationResultFile = 'squeezenetQuantObj.mat'; 
dlcfg.DataType = 'int8';
cfg.DeepLearningConfig = dlcfg;    
inputs = {coder.Constant('prunedDAGNet.mat'),ones(32,32,3,'uint8')};

Generate a PIL MEX function by using the codegen command.

codegen -config cfg predictResponses -args inputs
 Deploying code. This may take a few minutes. 
### Connectivity configuration for function 'predictResponses': 'Raspberry Pi'
Location of the generated elf : /home/pi/MATLAB_ws/R2023a/home/lnarasim/Documents/MATLAB/ExampleManager/lnarasim.Bdoc23a.j2174901/deeplearning_shared-ex40890309/codegen/lib/predictResponses/pil
Code generation successful.

Compare Classification Accuracy of the Transfer Learned, Pruned, and Quantized Networks

Evaluate the impact of quantization on the classification accuracy of the pruned network.

testImages = read(augimdsValidation);
testImage = table2array(testImages(4,1));
predictScores(:,1) = predictResponses('transferNet.mat', testImage{1}); 
predictScores(:,2) = predictResponses('prunedDAGNet.mat', testImage{1});
predictScores(:,3) = predictResponses_pil('prunedDAGNet.mat',testImage{1}); 
### Starting application: 'codegen/lib/predictResponses/pil/predictResponses.elf'
    To terminate execution: clear predictResponses_pil
### Launching application predictResponses.elf...
barh(predictScores)
xlabel('Probability')
yticklabels(classes)
ax = gca;
ax.XLim = [0 1.1];
ax.YAxisLocation = 'left';
legend('Trained Network (Single)','Pruned Network (Single)','ARM-Compute (8-bit integer)');
sgtitle('Network Predictions')

[Figure: bar chart titled Network Predictions (x-axis Probability), comparing Trained Network (Single), Pruned Network (Single), and ARM-Compute (8-bit integer).]

Helper Functions

Download CIFAR-10 Dataset

The downloadCIFARData function downloads the CIFAR-10 dataset from the external website. The download is approximately 175 MB in size.

function downloadCIFARData(destination)

url = '';

unpackedData = fullfile(destination,'cifar-10-batches-mat');
if ~exist(unpackedData,'dir')
    fprintf('Downloading CIFAR-10 dataset (175 MB). This can take a while...');
    untar(url,destination);
    fprintf('done.\n\n');
end

end


Process CIFAR-10 Dataset

Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images. Use the CIFAR-10 test images for network validation.

function [XTrain,YTrain,XTest,YTest] = loadCIFARData(location)

location = fullfile(location,'cifar-10-batches-mat');

[XTrain1,YTrain1] = loadBatchAsFourDimensionalArray(location,'data_batch_1.mat');
[XTrain2,YTrain2] = loadBatchAsFourDimensionalArray(location,'data_batch_2.mat');
[XTrain3,YTrain3] = loadBatchAsFourDimensionalArray(location,'data_batch_3.mat');
[XTrain4,YTrain4] = loadBatchAsFourDimensionalArray(location,'data_batch_4.mat');
[XTrain5,YTrain5] = loadBatchAsFourDimensionalArray(location,'data_batch_5.mat');
XTrain = cat(4,XTrain1,XTrain2,XTrain3,XTrain4,XTrain5);
YTrain = [YTrain1;YTrain2;YTrain3;YTrain4;YTrain5];

[XTest,YTest] = loadBatchAsFourDimensionalArray(location,'test_batch.mat');
end

function [XBatch,YBatch] = loadBatchAsFourDimensionalArray(location,batchFileName)
s = load(fullfile(location,batchFileName));
XBatch = s.data';
XBatch = reshape(XBatch,32,32,3,[]);
XBatch = permute(XBatch,[2 1 3 4]);
YBatch = convertLabelsToCategorical(location,s.labels);
end

function categoricalLabels = convertLabelsToCategorical(location,integerLabels)
s = load(fullfile(location,'batches.meta.mat'));
categoricalLabels = categorical(integerLabels,0:9,s.label_names);
end

Mini-Batch Preprocessing Function

The preprocessMiniBatchTraining function preprocesses a mini-batch of predictors and labels for loss computation during training.

function [X,T] = preprocessMiniBatchTraining(XCell,TCell)
% Concatenate.
X = cat(4,XCell{1:end});

% Extract label data from cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);
end
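
The one-hot layout produced by this function (classes along the first dimension, observations along the second) can be illustrated without the toolbox. The following sketch in base MATLAB uses three hypothetical class names and five hypothetical labels; it shows the encoding only and is not the onehotencode implementation.

```matlab
% Hypothetical labels for a mini-batch of five observations.
classNames = ["airplane" "automobile" "bird"];
T = categorical(["bird" "airplane" "bird" "automobile" "airplane"],classNames);

% Build a classes-by-observations indicator matrix: row k is 1 where the
% label equals class k. double(T) returns the category index of each label.
oneHot = double((1:numel(classNames))' == double(T));
```

Each column contains exactly one nonzero entry, which is the form the cross-entropy loss expects for its targets.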

Model Gradients Function for Fine-Tuning and Pruning

The modelLossPruning function takes as input a deep.prune.TaylorPrunableNetwork object prunableNet and a mini-batch of input data X with corresponding labels T. It returns the loss, the gradients of the loss with respect to the pruning activations, the pruning activations, the gradients of the loss with respect to the learnable parameters in prunableNet, and the network state. To compute the gradients automatically, use the dlgradient function.

function [loss,pruningGradient,pruningActivations,netGradients,state] = modelLossPruning(prunableNet, X, T)

[dlYPred,state,pruningActivations] = forward(prunableNet,X);
dlYPred = squeeze(dlYPred);

loss = crossentropy(dlYPred,T);
[pruningGradient,netGradients] = dlgradient(loss,pruningActivations,prunableNet.Learnables);

end
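
The loss value computed in this function can be illustrated with a small worked example. The following sketch in base MATLAB assumes that the network output is a matrix of softmax probabilities (classes by observations) and that the loss is averaged over the observations in the mini-batch; the probabilities and targets are hypothetical.

```matlab
% Hypothetical softmax outputs (classes-by-observations) and one-hot targets.
Y = [0.7 0.1; 0.2 0.8; 0.1 0.1];
T = [1 0; 0 1; 0 0];

% Cross-entropy: negative log-probability assigned to the true class,
% averaged over the observations in the mini-batch.
numObservations = size(Y,2);
loss = -sum(T .* log(Y),"all") / numObservations;
```

Only the probability assigned to the true class of each observation contributes, so the loss here is the mean of -log(0.7) and -log(0.8).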


Evaluate Model Accuracy

The modelAccuracy function takes as input the network (a dlnetwork object), a minibatchqueue object, the classes, and the number of observations, and returns the accuracy.

function accuracy = modelAccuracy(net, mbq, classes, numObservations)
% This function computes the accuracy of net (a dlnetwork object) on the minibatchqueue object mbq.

totalCorrect = 0;

classes = int32(categorical(classes));

% Reset the queue so that evaluation starts from the first mini-batch.
reset(mbq);

while hasdata(mbq)
    [dlX, Y] = next(mbq);

    dlYPred = extractdata(predict(net, dlX));
    dlYPred = squeeze(dlYPred);

    YPred = onehotdecode(dlYPred,classes,1)';
    YReal = onehotdecode(Y,classes,1)';

    miniBatchCorrect = nnz(YPred == YReal);

    totalCorrect = totalCorrect + miniBatchCorrect;
end

accuracy = totalCorrect / numObservations * 100;
end

Evaluate Number of Filters in Convolution Layers

The numConvLayerFilters function returns the number of filters in each convolution layer.

function [nFilters, convNames] = numConvLayerFilters(net)
numLayers = numel(net.Layers);
convNames = [];
nFilters = [];
% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
        sizeW = size(net.Layers(cnt).Weights);
        nFilters = [nFilters; sizeW(end)]; %#ok<AGROW>
        convNames = [convNames; string(net.Layers(cnt).Name)]; %#ok<AGROW>
    end
end
end

Evaluate Network Statistics of Original Network and Pruned Network

The analyzeNetworkMetrics function takes as input the original network, the pruned network, the accuracy of the original network, and the accuracy of the pruned network. It returns a table of statistics: the number of network learnables, the approximate network memory, and the accuracy on the test data.

function [statistics] = analyzeNetworkMetrics(originalNet,prunedNet,accuracyOriginal,accuracyPruned)

originalNetMetrics = estimateNetworkMetrics(originalNet);
prunedNetMetrics = estimateNetworkMetrics(prunedNet);

% Accuracy of original network and pruned network
perChangeAccu = 100*(accuracyPruned - accuracyOriginal)/accuracyOriginal;
AccuracyForNetworks = [accuracyOriginal;accuracyPruned;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
prunedNetLearnables = sum(prunedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(prunedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;prunedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxPrunedMemory = sum(prunedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxPrunedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [ approxOriginalMemory; approxPrunedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory,AccuracyForNetworks, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)","Accuracy"], ...
    'RowNames',{'Original Network','Pruned Network','Percentage Change'});
end


function [statistics] = analyzeQuantizedNetworkMetrics(originalNet,quantizedNet)

originalNetMetrics = estimateNetworkMetrics(originalNet);
quantizedNetMetrics = estimateNetworkMetrics(quantizedNet);

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
quantizedNetLearnables = sum(quantizedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(quantizedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;quantizedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxQuantizedMemory = sum(quantizedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxQuantizedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [ approxOriginalMemory; approxQuantizedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)"], ...
    'RowNames',{'Original Network', 'Pruned & Quantized Network','Percentage Change'});
end



References

[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images" (2009).

[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning for image recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.

[3] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. “Importance Estimation for Neural Network Pruning.” In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019.
