trainNetwork to trainnet conversion

Emre Can Ertekin on 7 Jul 2024 at 16:08
Answered: Paras Gupta on 15 Jul 2024 at 15:11
Hi there,
I was using the trainNetwork function (https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_408bdd15-2d34-4c0d-ad91-bc83942f7493) for my study. However, in 2024b the trainnet function (https://www.mathworks.com/help/deeplearning/ref/trainnet.html#mw_ffa5eeae-b6e0-444e-a464-91e257cef95b) is slightly faster. I tried to convert my trainNetwork call to trainnet, but I couldn't manage it. How can I convert it? My code is below. Thank you.
%% Train network part
numClasses = numel(categories(trainImgs.Labels));
dropoutProb = 0.2;
layers = [...%my network layers in here.
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress',"MiniBatchSize",64, ...
'ValidationData',valImgs,"ExecutionEnvironment","gpu")
%% Training network
trainednet = trainNetwork(trainImgs,layers,options)
% trainednet = trainnet(trainImgs,layers,"crossentropy",options)
  1 Comment
Matt J on 7 Jul 2024 at 16:54
Edited: Matt J on 7 Jul 2024 at 17:24
Your post doesn't mention what bad behavior you're seeing with trainnet, nor give us enough detail and input to run the code ourselves.
There's nothing in the way you're calling trainnet that looks "wrong".

Answers (1)

Paras Gupta on 15 Jul 2024 at 15:11
Hi Emre,
I understand that you are experiencing issues when transitioning from the 'trainNetwork' function to the 'trainnet' function in MATLAB.
From the code provided in the question, I assume that you are passing the same layer array to both functions. However, the 'trainnet' function expects a slightly modified layer array that does not include the output layer (for example, 'classificationLayer'). Instead of an output layer, the loss function is specified through the 'lossFcn' argument of 'trainnet', such as "crossentropy".
The following example code illustrates the difference in the network architectures for both the functions:
%% Load and Preprocess Data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
% Split data into training and validation sets
[trainImgs, valImgs] = splitEachLabel(imds, 0.8, 'randomized');
%% Define Network Architectures
% Network for trainNetwork (includes output layer)
layersTrainNetwork = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% Network for trainnet (does not include the classification output layer)
% The loss function is passed to trainnet instead. softmaxLayer is kept because
% the built-in "crossentropy" loss expects probabilities as the network output.
layersTrainnet = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer];
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress', ...
'MiniBatchSize',64, ...
'ValidationData',valImgs, ...
'ExecutionEnvironment','gpu');
%% Train Network using trainNetwork
trainednet = trainNetwork(trainImgs, layersTrainNetwork, options);
%% Train Network using trainnet
trainednet_ = trainnet(trainImgs, layersTrainnet, "crossentropy", options);
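If the layer array already exists, as in the question, one option is to strip a trailing classification output layer programmatically instead of retyping the network. This is only a sketch under assumptions: 'layersOld' and 'layersNew' are illustrative names, the small network is a stand-in, and 'nnet.cnn.layer.ClassificationOutputLayer' is the class I understand 'classificationLayer' to return, so please confirm it with class(layers(end)) on your release. The softmax layer is kept because, as far as I know, the built-in "crossentropy" loss expects probabilities.

```matlab
% Stand-in for a layer array that was written for trainNetwork.
layersOld = [
    imageInputLayer([28 28 1])
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Drop the last layer only if it is a classification output layer.
% (Class name is an assumption; check class(layersOld(end)) to confirm.)
layersNew = layersOld;
if isa(layersOld(end), 'nnet.cnn.layer.ClassificationOutputLayer')
    layersNew = layersOld(1:end-1);
end

% layersNew now ends in softmaxLayer and could be passed to trainnet, e.g.
% net = trainnet(trainImgs, layersNew, "crossentropy", options);
disp(class(layersNew(end)))
```

The check on the last layer makes the snippet safe to run on an array that has already been converted, since nothing is removed in that case.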
Please refer to the documentation pages for 'trainNetwork' and 'trainnet' for more information on the differences between the two functions.
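One more point that may matter after converting: 'trainnet' returns a dlnetwork object, so code that called 'classify' on the output of 'trainNetwork' will likely need to change as well. Below is a minimal sketch, assuming R2024a or later, where to my knowledge 'minibatchpredict' and 'scores2label' are available. The untrained network, the random images, and 'classNames' are placeholders so that the snippet runs on its own; they are not part of the original code.

```matlab
% Placeholder network: untrained, only here to show the prediction calls.
layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    fullyConnectedLayer(10)
    softmaxLayer];
net = dlnetwork(layers);   % trainnet returns an object of this class

% Placeholder data: five random 28x28 grayscale images.
X = rand(28, 28, 1, 5, 'single');
classNames = string(0:9);  % hypothetical class names

% Predict class scores in mini-batches, then map the scores to labels.
scores = minibatchpredict(net, X);
labels = scores2label(scores, classNames);
disp(labels)
```

With the variables from the example above, the equivalent calls would be scores = minibatchpredict(trainednet_, valImgs) followed by scores2label(scores, categories(trainImgs.Labels)). To see accuracy in the training plot, 'Metrics','accuracy' can be added to the 'trainingOptions' call.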
Hope this helps resolve the issue.

Release

R2024a
