# predict

Predict next observation, next reward, or episode termination given observation and action input data

*Since R2022a*

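## Syntax

`predNextObs = predict(tsnFcnAppx,obs,act)`

`predReward = predict(rwdFcnAppx,obs,act,nextObs)`

`predIsDone = predict(idnFcnAppx,obs,act,nextObs)`

`___ = predict(___,UseForward=useForward)`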

## Description

`predNextObs = predict(tsnFcnAppx,obs,act)` evaluates the environment transition function approximator object `tsnFcnAppx` and returns the predicted next observation `predNextObs`, given the current observation `obs` and the action `act`.

`predReward = predict(rwdFcnAppx,obs,act,nextObs)` evaluates the environment reward function approximator object `rwdFcnAppx` and returns the predicted reward `predReward`, given the current observation `obs`, the action `act`, and the next observation `nextObs`.

`predIsDone = predict(idnFcnAppx,obs,act,nextObs)` evaluates the environment is-done function approximator object `idnFcnAppx` and returns the predicted is-done status `predIsDone`, given the current observation `obs`, the action `act`, and the next observation `nextObs`.

`___ = predict(___,UseForward=useForward)` allows you to explicitly call a forward pass when computing gradients.
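The three syntaxes above can be chained to simulate one step of a learned environment model. The following is a minimal sketch of that chaining, not a toolbox workflow. To keep it runnable without trained approximator objects, the anonymous functions `tsnStep`, `rwdStep`, and `idnStep` are hypothetical stand-ins that only mimic the input and output shapes of the corresponding `predict` calls. The dynamics, reward, and termination rule inside them are made up for illustration.

```matlab
% Hypothetical stand-ins with the same calling pattern as predict.
% With real approximator objects you would instead use, for example:
%   tsnStep = @(obs,act) predict(tsnFcnAppx,obs,act);
%   rwdStep = @(obs,act,nextObs) predict(rwdFcnAppx,obs,act,nextObs);
%   idnStep = @(obs,act,nextObs) predict(idnFcnAppx,obs,act,nextObs);
tsnStep = @(obs,act) {0.9*obs{1} + 0.1*sum(act{1})};            % cell array
rwdStep = @(obs,act,nextObs) -sum(nextObs{1}.^2);               % scalar
idnStep = @(obs,act,nextObs) double(any(abs(nextObs{1}) > 10)); % 0 or 1

obs = {ones(4,1)};   % one observation channel, 4-by-1
maxSteps = 5;
totalReward = 0;

for k = 1:maxSteps
    act = {0.5};                          % placeholder action
    nextObs = tsnStep(obs,act);           % predicted next observation
    reward  = rwdStep(obs,act,nextObs);   % predicted reward
    isDone  = idnStep(obs,act,nextObs);   % predicted termination status
    totalReward = totalReward + reward;
    if isDone
        break
    end
    obs = nextObs;                        % advance the simulated state
end
```

Passing the observation and action as cell arrays in every call mirrors the argument format that `predict` expects, so the stand-ins can be swapped for real approximator objects without changing the loop.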

## Examples

### Predict Next Observation Using Continuous Gaussian Transition Function Approximator

Create observation and action specification objects (or alternatively use `getObservationInfo` and `getActionInfo` to extract the specification objects from an environment). For this example, two observation channels carry vectors in a four- and two-dimensional space, respectively. The action is a continuous three-dimensional vector.

```
obsInfo = [
    rlNumericSpec([4 1],UpperLimit=10*ones(4,1));
    rlNumericSpec([1 2],UpperLimit=20*ones(1,2))
    ];

actInfo = rlNumericSpec([3 1]);
```

Create a deep neural network to use as the approximation model for the transition function approximator. For a continuous Gaussian transition function approximator, the network must have two output layers for each observation channel: one for the mean values and the other for the standard deviation values.

Define each network path as an array of layer objects. Get the dimensions of the observation and action spaces from the environment specification objects, and specify a name for the input layers, so you can later explicitly associate them with the appropriate environment channel.

```
% Input path layers from first observation channel
inPath1 = [
    featureInputLayer( ...
        prod(obsInfo(1).Dimension), ...
        Name="netObsIn1")
    fullyConnectedLayer(5,Name="infc1")
    ];

% Input path layers from second observation channel
inPath2 = [
    featureInputLayer( ...
        prod(obsInfo(2).Dimension), ...
        Name="netObsIn2")
    fullyConnectedLayer(5,Name="infc2")
    ];

% Input path layers from action channel
inPath3 = [
    featureInputLayer( ...
        prod(actInfo(1).Dimension), ...
        Name="netActIn")
    fullyConnectedLayer(5,Name="infc3")
    ];

% Joint path layers, concatenate 3 inputs along first dimension
jointPath = [
    concatenationLayer(1,3,Name="concat")
    tanhLayer(Name="tanhJnt")
    fullyConnectedLayer(10,Name="jntfc")
    ];

% Path layers for mean values of first predicted obs
% Using scalingLayer to scale range from (-1,1) to (-10,10)
% Note that scale vector must be a column vector
meanPath1 = [
    tanhLayer(Name="tanhMean1");
    fullyConnectedLayer(prod(obsInfo(1).Dimension));
    scalingLayer(Name="scale1", ...
        Scale=obsInfo(1).UpperLimit)
    ];

% Path layers for standard deviations first predicted obs
% Using softplus layer to make them non negative
sdevPath1 = [
    tanhLayer(Name="tanhStdv1");
    fullyConnectedLayer(prod(obsInfo(1).Dimension));
    softplusLayer(Name="splus1")
    ];

% Path layers for mean values of second predicted obs
% Using scalingLayer to scale range from (-1,1) to (-20,20)
% Note that scale vector must be a column vector
meanPath2 = [
    tanhLayer(Name="tanhMean2");
    fullyConnectedLayer(prod(obsInfo(2).Dimension));
    scalingLayer(Name="scale2", ...
        Scale=obsInfo(2).UpperLimit(:))
    ];

% Path layers for standard deviations second predicted obs
% Using softplus layer to make them non negative
sdevPath2 = [
    tanhLayer(Name="tanhStdv2")
    fullyConnectedLayer(prod(obsInfo(2).Dimension));
    softplusLayer(Name="splus2")
    ];

% Assemble dlnetwork object.
net = dlnetwork;
net = addLayers(net,inPath1);
net = addLayers(net,inPath2);
net = addLayers(net,inPath3);
net = addLayers(net,jointPath);
net = addLayers(net,meanPath1);
net = addLayers(net,sdevPath1);
net = addLayers(net,meanPath2);
net = addLayers(net,sdevPath2);

% Connect layers.
net = connectLayers(net,"infc1","concat/in1");
net = connectLayers(net,"infc2","concat/in2");
net = connectLayers(net,"infc3","concat/in3");
net = connectLayers(net,"jntfc","tanhMean1/in");
net = connectLayers(net,"jntfc","tanhStdv1/in");
net = connectLayers(net,"jntfc","tanhMean2/in");
net = connectLayers(net,"jntfc","tanhStdv2/in");

% Plot network.
plot(net)
```

```
% Initialize network.
net = initialize(net);

% Display the number of weights.
summary(net)
```

```
   Initialized: true

   Number of learnables: 352

   Inputs:
      1   'netObsIn1'   4 features
      2   'netObsIn2'   2 features
      3   'netActIn'    3 features
```

Create a continuous Gaussian transition function approximator object, specifying the names of all the input and output layers.

```
tsnFcnAppx = rlContinuousGaussianTransitionFunction(...
    net,obsInfo,actInfo,...
    ObservationInputNames=["netObsIn1","netObsIn2"], ...
    ActionInputNames="netActIn", ...
    NextObservationMeanOutputNames=["scale1","scale2"], ...
    NextObservationStandardDeviationOutputNames=["splus1","splus2"] );
```

Predict the next observation for a random observation and action.

```
predObs = predict(tsnFcnAppx, ...
    {rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)}, ...
    {rand(actInfo(1).Dimension)})
```

```
predObs=1×2 cell array
    {4x1 single}    {[-25.8987 -0.1389]}
```

Each element of the resulting cell array represents the prediction for the corresponding observation channel.

To display the mean values and standard deviations of the Gaussian probability distribution for the predicted observations, use `evaluate`.

```
predDst = evaluate(tsnFcnAppx, ...
    {rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension), ...
    rand(actInfo(1).Dimension)})
```

```
predDst=1×4 cell array
    {4x1 single}    {[-18.4500 4.3918]}    {4x1 single}    {[0.7007 1.2851]}
```

The result is a cell array in which the first and second element represent the mean values for the predicted observations in the first and second channel, respectively. The third and fourth element represent the standard deviations for the predicted observations in the first and second channel, respectively.
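The relationship between the two outputs can be sketched in plain MATLAB: `evaluate` returns the distribution parameters, while `predict` returns a sample. The snippet below draws such a sample from a cell array laid out like `predDst`. The numeric values are hypothetical, and sampling each element independently as mean plus scaled standard normal noise is an assumption made for illustration, not a statement about how the toolbox implements `predict`.

```matlab
% Hypothetical values arranged like the evaluate output above:
% {mean ch1, mean ch2, std ch1, std ch2}
predDst = { single([1;2;3;4]), single([-18.45 4.39]), ...
            single([0.5;0.5;0.5;0.5]), single([0.70 1.29]) };

numCh = numel(predDst)/2;        % number of observation channels
sampleObs = cell(1,numCh);
for ch = 1:numCh
    mu    = predDst{ch};         % mean values for this channel
    sigma = predDst{numCh+ch};   % standard deviations for this channel
    % Assumption: elements are sampled independently as mu + sigma.*noise
    sampleObs{ch} = mu + sigma .* randn(size(mu),"like",mu);
end
```

Because the values are sampled, an element can fall outside the range set by the scaling layer on the mean path, as the second channel of `predObs` above does.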

### Create Deterministic Reward Function and Predict Reward

Create an environment interface and extract observation and action specifications. Alternatively, you can create specifications using `rlNumericSpec` and `rlFiniteSetSpec`.

```
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
```

To approximate the reward function, create a deep neural network. For this example, the network has two input layers, one for the current action and one for the next observations. The single output layer contains a scalar, which represents the value of the predicted reward.

Define each network path as an array of layer objects. Get the dimensions of the observation and action spaces from the environment specifications, and specify a name for the input layers, so you can later explicitly associate them with the appropriate environment channel.

```
actionPath = featureInputLayer( ...
    actInfo.Dimension(1), ...
    Name="action");

nextStatePath = featureInputLayer( ...
    obsInfo.Dimension(1), ...
    Name="nextState");

commonPath = [concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(1)];
```

Assemble `dlnetwork` object.

```
net = dlnetwork();
net = addLayers(net,nextStatePath);
net = addLayers(net,actionPath);
net = addLayers(net,commonPath);
```

Connect layers.

```
net = connectLayers(net,"nextState","concat/in1");
net = connectLayers(net,"action","concat/in2");
```

Plot network.

```
plot(net)
```

Initialize network and display the number of weights.

```
net = initialize(net);
summary(net)
```

```
   Initialized: true

   Number of learnables: 8.7k

   Inputs:
      1   'nextState'   4 features
      2   'action'      1 features
```

Create a deterministic reward function approximator object.

```
rwdFcnAppx = rlContinuousDeterministicRewardFunction(...
    net,obsInfo,actInfo,...
    ActionInputNames="action", ...
    NextObservationInputNames="nextState");
```

Using this reward function object, you can predict the next reward value based on the current action and next observation. For example, predict the reward for a random action and next observation. Since, for this example, only the action and the next observation influence the reward, use an empty cell array for the current observation.

```
act = rand(actInfo.Dimension);
nxtobs = rand(obsInfo.Dimension);
reward = predict(rwdFcnAppx, {}, {act}, {nxtobs})
```

```
reward = single
    0.1034
```

To predict the reward, you can also use `evaluate`.

```
reward_ev = evaluate(rwdFcnAppx, {act,nxtobs} )
```

```
reward_ev = 1x1 cell array
    {[0.1034]}
```

### Create Is-Done Function and Predict Termination

Create an environment interface and extract observation and action specifications. Alternatively, you can create specifications using `rlNumericSpec` and `rlFiniteSetSpec`.

```
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
```

To approximate the is-done function, use a deep neural network. The network has one input channel for the next observations. The single output channel is for the predicted termination signal.

Create the neural network as a vector of layer objects.

```
net = [
    featureInputLayer( ...
        obsInfo.Dimension(1), ...
        Name="nextState")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer(Name="isdone")
    ];
```

Convert to `dlnetwork` object.

```
net = dlnetwork(net);
```

Plot network.

```
plot(net)
```

Initialize network and display the number of weights.

```
net = initialize(net);
summary(net);
```

```
   Initialized: true

   Number of learnables: 4.6k

   Inputs:
      1   'nextState'   4 features
```

Create an is-done function approximator object.

```
isDoneFcnAppx = rlIsDoneFunction(...
    net,obsInfo,actInfo,...
    NextObservationInputNames="nextState");
```

Using this is-done function approximator object, you can predict the termination signal based on the next observation. For example, predict the termination signal for a random next observation. Since for this example the termination signal only depends on the next observation, use empty cell arrays for the current action and observation inputs.

```
nxtobs = rand(obsInfo.Dimension);
predIsDone = predict(isDoneFcnAppx,{},{},{nxtobs})
```

```
predIsDone = 0
```

You can obtain the termination probability using `evaluate`.

```
predIsDoneProb = evaluate(isDoneFcnAppx,{nxtobs})
```

```
predIsDoneProb = 1x1 cell array
    {2x1 single}
```

```
predIsDoneProb{1}
```

```
ans = 2x1 single column vector
    0.5405
    0.4595
```

The first number is the probability of obtaining a `0` (no termination predicted), the second one is the probability of obtaining a `1` (termination predicted).
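To make the link between these probabilities and a 0-or-1 status concrete, the following plain MATLAB sketch turns the probability vector shown above into a termination status in two ways. Both rules are illustrative assumptions: the first draws the status at random with the given probabilities, and the second (a hypothetical deterministic alternative) picks the likelier status. Neither is claimed to be the exact rule that `predict` applies internally.

```matlab
% Probabilities as returned in predIsDoneProb{1} above:
% first element for status 0, second element for status 1
prob = single([0.5405; 0.4595]);

% Assumption: draw the status from this two-outcome distribution
sampledIsDone = double(rand < prob(2));

% Hypothetical deterministic alternative: pick the likelier status
[~,idx] = max(prob);
likelyIsDone = idx - 1;   % index 1 maps to status 0, index 2 to status 1
```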

## Input Arguments

`tsnFcnAppx` - Environment transition function approximator object

`rlContinuousDeterministicTransitionFunction` object | `rlContinuousGaussianTransitionFunction` object

Environment transition function approximator object, specified as one of the following:

- `rlContinuousDeterministicTransitionFunction` object
- `rlContinuousGaussianTransitionFunction` object

`rwdFcnAppx` - Environment reward function

`rlContinuousDeterministicRewardFunction` object | `rlContinuousGaussianRewardFunction` object | function handle

Environment reward function approximator object, specified as one of the following:

- `rlContinuousDeterministicRewardFunction` object
- `rlContinuousGaussianRewardFunction` object
- Function handle object. For more information about function handle objects, see What Is a Function Handle?.

`idnFcnAppx` - Environment is-done function approximator object

`rlIsDoneFunction` object

Environment is-done function approximator object, specified as an `rlIsDoneFunction` object.

`obs` - Observations

cell array

Observations, specified as a cell array with as many elements as there are observation input channels. Each element of `obs` contains an array of observations for a single observation input channel.

The dimensions of each element in `obs` are $M_O$-by-$L_B$, where:

- $M_O$ corresponds to the dimensions of the associated observation input channel.
- $L_B$ is the batch size. To specify a single observation, set $L_B = 1$. To specify a batch of observations, specify $L_B > 1$. If the approximator object has multiple observation input channels, then $L_B$ must be the same for all elements of `obs`.

$L_B$ must be the same for both `act` and `obs`.

For more information on input and output formats for recurrent neural networks, see the Algorithms section of `lstmLayer`.

`act` - Action

single-element cell array

Action, specified as a single-element cell array that contains an array of action values.

The dimensions of this array are $M_A$-by-$L_B$, where:

- $M_A$ corresponds to the dimensions of the associated action specification.
- $L_B$ is the batch size. To specify a single action, set $L_B = 1$. To specify a batch of actions, specify $L_B > 1$.

$L_B$ must be the same for both `act` and `obs`.

For more information on input and output formats for recurrent neural networks, see the Algorithms section of `lstmLayer`.
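As a rough illustration of the dimension rules for `obs` and `act`, the sketch below builds batched input cell arrays in plain MATLAB. The channel dimensions are taken from the first example on this page, the batch size of 5 is arbitrary, and the variable names are chosen for this sketch only. It assumes the batch dimension is appended directly after the channel dimensions.

```matlab
% Channel dimensions matching the first example on this page
obsDim1 = [4 1];
obsDim2 = [1 2];
actDim  = [3 1];
batchSize = 5;   % L_B, arbitrary value for this sketch

% Append the batch dimension after the channel dimensions
obs = {rand([obsDim1 batchSize]), rand([obsDim2 batchSize])};
act = {rand([actDim batchSize])};

% Check that every input array uses the same batch size
batchDim = numel(obsDim1) + 1;   % third dimension for these channels
batchSizes = cellfun(@(x) size(x,batchDim), [obs act]);
assert(all(batchSizes == batchSize))

% With an approximator object, the call would then be, for example:
%   predNextObs = predict(tsnFcnAppx,obs,act);
```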

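`useForward` - Option to use forward pass

`false` (default) | `true`

Option to use a forward pass, specified as a logical value. When you specify `UseForward=true`, the function computes its outputs using an explicit forward pass, which is the mode intended for computing gradients.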

## Output Arguments

`predNextObs` - Predicted next observation

cell array

Predicted next observation, that is, the observation predicted by the transition function approximator `tsnFcnAppx` given the current observation `obs` and the action `act`, returned as a cell array of the same dimension as `obs`.

`predReward` - Predicted reward

`single`

Predicted reward, that is, the reward predicted by the reward function approximator `rwdFcnAppx` given the current observation `obs`, the action `act`, and the following observation `nextObs`, returned as a `single`.

`predIsDone` - Predicted is-done episode status

`double`

Predicted is-done episode status, that is, the episode termination status predicted by the is-done function approximator `idnFcnAppx` given the current observation `obs`, the action `act`, and the following observation `nextObs`, returned as a `double`.

**Note**

If the approximator object is an `rlContinuousDeterministicRewardFunction` object, then `evaluate` behaves identically to `predict` except that it returns results inside a single-cell array. If it is an `rlContinuousDeterministicTransitionFunction` object, then `evaluate` behaves identically to `predict`. If it is an `rlContinuousGaussianTransitionFunction` object, then `evaluate` returns the mean value and standard deviation of the observation probability distribution, while `predict` returns an observation sampled from this distribution. Similarly, for an `rlContinuousGaussianRewardFunction` object, `evaluate` returns the mean value and standard deviation of the reward probability distribution, while `predict` returns a reward sampled from this distribution. Finally, if it is an `rlIsDoneFunction` object, then `evaluate` returns the probabilities of the termination status being false or true, respectively, while `predict` returns a predicted termination status sampled with these probabilities.

## Tips

When the elements of the input cell arrays are `dlarray` objects, the elements of the returned cell array are also `dlarray` objects. This allows `predict` to be used with automatic differentiation.

Specifically, you can write a custom loss function that directly uses `predict` and `dlgradient` within it, and then use `dlfeval` and `dlaccelerate` with your custom loss function. For an example, see Train Reinforcement Learning Policy Using Custom Training Loop and Custom Training Loop with Simulink Action Noise.

## Version History

**Introduced in R2022a**

