@@ -100,40 +100,43 @@ def find_solve_clients(var, assume_a):
100100 elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
101101 # If it's a left expand_dims, recurse on the output
102102 clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
103+
103104 return clients
104105
105106 assume_a = node .op .core_op .assume_a
106107
107108 if assume_a not in allowed_assume_a :
108109 return None
109110
110- A , _ = get_root_A (node .inputs [0 ])
111+ root_A , root_A_transposed = get_root_A (node .inputs [0 ])
111112
112113 # Find Solve using A (or left expand_dims of A)
113114 # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
114115 # that to the A_decomp outputs
115- A_solve_clients_and_transpose = [
116- (client , False ) for client in find_solve_clients (A , assume_a )
116+ root_A_solve_clients_and_transpose = [
117+ (client , False ) for client in find_solve_clients (root_A , assume_a )
117118 ]
118119
119120 # Find Solves using A.T
120- for cl , _ in fgraph .clients [A ]:
121+ for cl , _ in fgraph .clients [root_A ]:
121122 if isinstance (cl .op , DimShuffle ) and is_matrix_transpose (cl .out ):
122123 A_T = cl .out
123- A_solve_clients_and_transpose .extend (
124+ root_A_solve_clients_and_transpose .extend (
124125 (client , True ) for client in find_solve_clients (A_T , assume_a )
125126 )
126127
127- if not eager and len (A_solve_clients_and_transpose ) == 1 :
128+ if not eager and len (root_A_solve_clients_and_transpose ) == 1 :
128129 # If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
129130 # That's a "reuse" inside the inner vectorized loop
130131 batch_ndim = node .op .batch_ndim (node )
131- (client , _ ) = A_solve_clients_and_transpose [0 ]
132- original_A , b = client .inputs
132+ (client , _ ) = root_A_solve_clients_and_transpose [0 ]
133+
134+ A , b = client .inputs
135+
133136 if not any (
134137 a_bcast and not b_bcast
135138 for a_bcast , b_bcast in zip (
136- original_A .type .broadcastable [:batch_ndim ],
139+ A .type .broadcastable [:batch_ndim ],
137140 b .type .broadcastable [:batch_ndim ],
138141 strict = True ,
139142 )
@@ -142,19 +145,27 @@ def find_solve_clients(var, assume_a):
142145
143146 # If any Op had check_finite=True, we also do it for the LU decomposition
144147 check_finite_decomp = False
145- for client , _ in A_solve_clients_and_transpose :
148+ for client , _ in root_A_solve_clients_and_transpose :
146149 if client .op .core_op .check_finite :
147150 check_finite_decomp = True
148151 break
149152
150- lower = node .op .core_op .lower
153+ (first_solve , transposed ) = root_A_solve_clients_and_transpose [0 ]
154+ lower = first_solve .op .core_op .lower
155+ if transposed :
156+ lower = not lower
157+
151158 A_decomp = decompose_A (
152- A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
159+ root_A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
153160 )
154161
155162 replacements = {}
156- for client , transposed in A_solve_clients_and_transpose :
163+ for client , transposed in root_A_solve_clients_and_transpose :
157164 _ , b = client .inputs
165+ lower = client .op .core_op .lower
166+ if transposed :
167+ lower = not lower
168+
158169 new_x = solve_decomposed_system (
159170 A_decomp ,
160171 b ,
0 commit comments