# Physics-Informed Neural Network - Identify coefficient of loss function

Giulio Vialetto on 10 Sep 2023
Commented: Giulio on 21 Nov 2023
Is it possible in MATLAB to train a PINN so that it also finds unknown parameters of the physics loss function?
In this case, a PINN for projectile motion is presented in which the drag coefficient is defined as a trainable variable even though it is a coefficient of the loss function.

Ben on 18 Sep 2023
Yes, this is possible. You can make the coefficient μ a dlarray and train it alongside the dlnetwork or other dlarrays, as in https://uk.mathworks.com/help/deeplearning/ug/solve-partial-differential-equations-using-deep-learning.html
There's also some discussion here where I previously gave some details on this: https://uk.mathworks.com/matlabcentral/answers/1783690-physics-informed-nn-for-parameter-identification
Here's a simple example that uses an inverse PINN to find μ in d2x/dt2 = μx from solution data.
% Inverse PINN for d2x/dt2 = mu*x
%
% For mu<0 it's known the solution is
% x(t) = a*cos(sqrt(-mu)*t) + b*sin(sqrt(-mu)*t);
%
% where a, b are determined by initial conditions.
%
% Let's fix a = 1, b = 0 and train to learn mu from the solution data x
% Set some value for mu and specify the true solution function to generate data to train on.
% In practice these values are unknown.
muActual = -rand;
x = @(t) cos(sqrt(-muActual)*t);
% Create training data - in practice you might get this data elsewhere, e.g. from sensors on a physical system or model.
% Evaluate x(t) at uniformly spaced t in [0,maxT].
% Choose maxT such that a full wavelength occurs for x(t).
maxT = 2*pi/sqrt(-muActual);
batchSize = 100; % must be defined before it is used in linspace below
t = dlarray(linspace(0,maxT,batchSize),"CB");
xactual = dlarray(x(t),"CB");
% Specify a network and initial guess for mu as parameters to train
net = [
    featureInputLayer(1)
    fullyConnectedLayer(100)
    tanhLayer
    fullyConnectedLayer(100)
    tanhLayer
    fullyConnectedLayer(1)];
params.net = dlnetwork(net);
params.mu = dlarray(-0.5);
% Specify training configuration
numEpochs = 5000;
avgG = [];
avgSqG = [];
batchSize = 100;
lossFcn = dlaccelerate(@modelLoss);
clearCache(lossFcn);
lr = 1e-3;
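% Train
for i = 1:numEpochs
    % Evaluate the loss and its gradients, then update the network and mu with Adam.
    [loss,grad] = dlfeval(lossFcn,t,xactual,params);
    [params,avgG,avgSqG] = adamupdate(params,grad,avgG,avgSqG,i,lr);
    if mod(i,1000)==0
        fprintf("Epoch: %d, Predicted mu: %.3f, Actual mu: %.3f\n",i,extractdata(params.mu),muActual);
    end
end

function [loss,grad] = modelLoss(t,x,params)
    % Implement the PINN loss by predicting x(t) via params.net and computing the derivatives with dlgradient
    xpred = forward(params.net,t);
    % Here we sum xpred over the batch dimension to get a scalar,
    % since dlgradient only takes gradients of scalars.
    % This works because xpred(i) depends only on t(i)
    % - i.e. the network's forward pass vectorizes over the batch dimension
    dxdt = dlgradient(sum(xpred,"all"),t,EnableHigherDerivatives=true);
    d2xdt2 = dlgradient(sum(dxdt,"all"),t,EnableHigherDerivatives=true);
    odeResidual = d2xdt2 - params.mu*xpred;
    % Compute the mean square error of the ODE residual.
    odeLoss = mean(odeResidual.^2);
    % Compute the L2 difference between the predicted xpred and the true x.
    dataLoss = l2loss(xpred,x);
    % Sum the losses and take gradients.
    % Note that by creating the grad as a struct with fields matching params we can
    % pass grad directly to adamupdate in the training loop.
    loss = odeLoss + dataLoss;
    [grad.net,grad.mu] = dlgradient(loss,params.net.Learnables,params.mu);
end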
This script should run reasonably quickly and usually approximates μ well.
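For reference, the objective that the script minimizes can be written out as below. This is a reading of the code rather than something stated in the thread: x_θ denotes the output of params.net, N is batchSize, and the data term is written only up to whatever normalization l2loss applies.

```latex
\mathcal{L}(\theta,\mu)
= \underbrace{\frac{1}{N}\sum_{i=1}^{N}
  \Bigl(\ddot{x}_\theta(t_i)-\mu\,x_\theta(t_i)\Bigr)^{2}}_{\text{odeLoss}}
+ \underbrace{\frac{1}{N}\sum_{i=1}^{N}
  \bigl(x_\theta(t_i)-x(t_i)\bigr)^{2}}_{\text{dataLoss}},
\qquad
x(t)=\cos\!\bigl(\sqrt{-\mu_{\text{actual}}}\,t\bigr)
\;\Longrightarrow\;
\ddot{x}(t)=\mu_{\text{actual}}\,x(t).
```

The data term matters here: x_θ = 0 makes the ODE residual vanish for any μ, so it is the fit to the measured trajectory that lets the residual term determine μ.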
In general you will need to:
1. Modify the PINN loss in the modelLoss function for your particular ODE or PDE.
2. Modify the params to include both the dlnetwork and any additional coefficients.
3. Tweak the net design and training configuration to achieve a good loss.
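As a rough illustration of steps 1 and 2, here is a sketch that adapts the same pattern to a damped oscillator, x'' + c*x' + k*x = 0, with two unknown coefficients. The ODE, the synthetic data, the names c, k and dampedLoss, and the layer sizes are all assumptions chosen for illustration, not part of the original answer. It also assumes, as the script above does, that a struct holding a dlnetwork can be passed through dlfeval and adamupdate.

```matlab
% Sketch: inverse PINN for x'' + c*x' + k*x = 0 with unknown c and k.
% cActual and kActual are used only to synthesize training data.
cActual = 0.3;
kActual = 2;
w = sqrt(kActual - cActual^2/4);             % damped angular frequency
xTrue = @(t) exp(-cActual*t/2).*cos(w*t);    % one exact solution of the ODE

tData = linspace(0,10,200);
t = dlarray(tData,"CB");
xData = dlarray(xTrue(tData),"CB");

layers = [
    featureInputLayer(1)
    fullyConnectedLayer(64)
    tanhLayer
    fullyConnectedLayer(64)
    tanhLayer
    fullyConnectedLayer(1)];

% Step 2: params holds the network and every extra coefficient to learn.
params.net = dlnetwork(layers);
params.c = dlarray(0.1);    % initial guess
params.k = dlarray(1);      % initial guess

avgG = [];
avgSqG = [];
lr = 1e-3;
numEpochs = 2000;
for i = 1:numEpochs
    [loss,grad] = dlfeval(@dampedLoss,t,xData,params);
    [params,avgG,avgSqG] = adamupdate(params,grad,avgG,avgSqG,i,lr);
end
fprintf("c: %.3f (actual %.3f), k: %.3f (actual %.3f)\n", ...
    extractdata(params.c),cActual,extractdata(params.k),kActual);

function [loss,grad] = dampedLoss(t,x,params)
    % Step 1: the residual encodes the ODE being identified.
    xpred = forward(params.net,t);
    dxdt = dlgradient(sum(xpred,"all"),t,EnableHigherDerivatives=true);
    d2xdt2 = dlgradient(sum(dxdt,"all"),t,EnableHigherDerivatives=true);
    odeResidual = d2xdt2 + params.c*dxdt + params.k*xpred;
    loss = mean(odeResidual.^2,"all") + l2loss(xpred,x);
    % One gradient field per field of params, so adamupdate can pair them up.
    [grad.net,grad.c,grad.k] = dlgradient(loss, ...
        params.net.Learnables,params.c,params.k);
end
```

How closely c and k are recovered will depend on the training configuration (step 3), so treat the epoch count and learning rate here as starting points only.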
##### 3 Comments
Ben on 21 Nov 2023
@Giulio I don't get an error when running that script.
The error stack in your comment notes complex values, which aren't supported by adamupdate. I'm not sure how complex values would appear in your script though. Is this reproducible on every iteration or every training run? If so, you could remove the dlaccelerate calls and place a breakpoint in modelLoss to try to identify where the values become complex.
Giulio on 21 Nov 2023
Could you confirm that a complex gradient is not expected? I think it is related to an incorrect setup of the loss function.
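As a general MATLAB illustration of how real inputs can turn complex (not taken from either script in this thread, whose cause was not identified): the square root of a negative real number is complex, so a loss term such as sqrt(c) becomes complex if a trainable coefficient c drifts below zero. The variable names and the checkReal helper below are hypothetical.

```matlab
% Illustration only: c stands for a trainable coefficient that went negative.
c = -0.3;
v = 2;
badTerm  = sqrt(c)*v;         % complex, because sqrt of a negative real is complex
safeTerm = sqrt(abs(c))*v;    % real; one possible guard if c is meant to be >= 0

fprintf("badTerm is real:  %d\n",isreal(badTerm));
fprintf("safeTerm is real: %d\n",isreal(safeTerm));

checkReal("safeTerm",safeTerm);     % passes silently
% checkReal("badTerm",badTerm);     % would raise an error naming badTerm

function checkReal(name,value)
    % Hypothetical helper: call it on intermediate values inside modelLoss
    % to find the first quantity that stops being real.
    if isa(value,"dlarray")
        value = extractdata(value);
    end
    assert(isreal(value),"%s is complex",name);
end
```

Calling a check like this after each step of the loss, with dlaccelerate removed as suggested above, narrows down which operation introduces the imaginary part.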
