I have a small performance bottleneck in an application that requires zeroing out the off-diagonal elements of a large square matrix. For example, the matrix x

```
17 24  1  8 15
23  5  7 14 16
 4  6 13 20 22
10 12 19 21  3
11 18 25  2  9
```

becomes

```
17  0  0  0  0
 0  5  0  0  0
 0  0 13  0  0
 0  0  0 21  0
 0  0  0  0  9
```

Question: the bsxfun and diag solution below is the fastest I have found so far. I doubt I can improve on it while keeping the code in MATLAB, but is there a faster way?
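The example matrix above is magic(5), so the transformation is easy to reproduce. The sketch below does so with one more candidate that is not among the timed solutions: nesting diag. Applied to a matrix, diag returns its main diagonal as a column vector; applied to a vector, it returns a square matrix with that vector on the diagonal and zeros elsewhere. How it compares in speed with the bsxfun version has not been measured here.

```matlab
% Reproduce the 5x5 example from the question (it is magic(5)).
x = magic(5);

% Inner diag: extract the main diagonal of x as a 5x1 column vector.
% Outer diag: build a 5x5 matrix with that vector on the diagonal
% and zeros everywhere else.
y = diag(diag(x));

disp(y)
```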
Solutions
Here is what I thought of so far.
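Perform element-wise multiplication by the identity matrix. This is the simplest solution:

```matlab
y = x .* eye(n);
```

Using bsxfun and diag:

```matlab
% diag(x) is an n-by-1 column; bsxfun scales row i of eye(n) by x(i,i).
y = bsxfun(@times, diag(x), eye(n));
```

Subtracting the strictly lower and upper triangular parts:

```matlab
y = x - tril(x, -1) - triu(x, 1);
```

Various solutions using loops:

```matlab
% Visit every element and zero it unless it is on the diagonal.
y = x;
for ix = 1:n
    for jx = 1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end
```

and

```matlab
% Skip the comparison: zero the elements left and right of the diagonal.
y = x;
for ix = 1:n
    for jx = 1:ix-1
        y(ix, jx) = 0;
    end
    for jx = ix+1:n
        y(ix, jx) = 0;
    end
end
```

Timing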
The bsxfun solution is actually the fastest. This is my timing code:
```matlab
function timing()
    clear all
    n = 5000;
    x = rand(n, n);

    f1 = @() tf1(x, n);
    f2 = @() tf2(x, n);
    f3 = @() tf3(x);
    f4 = @() tf4(x, n);
    f5 = @() tf5(x, n);

    t1 = timeit(f1);
    t2 = timeit(f2);
    t3 = timeit(f3);
    t4 = timeit(f4);
    t5 = timeit(f5);

    fprintf('t1: %f s\n', t1)
    fprintf('t2: %f s\n', t2)
    fprintf('t3: %f s\n', t3)
    fprintf('t4: %f s\n', t4)
    fprintf('t5: %f s\n', t5)
end

function y = tf1(x, n)
    y = x .* eye(n);
end

function y = tf2(x, n)
    y = bsxfun(@times, diag(x), eye(n));
end

function y = tf3(x)
    y = x - tril(x, -1) - triu(x, 1);
end

function y = tf4(x, n)
    y = x;
    for ix = 1:n
        for jx = 1:n
            if ix ~= jx
                y(ix, jx) = 0;
            end
        end
    end
end

function y = tf5(x, n)
    y = x;
    for ix = 1:n
        for jx = 1:ix-1
            y(ix, jx) = 0;
        end
        for jx = ix+1:n
            y(ix, jx) = 0;
        end
    end
end
```

which returns

```
t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
```
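Another vectorized variant, not included in the timings above, writes the diagonal directly into a preallocated zero matrix with linear indexing. MATLAB stores matrices in column-major order, so in an n-by-n matrix the diagonal elements sit at linear indices 1, n+2, 2n+3, ..., n^2, that is, every (n+1)-th element starting from 1. This is a minimal sketch; the variable names idx and y_ref are illustrative, and its speed relative to the bsxfun solution would still need to be measured with the same timeit harness.

```matlab
% Sketch: keep only the main diagonal of x using linear indexing.
n = 5;
x = rand(n, n);

% Preallocate an all-zero result with the same numeric class as x.
y = zeros(n, n, class(x));

% Column-major storage puts the diagonal at linear indices
% 1, n+2, 2n+3, ..., n^2 (a stride of n+1).
idx = 1:n+1:n^2;
y(idx) = x(idx);

% Cross-check against the bsxfun/diag solution from the question.
y_ref = bsxfun(@times, diag(x), eye(n));
assert(isequal(y, y_ref))
```

Because this version copies values instead of multiplying or subtracting, the off-diagonal entries stay exactly 0 even when x contains Inf or NaN; the multiplication- and subtraction-based solutions can produce NaN there instead (for example, Inf * 0 and Inf - Inf are both NaN).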