Can I make a custom RL agent that has 2 distinct critics and 1 actor?

So I'm trying to create a custom DDPG agent that has 2 critics. This isn't like TD3, which also has 2 critics but uses the minimum Q value. Here, each critic learns a different part of the task, and the rewards are combined. To give context, I'm trying to train a manipulator to grab a target while avoiding obstacles. One critic handles tracking and closing the distance to the target; the other handles obstacle avoidance. Qtotal is the sum of the two critics' corresponding Q values.
I should also mention that the environment is implemented in a Simulink model, so I'm hoping to use the Reinforcement Learning Toolbox train() function with something like train(customagent, Simulinkenv, Trainoptions).
Is this possible to implement as a custom RL agent class while still using the built-in train() function? I've heard that MATLAB supports training multiple agents. I'm willing to use that as a last resort, but ideally I want 1 agent with 1 actor and 2 critics.
Thanks
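The decomposition described in the question can be justified under one assumption that the thread does not state explicitly: the scalar reward is split into a tracking term and an obstacle term, and both critics evaluate the same actor policy μ (all actions after the first chosen by μ) with the same discount factor γ. By linearity of expectation, the two Q functions then add up to the Q function of the total reward:

```latex
r_t = r_t^{\mathrm{track}} + r_t^{\mathrm{obst}}

Q_i^{\mu}(s,a) = \mathbb{E}\!\left[\sum_{k=0}^{\infty} \gamma^{k}\, r_{t+k}^{i} \,\middle|\, s_t = s,\ a_t = a\right], \qquad i \in \{\mathrm{track}, \mathrm{obst}\}

Q^{\mu}(s,a) = \mathbb{E}\!\left[\sum_{k=0}^{\infty} \gamma^{k} \left(r_{t+k}^{\mathrm{track}} + r_{t+k}^{\mathrm{obst}}\right) \,\middle|\, s_t = s,\ a_t = a\right] = Q_{\mathrm{track}}^{\mu}(s,a) + Q_{\mathrm{obst}}^{\mu}(s,a)
```

If one critic only sees part of the observation, its estimate can match its Q function only as far as that part of the observation is enough to predict its own reward term.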

Answers (2)

Maneet Kaur Bagga on 3 Apr 2025
Hi,
As I understand it, you can implement a custom RL agent with one actor and two critics while still using MATLAB's built-in "train()" function. To do so, create a custom agent class that subclasses MATLAB's "rl.agent.CustomAgent" class.
Refer to the MathWorks documentation on creating custom agents by subclassing "rl.agent.CustomAgent". This is how you implement agents with nonstandard architectures, such as one with multiple critics.
To train the agent, use the "train" function. The MathWorks example on creating and training a custom PG agent shows how to define a custom agent and train it with the built-in "train" function.
Hope this helps!

Aravind on 3 Apr 2025
From your question, it seems you want to implement a custom Reinforcement Learning (RL) agent, specifically a Deep Deterministic Policy Gradient (DDPG) agent with two critic networks that learn different things and a single actor network, while still using the "train" function to train the RL agent in the environment.
To achieve this, you need to create a custom agent class that inherits from the "rl.agent.CustomAgent" class. This abstract base class is designed to work with the "train" function: once your subclass implements the required methods, you can train it with "train" just like a predefined RL agent. More information on implementing a custom agent class can be found at: https://www.mathworks.com/help/releases/R2024a/reinforcement-learning/ug/create-custom-pg-agent.html.
In your custom agent class, you need to implement a constructor that performs the following tasks:
  • Defines the action and observation specifications.
  • Sets the agent properties.
  • Calls the constructor of the base abstract class.
  • Defines the sample time (necessary for training in Simulink environments).
Additionally, you need to implement three essential functions:
  • getActionImpl — Evaluates the agent policy and selects an action during simulation.
  • getActionWithExplorationImpl — Evaluates the policy and selects an action with exploration during training.
  • learnImpl — Updates the learnable parameters from the current experience and returns the action to take at the next step.
For your specific case, initialize the two critic networks and the actor network in the constructor. In the learnImpl function, implement the DDPG updates: update each critic's learnable parameters for its own part of the task, and update the actor's learnable parameters using the total Q value (the sum of the two critics' outputs). You can also add a helper function that computes the total Q value from both critics.
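As a sketch, assuming standard DDPG with target networks (marked with a prime), a minibatch of N transitions (s_j, a_j, r_j, s'_j, d_j) where d_j is the isDone flag, and a reward split r = r^track + r^obst as in the question, each critic regresses toward its own target while the actor maximizes the summed Q value (written here as minimizing its negative):

```latex
y_j^{i} = r_j^{i} + \gamma\,(1 - d_j)\,Q'_{i}\big(s'_j,\ \mu'(s'_j)\big), \qquad i \in \{\mathrm{track}, \mathrm{obst}\}

\mathcal{L}_{i} = \frac{1}{N}\sum_{j=1}^{N}\Big(y_j^{i} - Q_{i}(s_j, a_j)\Big)^{2}

\mathcal{L}_{\mathrm{actor}}(\theta) = -\frac{1}{N}\sum_{j=1}^{N}\Big(Q_{\mathrm{track}}\big(s_j, \mu_{\theta}(s_j)\big) + Q_{\mathrm{obst}}\big(s_j, \mu_{\theta}(s_j)\big)\Big)
```

Only the actor loss uses the sum; each critic's loss involves only its own reward term and its own target network.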
Here's a basic skeleton of the class you might need:
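classdef CustomDDPGAgent < rl.agent.CustomAgent
    properties
        Actor    % rlContinuousDeterministicActor
        Critic1  % rlQValueFunction (e.g., target tracking)
        Critic2  % rlQValueFunction (e.g., obstacle avoidance)
    end
    methods
        function obj = CustomDDPGAgent(actor, critic1, critic2)
            % (required) Call the abstract class constructor
            obj = obj@rl.agent.CustomAgent();
            % Observation and action specifications, taken from the actor
            obj.ObservationInfo = actor.ObservationInfo;
            obj.ActionInfo = actor.ActionInfo;
            % Sample time: must match the Simulink model (0.01 is an assumed value)
            obj.SampleTime = 0.01;
            % Store the actor and critics
            obj.Actor = actor;
            obj.Critic1 = critic1;
            obj.Critic2 = critic2;
            % Initialize other parameters (replay buffer, target networks, optimizers, ...)
        end
    end
    methods (Access = protected)
        function action = getActionImpl(obj, observation)
            % Greedy action from the actor (no exploration)
            action = getAction(obj.Actor, observation);
        end
        function action = getActionWithExplorationImpl(obj, observation)
            % Actor action; add exploration noise here
            action = getAction(obj.Actor, observation);
        end
        function action = learnImpl(obj, experience)
            % experience = {observation, action, reward, nextObservation, isDone}
            % Implement the DDPG updates here; use computeQTotal for the actor update
            % Return the action to take at the next step
            action = getActionWithExplorationImpl(obj, experience{4});
        end
        function qTotal = computeQTotal(obj, observation, action)
            % Total Q value: sum of the two critics' Q values
            q1 = getValue(obj.Critic1, observation, action);
            q2 = getValue(obj.Critic2, observation, action);
            qTotal = q1 + q2;
        end
        % Implement other functions, such as resetImpl, and other helper functions
    end
end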
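The exploration step in getActionWithExplorationImpl is only a placeholder in the skeleton above. One simple option, shown here as a sketch, is zero-mean Gaussian noise clipped to the action limits (the original DDPG paper used Ornstein-Uhlenbeck noise instead). The helper name addGaussianNoise, the noise standard deviation sigma, and the example action are hypothetical; the limits [-1, 1] match the rlNumericSpec used later in this thread.

```matlab
% Sketch: Gaussian exploration noise clipped to the action limits.
% sigma and the greedy action below are hypothetical example values.
rng(0);                              % reproducible noise for this demo
lo = -1; hi = 1;                     % action limits of the rlNumericSpec
sigma = 0.1;                         % noise standard deviation (assumed)
greedyAction = {[0.5; -0.95; 0.0]};  % getAction returns a cell array
noisyAction = addGaussianNoise(greedyAction, sigma, lo, hi);
disp(noisyAction{1});

function action = addGaussianNoise(action, sigma, lo, hi)
    % Add zero-mean Gaussian noise to each action in the cell array,
    % then clip every element to [lo, hi]
    for k = 1:numel(action)
        noisy = action{k} + sigma * randn(size(action{k}));
        action{k} = min(max(noisy, lo), hi);
    end
end
```

Inside the agent, with the helper saved as its own function file or added as a private method, the exploration method body could become action = addGaussianNoise(getAction(obj.Actor, observation), 0.1, -1, 1);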
You can use the following code to train the agent:
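% Observation and action specifications (nObs and nAct are placeholders for your sizes)
observationInfo = rlNumericSpec([nObs 1]);
actionInfo = rlNumericSpec([nAct 1], LowerLimit=-1, UpperLimit=1);
% Define your Simulink environment (the second argument is the path to the RL Agent block)
mdl = 'modelName';
agentBlk = [mdl '/RL Agent'];
env = rlSimulinkEnv(mdl, agentBlk, observationInfo, actionInfo);
% Create actor and critics (the network comes first, then the specifications)
actor = rlContinuousDeterministicActor(actorNetwork, observationInfo, actionInfo);
critic1 = rlQValueFunction(criticNetwork1, observationInfo, actionInfo);
critic2 = rlQValueFunction(criticNetwork2, observationInfo, actionInfo);
% Create custom agent; the RL Agent block must reference this variable (agent)
agent = CustomDDPGAgent(actor, critic1, critic2);
% Define training options
trainOpts = rlTrainingOptions('MaxEpisodes', 1000, 'MaxStepsPerEpisode', 200);
% Train the agent
trainResults = train(agent, env, trainOpts);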
You can also refer to the following example, which implements a custom LQR agent, for more information on using "rl.agent.CustomAgent" to implement custom RL agents: https://www.mathworks.com/help/releases/R2024a/reinforcement-learning/ug/create-custom-agents.html.
I hope this helps resolve your query!
Vincent on 6 Apr 2025
Edited: Vincent on 6 Apr 2025
Hello, thanks for the response. I've gotten the basic skeleton of my custom agent working so far. Currently I'm running into issues computing the gradient of Qtotal with respect to the parameters of my actor network using dlfeval, dlgradient, and dlarray. The relevant parts are the actorupdate and learnImpl functions. I added disp() calls to check that the various values inside actorupdate are all dlarrays, but when I pass Qtotal to dlgradient, I get the following error:
Error using dlarray/dlgradient (line 115)
'dlgradient' inputs must be traced dlarray objects or cell arrays, structures or tables containing traced dlarray objects. To enable tracing, use 'dlfeval'.
My code is below; I'd appreciate any feedback.
classdef CustomDDPGAgent < rl.agent.CustomAgent
    properties
        %actor NN
        actor
        %critic for tracking target
        critic_track
        %critic for obstacle avoidance
        critic_obstacle
        %dimensions
        statesize
    end
    methods
        %constructor function
        function obj = CustomDDPGAgent(ActorNN,Critic_Track,Critic_Obst,statesize,actionsize)
            %(required) call abstract class constructor
            obj = obj@rl.agent.CustomAgent();
            %define observation + action space
            obj.ObservationInfo = rlNumericSpec([statesize 1]);
            obj.ActionInfo = rlNumericSpec([actionsize 1],LowerLimit = -1,UpperLimit = 1);
            obj.SampleTime = 0.01;
            %define the actor and 2 critics
            obj.actor = ActorNN;
            obj.critic_track = Critic_Track;
            obj.critic_obstacle = Critic_Obst;
            %record observation dimensions
            obj.statesize = statesize;
        end
    end
    methods (Access = protected)
        %Actor update based on Q value
        function actorgradient = actorupdate(obj,Observation)
            Obs_Obstacle = {dlarray([])};
            for index = 1:20
                Obs_Obstacle{1}(index) = Observation{1}(index);
            end
            disp(Observation);
            disp(Obs_Obstacle);
            action = evaluate(obj.actor,Observation,UseForward=true);
            disp(action);
            %Obtained combined Q values
            Qtrack = getValue(obj.critic_track,Observation,action);
            Qobstacle = getValue(obj.critic_obstacle,Obs_Obstacle,action);
            Qtotal = Qtrack + Qobstacle;
            Qtotal = sum(Qtotal);
            disp(Qtotal);
            %obtain gradient of Q value wrt parameters of actor network
            actorgradient = dlgradient(Qtotal,obj.actor.Learnables); %ERROR
        end
        %Action method
        function action = getActionImpl(obj,Observation)
            % Given the current state of the system, return an action
            action = getAction(obj.actor,Observation);
        end
        %Action with noise method
        function action = getActionWithExplorationImpl(obj,Observation)
            % Given the current observation, select an action
            action = getAction(obj.actor,Observation);
            % Add random noise to action
        end
        %Learn method
        function action = learnImpl(obj,Experience)
            %parse experience
            Obs = Experience{1};
            %reformat in dlarrays
            Obs_reformat = {dlarray(Obs{1})};
            action = getAction(obj.actor,Obs_reformat);
            %update actor network
            ActorGradient = dlfeval(@actorupdate,obj,Obs_reformat);
        end
    end
end
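One likely cause of this error (an assumption; the thread does not confirm it) is that dlgradient can only differentiate with respect to dlarray values that dlfeval has traced, and the values returned by obj.actor.Learnables are probably not part of that trace, because the actor object evaluates its own internal copy of the network. A common Deep Learning Toolbox pattern is to work with the underlying dlnetwork objects and pass them into dlfeval, so their Learnables are traced. The sketch below uses only Deep Learning Toolbox functions; the sizes, layers, and helper names (actorLoss, makeCritic) are hypothetical. In the agent, the networks could be obtained with getModel and written back with setModel, but that wiring is not shown. Note the minus sign: DDPG maximizes Q, so the loss is the negative mean of Qtotal.

```matlab
% Sketch: actor gradient of the summed Q value with dlfeval/dlgradient.
% Sizes are hypothetical; the obstacle critic sees only the first nObst
% observation entries, mirroring the actorupdate code above.
obsDim = 26; nObst = 20; actDim = 6; batchSize = 32;

actorNet = dlnetwork([
    featureInputLayer(obsDim)
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(actDim)
    tanhLayer]);                              % actions in [-1, 1]
criticTrack = makeCritic(obsDim + actDim);    % input: [obs; action]
criticObst  = makeCritic(nObst + actDim);     % input: [obs(1:nObst); action]

rng(0);
obs = dlarray(rand(obsDim, batchSize, "single"), "CB");

% dlfeval traces actorNet.Learnables because actorNet is passed in as an input
[loss, actorGrad] = dlfeval(@actorLoss, actorNet, criticTrack, criticObst, obs, nObst);

% One Adam step on the actor only; keep avgG/avgSqG for later steps
[actorNet, avgG, avgSqG] = adamupdate(actorNet, actorGrad, [], [], 1, 1e-3);
fprintf("actor loss: %.4f\n", extractdata(loss));

function [loss, grad] = actorLoss(actorNet, criticTrack, criticObst, obs, nObst)
    act = forward(actorNet, obs);                          % a = mu(s), traced
    qTrack = forward(criticTrack, [obs; act]);             % Q_track(s, mu(s))
    qObst  = forward(criticObst, [obs(1:nObst, :); act]);  % Q_obst(s_obst, mu(s))
    % Maximize the summed Q value by minimizing its negative mean
    loss = -mean(qTrack + qObst, "all");
    grad = dlgradient(loss, actorNet.Learnables);          % w.r.t. the actor only
end

function net = makeCritic(nIn)
    % Hypothetical critic: concatenated [observation; action] -> scalar Q
    net = dlnetwork([
        featureInputLayer(nIn)
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(1)]);
end
```

The critics appear inside the loss only so that the gradient flows through them to the action; since dlgradient is taken with respect to actorNet.Learnables alone, their parameters are not changed here. Each critic would get its own loss function and update toward its own DDPG target.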

Release: R2024a
