encoderDecoderNetwork

Create encoder-decoder network

Description

net = encoderDecoderNetwork(inputSize,encoder,decoder) connects an encoder network and a decoder network to create an encoder-decoder network, net.

This function requires Deep Learning Toolbox™.

net = encoderDecoderNetwork(inputSize,encoder,decoder,Name,Value) modifies aspects of the encoder-decoder network using name-value arguments.

Examples

Create the encoder module consisting of four encoder blocks.

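% Each encoder block: two 3-by-3 convolutions with 2^(5+block) filters, each
% followed by ReLU, then 2-by-2 max pooling with stride 2
encoderBlock = @(block) [
    convolution2dLayer(3,2^(5+block),"Padding","same")
    reluLayer
    convolution2dLayer(3,2^(5+block),"Padding","same")
    reluLayer
    maxPooling2dLayer(2,"Stride",2)];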
encoder = blockedNetwork(encoderBlock,4,"NamePrefix","encoder_");
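As a quick check of the channel arithmetic, the following base-MATLAB sketch (it calls no toolbox functions) evaluates the filter count 2^(5+block) used in encoderBlock for the four blocks that blockedNetwork creates. The variable names are illustrative only.

```matlab
% Block indices 1..4, matching blockedNetwork(encoderBlock,4,...)
blockIdx = 1:4;

% Filters per convolution layer in each encoder block: 2^(5+block)
encoderChannels = 2.^(5 + blockIdx);
disp(encoderChannels)   % 64 128 256 512
```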

Create the decoder module consisting of four decoder blocks.

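% Each decoder block: a 2-by-2 transposed convolution with stride 2 (doubles
% height and width), then two 3-by-3 convolutions with 2^(10-block) filters,
% each followed by ReLU
decoderBlock = @(block) [
    transposedConv2dLayer(2,2^(10-block),"Stride",2)
    convolution2dLayer(3,2^(10-block),"Padding","same")
    reluLayer
    convolution2dLayer(3,2^(10-block),"Padding","same")
    reluLayer];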
decoder = blockedNetwork(decoderBlock,4,"NamePrefix","decoder_");
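With the filter count 2^(10-block), the decoder blocks use the encoder channel counts in reverse order. This base-MATLAB sketch, with illustrative variable names, confirms that relationship for four blocks.

```matlab
blockIdx = 1:4;
encoderChannels = 2.^(5 + blockIdx);    % filters per encoder block, from encoderBlock
decoderChannels = 2.^(10 - blockIdx);   % filters per decoder block, from decoderBlock

% The decoder channel counts are the encoder channel counts reversed
assert(isequal(decoderChannels, fliplr(encoderChannels)))
disp(decoderChannels)   % 512 256 128 64
```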

Create the bridge layers.

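% Bridge between encoder and decoder: two 3-by-3 convolutions with 1024
% filters, each followed by ReLU, then dropout with probability 0.5
bridge = [
    convolution2dLayer(3,1024,"Padding","same")
    reluLayer
    convolution2dLayer(3,1024,"Padding","same")
    reluLayer
    dropoutLayer(0.5)];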

Specify the network input size.

inputSize = [224 224 3];
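Each encoder block ends with a 2-by-2 max pooling layer with stride 2, and each decoder block starts with a 2-by-2 transposed convolution with stride 2. The base-MATLAB sketch below traces height and width through the four encoder and four decoder blocks for this input size, assuming that each pooling layer halves the size (rounding down) and each transposed convolution doubles it. Because 224 is divisible by 2^4, the halvings are exact and the decoder returns to 224-by-224.

```matlab
inputSize = [224 224 3];
depth = 4;                       % number of encoder (and decoder) blocks
hw = inputSize(1:2);             % [height width]

% Encoder: each 2-by-2, stride-2 max pooling halves height and width
encoderSizes = zeros(depth,2);
for k = 1:depth
    hw = floor(hw/2);
    encoderSizes(k,:) = hw;
end

% Decoder: each 2-by-2, stride-2 transposed convolution doubles them
decoderSizes = zeros(depth,2);
for k = 1:depth
    hw = 2*hw;
    decoderSizes(k,:) = hw;
end

disp(encoderSizes(:,1)')   % 112 56 28 14
disp(decoderSizes(:,1)')   % 28 56 112 224
```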

Create the U-Net network by connecting the encoder module, bridge, and decoder module and adding skip connections.

unet = encoderDecoderNetwork(inputSize,encoder,decoder, ...
    "OutputChannels",3, ...
    "SkipConnections","concatenate", ...
    "LatentNetwork",bridge)
unet = 
  dlnetwork with properties:

         Layers: [55x1 nnet.cnn.layer.Layer]
    Connections: [62x2 table]
     Learnables: [46x3 table]
          State: [0x3 table]
     InputNames: {'encoderImageInputLayer'}
    OutputNames: {'encoderDecoderFinalConvLayer'}
    Initialized: 1

Display the network.

analyzeNetwork(unet)

Create a GAN encoder network with four downsampling operations from a pretrained GoogLeNet network.

depth = 4;
[encoder,outputNames] = pretrainedEncoderNetwork('googlenet',depth);

Determine the input size of the encoder network.

inputSize = encoder.Layers(1).InputSize;

Determine the output size of the activation layers in the encoder network by creating a sample data input and then calling forward, which returns the activations.

exampleInput = dlarray(zeros(inputSize),'SSC');   % dimension labels: spatial, spatial, channel
exampleOutput = cell(1,length(outputNames));      % one cell per requested activation
[exampleOutput{:}] = forward(encoder,exampleInput,'Outputs',outputNames);

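Determine the number of channels in the decoder blocks from the size of the third dimension (the channel dimension) of each activation. Then discard the last value and reverse the order, so that the first decoder block uses the channel count of the deepest remaining activation.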

numChannels = cellfun(@(x) size(extractdata(x),3),exampleOutput);
numChannels = fliplr(numChannels(1:end-1));
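The two lines above can be tried without a network. In this base-MATLAB sketch, the activations are hypothetical zero arrays standing in for the dlarray outputs of forward (so extractdata is not needed), and their channel counts are made up for illustration; they are not the actual GoogLeNet values.

```matlab
% Hypothetical H-by-W-by-C activations with made-up channel counts
acts = {zeros(32,32,8), zeros(16,16,16), zeros(8,8,32), ...
        zeros(4,4,64), zeros(2,2,128)};

% Size of the third (channel) dimension of each activation
numChannels = cellfun(@(x) size(x,3), acts);   % [8 16 32 64 128]

% Drop the last value and reverse the order, as in the example
numChannels = fliplr(numChannels(1:end-1));    % [64 32 16 8]
```

With this ordering, the first decoder block uses numChannels(1), the largest of the remaining channel counts.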

Define a function that creates an array of layers for one decoder block.

decoderBlock = @(block) [
    transposedConv2dLayer(2,numChannels(block),'Stride',2)
    convolution2dLayer(3,numChannels(block),'Padding','same')
    reluLayer
    convolution2dLayer(3,numChannels(block),'Padding','same')
    reluLayer];

Create the decoder module with the same number of upsampling blocks as there are downsampling blocks in the encoder module.

decoder = blockedNetwork(decoderBlock,depth);

Create the U-Net network by connecting the encoder module and decoder module and adding skip connections.

net = encoderDecoderNetwork(inputSize,encoder,decoder, ...
    'OutputChannels',3,'SkipConnections','concatenate')
net = 
  dlnetwork with properties:

         Layers: [139x1 nnet.cnn.layer.Layer]
    Connections: [167x2 table]
     Learnables: [116x3 table]
          State: [0x3 table]
     InputNames: {'data'}
    OutputNames: {'encoderDecoderFinalConvLayer'}
    Initialized: 1

Display the network.

analyzeNetwork(net)

Input Arguments

inputSize

Network input size, specified as a 3-element vector of positive integers. inputSize has the form [H W C], where H is the height, W is the width, and C is the number of channels.

Example: [28 28 3] specifies an input size of 28-by-28 pixels for a 3-channel image.
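One way to check that a candidate value meets this requirement is validateattributes. The sketch below is a hypothetical helper for illustration only; it is not part of encoderDecoderNetwork, which may validate its inputs differently.

```matlab
inputSize = [28 28 3];

% Error unless inputSize is a 3-element vector of positive integers
validateattributes(inputSize, {'numeric'}, ...
    {'vector','numel',3,'positive','integer'}, ...
    'encoderDecoderNetwork', 'inputSize');

% Unpack the [H W C] form
H = inputSize(1);   % height in pixels
W = inputSize(2);   % width in pixels
C = inputSize(3);   % number of channels
```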

encoder

Encoder network, specified as a dlnetwork (Deep Learning Toolbox) object.

decoder

Decoder network, specified as a dlnetwork (Deep Learning Toolbox) object. The network must have a single input and a single output.

Name-Value Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside quotes. You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: "SkipConnections","concatenate" specifies the type of skip connection between the encoder and decoder networks as concatenation.

'LatentNetwork'

Network connecting the encoder and decoder, specified as a layer or array of layers.

'FinalNetwork'

Network connected to the output of the decoder, specified as a layer or array of layers. If you also specify the 'OutputChannels' argument, then the final network is connected after the 1-by-1 convolution layer that 'OutputChannels' adds at the end of the decoder.

'OutputChannels'

Number of output channels of the decoder network, specified as a positive integer. If you specify this argument, then the final layer of the decoder performs a 1-by-1 convolution operation with the specified number of channels.
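Mathematically, a 1-by-1 convolution with K filters applies the same C-by-K linear map, plus a bias, to the channel vector at every pixel. The base-MATLAB sketch below computes this directly for a hypothetical H-by-W-by-C activation with random weights; it illustrates the operation only and is not how the toolbox implements its convolution layers.

```matlab
H = 4; W = 5; C = 64;            % hypothetical decoder output size
K = 3;                           % number of output channels, as in 'OutputChannels',3
X  = rand(H, W, C);              % decoder activation
Wt = rand(C, K);                 % 1-by-1 filter weights, one column per output channel
b  = rand(1, K);                 % one bias per output channel

% Flatten pixels to rows, mix channels, then restore the spatial layout
Y = reshape(reshape(X, [], C) * Wt + b, H, W, K);

% Check one pixel against the direct per-pixel formula
p = squeeze(X(2,3,:))' * Wt + b;
assert(max(abs(squeeze(Y(2,3,:))' - p)) < 1e-10)
```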

'SkipConnectionNames'

Names of pairs of encoder/decoder layers whose activations are merged by skip connections, specified as one of these values.

  • "auto" — The encoderDecoderNetwork function determines the names of pairs of encoder/decoder layers automatically.

  • M-by-2 string array — The first column is the name of the encoder layer and the second column is the name of the respective decoder layer.

When you specify the 'SkipConnections' argument as "none", the encoderDecoderNetwork function ignores the value of 'SkipConnectionNames'.

Data Types: char | string

'SkipConnections'

Type of skip connection between the encoder and decoder networks, specified as "none", "auto", or "concatenate".
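Assuming that "concatenate" joins each pair of matched activations along the channel dimension, as in a U-Net, the channel counts add while height and width stay the same. The base-MATLAB sketch below shows the shape arithmetic with cat on two hypothetical activations whose size (28-by-28 with 512 channels) is consistent with the deepest encoder block of the U-Net example for a 224-by-224 input. The pairing and the concatenation order used here are assumptions for illustration.

```matlab
% Hypothetical H-by-W-by-C activations of a matched encoder/decoder pair
encoderAct = rand(28, 28, 512);
decoderAct = rand(28, 28, 512);

% Concatenate along the third (channel) dimension; cat requires the
% first two dimensions to agree. The order here is arbitrary.
merged = cat(3, decoderAct, encoderAct);
disp(size(merged))   % 28 28 1024
```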

Data Types: char | string

Output Arguments

net

Encoder/decoder network, returned as a dlnetwork (Deep Learning Toolbox) object.

Introduced in R2021a