Create and Train Custom PG Agent
This example shows how to create and train a custom policy gradient (PG) agent. A custom agent allows you to leverage built-in functionality from the Reinforcement Learning Toolbox™ software, such as training the agent with train and simulating it with sim.
In this example, you convert a custom REINFORCE training loop into a custom agent class, and then train an object of this class (your custom agent) using train. For more information on writing custom agent classes, see Create Custom Reinforcement Learning Agents. For an example that shows how to create and train a custom agent that learns to solve an LQR problem, see Create and Train Custom LQR Agent.
For more information on custom training loops (that is, loops that do not rely on train or sim), see Train Reinforcement Learning Policy Using Custom Training Loop.
The example code might involve random number computation at various stages, such as agent initialization, creation of the actor and critic, resetting the environment during simulations, generating observations (for stochastic environments), generating exploration actions, and sampling mini-batches of experiences for learning. Fixing the random number stream preserves the sequence of random numbers every time you run the code and improves the reproducibility of results. You fix the random number stream at various locations in the example.
Fix the random number stream with the seed 0 and the Mersenne Twister random number algorithm. For more information on random number generation, see rng.
previousRngState = rng(0,"twister")
previousRngState = struct with fields:
Type: 'twister'
Seed: 0
State: [625x1 uint32]
The output previousRngState is a structure that contains information about the previous state of the stream. You restore the state at the end of the example.
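The reproducibility claim above can be checked directly in base MATLAB. This minimal sketch (independent of the Reinforcement Learning Toolbox) reseeds the stream twice with the same seed and generator, compares the two draws, and then restores the saved state in the same way the example does at the end.

```matlab
% Save the current generator settings and seed the stream.
saved = rng(0,"twister");
a = rand(1,3);          % first sequence of draws

% Reseed identically: the stream restarts from the same point.
rng(0,"twister");
b = rand(1,3);          % second sequence of draws

isSame = isequal(a,b);  % identical sequences after identical seeding

% Restore the generator settings saved before seeding.
rng(saved);
```

Restoring the saved structure, rather than reseeding with some other value, returns the global stream to exactly where it was before the sketch ran.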
Create Environment Object
Create the same training environment used in the Train Reinforcement Learning Policy Using Custom Training Loop example. The environment is a cart-pole balancing environment with a discrete action space. Create the environment using the rlPredefinedEnv function.
env = rlPredefinedEnv("CartPole-Discrete");
Extract the observation and action specifications from the environment.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
Obtain the dimension of the observation space (numObs) and the number of possible actions (numAct).
numObs = obsInfo.Dimension(1); numAct = numel(actInfo.Elements);
For more information on this environment, see Load Predefined Control System Environments.
Define Policy
The reinforcement learning policy in this example is a parametrized stochastic policy over a discrete action space, which is learned by a discrete categorical actor. This actor takes an observation as input and returns as output a random action sampled (among the finite number of possible actions) from a categorical probability distribution.
To model the parametrized policy within the actor, use a neural network with one input layer (which receives the content of the environment observation channel, as specified by obsInfo) and one output layer. The output layer must return a vector of probabilities, one for each possible action, as specified by actInfo.
Define the network as an array of layer objects, using featureInputLayer, fullyConnectedLayer, reluLayer, and softmaxLayer layers. The softmaxLayer ensures that the policy outputs probability values in the range [0 1] and that all probabilities sum to 1.
actorNetwork = [
    featureInputLayer(numObs)
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(numAct)   % one output per possible action
    softmaxLayer
    ];
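What the final softmax layer computes can be checked numerically in base MATLAB. The following is a minimal sketch, not the toolbox implementation: the logit values are hypothetical, and subtracting the largest logit before exponentiating is a standard numerical-stability trick that the example itself does not specify.

```matlab
% Hypothetical pre-softmax outputs (logits) of the last
% fullyConnectedLayer for the two cart-pole actions.
logits = [1.2; -0.3];

% Softmax shifted by the largest logit. The shift cancels in the
% ratio, so the probabilities are unchanged but exp cannot overflow.
z = exp(logits - max(logits));
probs = z / sum(z);

% Unshifted softmax, for comparison only.
probsNaive = exp(logits) / sum(exp(logits));
```

Both versions give nonnegative values that sum to 1, which is the property the categorical actor relies on when it samples an action.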
When you initialize a dlnetwork object, the network weights are initialized with random values. Fix the random number stream so that the network is always initialized with the same weight values.
rng(0,"twister");
Convert to a dlnetwork object and summarize its properties.
actorNetwork = dlnetwork(actorNetwork)
actorNetwork = 
  dlnetwork with properties:

         Layers: [7x1 nnet.cnn.layer.Layer]
    Connections: [6x2 table]
     Learnables: [6x3 table]
          State: [0x3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.
summary(actorNetwork)
   Initialized: true

   Number of learnables: 770

   Inputs:
      1   'input'   4 features
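The learnable count reported by summary can be verified by hand. Each fully connected layer with nIn inputs and nOut outputs holds nIn*nOut weights and nOut biases, while the ReLU and softmax layers have no learnable parameters. This base-MATLAB sketch assumes the layer widths defined above: 4 observations, two hidden layers of 24 units, and 2 actions.

```matlab
% Widths of the observation input, the two hidden layers, and the
% output layer (one unit per action).
layerSizes = [4 24 24 2];

numLearnables = 0;
for i = 1:numel(layerSizes)-1
    nIn  = layerSizes(i);
    nOut = layerSizes(i+1);
    % Weights (nOut x nIn) plus biases (nOut x 1) of one layer.
    numLearnables = numLearnables + nIn*nOut + nOut;
end

% Each fully connected layer stores two learnable arrays.
numLearnableArrays = 2*(numel(layerSizes)-1);

disp(numLearnables)        % 770
disp(numLearnableArrays)   % 6
```

The three layers contribute 120 + 600 + 50 = 770 parameters, and the six weight and bias arrays correspond to the six rows of the Learnables table.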
Create the actor using an rlDiscreteCategoricalActor object.
actor = rlDiscreteCategoricalActor(actorNetwork,obsInfo,actInfo);
Create the optimizer options object using rlOptimizerOptions.
actorOpts = rlOptimizerOptions(LearnRate=1e-3);
Custom Agent Class
To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. The custom agent class for this example is defined in CustomReinforceAgent.m.
The CustomReinforceAgent class has the following class definition, which indicates the agent class name and the associated abstract agent.
classdef CustomReinforceAgent < rl.agent.CustomAgent
To define your agent you must specify the following:
Agent properties
Constructor function
Critic approximator, to estimate the value of the policy (if needed)
Actor, to learn the policy (if needed)
Required agent methods
Optional agent methods
Agent Properties
In the properties section of the class file, specify any parameters necessary for creating and training the agent.
The rl.agent.CustomAgent class already includes properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).
The custom REINFORCE agent defines the following additional agent properties.
properties
    % Policy
    Policy
    ActorOptimizer

    % Agent options
    Options

    % Experience buffer
    ObservationBuffer
    ActionBuffer
    RewardBuffer
    MaskBuffer
end

properties (Access = private)
    % Training utilities
    Counter
    NumObservation
    NumAction
end

properties (Access = private,Transient)
    % Accelerated gradient function, not saved with the agent
    AccelGradFcn = []
end
Constructor Function
To create your custom agent, you must define a constructor function. The constructor function performs the following actions.
Defines the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.
Sets the agent properties.
Calls the constructor of the base abstract class.
Defines the sample time (required for training in Simulink environments).
For example, the CustomReinforceAgent constructor defines action and observation spaces based on the input actor.
function obj = CustomReinforceAgent(Actor,Options)
    %CUSTOMREINFORCEAGENT Construct custom agent
    %   AGENT = CUSTOMREINFORCEAGENT(ACTOR,OPTIONS) creates custom
    %   REINFORCE AGENT from rlDiscreteCategoricalActor ACTOR
    %   and structure OPTIONS. OPTIONS has fields:
    %       - DiscountFactor
    %       - MaxStepsPerEpisode
    %       - OptimizerOptions

    % (required) Call the abstract class constructor.
    obj = obj@rl.agent.CustomAgent();
    obj.ObservationInfo = Actor.ObservationInfo;
    obj.ActionInfo = Actor.ActionInfo;

    % (required for Simulink environment) Register sample time.
    % For MATLAB environment, use -1.
    obj.SampleTime = -1;

    % (optional) Store the policy and agent options.
    obj.Policy = rlStochasticActorPolicy(Actor);
    obj.Options = Options;
    obj.ActorOptimizer = rlOptimizer(Options.OptimizerOptions);

    % (optional) Cache the number of observations and actions.
    obj.NumObservation = prod(obj.ObservationInfo.Dimension);
    obj.NumAction = prod(obj.ActionInfo.Dimension);

    % (optional) Initialize buffer and counter.
    resetImpl(obj);
end
Required Functions
To create a custom reinforcement learning agent you must define the following implementation functions.
getActionImpl: Evaluates the agent policy and selects an action during simulation.
getActionWithExplorationImpl: Evaluates the policy and selects an action with exploration during training.
learnImpl: Updates the learnable parameters, allowing the agent to learn from the current experience.
To call these functions in your own code, use the wrapper methods from the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.
getActionImpl Function
The getActionImpl function is used to evaluate the policy of your agent and select an action when simulating the agent using the sim function. This function must have the following signature, where obj is your custom agent object, Observation is the current observation, and Action is the selected action.
function Action = getActionImpl(obj,Observation)
For the custom REINFORCE agent, you select an action by calling the getAction function for the policy. The rlStochasticActorPolicy object created from the rlDiscreteCategoricalActor generates a discrete distribution from an observation. Because UseMaxLikelihoodAction is set to true, the policy then returns the maximum likelihood (most probable) action from that distribution instead of sampling from it.
function Action = getActionImpl(obj,Observation)
    % Compute the maximum likelihood action given an observation.
    obj.Policy.UseMaxLikelihoodAction = true;
    Action = getAction(obj.Policy,Observation);
end
getActionWithExplorationImpl Function
The getActionWithExplorationImpl function selects an action using the exploration model of your agent when training the agent using the train function. Using this function, you can implement exploration techniques such as epsilon-greedy exploration or the addition of Gaussian noise. This function must have the following signature, where obj is your custom agent object, Observation is the current observation, and Action is the selected action.
function Action = getActionWithExplorationImpl(obj,Observation)
For the custom REINFORCE agent, the getActionWithExplorationImpl function randomly samples actions from the discrete action probability distribution corresponding to the observation.
function Action = getActionWithExplorationImpl(obj,Observation)
    % Compute an action using the exploration policy given an
    % observation.
    % REINFORCE: Stochastic actors always explore by default
    % (sample from a probability distribution)
    obj.Policy.UseMaxLikelihoodAction = false;
    Action = getAction(obj.Policy,Observation);
end
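The difference between the two action-selection modes can be illustrated without the toolbox. In this sketch, the probability vector and the action elements (-10 and 10) are assumed values, and sampling by inverting the cumulative distribution is a generic technique chosen for illustration, not how rlStochasticActorPolicy is implemented internally. The greedy choice mirrors UseMaxLikelihoodAction = true, and the repeated draws mirror UseMaxLikelihoodAction = false.

```matlab
% Assumed action probabilities for one observation and assumed
% action elements (placeholders for actInfo.Elements).
probs    = [0.7; 0.3];
elements = [-10; 10];

% Greedy choice: always the most probable element.
[~, iMax] = max(probs);
greedyAction = elements(iMax);

% Exploratory choice: draw many samples from the categorical
% distribution by comparing uniform draws with the cumulative sum.
rng(0,"twister");
n = 10000;
c = cumsum(probs);
c(end) = 1;                        % guard against round-off in the last bin
u = rand(n,1);
idx = sum(u > c', 2) + 1;          % index of the first bin with u <= c
sampledActions = elements(idx);
freq = accumarray(idx, 1, [numel(probs) 1]) / n;
```

The greedy mode never returns the less likely action, while the empirical frequencies of the sampled actions approach the probabilities, which is the exploration that REINFORCE relies on during training.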
learnImpl Function
The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of your agent by updating the policy parameters and selecting an action with exploration for the next state. This function must have the following signature, where obj is the agent object, Experience is the current agent experience, and Action is the selected action.
function Action = learnImpl(obj,Experience)
The agent experience is the cell array Experience = {observation,action,reward,nextObs,isDone}. Here:
observation is the current observation.
action is the current action. This is different from the output argument Action, which is an action for the next state.
reward is the current reward.
nextObs is the next observation.
isDone is a logical flag indicating that the training episode is complete.
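To make this layout concrete, the following sketch builds a hypothetical one-step experience for the cart-pole environment (4-element observation, scalar action) and unpacks it in the same order as learnImpl. The numeric values are made up, and wrapping the observation and action in one-element cell arrays is inferred from the Obs{1} and Action{1} indexing in learnImpl.

```matlab
% Hypothetical experience values (not produced by the environment).
observation = {[0.01; -0.02; 0.03; 0.00]};   % one cell per observation channel
action      = {10};                          % one cell per action channel
reward      = 1;
nextObs     = {[0.01; 0.18; 0.03; -0.30]};
isDone      = false;

Experience = {observation, action, reward, nextObs, isDone};

% Unpack in the same order as learnImpl.
Obs = Experience{1};
Act = Experience{2};
obsVector = Obs{1};   % 4x1 numeric observation
actValue  = Act{1};   % scalar action
```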
function Action = learnImpl(obj,Experience)
    % Define how the agent learns from an Experience, which is a
    % cell array with the following format.
    %   Experience = ...
    %   {observation,action,reward,nextObservation,isDone}

    % Extract data from experience.
    Obs = Experience{1};
    Action = Experience{2};
    Reward = Experience{3};
    NextObs = Experience{4};
    IsDone = Experience{5};

    % Save data to buffer.
    obj.ObservationBuffer(:,:,obj.Counter) = Obs{1};
    obj.ActionBuffer(:,:,obj.Counter) = Action{1};
    obj.RewardBuffer(:,obj.Counter) = Reward;
    obj.MaskBuffer(:,obj.Counter) = 1;

    if ~IsDone
        % Choose an action for the next state.
        Action = getActionWithExplorationImpl(obj, NextObs);
        obj.Counter = obj.Counter + 1;
    else
        % Learn from episodic data.
        BatchSize = obj.Options.MaxStepsPerEpisode;

        % Compute the discounted future reward.
        DiscountedReturn = dlarray(zeros(1,BatchSize));
        gamma = obj.Options.DiscountFactor;
        for t = 1:BatchSize
            k = t:BatchSize;
            G = sum(gamma.^(k-t).*obj.RewardBuffer(k));
            DiscountedReturn(t) = G;
        end

        % Compute the indices of actions sampled during the
        % trajectory.
        Z = repmat(obj.ActionInfo.Elements',1,BatchSize);
        actionIndicationMatrix = obj.ActionBuffer(:,:) == Z;

        % Compute the gradient of the loss with respect to the
        % actor parameters. Use dlaccelerate to improve gradient
        % computation performance. Note, the mask buffer is used
        % here to make sure accelerated functions are not
        % re-generated due to varying episode lengths.
        if isempty(obj.AccelGradFcn)
            obj.AccelGradFcn = dlaccelerate(@lossFunction);
        end
        ActorGradient = dlfeval(obj.AccelGradFcn,...
            obj.Policy.Actor,{obj.ObservationBuffer},...
            actionIndicationMatrix,DiscountedReturn,obj.MaskBuffer);

        % Update the actor parameters using the computed gradients.
        [obj.Policy.Actor,obj.ActorOptimizer] = update( ...
            obj.ActorOptimizer,obj.Policy.Actor,ActorGradient);

        % Reset the counter and flush the reward and mask buffers
        % for the next trajectory.
        obj.Counter = 1;
        obj.MaskBuffer(:) = 0;
        obj.RewardBuffer(:) = 0;
    end
end
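The nested loop in learnImpl computes each discounted return G_t, the sum over k >= t of gamma^(k-t)*r_k, directly, which costs O(T^2) operations for a buffer of length T. An equivalent O(T) alternative is the backward recursion G_t = r_t + gamma*G_(t+1). The following is only a sketch comparing the two, not the example's implementation. The reward vector is hypothetical; its trailing zeros mimic the zero-filled RewardBuffer after an episode ends early, and they do not change the returns of earlier steps.

```matlab
gamma = 0.995;                    % discount factor used in the example options
rewards = [1 1 1 1 0 0 0 0];      % hypothetical rewards, zero-padded after the episode ends
T = numel(rewards);

% Direct form, as in learnImpl.
Gdirect = zeros(1,T);
for t = 1:T
    k = t:T;
    Gdirect(t) = sum(gamma.^(k-t) .* rewards(k));
end

% Backward recursion: G_T = r_T, then G_t = r_t + gamma*G_(t+1).
Grec = zeros(1,T);
Grec(T) = rewards(T);
for t = T-1:-1:1
    Grec(t) = rewards(t) + gamma*Grec(t+1);
end
```

For the 250-step buffer used in this example, either form is cheap, so the direct loop in learnImpl is a reasonable choice for readability.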
The custom REINFORCE agent is similar to the custom training loop in Train Reinforcement Learning Policy Using Custom Training Loop, except that, for the custom agent, the built-in train function manages the training loop. In the custom training loop example, you write and manage the training loop yourself instead of using train.
The lossFunction in CustomReinforceAgent.m computes the gradient of the loss function with respect to the actor parameters. The REINFORCE loss is the negative of the product between the discounted return and the logarithm of the probability of the selected action (obtained by evaluating the policy for a given observation), summed across all time steps and normalized by the number of valid (unmasked) time steps.
The loss function accepts the actor function approximator as an input argument and calls evaluate to compute the batch discrete action probabilities given batch observations. UseForward is set to true to handle cases where the actor has layers that modify their behavior during the forward pass of the network. The remaining input arguments are the data required to compute the REINFORCE loss. The actor gradients are then computed from the loss with respect to the learnable parameters of the actor (actor.Learnables).
To reduce the time needed to compute actor gradients, dlaccelerate is used in conjunction with dlfeval and dlgradient. To prevent multiple accelerated functions from being generated, all varying input arguments (aside from the actor) must be dlarray objects, cell arrays of dlarray objects, or structures of dlarray objects with fixed size. The mask argument prevents certain batch elements from contributing to the loss function (for example, when the episode terminates early).
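The effect of actionIndicationMatrix and the mask can be reproduced with plain numeric arrays. In this sketch all values are hypothetical (action elements assumed to be -10 and 10, four time steps with the last one masked out), dlarray and automatic differentiation are omitted, and the elementwise form used in lossFunction is compared with the per-step form of the loss.

```matlab
elements = [-10 10];              % assumed action elements (row vector)
actions  = [10 -10 10 -10];       % hypothetical actions; step 4 is padding
T        = numel(actions);
mask     = [1 1 1 0];             % step 4 lies beyond the end of the episode
G        = [2.9 2.0 1.0 0];       % hypothetical discounted returns
probs    = [0.4 0.7 0.2 0.5;      % hypothetical action probabilities,
            0.6 0.3 0.8 0.5];     % one column per time step

% One-hot matrix: entry (i,t) is true when action t equals element i.
Z = repmat(elements', 1, T);
actionIndicationMatrix = (actions == Z);

% Elementwise form, mirroring lossFunction.
W = actionIndicationMatrix .* G .* mask;
probs(probs < eps) = eps;         % clip to avoid log(0)
loss = -sum(W(:) .* log(probs(:))) / sum(mask);

% Per-step form: pick the probability of the taken action at each step.
[~, aIdx] = max(double(actionIndicationMatrix), [], 1);
pTaken = probs(sub2ind(size(probs), aIdx, 1:T));
lossPerStep = -sum(mask .* G .* log(pTaken)) / sum(mask);
```

Both forms agree. The mask also sets the denominator to the true episode length (3 here) instead of the full buffer length, so short episodes are not diluted by padding.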
function actorGradient = lossFunction(...
    actor,observations,actionIndicationMatrix,discountedReturn,mask)

    % Evaluate the action probabilities given batch observations.
    % Set UseForward=true to handle layers such as batch normalization
    % which modify their behavior during the forward pass.
    actionProbs = evaluate(actor,observations,UseForward=true);
    actionProbs = actionProbs{1};

    % Resize the discounted return to the size of actionProbs.
    % Elements in the batch data corresponding to mask == 0
    % do not contribute to the loss function.
    G = actionIndicationMatrix .* discountedReturn .* mask;
    G = reshape(G,size(actionProbs));

    % Clip action probability values less than eps to eps.
    actionProbs(actionProbs < eps) = eps;

    % Compute the loss.
    loss = -sum(G.*log(actionProbs),"all")/sum(mask);
    actorGradient = dlgradient(loss,actor.Learnables);
end
Optional Functions
Optionally, you can define how your agent is reset at the start of training by specifying a resetImpl function with the following function signature, where obj is the agent object.
function resetImpl(obj)
Using this function, you can set the agent into a known or random condition before training.
function resetImpl(obj)
    % (Optional) Define how the agent is reset before training.
    resetBuffer(obj);
    obj.Counter = 1;
end
Also, you can define any other helper functions in your custom agent class as required. For example, the custom REINFORCE agent defines a resetBuffer function for reinitializing the experience buffer at the beginning of each training episode.
function resetBuffer(obj)
    % Reinitialize observation buffer. Allocate as dlarray to
    % support automatic differentiation with dlfeval and
    % dlgradient.
    obj.ObservationBuffer = dlarray( ...
        zeros(obj.NumObservation,1,obj.Options.MaxStepsPerEpisode));

    % Reinitialize action buffer with valid actions.
    obj.ActionBuffer = dlarray(...
        repmat( ...
        obj.ActionInfo.Elements(1), ...
        1, ...
        1, ...
        obj.Options.MaxStepsPerEpisode) ...
        );

    % Reinitialize reward buffer.
    obj.RewardBuffer = dlarray(zeros(1,obj.Options.MaxStepsPerEpisode));

    % Reinitialize mask buffer.
    obj.MaskBuffer = dlarray(zeros(1,obj.Options.MaxStepsPerEpisode));
end
Create Custom Agent
Once you have defined your custom agent class, create an instance of it in the MATLAB® workspace. To create the custom REINFORCE agent, first specify the agent options.
options.MaxStepsPerEpisode = 250;
options.DiscountFactor = 0.995;
options.OptimizerOptions = actorOpts;
Then, using the options and the previously defined actor, call the constructor function of the custom agent.
agent = CustomReinforceAgent(actor,options);
Train Custom Agent
Configure the training to use the following options.
Set up the training to last at most 5000 episodes, with each episode lasting at most 250 steps.
Terminate the training after the maximum number of episodes is reached or when the average reward across 100 episodes reaches a value of 220.
For more information, see rlTrainingOptions.
numEpisodes = 5000;
aveWindowSize = 100;
trainingTerminationValue = 220;

trainOpts = rlTrainingOptions(...
    MaxEpisodes=numEpisodes,...
    MaxStepsPerEpisode=options.MaxStepsPerEpisode,...
    ScoreAveragingWindowLength=aveWindowSize,...
    StopTrainingValue=trainingTerminationValue);
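The stopping criterion compares a running average of episode rewards with the termination value. The following is a simplified illustration with hypothetical, linearly improving episode rewards. It assumes a trailing window that is shorter during the first episodes, before the window fills. The exact averaging performed by train may differ, so treat this only as a model of the criterion, not of the toolbox internals.

```matlab
aveWindowSize = 100;
trainingTerminationValue = 220;

% Hypothetical episode rewards that improve steadily over training.
episodeReward = linspace(10, 250, 500);

% Trailing average over at most aveWindowSize episodes.
avgReward = zeros(size(episodeReward));
for e = 1:numel(episodeReward)
    first = max(1, e - aveWindowSize + 1);
    avgReward(e) = mean(episodeReward(first:e));
end

% First episode at which the average reaches the termination value.
stopEpisode = find(avgReward >= trainingTerminationValue, 1);
```

Because the criterion uses an average, training stops noticeably later than the first episode whose individual reward exceeds 220, which makes the stopping decision robust to a single lucky episode.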
Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % Train the agent.
    trainStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load("CustomReinforce.mat","agent");
end
Simulate Custom Agent
Enable the environment visualization, which is updated each time the environment step function is called.
plot(env)
To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim.
simOpts = rlSimulationOptions(MaxSteps=options.MaxStepsPerEpisode); experience = sim(env,agent,simOpts);
Restore the random number stream using the information stored in previousRngState.
rng(previousRngState);
See Also
Functions
rlPredefinedEnv | train | evaluate | gradient | accelerate | getAction | sim
Objects
rlDiscreteCategoricalActor | rlNumericSpec | rlFiniteSetSpec | rlTrainingOptions | rlSimulationOptions