About GridWorld in RL Toolbox
shoki kobayashi
on 14 Jun 2020
Commented: shoki kobayashi
on 28 Jul 2020
I am having trouble solving a GridWorld with Q-learning.
I wrote a program to solve the GridWorld, but the agent does not learn well.
How can I improve it?
% Build the maze
GW = createGridWorld(8,8);
GW.CurrentState = '[2,1]';
GW.TerminalStates = '[8,8]';
GW.ObstacleStates = ["[3,3]";"[3,4]";"[3,5]";"[3,6]";"[3,7]";"[4,3]";"[7,3]";"[6,3]";"[5,3]"];
updateStateTranstionForObstacles(GW)
GW.T(state2idx(GW,"[2,4]"),:,:) = 0;
GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(state2idx(GW,"[4,2]"),state2idx(GW,"[5,2]"),:) = 5;
GW.R(state2idx(GW,"[8,3]"),state2idx(GW,"[8,4]"),:) = 5;
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;
% Set up the environment
env = rlMDPEnv(GW)
env.ResetFcn = @() 2;
rng(0)
% Q-learning
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 101;
trainOpts.ScoreAveragingWindowLength = 30;
doTraining = false;
if doTraining
% Train the agent.
trainingStats = train(qAgent,env,trainOpts);
end
% Plot the results
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
Accepted Answer
Kazuaki Yamada
on 28 Jul 2020
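With the following changes, the agent learned:
Comment out lines 12-13 of the script (the two +5 reward assignments).
On line 32, change false to true (doTraining).
% Build the maze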
GW = createGridWorld(8,8);
GW.CurrentState = '[2,1]';
GW.TerminalStates = '[8,8]';
GW.ObstacleStates = ["[3,3]";"[3,4]";"[3,5]";"[3,6]";"[3,7]";"[4,3]";"[7,3]";"[6,3]";"[5,3]"];
updateStateTranstionForObstacles(GW)
GW.T(state2idx(GW,"[2,4]"),:,:) = 0;
GW.T(state2idx(GW,"[2,4]"),state2idx(GW,"[4,4]"),:) = 1;
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
%GW.R(state2idx(GW,"[4,2]"),state2idx(GW,"[5,2]"),:) = 5; %--- ?
%GW.R(state2idx(GW,"[8,3]"),state2idx(GW,"[8,4]"),:) = 5; %--- ?
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;
% Set up the environment
env = rlMDPEnv(GW)
env.ResetFcn = @() 2;
rng(0)
% Q-learning
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 101;
trainOpts.ScoreAveragingWindowLength = 30;
doTraining = true; %--- unless this is true, the if block below is never entered
if doTraining
% Train the agent.
trainingStats = train(qAgent,env,trainOpts);
end
% Plot the results
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)
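One possible reading of why commenting out the two +5 rewards helps (the answer itself does not state a reason): with those rewards in place, stepping back and forth between [4,2] and [5,2] may pay more over an episode than walking to the goal. The sketch below is only a rough, undiscounted estimate in plain MATLAB. The path length `pathLen` is a hypothetical value, not measured on this maze, and the steps needed to reach [4,2] from the start cell, as well as the second bonus at [8,3], are ignored.

```matlab
% Rough, undiscounted comparison of episode returns.
% Assumptions: no discounting; the walk from the start cell to [4,2]
% and the second bonus transition at [8,3] are ignored.
maxSteps   = 50;   % trainOpts.MaxStepsPerEpisode in the script
stepReward = -1;   % default reward for every transition
bonus      = 5;    % reward for [4,2] -> [5,2] in the original script
goalReward = 10;   % reward for entering the terminal state [8,8]

% Behaviour A (original rewards): bounce between [4,2] and [5,2].
nForward   = ceil(maxSteps/2);    % moves [4,2] -> [5,2], each pays +5
nBackward  = floor(maxSteps/2);   % moves [5,2] -> [4,2], each pays -1
loopReturn = nForward*bonus + nBackward*stepReward;

% Behaviour B: walk to the goal in pathLen steps.
% pathLen is hypothetical; the last step earns goalReward instead of -1.
pathLen    = 15;
goalReturn = (pathLen-1)*stepReward + goalReward;

% Behaviour A again, but with the +5 rewards commented out:
% every step of the loop now pays -1.
loopReturnFixed = maxSteps*stepReward;

fprintf('loop (original rewards): %d\n', loopReturn);
fprintf('goal path:               %d\n', goalReturn);
fprintf('loop (bonuses removed):  %d\n', loopReturnFixed);
```

Under these assumptions the goal path can never return more than 10, so the loop looks better to the agent as long as the +5 transition is available. Once the bonuses are removed, every step that is not the final one costs -1 and heading for [8,8] becomes the better choice.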