
Specify Custom Layer Backward Function

If Deep Learning Toolbox™ does not provide the layer you require for your classification or regression problem, then you can define your own custom layer. For a list of built-in layers, see List of Deep Learning Layers.

The example Define Custom Deep Learning Layer with Learnable Parameters shows how to create a custom PReLU layer and goes through the following steps:

  1. Name the layer — Give the layer a name so that you can use it in MATLAB®.

  2. Declare the layer properties — Specify the properties of the layer, including learnable parameters and state parameters.

  3. Create a constructor function (optional) — Specify how to construct the layer and initialize its properties. If you do not specify a constructor function, then at creation, the software initializes the Name, Description, and Type properties with [] and sets the number of layer inputs and outputs to 1.

  4. Create an initialize function (optional) — Specify how to initialize the learnable and state parameters when the software initializes the network. If you do not specify an initialize function, then the software does not initialize parameters when it initializes the network.

  5. Create forward functions — Specify how data passes forward through the layer (forward propagation) at prediction time and at training time.

  6. Create a reset state function (optional) — Specify how to reset state parameters.

  7. Create a backward function (optional) — Specify the derivatives of the loss with respect to the input data and the learnable parameters (backward propagation). If you do not specify a backward function, then the forward functions must support dlarray objects.

If the forward function only uses functions that support dlarray objects, then creating a backward function is optional. In this case, the software determines the derivatives automatically using automatic differentiation. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. If you want to use functions that do not support dlarray objects, or want to use a specific algorithm for the backward function, then you can define a custom backward function using this example as a guide.

Create Custom Layer

The example Define Custom Deep Learning Layer with Learnable Parameters shows how to create a PReLU layer. A PReLU layer performs a threshold operation, where for each channel, any input value less than zero is multiplied by a scalar learned at training time [1]. For values less than zero, a PReLU layer applies scaling coefficients $\alpha_i$ to each channel of the input. These coefficients form a learnable parameter, which the layer learns during training.

The PReLU operation is given by

$$f(x_i) = \begin{cases} x_i & \text{if } x_i > 0 \\ \alpha_i x_i & \text{if } x_i \le 0 \end{cases}$$

where $x_i$ is the input of the nonlinear activation $f$ on channel $i$, and $\alpha_i$ is the coefficient controlling the slope of the negative part. The subscript $i$ in $\alpha_i$ indicates that the nonlinear activation can vary on different channels.
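As a quick illustration, the following sketch evaluates the same PReLU expression that the predict method uses, on a small numeric array. The array size and coefficient values are arbitrary assumptions; the coefficients are stored as a 1-by-1-by-C array so that implicit expansion applies one coefficient per channel. No toolbox is needed.

```matlab
% Sketch: PReLU forward operation on a plain numeric array.
% Assumed layout: height-by-width-by-channel, single observation.
X = reshape(-5:6, 2, 2, 3);               % channel 1 all negative, channel 3 all positive
alpha = reshape([0.1 0.2 0.3], 1, 1, 3);  % one coefficient per channel (arbitrary values)

% Same expression as the predict method: keep positives, scale negatives.
Z = max(X,0) + alpha .* min(0,X);

disp(Z(:,:,1))  % every element of channel 1 is scaled by 0.1
disp(Z(:,:,3))  % channel 3 passes through unchanged
```

Channel 2 contains a mix of negative, zero, and positive values, so it shows both branches of the definition at once.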

View the layer created in the example Define Custom Deep Learning Layer with Learnable Parameters. This layer does not have a backward function.

classdef preluLayer < nnet.layer.Layer ...
        & nnet.layer.Acceleratable
    % Example custom PReLU layer.

    properties (Learnable)
        % Layer learnable parameters
            
        % Scaling coefficient
        Alpha
    end

    methods
        function layer = preluLayer(args) 
            % layer = preluLayer creates a PReLU layer.
            %
            % layer = preluLayer(Name=name) also specifies the
            % layer name.
    
            arguments
                args.Name = "";
            end
    
            % Set layer name.
            layer.Name = args.Name;

            % Set layer description.
            layer.Description = "PReLU";
        end

        function layer = initialize(layer,layout)
            % layer = initialize(layer,layout) initializes the layer
            % learnable parameters using the specified input layout.

            % Skip initialization of nonempty parameters.
            if ~isempty(layer.Alpha)
                return
            end

            % Input data size.
            sz = layout.Size;
            ndims = numel(sz);

            % Find number of channels.
            idx = finddim(layout,"C");
            numChannels = sz(idx);

            % Initialize Alpha.
            szAlpha = ones(1,ndims);
            szAlpha(idx) = numChannels;
            layer.Alpha = rand(szAlpha);
        end

        function Z = predict(layer, X)
            % Z = predict(layer, X) forwards the input data X through the
            % layer and outputs the result Z.
            
            Z = max(X,0) + layer.Alpha .* min(0,X);
        end
    end
end

Note

If the layer has a custom backward function, then you can still inherit from nnet.layer.Formattable.

Create Backward Function

Implement the backward function that returns the derivatives of the loss with respect to the input data and the learnable parameters.

The backward function syntax depends on the type of layer.

  • dLdX = backward(layer,X,Z,dLdZ,memory) returns the derivatives dLdX of the loss with respect to the layer input, where layer has a single input and a single output. Z corresponds to the forward function output and dLdZ corresponds to the derivative of the loss with respect to Z. The function input memory corresponds to the memory output of the forward function.

  • [dLdX,dLdW] = backward(layer,X,Z,dLdZ,memory) also returns the derivative dLdW of the loss with respect to the learnable parameter, where layer has a single learnable parameter.

  • [dLdX,dLdSin] = backward(layer,X,Z,dLdZ,dLdSout,memory) also returns the derivative dLdSin of the loss with respect to the state input, where layer has a single state parameter and dLdSout corresponds to the derivative of the loss with respect to the layer state output.

  • [dLdX,dLdW,dLdSin] = backward(layer,X,Z,dLdZ,dLdSout,memory) also returns the derivative dLdW of the loss with respect to the learnable parameter and returns the derivative dLdSin of the loss with respect to the layer state input, where layer has a single state parameter and single learnable parameter.

You can adjust the syntaxes for layers with multiple inputs, multiple outputs, multiple learnable parameters, or multiple state parameters:

  • For layers with multiple inputs, replace X and dLdX with X1,...,XN and dLdX1,...,dLdXN, respectively, where N is the number of inputs.

  • For layers with multiple outputs, replace Z and dLdZ with Z1,...,ZM and dLdZ1,...,dLdZM, respectively, where M is the number of outputs.

  • For layers with multiple learnable parameters, replace dLdW with dLdW1,...,dLdWP, where P is the number of learnable parameters.

  • For layers with multiple state parameters, replace dLdSin and dLdSout with dLdSin1,...,dLdSinK and dLdSout1,...,dLdSoutK, respectively, where K is the number of state parameters.
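To make the argument ordering concrete, here is a sketch for a hypothetical two-input layer (N = 2) with one output (M = 1) and no learnable parameters that computes Z = X1 .* X2. In a real layer this would be a method with the signature [dLdX1,dLdX2] = backward(layer,X1,X2,~,dLdZ,~); here it is written as a local function with the same argument order so that the script runs on its own. The name productLayerBackward is hypothetical.

```matlab
% Sketch: backward computation for a hypothetical two-input product layer.
X1 = [1 -2; 3 4];
X2 = [5 6; -7 8];
dLdZ = ones(2);   % stand-in for the gradient from the next layer

% Argument order mirrors backward(layer,X1,X2,Z,dLdZ,memory).
[dLdX1, dLdX2] = productLayerBackward([], X1, X2, [], dLdZ, []);

disp(dLdX1)  % equals X2 .* dLdZ
disp(dLdX2)  % equals X1 .* dLdZ

function [dLdX1, dLdX2] = productLayerBackward(~, X1, X2, ~, dLdZ, ~)
    % The layer object, the output Z, and memory are unused, so they are
    % replaced with ~ to avoid keeping them between passes.
    % Product rule (elementwise): dZ/dX1 = X2 and dZ/dX2 = X1.
    dLdX1 = X2 .* dLdZ;
    dLdX2 = X1 .* dLdZ;
end
```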

To reduce memory usage by preventing unused variables from being saved between the forward and backward passes, replace the corresponding input arguments with ~.

Tip

If the number of inputs to backward can vary, then use varargin instead of the input arguments after layer. In this case, varargin is a cell array of the inputs, where the first N elements correspond to the N layer inputs, the next M elements correspond to the M layer outputs, the next M elements correspond to the derivatives of the loss with respect to the M layer outputs, the next K elements correspond to the K derivatives of the loss with respect to the K state outputs, and the last element corresponds to memory.

If the number of outputs can vary, then use varargout instead of the output arguments. In this case, varargout is a cell array of the outputs, where the first N elements correspond to the N derivatives of the loss with respect to the N layer inputs, the next P elements correspond to the derivatives of the loss with respect to the P learnable parameters, and the next K elements correspond to the derivatives of the loss with respect to the K state inputs.
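The following sketch shows one way to unpack varargin and fill varargout in the order this tip describes. It assumes a PReLU-like layer with N = 1 input, M = 1 output, P = 1 learnable parameter, and K = 0 state parameters, so varargin is {X, Z, dLdZ, memory}. To keep the script self-contained, the hypothetical function preluBackwardVarargin takes the coefficients directly in place of the layer object.

```matlab
% Sketch: varargin/varargout ordering for a backward function.
X = reshape(-3:4, 2, 2, 2);              % height-by-width-by-channel, one observation
alpha = reshape([0.5 0.25], 1, 1, 2);    % arbitrary per-channel coefficients
Z = max(X,0) + alpha .* min(0,X);
dLdZ = ones(size(X));

[dLdX, dLdAlpha] = preluBackwardVarargin(alpha, X, Z, dLdZ, []);

function varargout = preluBackwardVarargin(alpha, varargin)
    % In a layer method the first argument would be the layer object;
    % alpha stands in for layer.Alpha here (hypothetical simplification).
    N = 1; M = 1;                        % numbers of layer inputs and outputs
    X    = varargin{1};                  % first N elements: layer inputs
    dLdZ = varargin{N + M + 1};          % after the M outputs: dL/dZ
    % varargin{N+1} is Z and varargin{end} is memory; neither is needed.

    dLdX = alpha .* dLdZ;
    dLdX(X > 0) = dLdZ(X > 0);
    dLdAlpha = sum(min(0,X) .* dLdZ, [1 2]);

    % First N outputs: input derivatives; next P: parameter derivatives.
    varargout{1} = dLdX;
    varargout{2} = dLdAlpha;
end
```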

Note

dlnetwork objects do not support custom layers that require a memory value in a custom backward function. To use a custom layer with a custom backward function in a dlnetwork object, the memory input of the backward function definition must be ~.

Because a PReLU layer has only one input, one output, and one learnable parameter, and does not require the output of the layer forward function or a memory value, the syntax for backward for a PReLU layer is [dLdX,dLdAlpha] = backward(layer,X,~,dLdZ,~). The dimensions of X are the same as in the forward function. The dimensions of dLdZ are the same as the dimensions of the output Z of the forward function. The dimensions and data type of dLdX are the same as the dimensions and data type of X. The dimensions and data type of dLdAlpha are the same as the dimensions and data type of the learnable parameter Alpha.

The backward function only computes these derivatives and does not modify the layer. During training, the derivatives with respect to the learnable parameters are used to update those parameters.

To include a custom layer in a network, the layer forward functions must accept the outputs of the previous layer and forward propagate arrays with the size expected by the next layer. Similarly, when backward is specified, the backward function must accept inputs with the same size as the corresponding output of the forward function and backward propagate derivatives with the same size.

The derivative of the loss with respect to the input data is

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial f(x_i)} \frac{\partial f(x_i)}{\partial x_i}$$

where $\partial L / \partial f(x_i)$ is the gradient propagated from the next layer, and the derivative of the activation is

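$$\frac{\partial f(x_i)}{\partial x_i} = \begin{cases} 1 & \text{if } x_i > 0 \\ \alpha_i & \text{if } x_i \le 0 \end{cases}$$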

The derivative of the loss with respect to the learnable parameters is

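$$\frac{\partial L}{\partial \alpha_i} = \sum_j \frac{\partial L}{\partial f(x_{ij})} \frac{\partial f(x_{ij})}{\partial \alpha_i}$$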

where i indexes the channels, j indexes the elements over height, width, and observations, and the gradient of the activation is

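$$\frac{\partial f(x_i)}{\partial \alpha_i} = \begin{cases} 0 & \text{if } x_i > 0 \\ x_i & \text{if } x_i \le 0 \end{cases}$$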
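One way to gain confidence in these formulas is to compare them with automatic differentiation. The sketch below assumes Deep Learning Toolbox is installed (for dlarray, dlfeval, dlgradient, and extractdata), uses an arbitrary 4-D input ordered height, width, channel, observation, and takes the loss to be the sum of all layer outputs, so that the gradient propagated from the next layer is 1 everywhere. The helper name preluGradients is hypothetical.

```matlab
% Sketch: compare hand-derived PReLU derivatives with automatic differentiation.
rng(0)
X = randn(4, 4, 3, 2);       % height-by-width-by-channel-by-observation
alpha = rand(1, 1, 3);       % one coefficient per channel

[gX, gAlpha] = dlfeval(@preluGradients, dlarray(X), dlarray(alpha));
gX = extractdata(gX);
gAlpha = extractdata(gAlpha);

% Hand-derived formulas with dL/df = 1:
%   df/dx     = 1 where x > 0, alpha where x <= 0
%   dL/dalpha = sum over height, width, and observations of min(0, x)
manualX = alpha .* ones(size(X));
manualX(X > 0) = 1;
manualAlpha = sum(min(0, X), [1 2 4]);

disp(max(abs(gX - manualX), [], 'all'))          % expected to be near zero
disp(max(abs(gAlpha - manualAlpha), [], 'all'))  % expected to be near zero

function [gX, gAlpha] = preluGradients(X, alpha)
    % Forward pass with the same expression as the predict method.
    Z = max(X,0) + alpha .* min(0,X);
    L = sum(Z(:));                               % scalar loss
    [gX, gAlpha] = dlgradient(L, X, alpha);
end
```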

Create the backward function that returns these derivatives.

        function [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % backward propagates the derivative of the loss function
            % through the layer.
            % Inputs:
            %         layer    - Layer to backward propagate through
            %         X        - Input data
            %         dLdZ     - Gradient propagated from the deeper layer
            % Outputs:
            %         dLdX     - Derivative of the loss with respect to the
            %                    input data
            %         dLdAlpha - Derivative of the loss with respect to the
            %                    learnable parameter Alpha
            
            dLdX = layer.Alpha .* dLdZ;
            dLdX(X>0) = dLdZ(X>0);
            dLdAlpha = min(0,X) .* dLdZ;
            % Sum over the spatial dimensions (height and width).
            dLdAlpha = sum(dLdAlpha,[1 2]);
    
            % Sum over all observations in mini-batch.
            dLdAlpha = sum(dLdAlpha,4);
        end
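To check the backward computation outside a network, the following sketch runs the method body on random data, confirms that the output sizes follow the rules stated earlier, and compares dLdAlpha with a central finite difference. The local function preluBackward is a hypothetical helper that copies the method body, with alpha standing in for layer.Alpha; the input is assumed to be 4-D (height, width, channel, observation).

```matlab
% Sketch: size and finite-difference check of the PReLU backward formulas.
rng(1)
X = randn(5, 5, 3, 4);
alpha = rand(1, 1, 3);
dLdZ = randn(size(X));                 % arbitrary upstream gradient

[dLdX, dLdAlpha] = preluBackward(alpha, X, dLdZ);
disp(size(dLdX))                       % same size as X
disp(size(dLdAlpha))                   % same size as alpha

% Scalar loss L = sum(dLdZ .* Z), whose derivative with respect to Z is dLdZ.
prelu = @(a) max(X,0) + a .* min(0,X);
lossFcn = @(a) sum(dLdZ .* prelu(a), 'all');
h = 1e-6;
e1 = zeros(size(alpha)); e1(1) = h;    % perturb the first channel only
fd = (lossFcn(alpha + e1) - lossFcn(alpha - e1)) / (2*h);
disp(abs(fd - dLdAlpha(1)))            % expected to be close to zero

function [dLdX, dLdAlpha] = preluBackward(alpha, X, dLdZ)
    dLdX = alpha .* dLdZ;
    dLdX(X>0) = dLdZ(X>0);
    dLdAlpha = min(0,X) .* dLdZ;
    dLdAlpha = sum(dLdAlpha,[1 2]);    % sum over height and width
    dLdAlpha = sum(dLdAlpha,4);        % sum over observations
end
```

Because the loss is linear in alpha, the finite difference agrees with the analytic value up to rounding error.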

Complete Layer

View the completed layer class file.

classdef preluLayer < nnet.layer.Layer
    % Example custom PReLU layer.

    properties (Learnable)
        % Layer learnable parameters
            
        % Scaling coefficient
        Alpha
    end
    
    methods
        function layer = preluLayer(args) 
            % layer = preluLayer creates a PReLU layer.
            %
            % layer = preluLayer(Name=name) also specifies the
            % layer name.
    
            arguments
                args.Name = "";
            end
    
            % Set layer name.
            layer.Name = args.Name;

            % Set layer description.
            layer.Description = "PReLU";
        end

        function layer = initialize(layer,layout)
            % layer = initialize(layer,layout) initializes the layer
            % learnable parameters using the specified input layout.

            % Skip initialization of nonempty parameters.
            if ~isempty(layer.Alpha)
                return
            end

            % Input data size.
            sz = layout.Size;
            ndims = numel(sz);

            % Find number of channels.
            idx = finddim(layout,"C");
            numChannels = sz(idx);

            % Initialize Alpha.
            szAlpha = ones(1,ndims);
            szAlpha(idx) = numChannels;
            layer.Alpha = rand(szAlpha);
        end

        function Z = predict(layer, X)
            % Z = predict(layer, X) forwards the input data X through the
            % layer and outputs the result Z.
            
            Z = max(X,0) + layer.Alpha .* min(0,X);
        end
        
        function [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % [dLdX, dLdAlpha] = backward(layer, X, ~, dLdZ, ~)
            % backward propagates the derivative of the loss function
            % through the layer.
            % Inputs:
            %         layer    - Layer to backward propagate through
            %         X        - Input data
            %         dLdZ     - Gradient propagated from the deeper layer
            % Outputs:
            %         dLdX     - Derivative of the loss with respect to the
            %                    input data
            %         dLdAlpha - Derivative of the loss with respect to the
            %                    learnable parameter Alpha
            
            dLdX = layer.Alpha .* dLdZ;
            dLdX(X>0) = dLdZ(X>0);
            dLdAlpha = min(0,X) .* dLdZ;
            % Sum over the spatial dimensions (height and width).
            dLdAlpha = sum(dLdAlpha,[1 2]);
    
            % Sum over all observations in mini-batch.
            dLdAlpha = sum(dLdAlpha,4);
        end
    end
end
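One way to validate the completed layer is the checkLayer function from Deep Learning Toolbox, which runs a series of tests on a custom layer, including a numerical check of the derivatives. The sketch below assumes the class above is saved as preluLayer.m on the MATLAB path and uses an arbitrary input size of 24-by-24 with 20 channels for a single observation.

```matlab
% Sketch: validate the custom layer (requires Deep Learning Toolbox).
layer = preluLayer;
validInputSize = [24 24 20];   % height, width, channels (arbitrary choice)

% The fourth dimension holds observations for image data.
checkLayer(layer, validInputSize, ObservationDimension=4)
```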

GPU Compatibility

If the layer forward functions fully support dlarray objects, then the layer is GPU compatible. Otherwise, to be GPU compatible, the layer functions must support inputs and return outputs of type gpuArray (Parallel Computing Toolbox).

Many MATLAB built-in functions support gpuArray (Parallel Computing Toolbox) and dlarray input arguments. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. For a list of functions that execute on a GPU, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox). To use a GPU for deep learning, you must also have a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). For more information on working with GPUs in MATLAB, see GPU Computing in MATLAB (Parallel Computing Toolbox).
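The PReLU backward arithmetic uses only elementwise operations, min, sum, and logical indexing, all of which accept gpuArray inputs, so the same code can run on a GPU. The sketch below assumes Parallel Computing Toolbox and a supported GPU are available and falls back to the CPU otherwise; preluBackward is a hypothetical local copy of the method body, with the two sums combined into one.

```matlab
% Sketch: run the backward computation on a GPU when one is available.
X = randn(8, 8, 4, 2, 'single');
alpha = rand(1, 1, 4, 'single');
dLdZ = randn(size(X), 'single');

if canUseGPU
    X = gpuArray(X);
    alpha = gpuArray(alpha);
    dLdZ = gpuArray(dLdZ);
end

[dLdX, dLdAlpha] = preluBackward(alpha, X, dLdZ);
fprintf("dLdX class: %s, on GPU: %d\n", class(dLdX), isa(dLdX, "gpuArray"));

% Bring the parameter gradient back to host memory if needed.
if isa(dLdAlpha, "gpuArray")
    dLdAlpha = gather(dLdAlpha);
end

function [dLdX, dLdAlpha] = preluBackward(alpha, X, dLdZ)
    dLdX = alpha .* dLdZ;
    dLdX(X>0) = dLdZ(X>0);
    % Sum over height, width, and observations in one call.
    dLdAlpha = sum(min(0,X) .* dLdZ, [1 2 4]);
end
```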

References

[1] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification." In 2015 IEEE International Conference on Computer Vision (ICCV), 1026–34. Santiago, Chile: IEEE, 2015. https://doi.org/10.1109/ICCV.2015.123.
