Importing pre-trained recurrent network to reinforcement learning agent
Javier Maruenda on 28 May 2020
Commented: Javier Maruenda on 1 Jun 2020
Hello,
Are pre-trained recurrent networks re-initialized when they are used in reinforcement learning agents? If so, how can this be avoided?
I am importing an LSTM network, trained with supervised learning, as the actor of a PPO agent. When I simulate the agent without training, the reward is fine. However, if I train the agent, the reward drops as if no pre-trained network had been used. I would expect the reward after training to be similar or higher, so presumably the network is being re-initialized. Is there a way around this?
Thanks
% Load the pre-trained actor network (the MAT-file stores it in the variable net)
load(netDir);
actorNetwork = net.Layers;
actorOpts = rlRepresentationOptions('LearnRate',learnRate);
actor = rlStochasticActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'input'},actorOpts);
% Create critic (a new network, not pre-trained)
criticNetwork = [sequenceInputLayer(numObs,"Name","input")
    lstmLayer(numObs)
    softplusLayer()
    fullyConnectedLayer(1)];
criticOpts = rlRepresentationOptions('LearnRate',learnRate);
critic = rlValueRepresentation(criticNetwork,obsInfo,'Observation',{'input'},criticOpts);
% Create agent
agentOpts = rlPPOAgentOptions('ExperienceHorizon',expHorizon, 'MiniBatchSize',miniBatchSz, ...
    'NumEpoch',nEpoch, 'ClipFactor',0.1);
agent = rlPPOAgent(actor,critic,agentOpts);
% Train agent
trainOpts = rlTrainingOptions('MaxEpisodes',episodes, 'MaxStepsPerEpisode',episodeSteps, ...
    'Verbose',false, 'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',10);
% Run training
trainingStats = train(agent,env,trainOpts);
% Simulate
simOptions = rlSimulationOptions('MaxSteps',2000);
experience = sim(env,agent,simOptions);
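A quick way to test the re-initialization hypothesis is to snapshot the actor's learnable parameters right after the agent is built and again after training. This is a minimal sketch that would replace the `train` call above. It assumes the variables `agent`, `env` and `trainOpts` from that script, the R2020a-era Reinforcement Learning Toolbox functions `getActor` and `getLearnableParameters`, and that the parameters come back as numeric arrays. Running it first with a small `MaxEpisodes` (for example 1) may help separate the two cases: a reset would be expected to show large changes immediately, whereas gradual drift during training would not.

```matlab
% Sketch: measure how far the actor's weights move during training.
% Assumes agent, env and trainOpts from the script above.

% getActor returns a copy of the actor representation before training
paramsBefore = getLearnableParameters(getActor(agent));

% train updates the agent in place
trainingStats = train(agent,env,trainOpts);

% Copy of the actor representation after training
paramsAfter = getLearnableParameters(getActor(agent));

% Report the largest absolute change in each learnable parameter array
for k = 1:numel(paramsBefore)
    delta = max(abs(paramsAfter{k}(:) - paramsBefore{k}(:)));
    fprintf('Parameter %d: max abs change = %g\n', k, delta);
end
```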
Accepted Answer
Ryan Comeau on 29 May 2020
Hello,
So, transfer learning does not work the same way in reinforcement learning (RL) as it does in deep learning (DL). In DL, there are no environment physics that need to be understood. Recall that neural networks are essentially nonlinear curve-fitting tools. In DL, transfer learning works like this: you take a pre-trained feature-extraction network, which has already learned which shapes are useful (lines, circles, and so on). You then add some of your own images to the mix and retrain it to fit your own data.
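For comparison, DL-style transfer learning typically looks like the following in MATLAB: the pre-trained feature-extraction layers are kept and only a new final classifier is trained on your own images. This is a sketch adapted from the standard AlexNet transfer-learning workflow; it assumes Deep Learning Toolbox and the AlexNet support package are installed, and `numClasses` is a hypothetical value.

```matlab
% Sketch of DL transfer learning: reuse pretrained features, replace the head.
% Requires the Deep Learning Toolbox Model for AlexNet Network support package.
net = alexnet;

% Keep every layer except the last three (fc8, prob, output),
% i.e. the pretrained feature-extraction part of the network
layersTransfer = net.Layers(1:end-3);

numClasses = 5;  % hypothetical number of classes in your own image set

% New task-specific head; higher learn-rate factors let it adapt faster
% than the transferred layers
layers = [
    layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
```

The resulting `layers` array would then be trained with `trainNetwork` on the new images. Only the head starts from scratch; the transferred layers already encode generic image features.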
In MATLAB's current RL framework, we are not extracting information from images with a CNN; we are supplying observations as a vector. This means transfer learning will not bring you any benefit. Moreover, a transferred network cannot know the physics of the environment you have built. For example, it will not know what to do if you halved gravity, because gravity is not observable to the actor. So it has no way of being useful to you.
Hope this helps,
RC