Skip to content

Commit 278996e

Browse files
committed
Implement numba tridiagonal solve
1 parent d50518f commit 278996e

File tree

3 files changed

+320
-3
lines changed

3 files changed

+320
-3
lines changed
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
from numba.core.extending import overload
5+
from numba.np.linalg import ensure_lapack
6+
from numpy import ndarray
7+
from scipy import linalg
8+
9+
from pytensor.link.numba.dispatch.basic import numba_njit
10+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
11+
_LAPACK,
12+
_get_underlying_float,
13+
int_ptr_to_val,
14+
val_to_int_ptr,
15+
)
16+
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
17+
from pytensor.link.numba.dispatch.linalg.utils import (
18+
_check_scipy_linalg_matrix,
19+
_copy_to_fortran_order_even_if_1d,
20+
_solve_check,
21+
_trans_char_to_int,
22+
)
23+
24+
25+
def _gttrf(
26+
dl: ndarray, d: ndarray, du: ndarray
27+
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
28+
"""Placeholder for LU factorization of tridiagonal matrix."""
29+
return # type: ignore
30+
31+
32+
@overload(_gttrf)
def gttrf_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
) -> Callable[
    [ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
]:
    """Numba overload of ``_gttrf``: LU-factorize a tridiagonal matrix.

    Runs at typing time: it makes sure LAPACK is available, validates the
    three diagonals, and binds the ``xgttrf`` routine matching the dtype of
    ``d``. The returned ``impl`` is what numba compiles.

    Parameters
    ----------
    dl, d, du : ndarray
        Sub-diagonal, main diagonal and super-diagonal of the matrix.

    Returns
    -------
    Callable
        ``impl(dl, d, du) -> (dl, d, du, du2, ipiv, info)``. ``dl``, ``d`` and
        ``du`` are handed to LAPACK by pointer and returned as-is (the LAPACK
        reference documents them as overwritten with the factors); ``du2`` and
        ``ipiv`` are freshly allocated outputs and ``info`` is the LAPACK
        status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gttrf")
    _check_scipy_linalg_matrix(d, "gttrf")
    _check_scipy_linalg_matrix(du, "gttrf")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrf = _LAPACK().numba_xgttrf(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
    ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
        n = np.int32(d.shape[-1])
        ipiv = np.empty(n, dtype=np.int32)
        # The second super-diagonal has n - 2 entries. Clamp at zero so that
        # systems with n < 2 do not request a negative-length array.
        du2 = np.empty(max(n - 2, 0), dtype=dtype)
        info = val_to_int_ptr(0)

        numba_gttrf(
            val_to_int_ptr(n),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            info,
        )

        return dl, d, du, du2, ipiv, int_ptr_to_val(info)

    return impl
71+
72+
73+
def _gttrs(
74+
dl: ndarray,
75+
d: ndarray,
76+
du: ndarray,
77+
du2: ndarray,
78+
ipiv: ndarray,
79+
b: ndarray,
80+
overwrite_b: bool,
81+
trans: bool,
82+
) -> tuple[ndarray, int]:
83+
"""Placeholder for solving an LU-decomposed tridiagonal system."""
84+
return # type: ignore
85+
86+
87+
@overload(_gttrs)
def gttrs_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool],
    tuple[ndarray, int],
]:
    """Numba overload of ``_gttrs``: solve a factorized tridiagonal system.

    Runs at typing time: it makes sure LAPACK is available, validates the
    array arguments (``ipiv`` is not checked), and binds the ``xgttrs``
    routine matching the dtype of ``d``.

    Returns
    -------
    Callable
        ``impl(dl, d, du, du2, ipiv, b, overwrite_b, trans) -> (x, info)``
        where ``x`` is the buffer handed to LAPACK for the right-hand side and
        ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    for operand in (dl, d, du, du2, b):
        _check_scipy_linalg_matrix(operand, "gttrs")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrs = _LAPACK().numba_xgttrs(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        b: ndarray,
        overwrite_b: bool,
        trans: bool,
    ) -> tuple[ndarray, int]:
        size = np.int32(d.shape[-1])

        # A 1-D right-hand side counts as a single column.
        if b.ndim == 1:
            n_rhs = 1
        else:
            n_rhs = int(b.shape[-1])

        # Reuse ``b`` only when the caller allows it and it is already
        # Fortran-contiguous; otherwise work on a Fortran-ordered copy.
        if overwrite_b and b.flags.f_contiguous:
            rhs = b
        else:
            rhs = _copy_to_fortran_order_even_if_1d(b)

        trans_ptr = val_to_int_ptr(_trans_char_to_int(trans))
        n_ptr = val_to_int_ptr(size)
        nrhs_ptr = val_to_int_ptr(n_rhs)
        # The leading dimension of the right-hand side equals the system size.
        ldb_ptr = val_to_int_ptr(size)
        info_ptr = val_to_int_ptr(0)

        numba_gttrs(
            trans_ptr,
            n_ptr,
            nrhs_ptr,
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            rhs.view(w_type).ctypes,
            ldb_ptr,
            info_ptr,
        )

        return rhs, int_ptr_to_val(info_ptr)

    return impl
147+
148+
149+
def _gtcon(
150+
dl: ndarray,
151+
d: ndarray,
152+
du: ndarray,
153+
du2: ndarray,
154+
ipiv: ndarray,
155+
anorm: float,
156+
norm: str,
157+
) -> tuple[ndarray, int]:
158+
"""Placeholder for computing the condition number of a tridiagonal system."""
159+
return # type: ignore
160+
161+
162+
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """Numba overload of ``_gtcon``: condition estimate for a tridiagonal matrix.

    Runs at typing time: it makes sure LAPACK is available, validates the
    four factor arrays (``ipiv`` is not checked), and binds the ``xgtcon``
    routine matching the dtype of ``d``.

    Returns
    -------
    Callable
        ``impl(dl, d, du, du2, ipiv, anorm, norm) -> (rcond, info)`` where
        ``rcond`` is a length-1 array filled in by LAPACK and ``info`` is the
        LAPACK status code.
    """
    ensure_lapack()
    for operand in (dl, d, du, du2):
        _check_scipy_linalg_matrix(operand, "gtcon")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        size = np.int32(d.shape[-1])

        # Output slot plus the scratch buffers handed to the LAPACK routine.
        rcond = np.empty(1, dtype=dtype)
        work = np.empty(2 * size, dtype=dtype)
        iwork = np.empty(size, dtype=np.int32)

        # LAPACK takes every scalar by reference, so box ``anorm`` in an array.
        anorm_arr = np.array(anorm, dtype=dtype)

        norm_ptr = val_to_int_ptr(ord(norm))
        n_ptr = val_to_int_ptr(size)
        info_ptr = val_to_int_ptr(0)

        numba_gtcon(
            norm_ptr,
            n_ptr,
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            anorm_arr.view(w_type).ctypes,
            rcond.view(w_type).ctypes,
            work.view(w_type).ctypes,
            iwork.ctypes,
            info_ptr,
        )

        return rcond, int_ptr_to_val(info_ptr)

    return impl
216+
217+
218+
def _solve_tridiagonal(
219+
a: ndarray,
220+
b: ndarray,
221+
lower: bool,
222+
overwrite_a: bool,
223+
overwrite_b: bool,
224+
check_finite: bool,
225+
transposed: bool,
226+
):
227+
"""
228+
Solve a positive-definite linear system using the Cholesky decomposition.
229+
"""
230+
return linalg.solve(
231+
a=a,
232+
b=b,
233+
lower=lower,
234+
overwrite_a=overwrite_a,
235+
overwrite_b=overwrite_b,
236+
check_finite=check_finite,
237+
transposed=transposed,
238+
assume_a="tridiagonal",
239+
)
240+
241+
242+
@numba_njit
def tridiagonal_norm(du, d, dl):
    """Return the largest absolute column sum built from three diagonals.

    Starts from ``|d|``, adds ``|du|`` to every entry but the first and
    ``|dl|`` to every entry but the last, then takes the maximum. With ``du``
    the super-diagonal and ``dl`` the sub-diagonal this is the matrix 1-norm.
    """
    # Adapted from scipy _matrix_norm_tridiagonal:
    # https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
    column_sums = np.abs(d)
    column_sums[1:] += np.abs(du)
    column_sums[:-1] += np.abs(dl)
    return column_sums.max()
251+
252+
253+
@overload(_solve_tridiagonal)
def _tridiagonal_solve_impl(
    A: ndarray,
    B: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
    """Numba overload of ``_solve_tridiagonal``.

    Runs at typing time: it makes sure LAPACK is available and validates
    ``A`` and ``B``. The returned ``impl`` extracts the three diagonals of
    ``A``, LU-factorizes them (``_gttrf``), solves for ``B`` (``_gttrs``) and
    then checks the conditioning of the system (``_gtcon``).

    Parameters
    ----------
    A : ndarray
        Coefficient matrix; only its main, first sub- and first
        super-diagonal are read.
    B : ndarray
        Right-hand side(s).
    lower, overwrite_a, check_finite : bool
        Accepted for signature compatibility; not used by ``impl``.
    overwrite_b : bool
        Forwarded to ``_gttrs``.
    transposed : bool
        When true, solve ``A.T x = B`` instead of ``A x = B``.

    Returns
    -------
    Callable
        ``impl`` with the same parameters, returning the solution array.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: ndarray,
        B: ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> ndarray:
        n = np.int32(A.shape[-1])
        _solve_check_input_shapes(A, B)
        norm = "1"

        # The transposed system is handled by factorizing A.T itself, so the
        # diagonals, the norm and the LU factors below all describe A.T.
        if transposed:
            A = A.T
        dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)

        # Compute the norm before the factorization results replace dl, d, du.
        anorm = tridiagonal_norm(du, d, dl)

        dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
        _solve_check(n, INFO)

        # trans must stay False here: A was already transposed above, and
        # forwarding ``transposed`` would transpose a second time and solve
        # the original, untransposed system.
        X, INFO = _gttrs(
            dl, d, du, du2, IPIV, B, trans=False, overwrite_b=overwrite_b
        )
        _solve_check(n, INFO)

        RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
        _solve_check(n, INFO, True, RCOND)

        return X

    return impl

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
1010
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
1111
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
12+
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
1213
from pytensor.tensor.slinalg import (
1314
BlockDiagonal,
1415
Cholesky,
@@ -114,10 +115,12 @@ def numba_funcify_Solve(op, node, **kwargs):
114115
solve_fn = _solve_symmetric
115116
elif assume_a == "pos":
116117
solve_fn = _solve_psd
118+
elif assume_a == "tridiagonal":
119+
solve_fn = _solve_tridiagonal
117120
else:
118121
warnings.warn(
119122
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
120-
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', or 'triangular' to improve performance.",
123+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', 'triangular' or 'tridiagonal' to improve performance.",
121124
UserWarning,
122125
)
123126
solve_fn = _solve_gen

tests/link/numba/test_slinalg.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class TestSolves:
9797
[(5, 1), (5, 5), (5,)],
9898
ids=["b_col_vec", "b_matrix", "b_vec"],
9999
)
100-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
100+
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str)
101101
def test_solve(
102102
self,
103103
b_shape: tuple[int],
@@ -106,7 +106,7 @@ def test_solve(
106106
overwrite_a: bool,
107107
overwrite_b: bool,
108108
):
109-
if assume_a not in ("sym", "her", "pos") and not lower:
109+
if assume_a not in ("sym", "her", "pos", "tridiagonal") and not lower:
110110
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
111111
pytest.skip("Skipping redundant test already covered by lower=True")
112112

@@ -120,6 +120,14 @@ def A_func(x):
120120
# We have to set the unused triangle to something other than zero
121121
# to see lapack destroying it.
122122
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
123+
elif assume_a == "tridiagonal":
124+
_x = x
125+
x = np.zeros_like(x)
126+
n = x.shape[-1]
127+
arange_n = np.arange(n)
128+
x[arange_n[1:], arange_n[:-1]] = np.diag(_x, k=-1)
129+
x[arange_n, arange_n] = np.diag(_x, k=0)
130+
x[arange_n[:-1], arange_n[1:]] = np.diag(_x, k=1)
123131
return x
124132

125133
A = pt.matrix("A", dtype=floatX)
@@ -146,7 +154,14 @@ def A_func(x):
146154

147155
op = f.maker.fgraph.outputs[0].owner.op
148156
assert isinstance(op, Solve)
157+
assert op.assume_a == assume_a
149158
destroy_map = op.destroy_map
159+
160+
if overwrite_a and assume_a == "tridiagonal":
161+
# Tridiagonal solve never destroys the A matrix
162+
# Treat test from here as if overwrite_a is False
163+
overwrite_a = False
164+
150165
if overwrite_a and overwrite_b:
151166
raise NotImplementedError(
152167
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"

0 commit comments

Comments
 (0)