Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self,
*,
lower: bool = True,
check_finite: bool = True,
check_finite: bool = False,
on_error: Literal["raise", "nan"] = "raise",
overwrite_a: bool = False,
):
Expand Down Expand Up @@ -67,29 +67,55 @@ def make_node(self, x):
def perform(self, node, inputs, outputs):
[x] = inputs
[out] = outputs
try:
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy_linalg.cholesky(
x.T,
lower=not self.lower,
check_finite=self.check_finite,
overwrite_a=True,
).T
else:
out[0] = scipy_linalg.cholesky(
x,
lower=self.lower,
check_finite=self.check_finite,
overwrite_a=self.overwrite_a,
)

except scipy_linalg.LinAlgError:
if self.on_error == "raise":
raise
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))

# Quick return for square empty array
if x.size == 0:
out[0] = np.empty_like(x, dtype=potrf.dtype)
return

if self.check_finite and not np.isfinite(x).all():
if self.on_error == "nan":
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype)
return
else:
raise ValueError("array must not contain infs or NaNs")

# Squareness check
if x.shape[0] != x.shape[1]:
raise ValueError(
"Input array is expected to be square but has " f"the shape: {x.shape}."
)

# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
if c_contiguous_input:
x = x.T
lower = not self.lower
overwrite_a = True
else:
lower = self.lower
overwrite_a = self.overwrite_a

c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)

if info != 0:
if self.on_error == "nan":
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
elif info > 0:
raise scipy_linalg.LinAlgError(
f"{info}-th leading minor of the array is not positive definite"
)
elif info < 0:
raise ValueError(
f"LAPACK reported an illegal value in {-info}-th argument "
f'on entry to "POTRF".'
)
else:
# Transpose result if input was transposed
out[0] = c.T if c_contiguous_input else c

def L_op(self, inputs, outputs, gradients):
"""
Expand Down Expand Up @@ -201,7 +227,9 @@ def cholesky(

"""

return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)


class SolveBase(Op):
Expand Down
20 changes: 20 additions & 0 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ def test_cholesky():
check_upper_triangular(pd, ch_f)


def test_cholesky_performance(benchmark):
rng = np.random.default_rng(utt.fetch_seed())
r = rng.standard_normal((10, 10)).astype(config.floatX)
pd = np.dot(r, r.T)
x = matrix()
chol = cholesky(x)
ch_f = function([x], chol)
benchmark(ch_f, pd)


def test_cholesky_empty():
empty = np.empty([0, 0], dtype=config.floatX)
x = matrix()
chol = cholesky(x)
ch_f = function([x], chol)
ch = ch_f(empty)
assert ch.size == 0
assert ch.dtype == config.floatX


def test_cholesky_indef():
x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
Expand Down
Loading