You are now following this question
- You will see updates in your followed content feed.
- You may receive emails, depending on your communication preferences.
matlabのディープラーニングでは、なぜテストデータを使わずにバリデーションデータを使うのか
2 views (last 30 days)
Show older comments
プログラミング初心者です。
下記リンクにつきまして、
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
という一文がありますが、なぜ、テストデータを使わずにバリデーションデータを使うのでしょうか。
imdsValidationではなく、imdsTestだと納得できるのですが不思議です。
もしバリデーションデータを使うのであれば、テストデータは使わなくてもいいかご教示頂けますと幸いです。
Accepted Answer
Kenta
on 12 Mar 2019
-
-
Direct link to this answer
⋮
-
-
Direct link to this answer
単に、ここではバリデーションデータをテストデータと読み替えて問題ないと思います。また、以下のように、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2);
などとして、画像を訓練、バリデーション、テストデータに分けると良いかもしれません。
リンクの学習曲線のところでは、バリデーションデータを使います。
そして、最後のところで
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
とすると、テストデータで正答率を計算できます。ここで、optionsのところに
'ValidationPatience', 3
を追加すれば学習の早期終了ができます。「'ValidationPatience' の値は、ネットワークの学習が停止するまでに、検証セットでの損失が前の最小損失以上になることが許容される回数です。」
とあります。学習がある程度のところで限界が来たらそこで学習がストップするので学習時間を短縮できたり、過学習が抑えられる可能性があります。
11 Comments
ssk
on 12 Mar 2019
itakuraさま、いつもご回答いただきまして誠にありがとうございます。本例につきまして、バリデーションデータとテストデータを置き換えることができる旨、ご教示いただきありがとうございます。imdstestも使った例でも試してみます。また、'ValidationPatience', も利用してみたいと思います。
あわせて、下記リンクにつきましてitakura様宛に追加でご質問させていただきましたのでご覧頂けますと幸いです。
https://jp.mathworks.com/matlabcentral/answers/447586-cross-validation
ssk
on 13 Mar 2019
トレーニング、テスト、バリデーションの3つに分けたコードを試しに作成してみたのですが、以下のコードでご趣旨を反映できておりますでしょうか。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsValidation' ,'=', stname2,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
[imds11,imds12,imds13,imds14,imds15]...
= splitEachLabel(imds,0.2,0.2,0.2,0.2,'randomize');
imdsTest11 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds14.Files));
imdsTest11.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds14.Labels);
imdsTest12 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds15.Files));
imdsTest12.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds15.Labels);
imdsTest13 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds14.Files,imds15.Files));
imdsTest13.Labels = cat(1,imds11.Labels,imds12.Labels,imds14.Labels,imds15.Labels);
imdsTest14 = imageDatastore(cat(1,imds11.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest14.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
imdsTest15 = imageDatastore(cat(1,imds12.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest15.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
%% training for test data
accuracy=zeros(11,15);
for i3=11:15
stname3=sprintf('imdsTest%d',i3);
eval(['imdsTest' ,'=', stname3,';'])
%imdsTest.ReadFcn = @(filename)resize(filename);
i4=15-i+1;
stname4=sprintf('imds0%d',i4);
eval(['imdsValidation' ,'=', stname4,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
%%train network(中略)
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
Kenta
on 14 Mar 2019
i番目のループのなかで、トレーニングデータ(仮)をトレーニングデータとバリデーションデータに分けたらいいと思います。そして、バリデーションデータをテストデータ(ただ名前を変えるだけ)としてテストしたらいいです。
ある程度までロスが下がり切ったりしたら計算時間が冗長になるし、訓練データに過適合するのを防げます。ただ、たくさんの枚数をこなしたときに必ずしももこの操作が必要かどうかは不明です。1クラス100枚くらいで交差検証なしでやってみてはどうでしょうか。CPUで計算してもそこまで計算時間はかからないと思います。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsTest' ,'=', stname2,';'])
imdsTest.ReadFcn = @(filename)resize(filename);
%% training for test data
%imdstrainで訓練
%imdsvalidationをoptionsのなかのvalidationに指定
%imdstestでテスト
ssk
on 14 Mar 2019
Edited: ssk
on 14 Mar 2019
ありがとうございます!コードを試したところ無事に動きました。本コードにおけるクロスバリデーションのニュアンスの確認をしたいのですが、はじめに全ての画像をtrainingとして均等に10分割し、さらに10分割した画像をそれぞれtraining:validation = 8:2で分ける。このとき、testはvalidationと同視できるので、training:test = 8:2である。(つまり、本データの8割をtraining、2割をtest(validation)として使う。その後、組み合わせをかえてそれぞれの画像のaccuracyを調べて平均を取る。上記の認識でよろしいでしょうか?
以前あった例ですと、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2,0.1); で合計が100%ですが、今回の場合は、[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.8,0.2,0.2);で合計120%のような気もするのですが、例えば[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.6,0.2,0.2);のような形で修正する必要はないのでしょうか?
また、なぜテストデータとバリデーションデータを同視できるか理由をご存知でしたらご教示いただけますと幸いです。
Kenta
on 15 Mar 2019
今回の場合は、1000枚あったとすると、900, 100に分けて、その900をさらに800と100に分けたイメージです。
もちろんおっしゃるように、1000枚を一気に8:1:1に分けても等価です。そのサイクルを10回して、それを平均すればいいです。
同視できるというのは、バリデーションデータと名付けられたデータでテストをしているので、その場合に限り、バリデーションデータをテストデータと読み替えて問題ないのでは?ということです。
本来、バリデーションデータとテストデータは異なったニュアンスを持っているものと思います。
ssk
on 15 Mar 2019
ニュアンスをご教示いただきありがとうございました!
覚えが悪く大変申し訳なのでもう一度確認しますが、(1,000枚の画像がある場合、)まず900(training), 100(test)に分けて、その後で900を更に800(training)、100(varidation)に分けるということですね。以上から、800(training)、100(test)、100(validation)になるということでしょうか。
上記コードですと、for loop内では以下のようになっておりますので
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
imstrain1の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds10
これを順に10回続けていって・・・
imstrain10の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds01
以上の平均を求めるという認識でよろしいでしょうか?
More Answers (0)
See Also
Categories
Find more on Deep Learning Toolbox in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!An Error Occurred
Unable to complete the action because of changes made to the page. Reload the page to see its updated state.
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom(English)
Asia Pacific
- Australia (English)
- India (English)
- New Zealand (English)
- 中国
- 日本Japanese (日本語)
- 한국Korean (한국어)