
rlMBPOAgent

Model-based policy optimization (MBPO) reinforcement learning agent

Since R2022a

    Description

    A model-based policy optimization (MBPO) agent is a model-based, online, off-policy reinforcement learning method. An MBPO agent contains an internal model of the environment, which it uses to generate additional experiences without interacting with the environment. The action space can be discrete or continuous, depending on the base agent.

    During training, the MBPO agent generates real experiences by interacting with the environment. The agent uses these experiences to train its internal environment model, and then uses the model to generate additional experiences. The training algorithm updates the agent policy using both the real and the generated experiences.
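    The following MATLAB sketch illustrates the idea of updating the policy from both kinds of data by drawing one minibatch partly from a real-experience buffer and partly from a model-experience buffer. The buffers are stand-in index vectors, and the real-data fraction realRatio is a hypothetical value chosen for illustration; it is not taken from rlMBPOAgent.

```matlab
% Hypothetical sketch: mix real and model-generated experiences in one minibatch.
% The buffers below are stand-in index vectors, not actual experience objects.
realBuffer  = 1:100;        % indices of stored real experiences
modelBuffer = 1001:1400;    % indices of stored model-generated experiences
batchSize   = 256;
realRatio   = 0.2;          % hypothetical fraction of each minibatch drawn from real data

numReal  = round(realRatio*batchSize);   % 51 real samples
numModel = batchSize - numReal;          % 205 model samples

% Sample (with replacement) from each buffer, then concatenate
realIdx   = realBuffer(randi(numel(realBuffer),1,numReal));
modelIdx  = modelBuffer(randi(numel(modelBuffer),1,numModel));
miniBatch = [realIdx modelIdx];
```

    Because model experiences are cheap to generate, methods of this kind typically draw most of each minibatch from the model buffer, while the real data keeps the updates anchored to the true environment.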

    Note

    MBPO agents do not support recurrent networks.

    Creation

    Description


    agent = rlMBPOAgent(baseAgent,envModel) creates a model-based policy optimization agent with default options and sets the BaseAgent and EnvModel properties.

    agent = rlMBPOAgent(___,agentOptions) creates a model-based policy optimization agent with the specified options and sets the AgentOptions property.

    Properties


    Base reinforcement learning agent, specified as an off-policy agent object.

    For environments with a discrete action space, use an rlDQNAgent object.

    For environments with a continuous action space, use a continuous-action off-policy agent object, such as an rlSACAgent object.

    Environment model, specified as an rlNeuralNetworkEnvironment object. This environment contains one or more transition functions, a reward function, and an is-done function.

    Agent options, specified as an rlMBPOAgentOptions object.

    Current roll-out horizon value, specified as a positive integer. For more information on setting the initial horizon value and the horizon update method, see rlMBPOAgentOptions.
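    To make the roll-out horizon concrete, the following sketch generates one model roll-out: starting from an observation, it queries a model for up to horizon steps and stops early if the model predicts termination. The linear transition, quadratic reward, threshold is-done function, and proportional policy are hypothetical stand-ins for the learned models and the base agent policy.

```matlab
% Hypothetical toy models standing in for the learned environment model
transitionFcn = @(s,a) 0.9*s + a;       % predicts the next observation
rewardFcn     = @(a,sNext) -sNext.^2;   % reward from action and next observation
isDoneFcn     = @(sNext) abs(sNext) > 5;
policy        = @(s) -0.5*s;            % stand-in for the base agent policy

horizon = 3;   % roll-out horizon: maximum number of model steps
s = 1;         % starting observation (in MBPO, taken from real experience)

rollout = zeros(0,5);   % each row stores [S A R S' D]
for k = 1:horizon
    a     = policy(s);
    sNext = transitionFcn(s,a);
    r     = rewardFcn(a,sNext);
    d     = isDoneFcn(sNext);
    rollout(end+1,:) = [s a r sNext d]; %#ok<AGROW>
    if d
        break   % stop early when the model predicts the episode has ended
    end
    s = sNext;
end
```

    A longer horizon produces more synthetic data per starting observation, but model errors compound over each additional step.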

    Model experience buffer, specified as an rlReplayMemory object. During training, the agent stores each experience (S,A,R,S',D) that it generates in this buffer, where:

    • S is the current observation of the environment.

    • A is the action taken by the agent.

    • R is the reward for taking action A.

    • S' is the next observation after taking action A.

    • D is the is-done signal after taking action A.
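    A minimal sketch of storing (S,A,R,S',D) experiences in a fixed-capacity circular buffer, in which the newest experience overwrites the oldest once the buffer is full. The field names and the capacity of 3 are illustrative choices; this is not how rlReplayMemory is implemented internally.

```matlab
capacity = 3;   % hypothetical, deliberately tiny so overwriting is visible
buffer = struct('Observation',{},'Action',{},'Reward',{}, ...
                'NextObservation',{},'IsDone',{});
nextSlot = 1;

for t = 1:5
    % One illustrative experience (S,A,R,S',D)
    exper.Observation     = [t; 0];     % S
    exper.Action          = -t;         % A
    exper.Reward          = 1;          % R
    exper.NextObservation = [t+1; 0];   % S'
    exper.IsDone          = false;      % D
    buffer(nextSlot) = exper;
    nextSlot = mod(nextSlot, capacity) + 1;   % wrap around: overwrite oldest
end
```

    After five insertions into a buffer of capacity 3, the experiences from t = 1 and t = 2 have been overwritten by those from t = 4 and t = 5.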

    Option to use exploration policy when selecting actions, specified as one of the following logical values.

    • true — Use the base agent exploration policy when selecting actions.

    • false — Use the base agent greedy policy when selecting actions.

    The initial value of UseExplorationPolicy matches the value specified in BaseAgent. If you change the value of UseExplorationPolicy in either the base agent or the MBPO agent, the same value is used for the other agent.
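    For a stochastic policy, the difference between the two settings can be sketched as follows, assuming a hypothetical Gaussian policy with mean mu and standard deviation sigma. Real base agents such as SAC may additionally transform the sampled action (for example, to respect action bounds), which this sketch omits.

```matlab
% Hypothetical Gaussian policy output for a single observation
mu    = 0.3;   % mean action
sigma = 0.1;   % standard deviation

useExplorationPolicy = false;
if useExplorationPolicy
    a = mu + sigma*randn;   % exploration: sample from the action distribution
else
    a = mu;                 % greedy: take the most likely action
end
```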

    This property is read-only.

    Observation specifications, specified as an rlFiniteSetSpec or rlNumericSpec object or an array containing a mix of such objects. Each element in the array defines the properties of an environment observation channel, such as its dimensions, data type, and name.

    The value of ObservationInfo matches the corresponding value specified in BaseAgent.

    This property is read-only.

    Action specifications, specified either as an rlFiniteSetSpec (for discrete action spaces) or rlNumericSpec (for continuous action spaces) object. This object defines the properties of the environment action channel, such as its dimensions, data type, and name.

    Note

    Only one action channel is allowed.

    The value of ActionInfo matches the corresponding value specified in BaseAgent.

    Sample time of the agent, specified as a positive scalar or as -1. Setting this parameter to -1 allows for event-based simulations.

    Within a Simulink® environment, the RL Agent block in which the agent is specified executes every SampleTime seconds of simulation time. If SampleTime is -1, the block inherits the sample time from its parent subsystem.

    Within a MATLAB® environment, the agent is executed every time the environment advances. In this case, SampleTime is the time interval between consecutive elements in the output experience returned by sim or train. If SampleTime is -1, the time interval between consecutive elements in the returned output experience reflects the timing of the event that triggers the agent execution.
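    As a quick check of what a fixed sample time implies, the following sketch lists the times at which an agent with SampleTime = 0.02 executes during a hypothetical 1-second simulation, assuming execution starts at t = 0 with no offset.

```matlab
Ts = 0.02;   % agent SampleTime, in seconds
Tf = 1;      % hypothetical simulation stop time, in seconds

stepTimes = 0:Ts:Tf;           % execution times, starting at t = 0
numSteps  = numel(stepTimes);  % 51 executions: t = 0, 0.02, ..., 1
```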

    This property is shared between the agent and the agent options object within the agent. Therefore, if you change it in the agent options object, it gets changed in the agent, and vice versa.

    Example: SampleTime=-1

    Object Functions

    train: Train reinforcement learning agents within a specified environment
    sim: Simulate trained reinforcement learning agents within a specified environment

    Examples


    Create an environment interface and extract observation and action specifications.

    env = rlPredefinedEnv("CartPole-Continuous");
    obsInfo = getObservationInfo(env);
    actInfo = getActionInfo(env);

    Create a base off-policy agent. For this example, use a SAC agent.

    agentOpts = rlSACAgentOptions;
    agentOpts.MiniBatchSize = 256;
    initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
    baseagent = rlSACAgent(obsInfo,actInfo,initOpts,agentOpts);

    Check your agent with a random input observation.

    getAction(baseagent,{rand(obsInfo.Dimension)})
    ans = 1x1 cell array
        {[-7.2875]}
    
    

    The neural network environment uses a function approximator object to approximate the environment transition function. The function approximator object uses one or more neural networks as its approximation model. To account for modeling uncertainty, you can specify multiple transition models. For this example, create a single transition model.
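    When you specify multiple transition models, their predictions can differ. The following sketch uses three hypothetical transition functions, measures their disagreement for one observation-action pair, and picks one model at random to produce the next observation. Random selection among ensemble members is a common choice in MBPO-style methods; this sketch does not describe how rlNeuralNetworkEnvironment chooses among its models.

```matlab
% Three hypothetical transition models that disagree slightly
models = {@(s,a) 0.90*s + a, @(s,a) 0.95*s + a, @(s,a) 0.85*s + a};

s = 1;
a = 0.1;
preds  = cellfun(@(f) f(s,a), models);   % next-observation prediction from each model
spread = max(preds) - min(preds);        % simple measure of model disagreement

k = randi(numel(models));                % choose one model for this step
sNext = models{k}(s,a);
```

    A large spread indicates that the models are uncertain about this region of the observation-action space.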

    Create a neural network to use as the approximation model within the transition function object. Define each network path as an array of layer objects. Specify a name for the input and output layers, so you can later explicitly associate them with the appropriate channel.

    % Observation and action paths
    obsPath = featureInputLayer(obsInfo.Dimension(1),Name="obsInLyr");
    actionPath = featureInputLayer(actInfo.Dimension(1),Name="actInLyr");
    
    % Common path: concatenate along dimension 1
    commonPath = [concatenationLayer(1,2,Name="concat")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(obsInfo.Dimension(1),Name="nextObsOutLyr")
        ];

    Create a dlnetwork object and add the layers.

    transNet = dlnetwork;
    transNet = addLayers(transNet,obsPath);
    transNet = addLayers(transNet,actionPath);
    transNet = addLayers(transNet,commonPath);

    Connect layers.

    transNet = connectLayers(transNet,"obsInLyr","concat/in1");
    transNet = connectLayers(transNet,"actInLyr","concat/in2");

    Plot network.

    plot(transNet)

    Initialize the network and display the number of learnable parameters.

    transNet = initialize(transNet);
    summary(transNet)
       Initialized: true
    
       Number of learnables: 4.8k
    
       Inputs:
          1   'obsInLyr'   4 features
          2   'actInLyr'   1 features
    

    Create the transition function approximator object.

    transitionFcnAppx = rlContinuousDeterministicTransitionFunction( ...
        transNet,obsInfo,actInfo,...
        ObservationInputNames="obsInLyr",...
        ActionInputNames="actInLyr",...
        NextObservationOutputNames="nextObsOutLyr");

    Create a neural network to use as a reward model for the reward function approximator object.

    % Action and next observation paths
    actionPath = featureInputLayer(actInfo.Dimension(1),Name="actInLyr");
    nextObsPath = featureInputLayer(obsInfo.Dimension(1),Name="nextObsInLyr");
    
    % Common path: concatenate along dimension 1
    commonPath = [concatenationLayer(1,2,Name="concat")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(1)
        ];

    Create a dlnetwork object and add the layers.

    rewardNet = dlnetwork();
    rewardNet = addLayers(rewardNet,nextObsPath);
    rewardNet = addLayers(rewardNet,actionPath);
    rewardNet = addLayers(rewardNet,commonPath);

    Connect layers.

    rewardNet = connectLayers(rewardNet,"nextObsInLyr","concat/in1");
    rewardNet = connectLayers(rewardNet,"actInLyr","concat/in2");

    Plot network.

    plot(rewardNet)

    Initialize the network and display the number of learnable parameters.

    rewardNet = initialize(rewardNet);
    summary(rewardNet)
       Initialized: true
    
       Number of learnables: 8.7k
    
       Inputs:
          1   'nextObsInLyr'   4 features
          2   'actInLyr'       1 features
    

    Create the reward function approximator object.

    rewardFcnAppx = rlContinuousDeterministicRewardFunction( ...
        rewardNet,obsInfo,actInfo, ...
        ActionInputNames="actInLyr",...
        NextObservationInputNames="nextObsInLyr");

    Create a neural network to use as an is-done model for the is-done function approximator object.

    % Define main path
    isdNet = [featureInputLayer(obsInfo.Dimension(1),Name="nextObsInLyr")
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(2)
        softmaxLayer(Name="isdoneOutLyr")
        ];

    Convert the layer array to a dlnetwork object.

    isdNet = dlnetwork(isdNet);
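    The is-done network ends with a two-element softmaxLayer, so for each next observation it outputs two probabilities that sum to 1. The sketch below computes a softmax from hypothetical logits and turns it into a logical is-done signal; treating the second element as the probability of termination, and the 0.5 threshold, are assumptions made for illustration.

```matlab
% Hypothetical outputs of the last fullyConnectedLayer for one next observation
z = [1.2; -0.4];

% Numerically stable softmax: subtract the maximum before exponentiating
p = exp(z - max(z)) / sum(exp(z - max(z)));

% Assumption: element 2 is the probability that the episode terminates
isDone = p(2) > 0.5;
```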

    Display the number of learnable parameters.

    summary(isdNet)
       Initialized: true
    
       Number of learnables: 4.6k
    
       Inputs:
          1   'nextObsInLyr'   4 features
    

    Create the is-done function approximator object.

    isdoneFcnAppx = rlIsDoneFunction(isdNet,obsInfo,actInfo, ...
        NextObservationInputNames="nextObsInLyr");
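    The learnable counts reported by summary can be checked by hand. Only the fullyConnectedLayer objects have learnable parameters (an out-by-in weight matrix plus out biases); the concatenation, ReLU, and softmax layers have none. The following sketch assumes the 4-element observation and 1-element action shown in the summaries above.

```matlab
% Learnables of a fully connected layer: weights (out x in) plus biases (out)
fcParams = @(in,out) in*out + out;

nObs = 4;   % observation dimension
nAct = 1;   % action dimension

% Transition network: [obs; act] -> 64 -> 64 -> nObs
transCount  = fcParams(nObs+nAct,64) + fcParams(64,64) + fcParams(64,nObs);

% Reward network: [nextObs; act] -> 64 -> 64 -> 64 -> 1
rewardCount = fcParams(nObs+nAct,64) + fcParams(64,64) + fcParams(64,64) + fcParams(64,1);

% Is-done network: nextObs -> 64 -> 64 -> 2
isdCount    = fcParams(nObs,64) + fcParams(64,64) + fcParams(64,2);
```

    The totals are 4804, 8769, and 4610, which summary displays in abbreviated form.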

    Create the neural network environment using the observation and action specifications and the three function approximator objects.

    generativeEnv = rlNeuralNetworkEnvironment( ...
        obsInfo,actInfo,...
        transitionFcnAppx,rewardFcnAppx,isdoneFcnAppx);

    Specify options for creating an MBPO agent. Specify the optimizer options for the transition network and use default values for all other options.

    MBPOAgentOpts = rlMBPOAgentOptions;
    MBPOAgentOpts.TransitionOptimizerOptions = rlOptimizerOptions(...
        LearnRate=1e-4,...
        GradientThreshold=1.0);
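    The following sketch shows what a gradient threshold of 1.0 means for one parameter update, assuming the gradient is clipped by rescaling its L2 norm and then applied in a plain gradient-descent step with LearnRate = 1e-4. The gradient and weight values are hypothetical. rlOptimizerOptions uses the Adam optimizer by default, so the actual update rule differs; only the clipping step is meant to be representative.

```matlab
g   = [3; 4];         % hypothetical gradient, L2 norm = 5
thr = 1.0;            % plays the role of GradientThreshold=1.0
lr  = 1e-4;           % plays the role of LearnRate=1e-4
w   = [0.5; -0.5];    % hypothetical learnable parameters

gNorm = norm(g);
if gNorm > thr
    g = g * (thr / gNorm);   % rescale so that norm(g) equals the threshold
end

w = w - lr*g;   % plain gradient-descent step (simplification, not Adam)
```

    Clipping keeps a single large gradient, which is common early in model training, from producing an oversized update to the transition network.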

    Create the MBPO agent.

    agent = rlMBPOAgent(baseagent,generativeEnv,MBPOAgentOpts);

    Check your agent with a random input observation.

    getAction(agent,{rand(obsInfo.Dimension)})
    ans = 1x1 cell array
        {[7.8658]}
    
    

    Version History

    Introduced in R2022a