
How to plot a confusion matrix?

Adrian Kleffler on 22 May 2023
Edited: Venkat Siddarth on 29 May 2023
Hello, I want to plot a confusion matrix after training an object detector. Here is my code. How can I plot the confusion matrix?
data = load("letisko_labels_new.mat");
LabelData = data.gTruth.LabelData;
LabelData.imageFilename = fullfile(LabelData.imageFilename);
rng("default");
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * length(shuffledIndices) );
trainingIdx = 1:idx;
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:6));
imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:6));
imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,2:6));
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
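% validateInputData, preprocessData and augmentData are helper functions, not built-in
% MATLAB functions; they must be defined in this script or on the MATLAB path.
validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);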
inputSize = [256 256 3];
className = ["kamera","lietadlo","satelit","stlp","veza"];
rng("default")
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 9;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    anchors(7:9,:)
    };
detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);
augmentedTrainingData = transform(trainingData,@augmentData);
options = trainingOptions("adam",...
    GradientDecayFactor=0.9,...
    SquaredGradientDecayFactor=0.999,...
    InitialLearnRate=0.001,...
    LearnRateSchedule="none",...
    MiniBatchSize=4,...
    L2Regularization=0.0005,...
    MaxEpochs=50,...
    BatchNormalizationStatistics="moving",...
    DispatchInBackground=true,...
    ResetInputNormalization=false,...
    Shuffle="every-epoch",...
    VerboseFrequency=20,...
    ValidationFrequency=1000,...
    Plots="training-progress",...
    CheckpointPath='C:\BAKALARKA\checkpointYOLO',...
    ValidationData=validationData);
doTraining = true;
if doTraining
    % Train the YOLO v4 detector.
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % Load pretrained detector for the example (downloadPretrainedYOLOv4Detector is a helper function, not built in).
    detector = downloadPretrainedYOLOv4Detector();
end
detectionResults = detect(detector,testData,'MiniBatchSize',4);
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);
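% With several classes, recall and precision are cell arrays holding one curve per class.
% Plot each class separately instead of merging all points into one curve
% (assumes the class order of the outputs matches className).
figure
hold on
for k = 1:numel(recall)
    plot(recall{k},precision{k},DisplayName=className(k))
end
hold off
legend
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Mean Average Precision = %.2f",mean(ap)))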

Answers (1)

Venkat Siddarth on 29 May 2023
Edited: Venkat Siddarth on 29 May 2023
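I understand that you want to plot a confusion matrix for the model. Assuming you want it for the Labels column of detectionResults, you can use the confusionmat function. It takes two vectors of equal length, the true labels and the predicted labels, and returns the confusion matrix. Note that an object detector generally returns a different number of boxes per image than there are ground-truth boxes, so each detection first has to be paired with a ground-truth box (for example, by bounding-box overlap) before the two label vectors can be compared.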
y_true=[1 0 1 1 1 1 0 0];
y_pred=[1 1 1 1 0 0 1 1];
C=confusionmat(y_true,y_pred)
C = 2×2
     0     3
     2     3
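In C, rows correspond to the true class and columns to the predicted class; for numeric labels the classes are taken in sorted order ([0 1] here), so C(2,1) = 2 counts the two samples whose true label is 1 but which were predicted as 0. As a minimal sketch of that convention (base MATLAB only, not how confusionmat is implemented internally), the same matrix can be rebuilt by counting (true, predicted) pairs with accumarray; the name C_manual is just for illustration:

```matlab
y_true = [1 0 1 1 1 1 0 0];
y_pred = [1 1 1 1 0 0 1 1];

% Class order: sorted unique labels from both vectors, here [0 1]
classes = unique([y_true y_pred]);

% Map each label to its position in the class list
[~,rowIdx] = ismember(y_true,classes);   % row    = true class
[~,colIdx] = ismember(y_pred,classes);   % column = predicted class

% Count each (true, predicted) pair into an n-by-n matrix
n = numel(classes);
C_manual = accumarray([rowIdx(:) colIdx(:)],1,[n n])
```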
After generating the confusion matrix, you can plot it using the confusionchart function:
confusionchart(C)
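Because a detector can return more or fewer boxes than there are ground-truth objects, the Labels column of detectionResults cannot be passed to confusionmat as-is. One common approach, sketched here under stated assumptions, is to pair each detection with at most one ground-truth box by intersection-over-union (IoU) using bboxOverlapRatio (Computer Vision Toolbox), and to add an extra "background" class for unmatched detections (false positives) and unmatched ground-truth boxes (misses). The variable names, the hard-coded boxes for a single image, the 0.5 IoU threshold and the "background" label are all illustrative assumptions, not part of the question's code or a MathWorks API. The matching is greedy by score and class-agnostic, so the counts will not line up exactly with evaluateDetectionPrecision, which matches per class.

```matlab
% Hypothetical toy data for ONE image; boxes are [x y width height]
gtBoxes   = [10 10 50 50; 100 100 40 40; 300 300 20 20];
gtLabels  = ["kamera"; "veza"; "satelit"];
detBoxes  = [12 12 50 50; 102 98 40 40; 200 200 30 30];
detScores = [0.9; 0.8; 0.6];
detLabels = ["kamera"; "stlp"; "veza"];
iouThreshold = 0.5;   % assumption: IoU needed to count a detection as a hit

trueLabels = strings(0,1);
predLabels = strings(0,1);
gtUsed = false(size(gtBoxes,1),1);

% Greedy one-to-one matching, highest-scoring detection first
[~,order] = sort(detScores,"descend");
for k = order(:)'
    iou = bboxOverlapRatio(detBoxes(k,:),gtBoxes);   % 1-by-numGT IoU values
    iou(gtUsed) = 0;                                  % each GT box is matched at most once
    [bestIoU,j] = max(iou);
    if ~isempty(bestIoU) && bestIoU >= iouThreshold
        gtUsed(j) = true;
        trueLabels(end+1,1) = gtLabels(j);            % matched box; its class may still be wrong
    else
        trueLabels(end+1,1) = "background";           % detection with no object: false positive
    end
    predLabels(end+1,1) = detLabels(k);
end

% Ground-truth objects that no detection matched are misses
missed = gtLabels(~gtUsed);
trueLabels = [trueLabels; missed];
predLabels = [predLabels; repmat("background",numel(missed),1)];

% Convert to categorical with a fixed class order and plot
classOrder = ["kamera","lietadlo","satelit","stlp","veza","background"];
figure
confusionchart(categorical(trueLabels,classOrder),categorical(predLabels,classOrder), ...
    RowSummary="row-normalized",ColumnSummary="column-normalized");
```

To apply this to the question's results, run the matching once per test image: after reset(testData), each read(testData) returns a cell array whose second and third elements are that image's ground-truth boxes and labels, and detectionResults.Boxes{i}, detectionResults.Scores{i} and detectionResults.Labels{i} hold the detections for the same image (convert the categorical labels with string). Concatenate trueLabels and predLabels over all images before calling confusionchart. With the summaries enabled, the row summary shows, for each true class, the fraction of objects detected with the right class, and the column summary shows, for each predicted class, the fraction of detections that were correct.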
To learn more about these functions, see the confusionmat and confusionchart documentation.
I hope this resolves the issue,
Regards
Venkat Siddarth V.
