@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130130 if isinstance (idx .type , TensorType )
131131 ]
132132
133- def broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
134- # Check that x is not broadcasted to y based on broadcastable info
135- if len (x_bcast ) < len (to_bcast ):
136- return True
137- for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
138- if x_bcast_dim and not to_bcast_dim :
139- return True
140- return False
141-
142133 # Special implementation for consecutive integer vector indices
143134 if (
144135 not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
151142 )
152143 # Must be consecutive
153144 and not op .non_consecutive_adv_indexing (node )
154- # y in set/inc_subtensor cannot be broadcasted
155- and (
156- y is None
157- or not broadcasted_to (
158- y .type .broadcastable ,
159- (
160- x .type .broadcastable [: adv_idxs [0 ]["axis" ]]
161- + x .type .broadcastable [adv_idxs [- 1 ]["axis" ] :]
162- ),
163- )
164- )
165145 ):
166146 return numba_funcify_multiple_integer_vector_indexing (op , node , ** kwargs )
167147
@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
191171 return numba_funcify_default_subtensor (op , node , ** kwargs )
192172
193173
174+ def _broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
175+ # Check that x is not broadcasted to y based on broadcastable info
176+ if len (x_bcast ) < len (to_bcast ):
177+ return True
178+ for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
179+ if x_bcast_dim and not to_bcast_dim :
180+ return True
181+ return False
182+
183+
194184def numba_funcify_multiple_integer_vector_indexing (
195185 op : AdvancedSubtensor | AdvancedIncSubtensor , node , ** kwargs
196186):
197187 # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
198188 if isinstance (op , AdvancedSubtensor ):
199- y , idxs = None , node .inputs [ 1 :]
189+ x , y , idxs = None , node .inputs
200190 else :
201- y , * idxs = node .inputs [ 1 :]
191+ x , y , * idxs = node .inputs
202192
203193 first_axis = next (
204194 i for i , idx in enumerate (idxs ) if isinstance (idx .type , TensorType )
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
211201 )
212202 except StopIteration :
213203 after_last_axis = len (idxs )
204+ last_axis = after_last_axis - 1
205+
206+ vector_indices = idxs [first_axis :after_last_axis ]
207+ assert all (v .type .broadcastable == (False ,) for v in vector_indices )
214208
215209 if isinstance (op , AdvancedSubtensor ):
216210
@@ -231,43 +225,58 @@ def advanced_subtensor_multiple_vector(x, *idxs):
231225
232226 return advanced_subtensor_multiple_vector
233227
234- elif op . set_instead_of_inc :
228+ else :
235229 inplace = op .inplace
236230
237- @numba_njit
238- def advanced_set_subtensor_multiple_vector (x , y , * idxs ):
239- vec_idxs = idxs [first_axis :after_last_axis ]
240- x_shape = x .shape
231+ # Check if y must be broadcasted
232+ # Includes the last integer vector index,
233+ indexed_bcast_dims = (
234+ * x .type .broadcastable [:first_axis ],
235+ * x .type .broadcastable [last_axis :],
236+ )
237+ y_is_broadcasted = _broadcasted_to (y .type .broadcastable , indexed_bcast_dims )
241238
242- if inplace :
243- out = x
244- else :
245- out = x .copy ()
239+ if op .set_instead_of_inc :
246240
247- for outer in np . ndindex ( x_shape [: first_axis ]):
248- for i , scalar_idxs in enumerate ( zip ( * vec_idxs )): # noqa: B905
249- out [( * outer , * scalar_idxs )] = y [( * outer , i ) ]
250- return out
241+ @ numba_njit
242+ def advanced_set_subtensor_multiple_vector ( x , y , * idxs ):
243+ vec_idxs = idxs [ first_axis : after_last_axis ]
244+ x_shape = x . shape
251245
252- return advanced_set_subtensor_multiple_vector
246+ if inplace :
247+ out = x
248+ else :
249+ out = x .copy ()
253250
254- else :
255- inplace = op . inplace
251+ if y_is_broadcasted :
252+ y = np . broadcast_to ( y , x_shape [: first_axis ] + x_shape [ last_axis :])
256253
257- @numba_njit
258- def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
259- vec_idxs = idxs [first_axis :after_last_axis ]
260- x_shape = x .shape
254+ for outer in np .ndindex (x_shape [:first_axis ]):
255+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
256+ out [(* outer , * scalar_idxs )] = y [(* outer , i )]
257+ return out
258+
259+ return advanced_set_subtensor_multiple_vector
260+
261+ else :
262+
263+ @numba_njit
264+ def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
265+ vec_idxs = idxs [first_axis :after_last_axis ]
266+ x_shape = x .shape
267+
268+ if inplace :
269+ out = x
270+ else :
271+ out = x .copy ()
261272
262- if inplace :
263- out = x
264- else :
265- out = x .copy ()
273+ if y_is_broadcasted :
274+ y = np .broadcast_to (y , x_shape [:first_axis ] + x_shape [last_axis :])
266275
267- for outer in np .ndindex (x_shape [:first_axis ]):
268- for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
269- out [(* outer , * scalar_idxs )] += y [(* outer , i )]
270- return out
276+ for outer in np .ndindex (x_shape [:first_axis ]):
277+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
278+ out [(* outer , * scalar_idxs )] += y [(* outer , i )]
279+ return out
271280
272281 return advanced_inc_subtensor_multiple_vector
273282
0 commit comments