Conversation

@jessegrabowski (Member) commented May 23, 2025

Description

This PR adds a BandedDot Op that uses gbmv to do matrix-vector multiplication for the case that A is a banded matrix.
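For reference, the core idea can be sketched with SciPy's low-level `gbmv` wrapper (a minimal sketch, not the PR's code; `banded_dot` is an illustrative name, and unlike the Op this re-fetches the BLAS function on every call):

```python
import numpy as np
from scipy import linalg as scipy_linalg


def banded_dot(A, x, kl, ku):
    """Multiply a square banded matrix A (kl sub-, ku super-diagonals) by x via BLAS gbmv."""
    A = np.asarray(A, dtype="float64")
    m, n = A.shape
    # LAPACK band storage: A[i, j] lives at ab[ku + i - j, j], so diagonal k
    # occupies row ku - k, shifted right by k columns when k >= 0.
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for i, k in enumerate(range(ku, -kl - 1, -1)):
        if k >= 0:
            ab[i, k:] = np.diag(A, k=k)
        else:
            ab[i, : n + k] = np.diag(A, k=k)
    gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
    return gbmv(m, n, kl, ku, 1.0, ab, x)
```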

In my testing, I found that using gbmv sped up computation significantly when A is banded. Benchmarking against PyTensor's dot, however, the current implementation is significantly slower:

```
------------------------------------------------------------- benchmark: 8 tests -------------------------------------------------------------
Name (time in us)                    Min                   Max                  Mean                 StdDev               Median                IQR       Outliers  OPS                  Rounds  Iterations
test_dot_perf[10]                 1.7500 (1.0)         17.3330 (1.0)         1.9054 (1.0)          0.1292 (1.0)         1.9160 (1.0)        0.0420 (1.0)      585;1740  524,831.2234 (1.0)    38401   1
test_banded_dot_perf[10]         19.9580 (11.40)   13,765.1250 (794.16)     32.5111 (17.06)      282.5468 (>1000.0)    20.5830 (10.74)     0.3750 (8.93)       6;349    30,758.7051 (0.06)    3275    1
test_dot_perf[100]                2.4580 (1.40)        42.5420 (2.45)        2.7856 (1.46)         0.3265 (2.53)        2.7500 (1.44)      0.0420 (1.0)      343;7436  358,988.7425 (0.68)   71429   1
test_banded_dot_perf[100]        19.8330 (11.33)   15,203.3750 (877.13)     30.9185 (16.23)      193.8617 (>1000.0)    20.9580 (10.94)     0.4160 (9.90)     51;3057    32,343.1413 (0.06)   20566   1
test_dot_perf[1000]              15.0000 (8.57)        61.5000 (3.55)       16.6383 (8.73)         1.4182 (10.98)      17.2920 (9.03)      2.2080 (52.57)    905;126    60,102.3508 (0.11)   18377   1
test_banded_dot_perf[1000]       27.0420 (15.45)      423.8750 (24.45)      32.9042 (17.27)        5.2005 (40.25)      32.6250 (17.03)     0.6250 (14.88)   129;1334    30,391.2634 (0.06)   12501   1
test_dot_perf[10_000]         3,369.4580 (>1000.0)  5,011.3330 (289.12)  3,412.7784 (>1000.0)    119.9981 (928.81)  3,394.5625 (>1000.0)  17.2910 (411.69)      4;25       293.0164 (0.00)     198   1
test_banded_dot_perf[10_000]    109.9170 (62.81)      611.5830 (35.28)     139.2751 (73.10)       52.3002 (404.81)    116.5000 (60.80)    14.0000 (333.33)   472;678     7,180.0341 (0.01)    3386   1
-----------------------------------------------------------------------------------------------------------------------------------------------
```

I guess there's some major overhead from doing the diagonal extractions and looking up the BLAS function in Python? This could, and probably should, be a C Op, but I'm not sure I realistically have time to dig into all that anytime soon. Help wanted, at any rate.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/

@jessegrabowski added the labels enhancement (New feature or request), help wanted (Extra attention is needed), Op implementation, and linalg (Linear algebra) on May 23, 2025
@jessegrabowski (Member Author)

I added trust_input, and I now load the BLAS functions once at import time and cache them. That should remove some of the most obvious sources of Python overhead. New benchmarks (note that they're in ns now, not us):

```
---------------------------------------------------------------- benchmark: 8 tests ----------------------------------------------------------------
Name (time in ns)                            Min                       Max                      Mean                   StdDev                   Median                     IQR        Outliers   OPS                     Rounds  Iterations
test_banded_dot_perf[10-dot]              541.9988 (1.0)          4,292.0001 (1.0)           638.1136 (1.0)            51.0902 (1.0)           625.0011 (1.0)           41.0000 (40.91)     1506;209  1,567,119.1257 (1.0)    15636   1
test_banded_dot_perf[10-banded_dot]    17,500.0005 (32.29)      418,167.0010 (97.43)      18,191.1183 (28.51)       3,829.7598 (74.96)      18,083.0011 (28.93)        167.0014 (166.62)      70;630     54,971.8815 (0.04)   11353   1
test_banded_dot_perf[100-dot]           1,209.0004 (2.23)        23,959.0008 (5.58)        1,340.3628 (2.10)          103.1441 (2.02)        1,333.0009 (2.13)           1.0023 (1.0)     1217;34675    746,066.6804 (0.48)   88889   1
test_banded_dot_perf[100-banded_dot]   17,542.0009 (32.37)       77,083.9997 (17.96)      18,240.8191 (28.59)       1,230.1810 (24.08)      18,000.0006 (28.80)        250.0001 (249.44)    654;2431     54,822.0996 (0.03)   19018   1
test_banded_dot_perf[1000-dot]         13,291.9995 (24.52)       49,874.9996 (11.62)      15,195.7498 (23.81)       1,137.7872 (22.27)      15,833.0004 (25.33)      1,832.9993 (>1000.0)   2954;119     65,807.8747 (0.04)   22347   1
test_banded_dot_perf[1000-banded_dot]  24,624.9983 (45.43)       74,874.9990 (17.45)      30,233.2753 (47.38)       1,347.0049 (26.37)      30,125.0002 (48.20)        375.0010 (374.15)    874;1333     33,076.1385 (0.02)   15595   1
test_banded_dot_perf[10_000-dot]    3,394,874.9988 (>1000.0)  5,084,541.9992 (>1000.0) 3,585,834.0104 (>1000.0)   191,227.5142 (>1000.0) 3,558,604.5005 (>1000.0)  199,729.5003 (>1000.0)      16;3        278.8752 (0.00)     192   1
test_banded_dot_perf[10_000-banded_dot] 105,208.0006 (194.11)    389,250.0008 (90.69)     124,879.6041 (195.70)     35,967.3472 (704.00)    110,375.0001 (176.60)     8,343.4998 (>1000.0)    320;440      8,007.7128 (0.01)    2665   1
------------------------------------------------------------------------------------------------------------------------------------------------------
```
Comment on lines 1690 to 1699

```python
A = np.asarray(A)
m, n = A.shape
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")

for i, k in enumerate(range(ku, -kl - 1, -1)):
    padding = (k, 0) if k >= 0 else (0, -k)
    diag = np.pad(np.diag(A, k=k), padding)
    ab[i, :] = diag

return ab
```
Member:

I imagine this explains most of the python overhead for small cases?

Member Author:

One way or another we have to do that as part of the cost of the Op, unless we demand users have inputs ready in that form.

Member:

Yeah, it's fine, I was just thinking out loud.

Member Author:

This rearrangement could be done symbolically in a wrapper Op that calls the BLAS Op (which expects things to be ready in the correct form).

It might also be better to do smart column indexing on ab instead of using pad.
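For illustration, the two variants could look like this for a square A (a sketch; the helper names are mine, not the PR's). The indexing version writes each diagonal directly into a column slice, skipping the temporary padded arrays:

```python
import numpy as np


def to_banded_pad(A, kl, ku):
    # pad-based approach: build each diagonal, then pad it to full width
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
    for i, k in enumerate(range(ku, -kl - 1, -1)):
        padding = (k, 0) if k >= 0 else (0, -k)
        ab[i, :] = np.pad(np.diag(A, k=k), padding)
    return ab


def to_banded_index(A, kl, ku):
    # indexing approach: assign each diagonal into its slice, no temporaries
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
    for i, k in enumerate(range(ku, -kl - 1, -1)):
        if k >= 0:
            ab[i, k:] = np.diag(A, k=k)
        else:
            ab[i, : n + k] = np.diag(A, k=k)
    return ab
```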

Member:

Yeah, it's similar to Solve, in that you can also do the rearrangement once and reuse it many times, but I think that's too much micro-optimization for now. We also don't want to autodiff through it.

Comment on lines 1702 to 1703

```python
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
```
Member:

This will add import-time overhead to PyTensor.

I'm okay paying the extra 3us at runtime instead, since virtually nobody will ever use this (or use it in a case where they need those extra us).

Member Author:

I thought about this as well. It won't stay in the final version.

@ricardoV94 (Member) commented May 23, 2025

You can exploit prepare_node and add the function to node.tag, which the perform method can then retrieve. That's two attribute accesses instead of a string check / scipy caching.
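The suggestion could be sketched like this (a hedged sketch of the idea, not the PR's code; the `SimpleNamespace` object stands in for a real PyTensor Apply node, and the real `prepare_node` method takes more arguments):

```python
from types import SimpleNamespace

import numpy as np
from scipy import linalg as scipy_linalg


def prepare_node(node):
    # Sketch of the suggested Op.prepare_node body: look up the dtype-specific
    # BLAS routine once and stash it on the node's tag, so that perform() later
    # pays only two attribute accesses instead of a scipy lookup per call.
    dtype = "float64" if node.outputs[0].dtype == "float64" else "float32"
    node.tag.gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype=dtype)


# stand-in for a PyTensor Apply node with a float64 output
node = SimpleNamespace(
    outputs=[SimpleNamespace(dtype="float64")],
    tag=SimpleNamespace(),
)
prepare_node(node)
# node.tag.gbmv is now the cached dgbmv routine, ready for perform()
```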

Member:

Or you can sidestep perform and use make_thunk instead.

@ricardoV94 (Member)

I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may want to consider the backend (once we have it in numba, I suspect it will win for smaller matrices) and/or static shapes, if we think the worst-case penalty is still too big.

@jessegrabowski (Member Author)

Benchmark after tuning up the _to_banded_form function:

```
---------------------------------------------------------------- benchmark: 8 tests ----------------------------------------------------------------
Name (time in ns)                            Min                       Max                     Mean                  StdDev                  Median                     IQR        Outliers    OPS                     Rounds  Iterations
test_banded_dot_perf[10-dot]              499.9965 (1.0)         55,500.0006 (1.41)          665.4888 (1.0)          390.9718 (1.0)           666.0011 (1.0)          42.0005 (1.00)       31;2639  1,502,654.9287 (1.0)    32129   1
test_banded_dot_perf[10-banded_dot]     2,832.9996 (5.67)        71,957.9984 (1.82)        3,356.9474 (5.04)         782.8860 (2.00)        3,332.9998 (5.00)        332.9988 (7.93)     1874;2239    297,889.6806 (0.20)   32833   1
test_banded_dot_perf[100-dot]           1,000.0003 (2.00)        58,208.9997 (1.47)        1,191.9862 (1.79)         396.5918 (1.01)        1,166.9981 (1.75)         41.9968 (1.0)       305;3163    838,935.8643 (0.56)   91258   1
test_banded_dot_perf[100-banded_dot]    3,332.9998 (6.67)        39,499.9988 (1.0)         3,874.8349 (5.82)         471.5917 (1.21)        3,875.0004 (5.82)         84.0009 (2.00)    1020;11972    258,075.5142 (0.17)   71008   1
test_banded_dot_perf[1000-dot]         13,584.0019 (27.17)      118,374.9991 (3.00)       16,143.5130 (24.26)      1,984.1144 (5.07)       16,291.0001 (24.46)     2,042.0011 (48.62)     1390;171     61,944.3861 (0.04)   14202   1
test_banded_dot_perf[1000-banded_dot]   8,167.0005 (16.33)       68,749.9996 (1.74)       10,694.7895 (16.07)      1,131.4230 (2.89)       11,000.0001 (16.52)       416.9997 (9.93)     6811;7582     93,503.4764 (0.06)   32521   1
test_banded_dot_perf[10_000-dot]    3,379,415.9972 (>1000.0)  3,680,959.0019 (93.19)   3,463,207.0645 (>1000.0)   79,485.8545 (203.30)  3,434,124.9993 (>1000.0)  114,541.9992 (>1000.0)      6;0        288.7497 (0.00)      31   1
test_banded_dot_perf[10_000-banded_dot] 93,582.9994 (187.17)    294,458.0010 (7.45)      100,154.2338 (150.50)     22,660.4163 (57.96)      95,479.0012 (143.36)    2,083.4996 (49.61)       10;27      9,984.6004 (0.01)     248   1
------------------------------------------------------------------------------------------------------------------------------------------------------
```
@ricardoV94 (Member)

That looks much better!

@jessegrabowski (Member Author)

I agree numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 don't really show up in nature too often.

@ricardoV94 (Member) commented May 23, 2025

100x100 is 1us; you are at the edge of Python overhead there. Calling an identity PyTensor function without trust_input is 300-500ns. Calling np.zeros is like 100-200ns. That means you would basically need to have no Python overhead whatsoever.

Edit: those are on my machine, don't know about yours

@ricardoV94 (Member) commented May 23, 2025

This is the best I think we can get out of this in python?

```python
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
    kl = self.lower_diags
    ku = self.upper_diags
    if node.outputs[0].dtype == "float64":
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
    else:
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
    ab_size = kl + ku + 1
    a_storage = storage_map[node.inputs[0]]
    b_storage = storage_map[node.inputs[1]]
    out_storage = storage_map[node.outputs[0]]
    out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]

    def thunk(
        a_storage=a_storage,
        b_storage=b_storage,
        out_storage=out_storage,
        out_computed=out_computed,
        kl=kl,
        ku=ku,
        ab_size=ab_size,
        gbmv=gbmv,
    ):
        A = a_storage[0]
        b = b_storage[0]
        m, n = A.shape
        ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
        for i, k in enumerate(range(ku, -kl - 1, -1)):
            if k > 0:
                ab[i, k:] = diag(A, k=k)
            else:
                ab[i, :n + k] = diag(A, k=k)
        out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
        out_computed[0] = True

    return thunk
```
```python
A = as_tensor_variable(A)
B = as_tensor_variable(b)

out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
```
Member:

I suspect this is wrong for integer types.
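To illustrate the concern (using NumPy's promotion rules as a stand-in for pytensor.scalar.upcast, which I'm assuming behaves similarly here): promoting two integer dtypes yields another integer dtype, while gbmv only exists for float32/float64, so an integer out_dtype would need an explicit promotion to a float type.

```python
import numpy as np

# Promoting two integer dtypes stays integer...
out_dtype = np.promote_types(np.int32, np.int64)  # int64, not a BLAS float type

# ...so a banded-gemv Op would need an extra step along these lines
# (illustrative, not the PR's code):
blas_dtype = out_dtype if np.issubdtype(out_dtype, np.floating) else np.dtype("float64")
```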

@ricardoV94 (Member)

I'm not saying we should do that, but it gives you a lower bound on what to expect from your micro-optimizations

@jessegrabowski (Member Author) commented May 23, 2025

Here's what the thunk version benchmarks as for me:

```
---------------------------------------------------------------- benchmark: 8 tests ----------------------------------------------------------------
Name (time in ns)                            Min                       Max                     Mean                  StdDev                  Median                     IQR        Outliers   OPS                     Rounds  Iterations
test_banded_dot_perf[10-dot]              582.9970 (1.0)          7,208.0002 (1.0)           648.7823 (1.0)          105.4763 (1.0)           625.0011 (1.0)          41.9968 (1.0)        184;252  1,541,349.0560 (1.0)    18434   1
test_banded_dot_perf[10-banded_dot]     2,749.9991 (4.72)        28,665.9997 (3.98)        2,954.8453 (4.55)         350.8606 (3.33)        2,917.0005 (4.67)         42.9973 (1.02)      555;5229    338,427.1940 (0.22)   39868   1
test_banded_dot_perf[100-dot]           1,042.0008 (1.79)        15,624.9989 (2.17)        1,178.4495 (1.82)         197.8076 (1.88)        1,166.9981 (1.87)         42.0005 (1.00)      512;1917    848,572.6277 (0.55)  100848   1
test_banded_dot_perf[100-banded_dot]    3,166.9988 (5.43)        33,166.9980 (4.60)        3,418.6797 (5.27)         364.1081 (3.45)        3,415.9966 (5.47)         83.0005 (1.98)      826;2615    292,510.5862 (0.19)   65574   1
test_banded_dot_perf[1000-dot]         13,334.0000 (22.87)       45,625.0018 (6.33)       15,480.3238 (23.86)      1,366.7475 (12.96)      15,957.9977 (25.53)     1,958.0002 (46.62)     1490;223     64,598.1318 (0.04)   20426   1
test_banded_dot_perf[1000-banded_dot]   8,541.9997 (14.65)       50,667.0003 (7.03)       10,089.9543 (15.55)        777.8152 (7.37)       10,416.9994 (16.67)     1,290.9986 (30.74)    11635;128     99,108.4762 (0.06)   38096   1
test_banded_dot_perf[10_000-dot]    3,365,791.9994 (>1000.0)  5,034,374.9972 (698.44)  3,495,052.0250 (>1000.0)  345,179.3641 (>1000.0) 3,410,270.5013 (>1000.0)  47,562.5002 (>1000.0)       2;3        286.1188 (0.00)      40   1
test_banded_dot_perf[10_000-banded_dot] 80,417.0013 (137.94)    454,208.9991 (63.01)     119,363.4743 (183.98)    65,435.1952 (620.38)     91,417.0014 (146.27)   38,540.9949 (917.71)      33;33      8,377.7722 (0.01)     350   1
------------------------------------------------------------------------------------------------------------------------------------------------------
```

I'm curious whether it's possible to destroy A and turn it into A_banded in place. If it's possible, it doesn't seem trivial. BLAS doesn't have an overwrite_x option, so b can't be destroyed either.

Frankly my time would be better served thinking about how to do this in C at this point.

@jessegrabowski (Member Author)

Also we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?

@ricardoV94 (Member) commented May 23, 2025

> Also we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?

Well, SparseDot doesn't work with batch inputs, but I'm curious. Also, I don't think the code is too complex or performing too badly, so I don't agree with your sentiment that we should be thinking of a C impl. A numba one is more interesting...

@jessegrabowski (Member Author) commented May 23, 2025

Why not both?

Seriously though, my feeling is that if we're putting this stuff into a PyMC model, the code has to be ultra-performant. It's going to be called umptillion times: the inner loop of a PDE solver times the MCMC loop.

I'll work on the numba dispatch next, at any rate.
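A numba version could be as simple as a typed loop over the band (a sketch of the idea, not the PR's actual overload; it falls back to plain Python if numba isn't installed, so the untyped fallback is only for illustration):

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # numba is optional for this sketch
    def njit(func):
        return func


@njit
def banded_gemv_kernel(A, x, kl, ku):
    # Walk only the band of A instead of all m*n entries.
    m, n = A.shape
    y = np.zeros(m, dtype=A.dtype)
    for i in range(m):
        lo = max(0, i - kl)
        hi = min(n, i + ku + 1)
        for j in range(lo, hi):
            y[i] += A[i, j] * x[j]
    return y
```

Unlike the gbmv path, this operates directly on the dense A, so it also skips the band-storage rearrangement entirely.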

@ricardoV94 (Member)

By that argument, you can't really add any specialized Op that doesn't have a C implementation (unless it's replacing an Op that also doesn't have one).

Ignoring the general user, you can have code that decides whether to use this Op based on the size (or a rewrite). Also, how are you sampling / getting A? Can you avoid the boxing/unboxing of the diagonals?

@jessegrabowski (Member Author) commented May 23, 2025

Well, the point is that the specialization isn't adding anything over good ol' pt.dot (yet!), except for really huge matrices.

@jessegrabowski (Member Author)

I just pushed a major refactor to this PR, which:

  1. Renames BandedDot to BandedGEMV (which is what it actually is; the underlying routine is called GBMV, but I thought "banded GEMV" was clearer)
  2. Adds support for all GBMV arguments (A, x, y, alpha, beta) in BandedGEMV
  3. Adjusts the numba overload accordingly
  4. Adds a numba overload for GEMV itself. Note that this will never be used, because we don't include BlasOpt in the numba rewrites.

Regarding point (4), here are the benchmarks using the numba GEMV overload versus what we currently get with mode="NUMBA":

```
--------------------------------------------------------- benchmark: 6 tests ---------------------------------------------------------
Name (time in us)                              Min                 Max               Mean            StdDev            Median              IQR     Outliers  OPS (Kops/s)       Rounds  Iterations
test_numba_gemv_benchmark[numba-10]         6.7500 (1.0)      109.3330 (1.32)    7.9952 (1.0)     1.6793 (1.09)    7.8750 (1.0)     0.2910 (1.40)  878;1881  125.0747 (1.0)    29888   1
test_numba_gemv_benchmark[numba+blas-10]    7.5000 (1.11)      82.7080 (1.0)     8.0753 (1.01)    1.9040 (1.24)    7.8750 (1.00)    0.2080 (1.0)    252;986  123.8337 (0.99)   14185   1
test_numba_gemv_benchmark[numba-100]        8.0410 (1.19)     129.5830 (1.57)   10.3897 (1.30)    1.7014 (1.11)   10.5830 (1.34)    1.2920 (6.21)   157;122   96.2488 (0.77)   29963   1
test_numba_gemv_benchmark[numba+blas-100]   7.9170 (1.17)      89.1250 (1.08)   10.1977 (1.28)    1.5343 (1.0)    10.4170 (1.32)    1.2080 (5.81)   218;165   98.0610 (0.78)   32129   1
test_numba_gemv_benchmark[numba-1000]      22.2920 (3.30)     708.7500 (8.57)   25.6014 (3.20)    6.1947 (4.04)   24.9170 (3.16)    1.6670 (8.01)  229;1198   39.0604 (0.31)   18824   1
test_numba_gemv_benchmark[numba+blas-1000] 21.3330 (3.16)     186.1250 (2.25)   24.8268 (3.11)    4.2993 (2.80)   24.0830 (3.06)    1.8750 (9.01)   448;739   40.2790 (0.32)   17992   1
---------------------------------------------------------------------------------------------------------------------------------------
```

It's about the same, or maybe slightly better, but at the cost that we can no longer cache the compiled function, due to the function pointer.

Also note that the test is very sensitive to the detection of the alpha parameter. I had to write:

```python
alpha * (A @ x) + beta * y
```

in order for the GEMV rewrite to correctly find alpha. If it fails to find alpha, mode="NUMBA" significantly outperforms the GEMV call.
