Applying vectorization techniques to speed up dividing a 3D matrix by a 2D matrix

I'm working on removing a for loop in my MATLAB code to improve performance. My original code has one for loop (for j=1:Nx) that hurts performance: in my production code, this loop executes over 20 million times for large simulations. I am curious whether I can remove this loop through vectorization, repmat, or a similar technique. My original MATLAB implementation is given below.
clc; clear all;
% Test Data
% I'm trying to remove the for loop for j in the code below
N = 10;
M = 10;
Nx = 32; % Ny=Nx=Nz
Nz = 32;
Ny = 32;
Fnmhat = rand(Nx,Nz+1);
Jnmhat = rand(Nx,1);
xi_n_m_hat = rand(Nx,N+1,M+1);
Uhat = zeros(Nx,Nz+1);
Uhat_2 = zeros(Nx,Nz+1);
identy = eye(Ny+1,Ny+1);
p = rand(Nx,1);
gammap = rand(Nx,1);
D = rand(Nx+1,Ny+1);
D2 = rand(Nx+1,Ny+1);
D_start = D(1,:);
D_end = D(end,:);
gamma = 1.5;
alpha = 0; % this could be non-zero
ntests = 100;
% Original Code (Partially vectorized)
tic
for n=0:N
for m=0:M
b = Fnmhat.';
alphaalpha = 1.0;
betabeta = 0.0; % this could be non-zero
gammagamma = gamma*gamma - p.^2 - 2*alpha.*p; % size (Nx,1)
d_min = 1.0;
n_min = 0.0; % this could be non-zero
r_min = xi_n_m_hat(:,n+1,m+1);
d_max = -1i.*gammap;
n_max = 1.0;
r_max = Jnmhat;
A = alphaalpha*D2 + betabeta*D + permute(gammagamma,[3,2,1]).*identy;
A(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
b(end,:) = r_min;
A(end,end,:) = A(end,end,:) + d_min;
A(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
A(1,1,:) = A(1,1,:) + permute(d_max,[2,3,1]);
b(1,:) = r_max;
% Non-vectorized code - can this part be vectorized?
for j=1:Nx
utilde = linsolve(A(:,:,j),b(:,j)); % A\b
Uhat(j,:) = utilde.';
end
end
end
toc
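One observation about the test code as posted: A is assembled from D, D2, p, gammap and scalar constants, none of which depend on n or m; only b changes with (n,m), through r_min. Assuming that also holds in the production code (it may not), each page of A can be factored once with lu and the factors reused for every (n,m), so the innermost loop only performs triangular solves. A minimal sketch with random stand-in data (rhs is a hypothetical array holding one b per (n,m)):

```matlab
% Sketch: factor each page of A once, then reuse the factors for all (n,m).
% Assumption: A does not depend on n or m (true for the test code as posted).
Nx = 32; N = 10; M = 10;
A = rand(Nx+1, Nx+1, Nx) + 1i*rand(Nx+1, Nx+1, Nx);  % stand-in for the assembled A
rhs = rand(Nx+1, Nx, N+1, M+1);                       % hypothetical: one b per (n,m)

% Precompute LU factors of every page once (P*A_j = L*U).
Lf = cell(Nx,1); Uf = cell(Nx,1); Pf = cell(Nx,1);
for j = 1:Nx
    [Lf{j}, Uf{j}, Pf{j}] = lu(A(:,:,j));
end

Uhat = zeros(Nx, Nx+1, N+1, M+1);
for n = 0:N
    for m = 0:M
        b = rhs(:,:,n+1,m+1);
        for j = 1:Nx
            % Forward and back substitution replace a full factorization per call.
            Uhat(j,:,n+1,m+1) = (Uf{j} \ (Lf{j} \ (Pf{j}*b(:,j)))).';
        end
    end
end
```

MATLAB's decomposition object (R2017b and later) is another way to cache the factorization and keep using the `\` syntax.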
Here is my attempt at vectorizing the code (and removing the for loop for j).
% Same test data as original code
% New Code (completely vectorized but incorrect)
tic
for n=0:N
for m=0:M
b = Fnmhat.';
alphaalpha = 1.0;
betabeta = 0.0; % this could be non-zero
gammagamma = gamma*gamma - p.^2 - 2*alpha.*p; % size (Nx,1)
d_min = 1.0;
n_min = 0.0; % this could be non-zero
r_min = xi_n_m_hat(:,n+1,m+1);
d_max = -1i.*gammap;
n_max = 1.0;
r_max = Jnmhat;
A2 = alphaalpha*D2 + betabeta*D + permute(gammagamma,[3,2,1]).*identy;
A2(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
b(end,:) = r_min;
A2(end,end,:) = A2(end,end,:) + d_min;
A2(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
A2(1,1,:) = A2(1,1,:) + permute(d_max,[2,3,1]);
b(1,:) = r_max;
% Non-vectorized code - can this part be vectorized?
%for j=1:Nx
% utilde_2 = linsolve(A2(:,:,j),b(:,j)); % A2\b
% Uhat_2(j,:) = utilde_2.';
%end
% My attempt - this doesn't work since I don't loop through the index j
% in repmat
utilde_2 = squeeze(repmat(linsolve(A2(:,:,Nx),b(:,Nx)),[1,1,Nx]));
utilde_2 = utilde_2(:,1);
Uhat_2 = squeeze(repmat(utilde_2',[1,1,Nx]));
Uhat_2 = Uhat_2';
end
end
toc
diff = norm(Uhat - Uhat_2,inf); % is 0 if correct
I'm curious whether repmat (or a different built-in MATLAB function) can speed up this part of the code:
for j=1:Nx
utilde = linsolve(A(:,:,j),b(:,j)); % A\b
Uhat(j,:) = utilde.';
end
Is the for loop for j absolutely necessary or can it be removed?

Accepted Answer

Bruno Luong on 29 Jul 2021
If you have a C compiler, the fastest methods are perhaps mmx and MultipleQR, available on the File Exchange (FEX).
Bruno Luong on 30 Jul 2021
MMX and MultipleQRSolve parallelize the loop over the pages, whereas with a MATLAB for-loop the parallelization happens inside the linsolve algorithm itself.
That's why the matrix size matters, and possibly also the number of physical processor cores.
MultipleQRSolve is less efficient for large matrices because my QR implementation is less efficient than the LAPACK routines called by MMX and by the native MATLAB for-loop.
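The distinction between parallelizing over pages and parallelizing inside the solver can be illustrated with parfor, which hands the independent page solves to different workers. This is a sketch under assumptions: iterations run on workers only when the Parallel Computing Toolbox is installed and a pool is available; without the toolbox, parfor executes as an ordinary for-loop. For 33-by-33 systems the per-iteration overhead of parfor may outweigh the gain, so it should be timed against the plain loop. The names m, q, Ac and Br follow the test code in the comment below.

```matlab
% Sketch: page-level parallelism with parfor (one independent solve per page).
m = 33; q = 32;
Ac = rand(m,m,q) + 1i*rand(m,m,q);   % complex pages, as in the real data
Br = rand(m,q);                      % real right-hand sides, one column per page
z = complex(zeros(q, m));            % preallocate as complex for the sliced output
parfor j = 1:q
    % Ac(:,:,j), Br(:,j) and z(j,:) are sliced, so each worker gets only its page.
    z(j,:) = (Ac(:,:,j) \ Br(:,j)).';
end
```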
Matthew Kehoe on 31 Jul 2021
I think MMX could be faster when A and B are not both complex. In my data, B is real, so it may be possible to speed up the MMX calculation. Here is how I would implement the three methods in my real MATLAB code.
% These parameters mimic the real data in my code
m=33;
n=33;
p=1;
q=32;
ntests = 10000;
% My code calculates Ac and Br before going into the for loop
Ac = rand(m,n,q)+1i*rand(m,n,q); % A is a complex double of size (33,33,32)
Br = rand(m,q); % B is a (real) double of size (33,32)
% Before I decide to use a for loop/mmx/MultipleQRSolve my code
% "understands" that A is a complex double of size (33,33,32) and B is a
% (real) double of size (33,32). I don't need to calculate what A or B are inside
% the for loop. I only reshape B inside MMX and MultipleQRSolve because I
% have to for the divides operation.
% Here is how I would write the three functions below in my "real" code.
% for-loop
tic
for ii=1:ntests
z1 = zeros(q,m);
for j=1:q
% This is how my code currently computes A\b
utilde = linsolve(Ac(:,:,j),Br(:,j)); % A\b
z1(j,:) = utilde.';
end
end
toc % Elapsed time is 14.231135 seconds.
% mmx
tic
for ii=1:ntests
Bnew = reshape(Br,m,1,q); % Make Br size(33,1,32) to apply MMX
Ar = real(Ac);
Ai = imag(Ac);
Br = real(Bnew);
Bi = imag(Bnew); % is zero as b is a real double
% z_1 = Ar+Ai*i
% z_2 = Br+Bi*i
% z_1/z_2 = [(Ar*Br + Ai*Bi) + 1i*(Ai*Br - Ar*Bi)]/(Br^2 + Bi^2);
% Since Bi == 0, this is simplified to
% z_1/z_2 = [(Ar*Br) + 1i*(Ai*Br)]/(Br^2);
% I think that this makes the code below
%AA = [Ar,-Ai;Ai,Ar];
%BB = [Br;Bi];
%zz = mmx('backslash', AA, BB);
%z2=zz(1:n,:,:)+1i*zz(n+1:end,:,:);
% Into the faster version
Num = mmx('mult', Ar, Br);
Num = Num + 1i*mmx('mult', Ai, Br);
Den = Br.^2;
z2 = mmx('backslash',Num,Den);
z2 = permute(z2,[3 1 2]);
end
toc % Elapsed time is 2.441799 seconds.
% MultipleQRSolve
tic
for ii=1:ntests
Bnew_2 = reshape(Br,m,1,q); % Make Br size(33,1,32) to apply MultipleQRSolve
z3 = MultipleQRSolve(Ac,Bnew_2);
z3 = permute(z3,[3 1 2]);
end
toc % Elapsed time is 25.991396 seconds.
diff = norm(z1-z2,inf); % Not zero since my code for z_1/z_2 isn't correct.
diff2 = norm(z1-z3,inf);
If the code for
AA = [Ar,-Ai;Ai,Ar];
BB = [Br;Bi];
zz = mmx('backslash', AA, BB);
z2=zz(1:n,:,:)+1i*zz(n+1:end,:,:);
isn't needed (since B is not a complex double), then MMX would "beat" the for loop. Thanks for all of your help with this question (and for writing MultipleQRSolve).
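On whether the AA/BB construction can be skipped when B is real: writing A = Ar + 1i*Ai and x = xr + 1i*xi, the system A*x = b splits into real and imaginary parts as [Ar, -Ai; Ai, Ar] * [xr; xi] = [br; bi]. A real b only sets bi = 0, which zeros the lower half of the right-hand side; the coupling blocks -Ai and Ai remain, so the doubled real system is still required. The scalar identity z_1/z_2 = ((Ar*Br) + 1i*(Ai*Br))/Br^2 from the comments applies to dividing two complex scalars, not to solving a matrix system. A minimal single-page check of the equivalence, using random data and plain backslash in place of mmx:

```matlab
% Sketch: complex solve A\b versus the equivalent real 2n-by-2n solve.
n = 33;
A = rand(n) + 1i*rand(n);      % complex coefficient matrix (one page)
b = rand(n,1);                 % real right-hand side, so bi = 0
Ar = real(A); Ai = imag(A);
AA = [Ar, -Ai; Ai, Ar];        % real-equivalent matrix (coupling blocks remain)
BB = [b; zeros(n,1)];          % [br; bi] with bi = 0
zz = AA \ BB;
x_real_form = zz(1:n) + 1i*zz(n+1:end);
x_direct = A \ b;              % MATLAB handles complex A with real b directly
```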


More Answers (2)

Matt J on 29 Jul 2021
Edited: Matt J on 29 Jul 2021
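Another idea: place all Nx pages of A on the diagonal of one sparse block-diagonal matrix and solve every system with a single backslash.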
clc; clear all;
% Test Data
% I'm trying to remove the for loop for j in the code below
N = 10;
M = 10;
Nx = 32; % Ny=Nx=Nz
Nz = 32;
Ny = 32;
AA=kron(speye(Nx),ones(Nx+1));
map=logical(AA);
% Original Code (Partially vectorized)
tic
for n=0:N
for m=0:M
....
%Vectorized code
AA(map)=A(:);
Uhat=reshape(AA\b(:),Nx+1,Nx).';
end
end
toc
Matt J on 29 Jul 2021
Matt J on 29 Jul 2021
Yeah, I didn't see that A was complex-valued. So,
AA(map)=reshape(A,[],1);
Uhat=reshape(AA\b(:),Nx+1,Nx).';
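The block-diagonal approach relies on Ny = Nx, so each page is (Nx+1)-by-(Nx+1). kron(speye(Nx), ones(Nx+1)) marks the diagonal blocks, and logical indexing visits those entries in column-major order, which is the same order as reshape(A,[],1): all columns of page 1, then all columns of page 2, and so on. Likewise b(:) stacks the right-hand sides page by page. A minimal check against the original loop with random stand-in data (timings are not shown here; whether it beats the loop depends on the problem size and the sparse solver):

```matlab
% Sketch: solve all pages at once via one sparse block-diagonal system.
Nx = 32;
nb = Nx + 1;                                   % block size (Ny = Nx here)
A = rand(nb, nb, Nx) + 1i*rand(nb, nb, Nx);    % stand-in for the assembled A
b = rand(nb, Nx);                              % column j is the RHS of page j

AA = kron(speye(Nx), ones(nb));                % sparsity pattern of the blocks
map = logical(AA);                             % true on the diagonal blocks
AA(map) = reshape(A, [], 1);                   % column-major order matches page by page
Uhat = reshape(AA \ b(:), nb, Nx).';           % row j is the solution for page j

% Reference: the original per-page loop
Uref = zeros(Nx, nb);
for j = 1:Nx
    Uref(j,:) = (A(:,:,j) \ b(:,j)).';
end
```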



Matt J on 29 Jul 2021
Edited: Matt J on 29 Jul 2021
On the GPU (i.e., if A and b are gpuArrays), the for-loop can be removed:
Uhat = permute( pagefun(@mldivide,A,reshape(b,[],1,Nx)) ,[3,1,2]);
Matthew Kehoe on 29 Jul 2021
Matthew Kehoe on 29 Jul 2021
This approach requires the Parallel Computing Toolbox. I will investigate getting this toolbox. Is there another approach that doesn't require a separate toolbox?
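On the toolbox question: in releases that provide it, pagemldivide performs the same page-wise solve on the CPU without the Parallel Computing Toolbox (it was added in R2022a, later than the R2020a used in this thread). Either way, the page results come back as (Nx+1)-by-1-by-Nx, so reaching the Nx-by-(Nx+1) layout of the loop's Uhat takes permute(...,[3,1,2]). A sketch with random stand-in data that falls back to an ordinary loop when pagemldivide is not available:

```matlab
% Sketch: page-wise solve, then reorder to the Nx-by-(Nx+1) layout of Uhat.
Nx = 32;
A = rand(Nx+1, Nx+1, Nx) + 1i*rand(Nx+1, Nx+1, Nx);  % stand-in for the assembled A
b = rand(Nx+1, Nx);                                   % one RHS column per page
B3 = reshape(b, [], 1, Nx);          % (Nx+1)-by-1-by-Nx: one RHS per page
if exist('pagemldivide') > 0         % assumption: R2022a or later
    X = pagemldivide(A, B3);
else                                 % fallback: ordinary loop over pages
    X = complex(zeros(size(B3)));
    for j = 1:Nx
        X(:,:,j) = A(:,:,j) \ B3(:,:,j);
    end
end
Uhat = permute(X, [3, 1, 2]);        % X(k,1,j) -> Uhat(j,k)
```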



Release

R2020a
