Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,17 @@

B_is_1d = B.ndim == 1

if overwrite_b:
B_copy = B
else:
if B_is_1d:
# _copy_to_fortran_order does nothing with vectors
B_copy = np.copy(B)
else:
B_copy = _copy_to_fortran_order(B)
A_copy = _copy_to_fortran_order(A)

Check warning on line 129 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L129

Added line #L129 was not covered by tests

if B_is_1d:
B_copy = np.expand_dims(B_copy, -1)
# This list is exhaustive, but numba freaks out if we include a final else clause
if not overwrite_b and not B_is_1d:
B_copy = _copy_to_fortran_order(B)

Check warning on line 133 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L133

Added line #L133 was not covered by tests
elif overwrite_b and not B_is_1d:
B_copy = np.asfortranarray(B)

Check warning on line 135 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L135

Added line #L135 was not covered by tests
elif not overwrite_b and B_is_1d:
B_copy = np.copy(np.expand_dims(B, -1))

Check warning on line 137 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L137

Added line #L137 was not covered by tests
elif overwrite_b and B_is_1d:
B_copy = np.expand_dims(B, -1)

Check warning on line 139 in pytensor/link/numba/dispatch/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/slinalg.py#L139

Added line #L139 was not covered by tests

NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

Expand All @@ -155,7 +155,7 @@
DIAG,
N,
NRHS,
np.asfortranarray(A).T.view(w_type).ctypes,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
Expand Down
43 changes: 43 additions & 0 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.slinalg import SolveTriangular
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py

Expand Down Expand Up @@ -130,6 +131,48 @@ def A_func_pt(x):
)


@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
def test_solve_triangular_overwrite_b_correct(overwrite_b):
# Regression test for issue #1233

rng = np.random.default_rng(utt.fetch_seed())
a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
a_test_py = np.tril(a_test_py)
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))

# .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous)
a_test_nb = a_test_py.T.copy().T
b_test_nb = b_test_py.T.copy().T

op = SolveTriangular(
trans=0,
unit_diagonal=False,
lower=False,
check_finite=True,
b_ndim=2,
overwrite_b=overwrite_b,
)

a_pt = pt.matrix("a", shape=(3, 3))
b_pt = pt.matrix("b", shape=(3, 2))
out = op(a_pt, b_pt)

py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")

x_py = py_fn(a_test_py, b_test_py)
x_nb = numba_fn(a_test_nb, b_test_nb)

np.testing.assert_allclose(
py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
)
np.testing.assert_allclose(b_test_py, b_test_nb)

if overwrite_b:
np.testing.assert_allclose(b_test_py, x_py)
np.testing.assert_allclose(b_test_nb, x_nb)


@pytest.mark.parametrize("value", [np.nan, np.inf])
@pytest.mark.filterwarnings(
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
Expand Down