
PPO and LSTM agent creation

Sourabh on 12 Dec 2023
Commented: Sourabh on 23 Dec 2023
I am trying to implement PPO with LSTM layers and I am getting this error:
"To train an agent that has states, all actor and critic representations for that agent must have states."
I am giving the states as input, so why am I getting this error?
-----------------------------------------------
obsInfo = rlNumericSpec([3 1],...
    'LowerLimit',[ -inf -inf -inf ]',...
    'UpperLimit',[ inf inf inf ]');
numObservations = obsInfo.Dimension(1);
actInfo = rlNumericSpec([2 1],...
    'LowerLimit',[ -inf -inf ]',...
    'UpperLimit',[ inf inf ]');
numActions = actInfo.Dimension(1);
env = rlSimulinkEnv('simmodelppo','simmodelppo/RL Agent',...
    obsInfo,actInfo);
env.ResetFcn = @(in)localResetFcn(in);
rng(0)
criticNetwork = [
    sequenceInputLayer(prod(obsInfo.Dimension));
    fullyConnectedLayer(64);
    tanhLayer;
    fullyConnectedLayer(64);
    tanhLayer;
    lstmLayer(5,'Name','lstm1');
    fullyConnectedLayer(1)];
criticNetwork = dlnetwork(criticNetwork);
critic = rlValueFunction(criticNetwork,obsInfo);
commonPath = [
    featureInputLayer(prod(obsInfo.Dimension),Name="comPathIn")
    fullyConnectedLayer(150)
    tanhLayer
    fullyConnectedLayer(1,Name="comPathOut")
    ];
% Define mean value path
meanPath = [
    fullyConnectedLayer(50,Name="meanPathIn")
    tanhLayer
    fullyConnectedLayer(50,Name="fc_2")
    tanhLayer
    lstmLayer(5,'Name','lstm1');
    fullyConnectedLayer(prod(actInfo.Dimension))
    leakyReluLayer(0.01,Name="meanPathOut")
    ];
% Define standard deviation path
sdevPath = [
    fullyConnectedLayer(50,"Name","stdPathIn")
    tanhLayer
    lstmLayer(5,'Name','lstm2');
    fullyConnectedLayer(prod(actInfo.Dimension));
    reluLayer
    scalingLayer(Scale=0.9,Name="stdPathOut")
    ];
% Add layers to layerGraph object
actorNet = layerGraph(commonPath);
actorNet = addLayers(actorNet,meanPath);
actorNet = addLayers(actorNet,sdevPath);
% Connect paths
actorNet = connectLayers(actorNet,"comPathOut","meanPathIn/in");
actorNet = connectLayers(actorNet,"comPathOut","stdPathIn/in");
actorNetwork = dlnetwork(actorNet);
actor = rlContinuousGaussianActor(actorNetwork, obsInfo, actInfo, ...
    "ActionMeanOutputNames","meanPathOut",...
    "ActionStandardDeviationOutputNames","stdPathOut",...
    ObservationInputNames="comPathIn");
%%
agentOpts = rlPPOAgentOptions(...
    'SampleTime',800,...
    'ClipFactor',0.2,...
    'NumEpoch',3,...
    'EntropyLossWeight',0.025,...
    'AdvantageEstimateMethod','finite-horizon',...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',64, ...
    'ExperienceHorizon',128);
agent = rlPPOAgent(actor,critic,agentOpts);
agent.AgentOptions.CriticOptimizerOptions.LearnRate = 0.0001;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
agent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
agent.AgentOptions.ActorOptimizerOptions.LearnRate = 0.0003;
%%
maxepisodes = 6000;
maxsteps = 1;
trainingOpts = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',1,...
    'ScoreAveragingWindowLength',5, ...
    'Verbose',false,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-20);
% TO TRAIN
doTraining = true;
if doTraining
    trainingStats = train(agent,env,trainingOpts);
    % save('agent_new.mat','agent') %%% to save agent ###
else
    % Load pretrained agent for the example.
    load('agent_old.mat','agent')
end
%%
function in = localResetFcn(in) %%%%%%% RANDOM INPUT GENERATOR %%%%%%%
    % randomize setpoints -- ensure feasible
    set_point = 1 + 0.0*rand; % Set-point [0,1]
    in = setBlockParameter(in,'simmodelppo/Set','Value',num2str(set_point));
end

Answers (2)

Venu on 19 Dec 2023
Edited: Venu on 20 Dec 2023
The error likely occurs because LSTM layers require explicit handling of their states: it is not enough to feed in the external inputs (the observations); the internal (hidden and cell) states of the LSTM layers must also be managed. In the code you provided, there is a local function "localResetFcn" that is used as the environment's reset function. However, this function currently only sets the parameter 'Value' for the block 'simmodelppo/Set' in your Simulink model.
In the code snippet you've provided, there is no explicit handling of the LSTM states, so unless the "rlContinuousGaussianActor" and "rlValueFunction" objects handle this internally in a way that is not shown, you would need to add this functionality.
Here's an example of how you might modify your "localResetFcn" to include resetting the LSTM states:
function in = localResetFcn(in) %%%%%%% RANDOM INPUT GENERATOR %%%%%%%
    % randomize setpoints -- ensure feasible
    set_point = 1 + 0.0*rand; % Set-point [0,1]
    in = setBlockParameter(in, 'simmodelppo/Set', 'Value', num2str(set_point));
    % Reset the LSTM states of the actor and critic networks
    % Note: The following is an example and may need adjustment
    % to match the specifics of your MATLAB version and network setup.
    % NOTE: 'agent' is not defined inside this function as written (see the note below).
    % Reset actor LSTM states
    actor = getActor(agent);
    actor = resetState(actor);
    agent = setActor(agent, actor);
    % Reset critic LSTM states
    critic = getCritic(agent);
    critic = resetState(critic);
    agent = setCritic(agent, critic);
end
Since the "localResetFcn" currently does not have access to the 'agent' variable, you would need to modify your training setup to ensure that the agent's states can be reset at the start of each episode. This might involve changes to how the 'agent' variable is passed around or stored.
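For reference, at the level of a single network the LSTM hidden and cell states live in the State property of a dlnetwork, and resetState returns them to their initial values. The following is a minimal standalone sketch, assuming a recent release in which resetState accepts a dlnetwork; the layer sizes and the random input are hypothetical, and the sketch does not touch the agent or train.

```matlab
% Hypothetical recurrent network: 3 input channels, 5 LSTM hidden units
layers = [
    sequenceInputLayer(3)
    lstmLayer(5)
    fullyConnectedLayer(1)
    ];
net = dlnetwork(layers);

% Random input: 3 channels, batch of 1, 10 time steps ("CBT" format)
X = dlarray(rand(3,1,10),"CBT");

% predict returns the updated LSTM state as a second output;
% assigning it back carries the state forward to the next call
[~,state] = predict(net,X);
net.State = state;

% Return the hidden and cell states to their initial values (zeros here)
net = resetState(net);
```

After the last line, both rows of net.State (hidden state and cell state of the LSTM layer) hold their initial values again.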
Venu on 20 Dec 2023
Edited: Venu on 23 Dec 2023
Does the issue still persist?
Sourabh on 23 Dec 2023
It's fixed now. I had to change the featureInputLayer to a sequenceInputLayer in the critic and actor.
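In the code as posted, the critic already starts with a sequenceInputLayer, while the actor's common path starts with a featureInputLayer. A sketch of the modified common path, assuming the rest of the actor code from the question (meanPath, sdevPath, connectLayers, rlContinuousGaussianActor) stays unchanged:

```matlab
% Observation specification, as in the question
obsInfo = rlNumericSpec([3 1]);

% Common path of the actor: sequenceInputLayer replaces featureInputLayer,
% so the actor, like the critic, feeds sequence data to its LSTM layers
commonPath = [
    sequenceInputLayer(prod(obsInfo.Dimension),Name="comPathIn")
    fullyConnectedLayer(150)
    tanhLayer
    fullyConnectedLayer(1,Name="comPathOut")
    ];
```

The layer name "comPathIn" is kept so that ObservationInputNames="comPathIn" in the actor constructor still matches.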



Emmanouil Tzorakoleftherakis
Hi,
With LSTM policies, BOTH the actor and the critic should have LSTM layers. That's why you are getting this error.
LSTM policies tend to be harder to architect, so I would suggest using the default agent feature to get an initial architecture. See, for example, here. Don't forget to indicate that you want an RNN policy in the agent initialization options.
Hope this helps.
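A minimal sketch of the suggested approach, assuming a release that supports rlAgentInitializationOptions with the UseRNN option: the agent is created directly from the observation and action specifications used in the question, and the generated actor and critic networks can then be extracted with getModel for inspection or modification.

```matlab
% Observation and action specifications, as in the question
obsInfo = rlNumericSpec([3 1]);
actInfo = rlNumericSpec([2 1]);

% Request a recurrent (LSTM-based) default architecture for actor and critic
initOpts = rlAgentInitializationOptions(UseRNN=true);

% Create a PPO agent with default actor and critic networks
agent = rlPPOAgent(obsInfo,actInfo,initOpts);

% Extract the generated networks to inspect or modify them
actorNet = getModel(getActor(agent));
criticNet = getModel(getCritic(agent));
```

The agent options from the question can be passed as an additional argument, rlPPOAgent(obsInfo,actInfo,initOpts,agentOpts), so that both the default recurrent architecture and the custom PPO settings are used.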
