How to create a channel attention layer in MATLAB?

vino on 6 Sep 2024
Answered: Udit06 on 1 Oct 2024
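
I am trying to implement a CBAM-style channel attention layer as a custom layer: global average and max pooling over the spatial dimensions, a shared two-layer MLP (1-by-1 convolutions with a channel bottleneck), and a sigmoid that produces per-channel weights. Here is my attempt: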
classdef ChannelAttentionLayer < nnet.layer.Layer
    properties
        % Reduction ratio used in the channel attention bottleneck
        ReductionRatio
    end

    properties (Learnable)
        % Learnable parameters of the shared two-layer MLP (1x1 convolutions)
        Weights1
        Bias1
        Weights2
        Bias2
    end

    methods
        function layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
            % Constructor for ChannelAttentionLayer
            layer.Name = name;
            layer.ReductionRatio = reduction_ratio;
            % Number of channels in the squeezed (bottleneck) representation
            reduced_channels = max(1, round(input_channels / reduction_ratio));
            % He-style scaling keeps the initial activations reasonably sized
            layer.Weights1 = randn([1, 1, input_channels, reduced_channels], 'single') ...
                * sqrt(2 / input_channels);
            layer.Bias1 = zeros([1, 1, reduced_channels], 'single');
            layer.Weights2 = randn([1, 1, reduced_channels, input_channels], 'single') ...
                * sqrt(2 / reduced_channels);
            layer.Bias2 = zeros([1, 1, input_channels], 'single');
        end

        function Z = predict(layer, X)
            % The computation is identical in training and inference, so the
            % layer only needs predict; forward falls back to it below.
            % X is H-by-W-by-C(-by-B).
            if ~isa(X, 'dlarray')
                X = dlarray(X); % allow numeric input when called directly
            end
            % Global average pooling over the spatial dims -> 1x1xC(xB)
            avg_pool = mean(X, [1, 2]);
            % Global max pooling over the spatial dims -> 1x1xC(xB)
            max_pool = max(X, [], [1, 2]);
            % Shared MLP: 1x1 conv -> ReLU -> 1x1 conv, with the same
            % weights applied to both pooled descriptors
            avg_out = relu(channelFC(avg_pool, layer.Weights1, layer.Bias1));
            max_out = relu(channelFC(max_pool, layer.Weights1, layer.Bias1));
            avg_out = channelFC(avg_out, layer.Weights2, layer.Bias2);
            max_out = channelFC(max_out, layer.Weights2, layer.Bias2);
            % Sigmoid of the sum gives per-channel attention weights in (0, 1)
            attention = sigmoid(avg_out + max_out);
            % Rescale the input channel-wise (implicit expansion over H and W)
            Z = X .* attention;
        end

        function Z = forward(layer, X)
            % Training-mode pass: same computation as predict
            Z = predict(layer, X);
        end
    end
end

% Shared-MLP step: a 1x1 convolution over channels written as a matrix
% multiply. Named channelFC so it does not shadow the built-in fullyconnect.
function out = channelFC(input, weights, bias)
% input:   1-by-1-by-C_in(-by-B) pooled descriptor
% weights: 1-by-1-by-C_in-by-C_out
% bias:    1-by-1-by-C_out
C_in  = size(weights, 3);
C_out = size(weights, 4);
if size(input, 3) ~= C_in
    error('Number of channels in input and weights do not match.');
end
B = size(input, 4);
% Flatten to C_in-by-B so each column is one observation's channel vector
input_reshaped = reshape(input, [C_in, B]);
weights_reshaped = reshape(weights, [C_in, C_out]);
% [C_out x C_in] * [C_in x B] + [C_out x 1], bias broadcast over the batch
out = weights_reshaped.' * input_reshaped + reshape(bias, [C_out, 1]);
out = reshape(out, [1, 1, C_out, B]);
end
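
To sanity-check the layer before using it in a network, something like the following should work (a minimal sketch; the channel count, input size, and layer name below are arbitrary):

% Smoke test for ChannelAttentionLayer (sizes and names are arbitrary)
layer = ChannelAttentionLayer(8, 64, "channel_attention");
% Random input: 32x32 spatial, 64 channels, batch of 4
X = dlarray(rand(32, 32, 64, 4, 'single'));
Z = predict(layer, X);
size(Z) % should equal size(X): 32 32 64 4
% checkLayer runs the standard custom-layer validity checks, including gradients
checkLayer(layer, [32 32 64], 'ObservationDimension', 4);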

Answers (1)

Udit06 on 1 Oct 2024
