@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130130 if isinstance (idx .type , TensorType )
131131 ]
132132
133- def broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
134- # Check that x is not broadcasted to y based on broadcastable info
135- if len (x_bcast ) < len (to_bcast ):
136- return True
137- for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
138- if x_bcast_dim and not to_bcast_dim :
139- return True
140- return False
141-
142133 # Special implementation for consecutive integer vector indices
143134 if (
144135 not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
151142 )
152143 # Must be consecutive
153144 and not op .non_consecutive_adv_indexing (node )
154- # y in set/inc_subtensor cannot be broadcasted
155- and (
156- y is None
157- or not broadcasted_to (
158- y .type .broadcastable ,
159- (
160- x .type .broadcastable [: adv_idxs [0 ]["axis" ]]
161- + x .type .broadcastable [adv_idxs [- 1 ]["axis" ] :]
162- ),
163- )
164- )
165145 ):
166146 return numba_funcify_multiple_integer_vector_indexing (op , node , ** kwargs )
167147
@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
191171 return numba_funcify_default_subtensor (op , node , ** kwargs )
192172
193173
174+ def _broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
175+ # Check that x is not broadcasted to y based on broadcastable info
176+ if len (x_bcast ) < len (to_bcast ):
177+ return True
178+ for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
179+ if x_bcast_dim and not to_bcast_dim :
180+ return True
181+ return False
182+
183+
194184def numba_funcify_multiple_integer_vector_indexing (
195185 op : AdvancedSubtensor | AdvancedIncSubtensor , node , ** kwargs
196186):
197187 # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
198188 if isinstance (op , AdvancedSubtensor ):
199- y , idxs = None , node .inputs [1 :]
189+ idxs = node .inputs [1 :]
200190 else :
201- y , * idxs = node .inputs [1 :]
191+ idxs = node .inputs [2 :]
202192
203193 first_axis = next (
204194 i for i , idx in enumerate (idxs ) if isinstance (idx .type , TensorType )
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
211201 )
212202 except StopIteration :
213203 after_last_axis = len (idxs )
204+ last_axis = after_last_axis - 1
205+
206+ vector_indices = idxs [first_axis :after_last_axis ]
207+ assert all (v .type .broadcastable == (False ,) for v in vector_indices )
214208
215209 if isinstance (op , AdvancedSubtensor ):
216210
@@ -231,43 +225,59 @@ def advanced_subtensor_multiple_vector(x, *idxs):
231225
232226 return advanced_subtensor_multiple_vector
233227
234- elif op . set_instead_of_inc :
228+ else :
235229 inplace = op .inplace
236230
237- @numba_njit
238- def advanced_set_subtensor_multiple_vector (x , y , * idxs ):
239- vec_idxs = idxs [first_axis :after_last_axis ]
240- x_shape = x .shape
231+ # Check if y must be broadcasted
232+ # Includes the last integer vector index,
233+ x , y = node .inputs [:2 ]
234+ indexed_bcast_dims = (
235+ * x .type .broadcastable [:first_axis ],
236+ * x .type .broadcastable [last_axis :],
237+ )
238+ y_is_broadcasted = _broadcasted_to (y .type .broadcastable , indexed_bcast_dims )
241239
242- if inplace :
243- out = x
244- else :
245- out = x .copy ()
240+ if op .set_instead_of_inc :
246241
247- for outer in np . ndindex ( x_shape [: first_axis ]):
248- for i , scalar_idxs in enumerate ( zip ( * vec_idxs )): # noqa: B905
249- out [( * outer , * scalar_idxs )] = y [( * outer , i ) ]
250- return out
242+ @ numba_njit
243+ def advanced_set_subtensor_multiple_vector ( x , y , * idxs ):
244+ vec_idxs = idxs [ first_axis : after_last_axis ]
245+ x_shape = x . shape
251246
252- return advanced_set_subtensor_multiple_vector
247+ if inplace :
248+ out = x
249+ else :
250+ out = x .copy ()
253251
254- else :
255- inplace = op . inplace
252+ if y_is_broadcasted :
253+ y = np . broadcast_to ( y , x_shape [: first_axis ] + x_shape [ last_axis :])
256254
257- @numba_njit
258- def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
259- vec_idxs = idxs [first_axis :after_last_axis ]
260- x_shape = x .shape
255+ for outer in np .ndindex (x_shape [:first_axis ]):
256+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
257+ out [(* outer , * scalar_idxs )] = y [(* outer , i )]
258+ return out
259+
260+ return advanced_set_subtensor_multiple_vector
261+
262+ else :
263+
264+ @numba_njit
265+ def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
266+ vec_idxs = idxs [first_axis :after_last_axis ]
267+ x_shape = x .shape
268+
269+ if inplace :
270+ out = x
271+ else :
272+ out = x .copy ()
261273
262- if inplace :
263- out = x
264- else :
265- out = x .copy ()
274+ if y_is_broadcasted :
275+ y = np .broadcast_to (y , x_shape [:first_axis ] + x_shape [last_axis :])
266276
267- for outer in np .ndindex (x_shape [:first_axis ]):
268- for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
269- out [(* outer , * scalar_idxs )] += y [(* outer , i )]
270- return out
277+ for outer in np .ndindex (x_shape [:first_axis ]):
278+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
279+ out [(* outer , * scalar_idxs )] += y [(* outer , i )]
280+ return out
271281
272282 return advanced_inc_subtensor_multiple_vector
273283
0 commit comments