Improve performance of CAReduce in Numba backend #1109
Conversation
This Op does not really fit the CAReduce API, as it requires an extra bit of information (the number of elements in the reduced axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
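For context, here is a minimal sketch (plain NumPy, hypothetical helper names, not PyTensor's actual API) of why such an Op doesn't fit: a CAReduce only folds a commutative/associative binary scalar op, so the axis length has to come from a separate Elemwise step wrapped around the reduction.

```python
import numpy as np

def careduce(scalar_binary_op, x, axis):
    # CAReduce-style reduction: fold a commutative/associative binary op
    # along `axis`. The op only ever sees two values at a time; it has
    # no access to x.shape[axis].
    moved = np.moveaxis(x, axis, 0)
    acc = moved[0].copy()
    for i in range(1, moved.shape[0]):
        acc = scalar_binary_op(acc, moved[i])
    return acc

def fused_mean(x, axis):
    # A fused Elemwise+CAReduce can add an Elemwise step that *does* know
    # the axis length (here, the division by x.shape[axis]).
    return careduce(np.add, x, axis) / x.shape[axis]

x = np.random.default_rng(0).uniform(size=(4, 5))
np.testing.assert_allclose(fused_mean(x, axis=0), x.mean(axis=0))
```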
Codecov Report

```diff
@@            Coverage Diff             @@
##             main    #1109      +/-   ##
==========================================
- Coverage   82.12%   82.10%   -0.03%
==========================================
  Files         183      183
  Lines       48111    48030      -81
  Branches     8667     8658       -9
==========================================
- Hits        39510    39433      -77
+ Misses       6435     6434       -1
+ Partials     2166     2163       -3
```
Here is a direct comparison of the C and Numba backends for the non-C-contiguous case:

```python
import numpy as np
import pytensor

c_contiguous = False

for transpose_in_graph in (True, False):
    rng = np.random.default_rng(123)
    N = 256
    x_test = rng.uniform(size=(N, N, N))
    transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
    if not transpose_in_graph:
        x_test = x_test.transpose(transpose_axis)
    x = pytensor.shared(x_test, name="x", shape=x_test.shape, borrow=True)
    if transpose_in_graph:
        x = x.transpose(transpose_axis)
    out = x.sum(axis=0)

    c_fn = pytensor.function([], out, mode="FAST_COMPILE")
    numba_fn = pytensor.function([], out, mode="NUMBA").vm.jit_fn
    np.testing.assert_allclose(c_fn(), numba_fn()[0])

    print(f"{transpose_in_graph=}")
    %timeit c_fn()
    %timeit numba_fn()

# transpose_in_graph=True
# 33.7 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 188 ms ± 4.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# transpose_in_graph=False
# 33 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 103 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

A direct Numba implementation shows the same bad performance:

```python
import numpy as np
import numba

c_contiguous = False
rng = np.random.default_rng(123)
N = 256
x_test = rng.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x_test = x_test.transpose(transpose_axis)
out_dtype = np.float64

@numba.njit(fastmath=True, boundscheck=False)
def careduce_add(x):
    x_shape = x.shape
    res = np.full((x_shape[1], x_shape[2]), np.asarray(0.0).item(), dtype=out_dtype)
    for i0 in range(x_shape[0]):
        for i1 in range(x_shape[1]):
            for i2 in range(x_shape[2]):
                res[i1, i2] += x[i0, i1, i2]
    return res

np.testing.assert_allclose(careduce_add(x_test), np.sum(x_test, 0))

%timeit careduce_add(x_test)
# 136 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
AlexAndorra left a comment
Thanks for the walk-through in the comparison @ricardoV94, definitely interesting
Numba doing badly on the non-contiguous case is all down to loop ordering: LLVM doesn't reorder loops based on strides :( Anyway, this PR improves things overall; the better old speeds were just due to chance, when the reduced loop happened to be the one with the smallest strides.
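To make the loop-ordering point concrete, here is a hedged sketch (reusing the transposed `x_test` from the direct Numba snippet above, where the reduced axis 0 has the smallest stride) of the ordering LLVM would have to discover on its own, with the smallest-stride axis innermost:

```python
import numba
import numpy as np

@numba.njit(fastmath=True, boundscheck=False)
def careduce_add_reordered(x):
    # Same reduction as careduce_add above, but with the reduced axis as
    # the innermost loop: the inner reads are contiguous and the running
    # sum can stay in a register.
    res = np.empty((x.shape[1], x.shape[2]))
    for i1 in range(x.shape[1]):
        for i2 in range(x.shape[2]):
            acc = 0.0
            for i0 in range(x.shape[0]):
                acc += x[i0, i1, i2]
            res[i1, i2] = acc
    return res

# Same result as before; only the iteration order changed:
# np.testing.assert_allclose(careduce_add_reordered(x_test), np.sum(x_test, 0))
```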
Closes #935
Closes #931
The implementation for multiple axes no longer operates one axis at a time; the benchmarks for the Sum test before and after this PR show the improvement.
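As a rough illustration of the strategy change (plain Python/NumPy, not the PR's actual generated code), compare reducing axes (0, 1) of a 3D array one axis at a time, with an intermediate allocation per axis, against a single loop nest that accumulates directly into the output:

```python
import numpy as np

def sum_one_axis_at_a_time(x):
    # old strategy: one full reduction pass (and one temporary) per axis
    return x.sum(axis=0).sum(axis=0)

def sum_single_loop_nest(x):
    # new strategy: a single nested loop writing straight into the output
    res = np.zeros(x.shape[2])
    for i0 in range(x.shape[0]):
        for i1 in range(x.shape[1]):
            for i2 in range(x.shape[2]):
                res[i2] += x[i0, i1, i2]
    return res

x = np.random.default_rng(0).uniform(size=(3, 4, 5))
np.testing.assert_allclose(sum_one_axis_at_a_time(x), sum_single_loop_nest(x))
```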
Note that we have a special dispatch for `Sum(axes=None)` introduced in #92, so the changes are not reflected in that benchmark. I temporarily disabled the special dispatch to confirm that that case is also improved by the new implementation. Because the general implementation is still a bit slower than the special dispatch, and this is the most common reduction, I decided to keep the special case.
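For intuition, the kind of shortcut that makes `Sum(axes=None)` worth special-casing is a single flat loop over every element, with no per-dimension bookkeeping (a sketch, not the dispatcher's actual code):

```python
import numba
import numpy as np

@numba.njit(fastmath=True)
def sum_all(x):
    # One flat traversal of all elements, regardless of ndim.
    acc = 0.0
    for v in x.flat:
        acc += v
    return acc

assert sum_all(np.ones((4, 5, 6))) == 4 * 5 * 6
```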
Numba doesn't seem to optimize loops over non-contiguous arrays very well. The C backend implementation, with the explicit loop reordering written in #971, does not show such a penalty.
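The gist of that reordering, as a small hedged sketch (illustrative function name): put the largest-stride axis outermost, so a plain C-order loop nest over the permuted view walks memory contiguously in its innermost loop.

```python
import numpy as np

def locality_friendly_view(x):
    # Sort axes by stride, largest first; iterating the transposed view
    # in C order then makes the innermost loop contiguous.
    order = tuple(np.argsort(x.strides)[::-1])
    return x.transpose(order)
```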
Finally, we also see an improvement in the slowest case of the pre-existing numba-logsumexp benchmark.
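For reference, logsumexp chains two CAReduce ops, a max and a sum, which is why CAReduce speedups surface in that benchmark; a minimal NumPy version of the numerically stable formula:

```python
import numpy as np

def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)              # CAReduce: maximum
    s = np.sum(np.exp(x - m), axis=axis, keepdims=True)  # CAReduce: sum
    return np.squeeze(m + np.log(s), axis=axis)
```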