Error when retrieving the policy from an agent trained with DDPG for use in the Predict block of the Deep Learning Toolbox in Simulink
After training an agent with the DDPG algorithm, I want to extract the policy from the agent and use it in the Predict block of Simulink's Deep Learning Toolbox.
The agent trains successfully with DDPG.
However, generatePolicyFunction(agent,'MATFileName','QubeIPBalRLPolicy.mat') does not save a variable named policy; it saves policyData instead. Because no variable named policy exists, the line policy = cut_unnecessary_layers_for_SAC_policy(policy); throws an error.
In addition, when I use the network generated by generatePolicyFunction(agent,'MATFileName','QubeIPBalRLPolicy.mat') in the Predict block of the Deep Learning Toolbox in Simulink, the following error occurs:
Error in 'q_qube2_bal_rl_hw/Agent/Predict': Could not evaluate mask initialization command.
Cause: The network input data format string cannot be empty
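To confirm what the exported MAT-file actually contains, I list its variables before trying to post-process it. This is only a diagnostic sketch; policyData is simply the variable name I see on my release:
% List the variables saved by generatePolicyFunction
whos('-file','QubeIPBalRLPolicy.mat')
% Load into a struct and inspect the field names
exported = load('QubeIPBalRLPolicy.mat');
disp(fieldnames(exported))   % shows policyData rather than policy on my machine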
Please note that my MATLAB/Simulink setup is based on the following script from the package I am using.
statePath = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(200,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(200,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticOutput')];
%
criticNetwork = layerGraph();
criticNetwork = addLayers(criticNetwork,statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
%
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
% plot network
plot(criticNetwork)
%
% Specify options for the critic representation.
criticOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
%
% Create critic representation using the specified deep neural network
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'action'},criticOpts);
%
% Create the actor. First create a deep neural network with one input
% (the observation) and one output (the action).
% Construct the actor similarly to the critic.
actorNetwork = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','ActorFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(200,'Name','ActorFC2')
    reluLayer('Name','ActorRelu2')
    fullyConnectedLayer(1,'Name','ActorFC3')
    tanhLayer('Name','ActorTanh')
    scalingLayer('Name','ActorScaling','Scale',max(actInfo.UpperLimit))];
%
actorOpts = rlRepresentationOptions('LearnRate',5e-04,'GradientThreshold',1);
%
actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'observation'},'Action',{'ActorScaling'},actorOpts);
% specify agent options
agentOpts = rlDDPGAgentOptions(...
    'SampleTime',Ts,...
    'TargetSmoothFactor',1e-3,...
    'ExperienceBufferLength',1e6,...
    'DiscountFactor',0.99,...
    'MiniBatchSize',128);
% For continuous action signals, it is important to set the noise variance
% appropriately to encourage exploration. It is common to have
% Variance*sqrt(Ts) be between 1% and 10% of your action range
agentOpts.NoiseOptions.Variance = 0.4;
agentOpts.NoiseOptions.VarianceDecayRate = 1e-5;
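% Quick check of the rule of thumb above (my own sanity check; treating
% max(actInfo.UpperLimit) as the action range is an assumption):
noiseScale = agentOpts.NoiseOptions.Variance*sqrt(Ts);
fprintf('Variance*sqrt(Ts) = %.4f (%.1f%% of action range)\n', ...
    noiseScale, 100*noiseScale/max(actInfo.UpperLimit));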
%
% create the DDPG agent using the specified actor representation, critic
% representation and agent options.
agent = rlDDPGAgent(actor,critic,agentOpts);
% Load pre-defined policy (false) or generate new policy for RT code gen (true)
% doPolicy = false;
doPolicy = true;
%
if doPolicy
    % Generate policy for deployment
    generatePolicyFunction(agent,'MATFileName','QubeIPBalRLPolicy.mat');
    load("QubeIPBalRLPolicy.mat");
    policy = cut_unnecessary_layers_for_SAC_policy(policy);
    save("QubeIPBalRLPolicy.mat",'policy');
else
    % Load previously saved policy
    load("QubeIPBalRLPolicy.mat");
end
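For reference, this is how I try to sanity-check the exported policy outside Simulink before wiring it into the Predict block. The function name evaluatePolicy and the getActor/getModel route are my assumptions about the standard Reinforcement Learning Toolbox workflow, not part of the script above:
% 1) Exercise the generated policy function with a dummy observation
testObs = zeros(numObs,1);          % assumption: the observation is a numObs-by-1 vector
testAct = evaluatePolicy(testObs)   % evaluatePolicy.m is created by generatePolicyFunction
% 2) Alternatively, pull the actor network straight from the trained agent
actorNet = getModel(getActor(agent));      % may be a layerGraph or dlnetwork depending on release
save('QubeIPBalActorNet.mat','actorNet');  % hypothetical MAT-file name for the Predict block to reference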