Skip to content

Commit 773339f

Browse files
committed
array-unchecked: add runtime dimension support and array-compatible methods
The extends the previous unchecked support with the ability to determine the dimensions at runtime. This incurs a small performance hit when used (versus the compile-time fixed alternative), but is still considerably faster than the full checks on every call that happen with `.at()`/`.mutable_at()`.
1 parent 423a49b commit 773339f

File tree

4 files changed

+163
-30
lines changed

4 files changed

+163
-30
lines changed

docs/advanced/pycpp/numpy.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,39 @@ To obtain the proxy from an ``array`` object, you must specify both the data
340340
type and number of dimensions as template arguments, such as ``auto r =
341341
myarray.mutable_unchecked<float, 2>()``.
342342

343+
If the number of dimensions is not known at compile time, you can omit the
344+
dimensions template parameter (i.e. calling ``arr_t.unchecked()`` or
345+
``arr.unchecked<T>()``. This will give you a proxy object that works in the
346+
same way, but results in less optimizable code and thus a small efficiency
347+
loss in tight loops.
348+
343349
Note that the returned proxy object directly references the array's data, and
344350
only reads its shape, strides, and writeable flag when constructed. You must
345351
take care to ensure that the referenced array is not destroyed or reshaped for
346352
the duration of the returned object, typically by limiting the scope of the
347353
returned instance.
348354

355+
The returned proxy object supports some of the same methods as ``py::array`` so
356+
that it can be used as a drop-in replacement for some existing, index-checked
357+
uses of ``py::array``:
358+
359+
- ``r.ndim()`` returns the number of dimensions
360+
361+
- ``r.data(1, 2, ...)`` and ``r.mutable_data(1, 2, ...)``` returns a pointer to
362+
the ``const T`` or ``T`` data, respectively, at the given indices. The
363+
latter is only available to proxies obtained via ``a.mutable_unchecked()``.
364+
365+
- ``itemsize()`` returns the size of an item in bytes, i.e. ``sizeof(T)``.
366+
367+
- ``ndim()`` returns the number of dimensions.
368+
369+
- ``shape(n)`` returns the size of dimension ``n``
370+
371+
- ``size()`` returns the total number of elements (i.e. the product of the shapes).
372+
373+
- ``nbytes()`` returns the number of bytes used by the referenced elements
374+
(i.e. ``itemsize()`` times ``size()``).
375+
349376
.. seealso::
350377

351378
The file :file:`tests/test_numpy_array.cpp` contains additional examples

include/pybind11/numpy.h

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -242,67 +242,107 @@ size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
242242
}
243243

244244
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
245-
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
245+
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
246+
* will be -1 for dimensions determined at runtime.
246247
*/
247-
template <typename T, size_t Dims>
248+
template <typename T, ssize_t Dims>
248249
class unchecked_reference {
249250
protected:
251+
static constexpr bool Dynamic = Dims < 0;
250252
const unsigned char *data_;
251253
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
252-
// make large performance gains on big, nested loops.
253-
std::array<size_t, Dims> shape_, strides_;
254+
// make large performance gains on big, nested loops, but requires compile-time dimensions
255+
conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>>
256+
shape_, strides_;
257+
const size_t dims_;
254258

255259
friend class pybind11::array;
256-
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
257-
: data_{reinterpret_cast<const unsigned char *>(data)} {
258-
for (size_t i = 0; i < Dims; i++) {
260+
// Constructor for compile-time dimensions:
261+
template <bool Dyn = Dynamic>
262+
unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>)
263+
: data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
264+
for (size_t i = 0; i < dims_; i++) {
259265
shape_[i] = shape[i];
260266
strides_[i] = strides[i];
261267
}
262268
}
269+
// Constructor for runtime dimensions:
270+
template <bool Dyn = Dynamic>
271+
unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims)
272+
: data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
263273

264274
public:
265-
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
266-
* is equivalent to specifying them as 0.
275+
/** Unchecked const reference access to data at the given indices. For a compile-time known
276+
* number of dimensions, this requires the correct number of arguments; for run-time
277+
* dimensionality, this is not checked (and so is up to the caller to use safely).
267278
*/
268-
template <typename... Ix> const T& operator()(Ix... index) const {
269-
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
270-
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
279+
template <typename... Ix> const T &operator()(Ix... index) const {
280+
static_assert(sizeof...(Ix) == Dims || Dynamic,
281+
"Invalid number of indices for unchecked array reference");
282+
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...));
271283
}
272284
/** Unchecked const reference access to data; this operator only participates if the reference
273285
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
274286
*/
275-
template <size_t D = Dims, typename = enable_if_t<D == 1>>
287+
template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
276288
const T &operator[](size_t index) const { return operator()(index); }
277289

290+
/// Pointer access to the data at the given indices.
291+
template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); }
292+
293+
/// Returns the item size, i.e. sizeof(T)
294+
constexpr static size_t itemsize() { return sizeof(T); }
295+
278296
/// Returns the shape (i.e. size) of dimension `dim`
279297
size_t shape(size_t dim) const { return shape_[dim]; }
280298

281299
/// Returns the number of dimensions of the array
282-
constexpr static size_t ndim() { return Dims; }
300+
size_t ndim() const { return dims_; }
301+
302+
/// Returns the total number of elements in the referenced array, i.e. the product of the shapes
303+
template <bool Dyn = Dynamic>
304+
enable_if_t<!Dyn, size_t> size() const {
305+
return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>());
306+
}
307+
template <bool Dyn = Dynamic>
308+
enable_if_t<Dyn, size_t> size() const {
309+
return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>());
310+
}
311+
312+
/// Returns the total number of bytes used by the referenced data. Note that the actual span in
313+
/// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
314+
size_t nbytes() const {
315+
return size() * itemsize();
316+
}
283317
};
284318

285-
template <typename T, size_t Dims>
319+
template <typename T, ssize_t Dims>
286320
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
287321
friend class pybind11::array;
288322
using ConstBase = unchecked_reference<T, Dims>;
289323
using ConstBase::ConstBase;
324+
using ConstBase::Dynamic;
290325
public:
291326
/// Mutable, unchecked access to data at the given indices.
292327
template <typename... Ix> T& operator()(Ix... index) {
293-
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
328+
static_assert(sizeof...(Ix) == Dims || Dynamic,
329+
"Invalid number of indices for unchecked array reference");
294330
return const_cast<T &>(ConstBase::operator()(index...));
295331
}
296332
/** Mutable, unchecked access data at the given index; this operator only participates if the
297-
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
333+
* reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
334+
* exactly equivalent to `obj(index)`.
298335
*/
299-
template <size_t D = Dims, typename = enable_if_t<D == 1>>
336+
template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
300337
T &operator[](size_t index) { return operator()(index); }
338+
339+
/// Mutable pointer access to the data at the given indices.
340+
template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); }
301341
};
302342

303343
template <typename T, size_t Dim>
304344
struct type_caster<unchecked_reference<T, Dim>> {
305-
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
345+
static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
306346
};
307347
template <typename T, size_t Dim>
308348
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
@@ -580,11 +620,11 @@ class array : public buffer {
580620
* care: the array must not be destroyed or reshaped for the duration of the returned object,
581621
* and the caller must take care not to access invalid dimensions or dimension indices.
582622
*/
583-
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
584-
if (ndim() != Dims)
623+
template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
624+
if (Dims >= 0 && ndim() != (size_t) Dims)
585625
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
586626
"; expected " + std::to_string(Dims));
587-
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
627+
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
588628
}
589629

590630
/** Returns a proxy object that provides const access to the array's data without bounds or
@@ -593,11 +633,11 @@ class array : public buffer {
593633
* reshaped for the duration of the returned object, and the caller must take care not to access
594634
* invalid dimensions or dimension indices.
595635
*/
596-
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
597-
if (ndim() != Dims)
636+
template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
637+
if (Dims >= 0 && ndim() != (size_t) Dims)
598638
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
599639
"; expected " + std::to_string(Dims));
600-
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
640+
return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
601641
}
602642

603643
/// Return a new view with all of the dimensions of length 1 removed
@@ -625,7 +665,7 @@ class array : public buffer {
625665

626666
template<typename... Ix> size_t byte_offset(Ix... index) const {
627667
check_dimensions(index...);
628-
return detail::byte_offset_unsafe(strides(), size_t{index}...);
668+
return detail::byte_offset_unsafe(strides(), size_t(index)...);
629669
}
630670

631671
void check_writeable() const {
@@ -736,7 +776,7 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
736776
* care: the array must not be destroyed or reshaped for the duration of the returned object,
737777
* and the caller must take care not to access invalid dimensions or dimension indices.
738778
*/
739-
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
779+
template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
740780
return array::mutable_unchecked<T, Dims>();
741781
}
742782

@@ -746,7 +786,7 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
746786
* for the duration of the returned object, and the caller must take care not to access invalid
747787
* dimensions or dimension indices.
748788
*/
749-
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
789+
template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
750790
return array::unchecked<T, Dims>();
751791
}
752792

tests/test_numpy_array.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(
6868
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
6969
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
7070

71+
template <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) {
72+
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
73+
py::list l;
74+
l.append(*r.data(0, 0));
75+
l.append(*r2.mutable_data(0, 0));
76+
l.append(r.data(0, 1) == r2.mutable_data(0, 1));
77+
l.append(r.ndim());
78+
l.append(r.itemsize());
79+
l.append(r.shape(0));
80+
l.append(r.shape(1));
81+
l.append(r.size());
82+
l.append(r.nbytes());
83+
return l.release();
84+
}
85+
7186
test_initializer numpy_array([](py::module &m) {
7287
auto sm = m.def_submodule("array");
7388

@@ -191,6 +206,7 @@ test_initializer numpy_array([](py::module &m) {
191206
for (size_t j = 0; j < r.shape(1); j++)
192207
r(i, j) += v;
193208
}, py::arg().noconvert(), py::arg());
209+
194210
sm.def("proxy_init3", [](double start) {
195211
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
196212
auto r = a.mutable_unchecked<3>();
@@ -216,4 +232,36 @@ test_initializer numpy_array([](py::module &m) {
216232
sumsq += r[i] * r(i); // Either notation works for a 1D array
217233
return sumsq;
218234
});
235+
236+
sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
237+
auto r = a.unchecked<2>();
238+
auto r2 = a.mutable_unchecked<2>();
239+
return auxiliaries(r, r2);
240+
});
241+
242+
// Same as the above, but without a compile-time dimensions specification:
243+
sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) {
244+
auto r = a.mutable_unchecked();
245+
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
246+
for (size_t i = 0; i < r.shape(0); i++)
247+
for (size_t j = 0; j < r.shape(1); j++)
248+
r(i, j) += v;
249+
}, py::arg().noconvert(), py::arg());
250+
sm.def("proxy_init3_dyn", [](double start) {
251+
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
252+
auto r = a.mutable_unchecked();
253+
if (r.ndim() != 3) throw std::domain_error("error: ndim != 3");
254+
for (size_t i = 0; i < r.shape(0); i++)
255+
for (size_t j = 0; j < r.shape(1); j++)
256+
for (size_t k = 0; k < r.shape(2); k++)
257+
r(i, j, k) = start++;
258+
return a;
259+
});
260+
sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
261+
return auxiliaries(a.unchecked(), a.mutable_unchecked());
262+
});
263+
264+
sm.def("array_auxiliaries2", [](py::array_t<double> a) {
265+
return auxiliaries(a, a);
266+
});
219267
});

tests/test_numpy_array.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ def test_greedy_string_overload(): # issue 685
341341
assert issue685(123) == "other"
342342

343343

344-
def test_array_unchecked(msg):
345-
from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm
344+
def test_array_unchecked_fixed_dims(msg):
345+
from pybind11_tests.array import (proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm,
346+
proxy_auxiliaries2, array_auxiliaries2)
346347

347348
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
348349
proxy_add2(z1, 10)
@@ -359,3 +360,20 @@ def test_array_unchecked(msg):
359360

360361
assert proxy_squared_L2_norm(np.array(range(6))) == 55
361362
assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55
363+
364+
assert proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
365+
assert proxy_auxiliaries2(z1) == array_auxiliaries2(z1)
366+
367+
368+
def test_array_unchecked_dyn_dims(msg):
369+
from pybind11_tests.array import (proxy_add2_dyn, proxy_init3_dyn, proxy_auxiliaries2_dyn,
370+
array_auxiliaries2)
371+
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
372+
proxy_add2_dyn(z1, 10)
373+
assert np.all(z1 == [[11, 12], [13, 14]])
374+
375+
expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
376+
assert np.all(proxy_init3_dyn(3.0) == expect_c)
377+
378+
assert proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
379+
assert proxy_auxiliaries2_dyn(z1) == array_auxiliaries2(z1)

0 commit comments

Comments
 (0)