Labels: OpFromGraph, beginner friendly, enhancement, linalg
Description
Currently we have an Op that calls `np.tri`, but we can easily build lower-triangular mask matrices with `_iota` instead:
```python
import pytensor.tensor as pt
from pytensor.tensor.einsum import _iota

def tri(M, N, k=0):
    shape = pt.as_tensor([M, N])
    # _iota(shape, axis) fills each entry with its index along `axis`; keep j <= i + k, as np.tri does
    return ((_iota(shape, 0) + k) >= _iota(shape, 1)).astype("int64")
```

This is what JAX does. The benefit of building it this way is that we automatically get a dispatchable graph for Numba (Numba supports `np.tri`, but only under specific circumstances; I tried a naive dispatch and it didn't work) and for PyTorch (#821 asks for `Tri`, so this would check off that box).
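To check that the index comparison reproduces `np.tri`, here is a plain NumPy sketch of the same construction. It assumes `_iota((M, N), axis)` returns each element's index along `axis`; `tri_from_iota` is a hypothetical helper name used only for this illustration.

```python
import numpy as np


def tri_from_iota(M, N, k=0, dtype=float):
    # Index grids for an (M, N) array: rows[i, j] == i and cols[i, j] == j,
    # i.e. the NumPy analogue of an iota along axis 0 and along axis 1.
    rows = np.broadcast_to(np.arange(M)[:, None], (M, N))
    cols = np.broadcast_to(np.arange(N)[None, :], (M, N))
    # Entry (i, j) is 1 when it lies on or below the k-th diagonal, i.e. j <= i + k.
    return (rows + k >= cols).astype(dtype)


print(tri_from_iota(3, 4, 1))
```

Note that a strict `>` comparison would drop the main diagonal, giving `np.tri(M, N, k - 1)` rather than `np.tri(M, N, k)`.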
I suggest we wrap this in a dummy `OpFromGraph`, as we do for `Kron` and `AllocDiag`, so that the `dprint` output is nicer. We could also override `L_op` if we want. The current `Tri` Op uses `grad_undefined`, so we could keep that if it's correct, or simply rely on autodiff; the proposed `_iota`-based graph should be differentiable.