Skip to content

Commit 41c2131

Browse files
committed
Fix numba symmetrical solve reciprocal of condition number
1 parent 15fb803 commit 41c2131

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def impl(
654654

655655
def _sysv(
656656
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
657-
) -> tuple[np.ndarray, np.ndarray, int]:
657+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
658658
"""
659659
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
660660
"""
@@ -665,7 +665,8 @@ def _sysv(
665665
def sysv_impl(
666666
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
667667
) -> Callable[
668-
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
668+
[np.ndarray, np.ndarray, bool, bool, bool],
669+
tuple[np.ndarray, np.ndarray, np.ndarray, int],
669670
]:
670671
ensure_lapack()
671672
_check_scipy_linalg_matrix(A, "sysv")
@@ -741,8 +742,8 @@ def impl(
741742
)
742743

743744
if B_is_1d:
744-
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
745-
return B_copy, IPIV, int_ptr_to_val(INFO)
745+
B_copy = B_copy[..., 0]
746+
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
746747

747748
return impl
748749

@@ -771,7 +772,7 @@ def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int
771772

772773
N = val_to_int_ptr(_N)
773774
LDA = val_to_int_ptr(_N)
774-
UPLO = val_to_int_ptr(ord("L"))
775+
UPLO = val_to_int_ptr(ord("U"))
775776
ANORM = np.array(anorm, dtype=dtype)
776777
RCOND = np.empty(1, dtype=dtype)
777778
WORK = np.empty(2 * _N, dtype=dtype)
@@ -844,10 +845,10 @@ def impl(
844845
) -> np.ndarray:
845846
_solve_check_input_shapes(A, B)
846847

847-
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
848+
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
848849
_solve_check(A.shape[-1], info)
849850

850-
rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
851+
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
851852
_solve_check(A.shape[-1], info, True, rcond)
852853

853854
return x

0 commit comments

Comments
 (0)