How do I create a channel attention layer in MATLAB?
11 views (last 30 days)
Show older comments
classdef ChannelAttentionLayer < nnet.layer.Layer
    % ChannelAttentionLayer  Channel attention custom layer (CBAM-style).
    %
    %   layer = ChannelAttentionLayer(reduction_ratio, input_channels)
    %   layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
    %
    %   The layer summarizes each channel of its input twice: once with
    %   global average pooling and once with global max pooling over the
    %   two spatial dimensions. Both C-by-1 descriptors go through the same
    %   two-layer MLP (C -> R -> C, ReLU in between), where
    %   R = max(1, round(C / reduction_ratio)). The two MLP outputs are
    %   added and passed through a sigmoid. Each input channel is then
    %   scaled by its resulting weight.
    %
    %   NOTE(review): the layer is not Formattable, so predict assumes X is
    %   laid out height-by-width-by-channel-by-batch (dimension 4 may be
    %   absent when there is a single observation).

    properties
        % Ratio between the input channel count and the hidden width of
        % the shared MLP.
        ReductionRatio
    end

    properties (Learnable)
        % First MLP layer: 1x1xCxR weights and 1x1xR bias.
        Weights1
        Bias1
        % Second MLP layer: 1x1xRxC weights and 1x1xC bias.
        Weights2
        Bias2
    end

    methods
        function layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
            % Construct the layer and initialize its learnable parameters.
            %   reduction_ratio - positive scalar controlling the MLP width
            %   input_channels  - number of channels C of the layer input
            %   name            - (optional) layer name
            if nargin >= 3
                layer.Name = name;
            end
            if ~(isscalar(reduction_ratio) && isnumeric(reduction_ratio) && reduction_ratio > 0)
                error('ChannelAttentionLayer:invalidReductionRatio', ...
                    'ReductionRatio must be a positive numeric scalar.');
            end
            layer.ReductionRatio = reduction_ratio;

            % Hidden width of the shared MLP (at least one unit).
            reduced_channels = max(1, round(input_channels / reduction_ratio));

            % Glorot-scaled normal initialization. Unscaled randn makes the
            % MLP outputs grow like sqrt(C) and saturates the sigmoid.
            std1 = sqrt(2 / (input_channels + reduced_channels));
            std2 = sqrt(2 / (reduced_channels + input_channels));
            layer.Weights1 = std1 * randn([1, 1, input_channels, reduced_channels], 'single');
            layer.Bias1 = zeros([1, 1, reduced_channels], 'single');
            layer.Weights2 = std2 * randn([1, 1, reduced_channels, input_channels], 'single');
            layer.Bias2 = zeros([1, 1, input_channels], 'single');
        end

        function Z = predict(layer, X)
            % Apply channel attention to X (H x W x C x N).
            % Z has the same size as X; every channel of every observation
            % is multiplied by a weight in (0, 1).
            C = size(X, 3);
            N = size(X, 4);   % read explicitly: [H,W,C] = size(X) would fold N into C
            if C ~= size(layer.Weights1, 3)
                error('Number of channels in input and weights do not match.');
            end

            % Per-observation channel descriptors, C x N.
            avg_desc = reshape(mean(X, [1, 2]), C, N);
            max_desc = reshape(max(X, [], [1, 2]), C, N);

            % Shared MLP on both descriptors, then sigmoid gating.
            scores = sharedMLP(layer, avg_desc) + sharedMLP(layer, max_desc);
            attention = reshape(sigmoid(scores), 1, 1, C, N);

            % Implicit expansion broadcasts the weights over H and W.
            Z = X .* attention;
        end

        function Z = forward(layer, X)
            % Training-time pass; the layer behaves the same in both modes.
            Z = predict(layer, X);
        end
    end

    methods (Access = private)
        function out = sharedMLP(layer, desc)
            % Apply the shared two-layer MLP column-wise to desc (C x N).
            % Returns a C x N matrix of attention logits.
            W1 = reshape(layer.Weights1, size(layer.Weights1, 3), []);   % C x R
            W2 = reshape(layer.Weights2, size(layer.Weights2, 3), []);   % R x C
            hidden = relu(W1.' * desc + reshape(layer.Bias1, [], 1));    % R x N
            out = W2.' * hidden + reshape(layer.Bias2, [], 1);           % C x N
        end
    end
end
% Apply a 1x1 convolution (a per-pixel fully connected layer).
%
%   out = fullyconnect(input, weights, bias, input_channels, output_channels)
%
%   input   - H x W x C_in x N array (N and trailing singleton dims optional)
%   weights - 1 x 1 x C_in x C_out array
%   bias    - array with C_out elements (e.g. 1 x 1 x C_out)
%   out     - H x W x C_out x N array
%
%   input_channels and output_channels are kept for call compatibility
%   but are not trusted: both counts are read from the size of weights.
%   NOTE(review): as a local function this shadows the Deep Learning
%   Toolbox fullyconnect inside this file.
function out = fullyconnect(input, weights, bias, input_channels, output_channels) %#ok<INUSD>
[H, W, C_in, N] = size(input);
[~, ~, C, C_out] = size(weights);
if C_in ~= C
    error('Number of channels in input and weights do not match.');
end
% Move channels to the first dimension so that each column is one pixel
% of one observation. Column-major order is then preserved on the way back.
cols = reshape(permute(input, [3, 1, 2, 4]), C_in, []);     % C_in x (H*W*N)
W2d = reshape(weights, C_in, C_out);                        % C_in x C_out
y = W2d.' * cols + reshape(bias, [], 1);                    % C_out x (H*W*N)
out = permute(reshape(y, [C_out, H, W, N]), [2, 3, 1, 4]);  % H x W x C_out x N
end
Answers (1)
See Also
Categories
Find more on Calculus in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!