Design of a neural network with custom loss

Ramon Suarez on 19 Feb 2024
Answered: Ben on 9 Apr 2024
I would like to design a simple feedforward neural network with 1 input and 2 outputs.
The input is a parameter λ within a predefined range (for instance, between -5 and 5), and the output is a vector $x$ with two components.
The loss function I would like to implement is given by this expression
As can be seen from the loss definition, this network does not need any target outputs. The objective is to devise a network that predicts the vector that minimizes the loss.
Any help would be greatly appreciated!
I have read the following threads that talk about customizing a loss function
I also read a response that describes a newer way of doing this with custom training loops, but I have not managed to implement it for my problem.

Answers (1)

Ben on 9 Apr 2024
The term $\lVert A(\lambda)x - f(\lambda) \rVert$ is minimized if $A(\lambda)x = f(\lambda)$, which is a linear problem as you've stated, so you can use classical methods to solve it for $x$.
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) A(lambda)\f(lambda);
lambda = 0.123; % random choice
xlambda = x(lambda);
A(lambda)*xlambda - f(lambda) % returns [0;0] up to floating-point round-off, i.e. xlambda solves the system
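For this particular matrix the solve can also be written by hand. Since $\det A(\lambda) = (\lambda^2 + 1)\cdot 1 - \lambda^2 = 1$ for every $\lambda$, the matrix is always invertible, with $A(\lambda)^{-1} = \begin{bmatrix} 1 & -\lambda \\ -\lambda & \lambda^2 + 1 \end{bmatrix}$, which gives $x(\lambda) = [\lambda^2;\ 1 - \lambda - \lambda^3]$. The sketch below compares that closed form with the backslash solve; the name xClosed is introduced here for illustration only and does not appear in the answer.

```matlab
% Closed-form solution of A(lambda)*x = f(lambda); relies on det(A(lambda)) = 1
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
xClosed = @(lambda) [lambda^2; 1 - lambda - lambda^3];   % hypothetical helper name

lambda = 0.123;                                          % same arbitrary test value as above
residual = norm(A(lambda)*xClosed(lambda) - f(lambda));  % expected to be at round-off level
xBackslash = A(lambda)\f(lambda);
difference = norm(xClosed(lambda) - xBackslash);         % expected to be at round-off level
```

Having the closed form is also handy as a reference when judging how well a trained network approximates $x(\lambda)$.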
If you still want to model $x(\lambda) = N(\lambda)$ for a neural net $N$, you will have to use a custom training loop, since your loss is unsupervised and trainNetwork / trainnet are designed for supervised training. You can write a custom training loop as follows. Note, however, that I was unable to get this to train well, and certainly not as fast as computing the solution as above.
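net = [featureInputLayer(1)
    % NOTE: the hidden layers of the original post were not preserved;
    % the layer type and size below are assumptions, not the author's choice.
    fullyConnectedLayer(64)
    tanhLayer
    fullyConnectedLayer(2)]; % two outputs, one per component of x
net = dlnetwork(net);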
lambda = dlarray(linspace(-5,5,10000),"CB");
maxIters = 10000;
vel = [];
lr = 1e-4;
lossFcn = dlaccelerate(@modelLoss);
for iter = 1:maxIters
    [loss,grad] = dlfeval(lossFcn,net,lambda);
    fprintf("Iter: %d, Loss: %.4f\n",iter,extractdata(loss));
    [net,vel] = sgdmupdate(net,grad,vel,lr);
end
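function [loss,grad] = modelLoss(net,lambda)
    % Network prediction x, reshaped from 2xBatchSize to 2x1xBatchSize
    x = forward(net,lambda);
    x = stripdims(x);
    x = permute(x,[1,3,2]);
    % Permute lambda to 1x1xBatchSize
    lambda = stripdims(lambda);
    lambda = permute(lambda,[1,3,2]);
    % One 2x2 matrix A and one 2x1 vector f per observation
    A = [lambda.^2 + 1, lambda; lambda, ones(1,1,size(lambda,3),like=lambda)];
    Ax = pagemtimes(A,x);
    f = [lambda;1-lambda];
    % Unsupervised loss: residual of A*x = f, differentiated w.r.t. the learnables
    loss = l2loss(Ax,f,DataFormat="CUB");
    grad = dlgradient(loss,net.Learnables);
end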
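The batched layout used inside modelLoss can be checked on plain numeric arrays, without a network. The following is a minimal sketch, assuming a release that provides pagemldivide (R2022a or later); the batch size of 5 and the variable names are arbitrary choices for illustration.

```matlab
% One lambda per page: 1x1xB
lambda = reshape(linspace(-5,5,5),1,1,[]);

% Page-wise system matrices (2x2xB) and right-hand sides (2x1xB)
A = [lambda.^2 + 1, lambda; lambda, ones(size(lambda))];
f = [lambda; 1-lambda];

% Solve every page at once, then form the same residual that the loss penalizes
x = pagemldivide(A,f);                 % 2x1xB
residual = pagemtimes(A,x) - f;        % 2x1xB
maxResidual = max(abs(residual(:)));   % expected to be at round-off level
```

Each page of A is the 2x2 matrix for one value of lambda, so the page-wise product in modelLoss evaluates A(lambda)*x independently for every observation in the batch.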
If you can use a linear solve as above but need it to be compatible with automatic differentiation, you can use pinv, which is supported by dlarray in R2024a:
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) pinv(A(lambda))*f(lambda);
This supports auto-diff with dlarray, so you can compute things like $\mathrm{d}x_i/\mathrm{d}\lambda$.
% x supports auto-diff, e.g. we can compute dx/dlambda
lambda0 = dlarray(0.123);
dx1dlambda = dlfeval(@dxidlambda, lambda0, 1)

function dxi = dxidlambda(lambda,i)
    % Derivative of the i-th component of x(lambda) with respect to lambda
    A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
    f = @(lambda) [lambda; 1-lambda];
    x = @(lambda) pinv(A(lambda))*f(lambda);
    xlambda = x(lambda);
    xlambdai = xlambda(i);
    dxi = dlgradient(xlambdai,lambda);
end
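As a sanity check that needs no dlarray support, the derivative can be estimated with a central finite difference and compared against the derivative of the hand-derived solution $x(\lambda) = [\lambda^2;\ 1 - \lambda - \lambda^3]$, namely $[2\lambda;\ -1 - 3\lambda^2]$. This is a sketch only; the step size h and the names dxFD and dxClosed are illustrative choices, not part of the answer.

```matlab
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) pinv(A(lambda))*f(lambda);

lambda0 = 0.123;
h = 1e-6;                                        % assumed finite-difference step

% Central difference estimate of dx/dlambda at lambda0
dxFD = (x(lambda0 + h) - x(lambda0 - h))/(2*h);

% Derivative of the closed-form solution [lambda^2; 1 - lambda - lambda^3]
dxClosed = [2*lambda0; -1 - 3*lambda0^2];

discrepancy = norm(dxFD - dxClosed);             % expected to be small
```

The first component of dxClosed is the value the dlgradient call for i = 1 should approximately reproduce.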
