@@ -242,67 +242,107 @@ size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
242242}
243243
244244/* * Proxy class providing unsafe, unchecked const access to array data. This is constructed through
245- * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
245+ * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
246+ * will be -1 for dimensions determined at runtime.
246247 */
247- template <typename T, size_t Dims>
248+ template <typename T, ssize_t Dims>
248249class unchecked_reference {
249250protected:
251+ static constexpr bool Dynamic = Dims < 0 ;
250252 const unsigned char *data_;
251253 // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
252- // make large performance gains on big, nested loops.
253- std::array<size_t , Dims> shape_, strides_;
254+ // make large performance gains on big, nested loops, but requires compile-time dimensions
255+ conditional_t <Dynamic, const size_t *, std::array<size_t , (size_t ) Dims>>
256+ shape_, strides_;
257+ const size_t dims_;
254258
255259 friend class pybind11 ::array;
256- unchecked_reference (const void *data, const size_t *shape, const size_t *strides)
257- : data_{reinterpret_cast <const unsigned char *>(data)} {
258- for (size_t i = 0 ; i < Dims; i++) {
260+ // Constructor for compile-time dimensions:
261+ template <bool Dyn = Dynamic>
262+ unchecked_reference (const void *data, const size_t *shape, const size_t *strides, enable_if_t <!Dyn, size_t >)
263+ : data_{reinterpret_cast <const unsigned char *>(data)}, dims_{Dims} {
264+ for (size_t i = 0 ; i < dims_; i++) {
259265 shape_[i] = shape[i];
260266 strides_[i] = strides[i];
261267 }
262268 }
269+ // Constructor for runtime dimensions:
270+ template <bool Dyn = Dynamic>
271+ unchecked_reference (const void *data, const size_t *shape, const size_t *strides, enable_if_t <Dyn, size_t > dims)
272+ : data_{reinterpret_cast <const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
263273
264274public:
265- /* * Unchecked const reference access to data at the given indices. Omiting trailing indices
266- * is equivalent to specifying them as 0.
275+ /* * Unchecked const reference access to data at the given indices. For a compile-time known
276+ * number of dimensions, this requires the correct number of arguments; for run-time
277+ * dimensionality, this is not checked (and so is up to the caller to use safely).
267278 */
268- template <typename ... Ix> const T& operator ()(Ix... index) const {
269- static_assert (sizeof ...(Ix) <= Dims, " Invalid number of indices for unchecked array reference" );
270- return *reinterpret_cast <const T *>(data_ + byte_offset_unsafe (strides_, size_t {index}...));
279+ template <typename ... Ix> const T &operator ()(Ix... index) const {
280+ static_assert (sizeof ...(Ix) == Dims || Dynamic,
281+ " Invalid number of indices for unchecked array reference" );
282+ return *reinterpret_cast <const T *>(data_ + byte_offset_unsafe (strides_, size_t (index)...));
271283 }
272284 /* * Unchecked const reference access to data; this operator only participates if the reference
273285 * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
274286 */
275- template <size_t D = Dims, typename = enable_if_t <D == 1 >>
287+ template <size_t D = Dims, typename = enable_if_t <D == 1 || Dynamic >>
276288 const T &operator [](size_t index) const { return operator ()(index); }
277289
290+ // / Pointer access to the data at the given indices.
291+ template <typename ... Ix> const T *data (Ix... ix) const { return &operator ()(size_t (ix)...); }
292+
293+ // / Returns the item size, i.e. sizeof(T)
294+ constexpr static size_t itemsize () { return sizeof (T); }
295+
278296 // / Returns the shape (i.e. size) of dimension `dim`
279297 size_t shape (size_t dim) const { return shape_[dim]; }
280298
281299 // / Returns the number of dimensions of the array
282- constexpr static size_t ndim () { return Dims; }
300+ size_t ndim () const { return dims_; }
301+
302+ // / Returns the total number of elements in the referenced array, i.e. the product of the shapes
303+ template <bool Dyn = Dynamic>
304+ enable_if_t <!Dyn, size_t > size () const {
305+ return std::accumulate (shape_.begin (), shape_.end (), (size_t ) 1 , std::multiplies<size_t >());
306+ }
307+ template <bool Dyn = Dynamic>
308+ enable_if_t <Dyn, size_t > size () const {
309+ return std::accumulate (shape_, shape_ + ndim (), (size_t ) 1 , std::multiplies<size_t >());
310+ }
311+
312+ // / Returns the total number of bytes used by the referenced data. Note that the actual span in
313+ // / memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
314+ size_t nbytes () const {
315+ return size () * itemsize ();
316+ }
283317};
284318
285- template <typename T, size_t Dims>
319+ template <typename T, ssize_t Dims>
286320class unchecked_mutable_reference : public unchecked_reference <T, Dims> {
287321 friend class pybind11 ::array;
288322 using ConstBase = unchecked_reference<T, Dims>;
289323 using ConstBase::ConstBase;
324+ using ConstBase::Dynamic;
290325public:
291326 // / Mutable, unchecked access to data at the given indices.
292327 template <typename ... Ix> T& operator ()(Ix... index) {
293- static_assert (sizeof ...(Ix) == Dims, " Invalid number of indices for unchecked array reference" );
328+ static_assert (sizeof ...(Ix) == Dims || Dynamic,
329+ " Invalid number of indices for unchecked array reference" );
294330 return const_cast <T &>(ConstBase::operator ()(index...));
295331 }
296332 /* * Mutable, unchecked access data at the given index; this operator only participates if the
297- * reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
333+ * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
334+ * exactly equivalent to `obj(index)`.
298335 */
299- template <size_t D = Dims, typename = enable_if_t <D == 1 >>
336+ template <size_t D = Dims, typename = enable_if_t <D == 1 || Dynamic >>
300337 T &operator [](size_t index) { return operator ()(index); }
338+
339+ // / Mutable pointer access to the data at the given indices.
340+ template <typename ... Ix> T *mutable_data (Ix... ix) { return &operator ()(size_t (ix)...); }
301341};
302342
303343template <typename T, size_t Dim>
304344struct type_caster <unchecked_reference<T, Dim>> {
305- static_assert (Dim == ( size_t ) - 1 /* always fail */ , " unchecked array proxy object is not castable" );
345+ static_assert (Dim == 0 && Dim > 0 /* always fail */ , " unchecked array proxy object is not castable" );
306346};
307347template <typename T, size_t Dim>
308348struct type_caster <unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
@@ -580,11 +620,11 @@ class array : public buffer {
580620 * care: the array must not be destroyed or reshaped for the duration of the returned object,
581621 * and the caller must take care not to access invalid dimensions or dimension indices.
582622 */
583- template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked () {
584- if (ndim () != Dims)
623+ template <typename T, ssize_t Dims = - 1 > detail::unchecked_mutable_reference<T, Dims> mutable_unchecked () {
624+ if (Dims >= 0 && ndim () != ( size_t ) Dims)
585625 throw std::domain_error (" array has incorrect number of dimensions: " + std::to_string (ndim ()) +
586626 " ; expected " + std::to_string (Dims));
587- return detail::unchecked_mutable_reference<T, Dims>(mutable_data (), shape (), strides ());
627+ return detail::unchecked_mutable_reference<T, Dims>(mutable_data (), shape (), strides (), ndim () );
588628 }
589629
590630 /* * Returns a proxy object that provides const access to the array's data without bounds or
@@ -593,11 +633,11 @@ class array : public buffer {
593633 * reshaped for the duration of the returned object, and the caller must take care not to access
594634 * invalid dimensions or dimension indices.
595635 */
596- template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked () const {
597- if (ndim () != Dims)
636+ template <typename T, ssize_t Dims = - 1 > detail::unchecked_reference<T, Dims> unchecked () const {
637+ if (Dims >= 0 && ndim () != ( size_t ) Dims)
598638 throw std::domain_error (" array has incorrect number of dimensions: " + std::to_string (ndim ()) +
599639 " ; expected " + std::to_string (Dims));
600- return detail::unchecked_reference<T, Dims>(data (), shape (), strides ());
640+ return detail::unchecked_reference<T, Dims>(data (), shape (), strides (), ndim () );
601641 }
602642
603643 // / Return a new view with all of the dimensions of length 1 removed
@@ -625,7 +665,7 @@ class array : public buffer {
625665
626666 template <typename ... Ix> size_t byte_offset (Ix... index) const {
627667 check_dimensions (index...);
628- return detail::byte_offset_unsafe (strides (), size_t { index} ...);
668+ return detail::byte_offset_unsafe (strides (), size_t ( index) ...);
629669 }
630670
631671 void check_writeable () const {
@@ -736,7 +776,7 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
736776 * care: the array must not be destroyed or reshaped for the duration of the returned object,
737777 * and the caller must take care not to access invalid dimensions or dimension indices.
738778 */
739- template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked () {
779+ template <ssize_t Dims = - 1 > detail::unchecked_mutable_reference<T, Dims> mutable_unchecked () {
740780 return array::mutable_unchecked<T, Dims>();
741781 }
742782
@@ -746,7 +786,7 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
746786 * for the duration of the returned object, and the caller must take care not to access invalid
747787 * dimensions or dimension indices.
748788 */
749- template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked () const {
789+ template <ssize_t Dims = - 1 > detail::unchecked_reference<T, Dims> unchecked () const {
750790 return array::unchecked<T, Dims>();
751791 }
752792
0 commit comments