Deep Learning Semantic Segmentation Example
In order to familiarize myself with semantic segmentation and convolutional neural networks I am going through this tutorial by MathWorks:
I did not use the pretrained version of SegNet since I wanted to test on my custom data set. All the code is the same, but I have different classes and **fewer labels**. The image below shows each label name and the number of pixels associated with it.

To make up for the low pixel data for class 2, median frequency balancing was performed.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;
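For reference, the MathWorks example then supplies these weights to the output layer. A rough sketch of that step (the layer name `'pixelLabels'` is taken from the example's SegNet `lgraph`, so treat it as an assumption for a custom network):

```matlab
% Sketch: attach the class weights to the pixel classification layer.
% Assumes tbl (from countEachLabel), classWeights, and the example's lgraph,
% whose output layer is named 'pixelLabels' -- verify for a custom network.
pxLayer = pixelClassificationLayer('Name', 'labels', ...
    'Classes', tbl.Name, 'ClassWeights', classWeights);
lgraph = replaceLayer(lgraph, 'pixelLabels', pxLayer);
```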
I proceed to train the network using the code provided in the example, with `options` and `lgraph` unchanged. The SegNet network is created with weights initialized from the VGG-16 network.
Unlike the example, I get a much lower global accuracy:

To gain further insight I plotted the Mini-batch accuracy and Mini-batch loss against each iteration.

The plot clearly shows that the accuracy fluctuates wildly and ends up worse than it started, so the network learned almost nothing! However, the loss decreased gradually.
A possible solution I propose would be to use inverse frequency balancing. However, since median frequency balancing was already performed above, I doubt it would help much.
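For reference, inverse frequency balancing would look roughly like this. The pixel counts below are made-up toy values just for illustration; in practice `tbl` would come from `countEachLabel` as in the example:

```matlab
% Toy pixel counts for illustration (replace with tbl = countEachLabel(pxds)):
% class 1 holds 90% of labeled pixels, class 2 only 5%.
tbl = table([900000; 50000], [1000000; 1000000], ...
    'VariableNames', {'PixelCount', 'ImagePixelCount'});

% Inverse frequency balancing: weight each class by 1/frequency,
% so the rare class gets a proportionally larger weight.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = 1 ./ imageFreq;
```

With these toy counts the rare class is weighted 18x more heavily than the common one (0.9/0.05), which shows why inverse frequency weights can be aggressive on strongly imbalanced data.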
Is the terrible performance simply due to not having enough training data? Can anything be done to improve performance with the existing data?
Any suggestions are greatly appreciated.
Answers (1)
Sourabh
on 11 Jun 2025
Hey @Ryan Rizzo
The graph shows that your network is optimizing (the loss is decreasing) but not generalizing or learning meaningful class boundaries. You are already trying to address this with median frequency balancing, which is good, but in very low-data scenarios it can overcompensate, making the model oscillate or diverge.
Instead, you can try smoothing weights:
% Add a small epsilon so near-zero frequencies cannot blow up the weights.
epsilon = 1e-6;
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ (imageFreq + epsilon);
or try log-scaled weights:
totalPixels = sum(tbl.PixelCount);
classWeights = log(1 + totalPixels ./ tbl.PixelCount);
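To see why log scaling is gentler, compare the two schemes on a toy imbalance (the pixel counts below are made up purely for illustration):

```matlab
% Toy counts: class 1 has 95% of the pixels, class 2 has 5%.
pixelCount = [950000; 50000];
totalPixels = sum(pixelCount);

% Median frequency balancing: the weight ratio equals the
% frequency ratio, i.e. the rare class is weighted 19x more.
freq = pixelCount / totalPixels;
medianWeights = median(freq) ./ freq;

% Log-scaled weights: the same imbalance yields a much milder
% ratio (about 4x), so training is less likely to oscillate.
logWeights = log(1 + totalPixels ./ pixelCount);
```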
Segmentation tasks are data hungry. A small dataset can mean poor generalization, incomplete coverage of class variations and noisy or unstable learning. To solve this, you can perform Data Augmentation using MATLAB “imageDataAugmenter”. Apply it as:
imageAugmenter = imageDataAugmenter( ...
    'RandRotation', [-20 20], ...
    'RandXTranslation', [-10 10], ...
    'RandYTranslation', [-10 10], ...
    'RandXReflection', true);
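Note that for semantic segmentation the same geometric transform must be applied to an image and its pixel labels together. One way to do this (assuming `imds` and `pxds` datastores as in the example, plus the `lgraph` and `options` already defined) is to pass the augmenter to `pixelLabelImageDatastore`:

```matlab
% Sketch: pair images and pixel labels so identical augmentations are
% applied to both (imds, pxds, lgraph, and options assumed from the example).
pximds = pixelLabelImageDatastore(imds, pxds, ...
    'DataAugmentation', imageAugmenter);
net = trainNetwork(pximds, lgraph, options);
```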
Sometimes, unstable training is caused by a learning rate that is too high or a mini-batch size that is too small. You can try reducing "InitialLearnRate" to 1e-4 or lower and use a larger mini-batch size if possible.
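As a sketch, the corresponding `trainingOptions` tweak might look like this; the specific values are illustrative starting points, not tuned settings:

```matlab
% Illustrative training options: lower learning rate and a larger
% mini-batch than the defaults (all values here are assumptions to tune).
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-4, ...        % reduced to stabilize training
    'MiniBatchSize', 8, ...              % increase if GPU memory allows
    'MaxEpochs', 100, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress');
```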
For more information and examples on "imageDataAugmenter", kindly refer to the following MATLAB documentation: