Oversampling for deep learning: classification example
Image classification using oversampling imagedatastore
This example shows how to classify images when the training data set is imbalanced, that is, when the number of images differs from class to class. Two of the most popular remedies are down-sampling and over-sampling. Down-sampling reduces every class to the size of the smallest class. Over-sampling increases the number of images in the smaller classes, here up to the size of the largest class, by reusing images. Both strategies can be effective for imbalanced data sets. Down-sampling is easy to implement: call splitEachLabel and specify the size of the smallest class. Over-sampling takes more work, because the images of the minority classes have to be replicated.
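The following is a minimal sketch of down-sampling with splitEachLabel. It assumes only base MATLAB (imageDatastore, imwrite). To keep it self-contained, it writes a few random 8-by-8 PNG images into a temporary folder tree. The folder (from tempname), the class names 'a', 'b', 'c', and the per-class counts are hypothetical.

```matlab
% Build a tiny, deliberately imbalanced, synthetic labeled image folder tree
% (hypothetical data; only used so the sketch runs without the food images)
rootFolder = tempname;                 % unique folder path that does not exist yet
mkdir(rootFolder);
classNames = {'a','b','c'};
numPerClass = [5 2 8];                 % imbalanced on purpose
for k = 1:numel(classNames)
    classFolder = fullfile(rootFolder,classNames{k});
    mkdir(classFolder);
    for j = 1:numPerClass(k)
        % Random grayscale image standing in for a photograph
        imwrite(uint8(randi(255,8,8)),fullfile(classFolder,sprintf('img%02d.png',j)));
    end
end

% Labels come from the subfolder names, as in the food example
imds = imageDatastore(rootFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
tbl = countEachLabel(imds);

% Down-sampling: keep a random subset of minCount images from every class
minCount = min(tbl.Count);
imdsDown = splitEachLabel(imds,minCount,'randomized');
tblDown = countEachLabel(imdsDown)     % every class now has minCount images
```

Here minCount is 2, so 11 of the 15 images are discarded. That loss of data is the usual argument for over-sampling instead.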
This code does not use a transform datastore. As of R2020a, a transformed datastore cannot be passed to augmentedImageDatastore, which makes data augmentation difficult with that approach.
Load data
This example uses the Example Food Images data set provided by MathWorks. The script downloads it if it is not already present. The data set contains 978 photographs of food in nine classes (caesar_salad, caprese_salad, french_fries, greek_salad, hamburger, hot_dog, pizza, sashimi, and sushi).
Other data sets are available at https://jp.mathworks.com/help/deeplearning/ug/data-sets-for-deep-learning.html.
clear;clc;close all
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";
downloadFolder = pwd;
filename = fullfile(downloadFolder,'ExampleFoodImageDataset.zip');
dataFolder = fullfile(downloadFolder, "ExampleFoodImageDataset");
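% Download the zip file once, then extract it into the current folder
if ~exist(filename,'file')
fprintf("Downloading Example Food Image data set (77 MB)... ")
websave(filename,url);
fprintf("Done.\n")
end
if ~exist(dataFolder,'dir')
unzip(filename,downloadFolder);
end
% Create a datastore whose labels are taken from the subfolder names
imds=imageDatastore(dataFolder, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
% Split each class into 80% training, 10% validation, and 10% test images
[imdsTrain,imdsValid,imdsTest]=splitEachLabel(imds,0.8,0.1);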
Confirm the imbalance
The number of training images differs greatly across the classes: pizza has 239 images, while caprese_salad has only 12.
labelCount = countEachLabel(imdsTrain)
        Label         Count
    ______________    _____

    caesar_salad        21
    caprese_salad       12
    french_fries       145
    greek_salad         19
    hamburger          190
    hot_dog             25
    pizza              239
    sashimi             32
    sushi               99
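To put a number on the imbalance, the short calculation below uses the counts in the table above. It computes the ratio between the largest and smallest class and the number of images that over-sampling has to add to reach the size of the largest class. The variable names are illustrative only.

```matlab
% Training-set counts copied from the labelCount table
classNames = ["caesar_salad";"caprese_salad";"french_fries";"greek_salad"; ...
    "hamburger";"hot_dog";"pizza";"sashimi";"sushi"];
counts = [21;12;145;19;190;25;239;32;99];

total = sum(counts);                       % 782 training images
imbalanceRatio = max(counts)/min(counts)   % pizza (239) vs caprese_salad (12), about 19.9

% Images each class needs to reach the size of the largest class
toAdd = max(counts) - counts;
disp(table(classNames,counts,toAdd))
numAfterOversampling = numel(counts)*max(counts)   % 9 classes x 239 images
```

Over-sampling grows the training set from 782 to 2151 images. Caprese_salad images are reused about 20 times each on average.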
Over-sample the training data
I referred to https://jp.mathworks.com/help/vision/examples/point-cloud-classification-using-pointnet-deep-learning.html for this section.
% Group the training images by class and count the images in each class
labels = imdsTrain.Labels;
[G,classes] = findgroups(labels);
numObservations = splitapply(@numel,labels,G);
desiredNumObservationsPerClass is the maximum number of samples among all classes (239, the number of pizza images).
desiredNumObservationsPerClass = max(numObservations);
randReplicateFiles is a supporting function, defined at the end of this script, that randomly draws file paths with replacement. The number of files drawn for each class is set by desiredNumObservationsPerClass. The files of each class are taken from imdsTrain.Files, grouped by G.
files = splitapply(@(x){randReplicateFiles(x,desiredNumObservationsPerClass)},imdsTrain.Files,G);
files = vertcat(files{:});
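% splitapply returns one cell per class in the order of classes, so the
% replicated files are grouped by class; rebuild the labels to match
oversampledLabels = classes(repelem((1:numel(classes))',desiredNumObservationsPerClass));
imdsTrain.Files = files;
imdsTrain.Labels = oversampledLabels;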
labelCount_oversampled = countEachLabel(imdsTrain)
        Label         Count
    ______________    _____

    caesar_salad       239
    caprese_salad      239
    french_fries       239
    greek_salad        239
    hamburger          239
    hot_dog            239
    pizza              239
    sashimi            239
    sushi              239
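You can try the grouping-and-replication step on a toy file list, without any images. In this sketch the file names and labels are hypothetical. replicate is an anonymous stand-in for randReplicateFiles that uses the same randi-based sampling with replacement, which keeps the snippet a plain script with no local function. The final check shows why the labels can be rebuilt by repetition: after vertcat, the files appear class by class in the order of classes.

```matlab
% Hypothetical file list and labels (not from the food data set)
files = {'a1.jpg';'a2.jpg';'b1.jpg';'b2.jpg';'b3.jpg';'b4.jpg';'c1.jpg'};
labels = categorical({'a';'a';'b';'b';'b';'b';'c'});

% Group the files by class and find the size of the largest class
[G,classes] = findgroups(labels);
numObservations = splitapply(@numel,labels,G);   % [2;4;1]
numDesired = max(numObservations);               % 4

% Stand-in for randReplicateFiles: draw numDesired indices with replacement
replicate = @(f) f(randi(numel(f),numDesired,1));
filesOver = splitapply(@(f){replicate(f)},files,G);
filesOver = vertcat(filesOver{:});

% Groups come back in class order, so labels are rebuilt by repetition
labelsOver = classes(repelem((1:numel(classes))',numDesired));
countsOver = countcats(labelsOver)               % [4;4;4]

% Every replicated file name starts with the letter of its class
ok = all(startsWith(filesOver(labelsOver == 'b'),'b'))
```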
Load the pre-trained model, ResNet-18
net = resnet18;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
Modify the network for the current task
numClasses = numel(categories(imds.Labels));
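% ResNet-18 ends with the fully connected layer, a softmax layer, and the
% classification output layer; look up the names of the layers to replace
learnableLayer = net.Layers(end-2).Name;
classLayer = net.Layers(end).Name;
% New fully connected layer with one output per food class; the larger
% learn-rate factor lets it train faster than the transferred layers
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10);
lgraph = replaceLayer(lgraph,learnableLayer,newLearnableLayer);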
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer,newClassLayer);
Define image augmenter
pixelRange = [-30 30];
RotationRange = [-30 30];
scaleRange = [0.8 1.2];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange, ...
'RandXScale',scaleRange, ...
'RandYScale',scaleRange, ...
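'RandRotation',RotationRange);
% Resize training images to the network input size and augment them on the fly
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);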
augimdsValid = augmentedImageDatastore(inputSize(1:2),imdsValid);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);
Specify the training options
miniBatchSize = 64;
valFrequency = max(floor(numel(augimdsTest.Files)/miniBatchSize)*10,1);
options = trainingOptions('sgdm', ...
'MiniBatchSize',miniBatchSize, ...
'MaxEpochs',30, ...
'InitialLearnRate',3e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValid, ...
'ValidationFrequency',valFrequency, ...
'Verbose',false);
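As a rough sanity check of the training schedule these options imply, the arithmetic below assumes the over-sampled training set of 9 × 239 = 2151 images. trainingOptions documents that the data that does not fit into the final complete mini-batch of each epoch is discarded. With 'Shuffle','every-epoch', a different set of leftover images is skipped in each epoch.

```matlab
% Assumption: 9 classes x 239 images after over-sampling
numTrain = 9*239;                % 2151 training images
miniBatchSize = 64;
maxEpochs = 30;

% The incomplete final mini-batch of each epoch is discarded
itersPerEpoch = floor(numTrain/miniBatchSize)                   % 33
totalIters = itersPerEpoch*maxEpochs                            % 990
imagesSkippedPerEpoch = numTrain - itersPerEpoch*miniBatchSize  % 39
```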
Train the network
net = trainNetwork(augimdsTrain,lgraph,options);
Classification assessment
[YPred,probs] = classify(net,augimdsTest);
accuracy = mean(YPred == imdsTest.Labels)
accuracy = 0.9072
YValidation = imdsTest.Labels;
When I ran this code, the most frequent misclassification was between sashimi and sushi, which look similar. Please try this over-sampling code; I hope it helps your work.
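A confusion matrix shows which pairs of classes are mixed up. For the real results, confusionchart(imdsTest.Labels,YPred), available in recent releases, displays it graphically. The sketch below builds the same matrix with accumarray on hypothetical true and predicted labels, not on results from this example, and reports the largest off-diagonal entry.

```matlab
% Hypothetical labels for illustration only
classNames = {'hamburger','sashimi','sushi'};
YTrue = categorical({'sashimi';'sashimi';'sashimi';'sushi';'sushi';'hamburger'},classNames);
YPred = categorical({'sushi';'sushi';'sashimi';'sashimi';'sushi';'hamburger'},classNames);

accuracy = mean(YPred == YTrue)

% Rows = true class, columns = predicted class (category indices via double)
k = numel(classNames);
C = accumarray([double(YTrue) double(YPred)],1,[k k])

% The largest off-diagonal entry is the most frequent confusion
offDiag = C - diag(diag(C));
[~,idx] = max(offDiag(:));
[iTrue,iPred] = ind2sub([k k],idx);
fprintf("Most frequent confusion: true %s predicted as %s\n", ...
    classNames{iTrue},classNames{iPred});
```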
Supporting function
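This local function randomly replicates the image file paths of one class. For a target class A, the paths of the class A images (taken from imdsTrain.Files) are drawn at random with replacement until numDesired paths have been selected. This balances the number of images across the classes. Because sampling is with replacement, some images can be drawn several times and others not at all.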
function files = randReplicateFiles(files,numDesired)
% Draw numDesired file paths at random, with replacement, from files
n = numel(files);
ind = randi(n,numDesired,1);
files = files(ind);
end
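randReplicateFiles also resamples the largest class. For pizza, n = numDesired = 239, so each image is left out with probability (1 - 1/239)^239, which is about e^-1. Roughly 37% of the original pizza images are therefore expected to be missing from the over-sampled set. The sketch below is a variant, not the author's function. It keeps every original file once and samples with replacement only to fill the remaining slots. It assumes numDesired >= numel(files), which holds when numDesired is the size of the largest class. replicateKeepAll and the file names are hypothetical.

```matlab
% Variant: keep all originals, then pad with random draws (with replacement)
replicateKeepAll = @(files,numDesired) ...
    [files; files(randi(numel(files),numDesired-numel(files),1))];

% Hypothetical class with 3 files, padded to 5
files = {'x1.jpg';'x2.jpg';'x3.jpg'};
out = replicateKeepAll(files,5)
```

In the main script, this would replace randReplicateFiles inside the splitapply call. When a class is already at numDesired, it is returned unchanged.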
Cite As
Kenta (2025). Oversampling for deep learning: classification example (https://github.com/KentaItakura/Image-classification-using-oversampling-imagedatastore/releases/tag/2.0), GitHub. Retrieved .
Platform Compatibility: Windows, macOS, Linux
Version history
2.0: See release notes for this release on GitHub: https://github.com/KentaItakura/Image-classification-using-oversampling-imagedatastore/releases/tag/2.0
1.0.0