Skip to content

Commit bcb9ebd

Browse files
committed
Allow broadcasting in specialized numba dispatch of AdvancedIncSubtensor
1 parent 89d5366 commit bcb9ebd

File tree

2 files changed

+61
-52
lines changed

2 files changed

+61
-52
lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 59 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130130
if isinstance(idx.type, TensorType)
131131
]
132132

133-
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
134-
# Check that x is not broadcasted to y based on broadcastable info
135-
if len(x_bcast) < len(to_bcast):
136-
return True
137-
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
138-
if x_bcast_dim and not to_bcast_dim:
139-
return True
140-
return False
141-
142133
# Special implementation for consecutive integer vector indices
143134
if (
144135
not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
151142
)
152143
# Must be consecutive
153144
and not op.non_consecutive_adv_indexing(node)
154-
# y in set/inc_subtensor cannot be broadcasted
155-
and (
156-
y is None
157-
or not broadcasted_to(
158-
y.type.broadcastable,
159-
(
160-
x.type.broadcastable[: adv_idxs[0]["axis"]]
161-
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
162-
),
163-
)
164-
)
165145
):
166146
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
167147

@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
191171
return numba_funcify_default_subtensor(op, node, **kwargs)
192172

193173

174+
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
175+
# Check that x is not broadcasted to y based on broadcastable info
176+
if len(x_bcast) < len(to_bcast):
177+
return True
178+
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
179+
if x_bcast_dim and not to_bcast_dim:
180+
return True
181+
return False
182+
183+
194184
def numba_funcify_multiple_integer_vector_indexing(
195185
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
196186
):
197187
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
198188
if isinstance(op, AdvancedSubtensor):
199-
y, idxs = None, node.inputs[1:]
189+
x, y, idxs = None, node.inputs
200190
else:
201-
y, *idxs = node.inputs[1:]
191+
x, y, *idxs = node.inputs
202192

203193
first_axis = next(
204194
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
211201
)
212202
except StopIteration:
213203
after_last_axis = len(idxs)
204+
last_axis = after_last_axis - 1
205+
206+
vector_indices = idxs[first_axis:after_last_axis]
207+
assert all(v.type.broadcastable == (False,) for v in vector_indices)
214208

215209
if isinstance(op, AdvancedSubtensor):
216210

@@ -231,43 +225,58 @@ def advanced_subtensor_multiple_vector(x, *idxs):
231225

232226
return advanced_subtensor_multiple_vector
233227

234-
elif op.set_instead_of_inc:
228+
else:
235229
inplace = op.inplace
236230

237-
@numba_njit
238-
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
239-
vec_idxs = idxs[first_axis:after_last_axis]
240-
x_shape = x.shape
231+
# Check if y must be broadcasted
232+
# The second slice starts at last_axis, so it includes the axis of the last integer vector index.
233+
indexed_bcast_dims = (
234+
*x.type.broadcastable[:first_axis],
235+
*x.type.broadcastable[last_axis:],
236+
)
237+
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
241238

242-
if inplace:
243-
out = x
244-
else:
245-
out = x.copy()
239+
if op.set_instead_of_inc:
246240

247-
for outer in np.ndindex(x_shape[:first_axis]):
248-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
249-
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
250-
return out
241+
@numba_njit
242+
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
243+
vec_idxs = idxs[first_axis:after_last_axis]
244+
x_shape = x.shape
251245

252-
return advanced_set_subtensor_multiple_vector
246+
if inplace:
247+
out = x
248+
else:
249+
out = x.copy()
253250

254-
else:
255-
inplace = op.inplace
251+
if y_is_broadcasted:
252+
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
256253

257-
@numba_njit
258-
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
259-
vec_idxs = idxs[first_axis:after_last_axis]
260-
x_shape = x.shape
254+
for outer in np.ndindex(x_shape[:first_axis]):
255+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
256+
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
257+
return out
258+
259+
return advanced_set_subtensor_multiple_vector
260+
261+
else:
262+
263+
@numba_njit
264+
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
265+
vec_idxs = idxs[first_axis:after_last_axis]
266+
x_shape = x.shape
267+
268+
if inplace:
269+
out = x
270+
else:
271+
out = x.copy()
261272

262-
if inplace:
263-
out = x
264-
else:
265-
out = x.copy()
273+
if y_is_broadcasted:
274+
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
266275

267-
for outer in np.ndindex(x_shape[:first_axis]):
268-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
269-
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
270-
return out
276+
for outer in np.ndindex(x_shape[:first_axis]):
277+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
278+
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
279+
return out
271280

272281
return advanced_inc_subtensor_multiple_vector
273282

tests/link/numba/test_subtensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
392392
np.array(-99), # Broadcasted value
393393
([1, 2], [2, 3]), # 2 vector indices
394394
False,
395-
True,
396-
True,
395+
False,
396+
False,
397397
),
398398
(
399399
np.arange(3 * 4 * 5).reshape((3, 4, 5)),

0 commit comments

Comments
 (0)