I've improved my N-dimensional C++20 matrix project (see my previous question, "C++20 : N-dimensional minimal Matrix class").

I've implemented general matrix addition/subtraction, elementwise multiplication/division, dot product, matrix product, reshape, and transpose.
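
For reference, the shape rule that the "general" addition/subtraction appears to rely on (implemented at compile time by `bidirBroadcastable` and `bidirBroadcastedDims` in MatrixUtils.h) can be sketched at run time as follows. This is an illustrative sketch, not code from the project: `broadcastShape` is a hypothetical name, and it uses `std::vector` shapes and `std::optional` instead of the project's `std::array` and exceptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper: NumPy-style broadcast of two shapes.
// Shapes are right-aligned; the shorter one is treated as if padded with
// leading 1s. Each aligned pair must be equal, or one of them must be 1.
std::optional<std::vector<std::size_t>> broadcastShape(const std::vector<std::size_t>& a,
                                                       const std::vector<std::size_t>& b) {
    const std::size_t n = std::max(a.size(), b.size());
    std::vector<std::size_t> out(n);
    for (std::size_t i = 0; i < n; ++i) {
        // Walk from the last axis backwards; missing axes count as extent 1.
        const std::size_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        const std::size_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) {
            return std::nullopt;  // incompatible shapes
        }
        out[n - 1 - i] = std::max(da, db);
    }
    return out;
}
```

For example, shapes {3, 1, 5} and {4, 5} broadcast to {3, 4, 5}, while {2, 3} and {4, 3} are rejected.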

There's a lot of code:

ObjectBase.h

    #ifndef FROZENCA_OBJECTBASE_H
    #define FROZENCA_OBJECTBASE_H
    
    #include <functional>
    #include <utility>
    #include "MatrixUtils.h"
    
    namespace frozenca {
    
    template <typename Derived>
    class ObjectBase {
    private:
        Derived& self() { return static_cast<Derived&>(*this); }
        const Derived& self() const { return static_cast<const Derived&>(*this); }
    
    protected:
        ObjectBase() = default;
        ~ObjectBase() noexcept = default;
    
    public:
        auto begin() { return self().begin(); }
        auto begin() const { return self().begin(); }
        auto cbegin() const { return self().cbegin(); }
        auto end() { return self().end(); }
        auto end() const { return self().end(); }
        auto cend() const { return self().cend(); }
        auto rbegin() { return self().rbegin(); }
        auto rbegin() const { return self().rbegin(); }
        auto crbegin() const { return self().crbegin(); }
        auto rend() { return self().rend(); }
        auto rend() const { return self().rend(); }
        auto crend() const { return self().crend(); }
    
        template <typename F>
        requires std::invocable<F, typename Derived::reference>
        ObjectBase& applyFunction(F&& f);
    
        template <typename F, typename... Args>
        requires std::invocable<F, typename Derived::reference, Args...>
        ObjectBase& applyFunction(F&& f, Args&&... args);
    
        template <typename DerivedOther, typename F>
        requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
        ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f);
    
        template <typename DerivedOther, typename F, typename... Args>
        requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
        ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args);
    
        template <isNotMatrix U> requires Addable<typename Derived::value_type, U>
        ObjectBase& operator=(const U& val) { return applyFunction([&val](auto& v) {v = val;}); }
    
        template <isNotMatrix U> requires Addable<typename Derived::value_type, U>
        ObjectBase& operator+=(const U& val) { return applyFunction([&val](auto& v) {v += val;}); }
    
        template <isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
        ObjectBase& operator-=(const U& val) { return applyFunction([&val](auto& v) {v -= val;}); }
    
        template <isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
        ObjectBase& operator*=(const U& val) { return applyFunction([&val](auto& v) {v *= val;}); }
    
        template <isNotMatrix U> requires Dividable<typename Derived::value_type, U>
        ObjectBase& operator/=(const U& val) { return applyFunction([&val](auto& v) {v /= val;}); }
    
        template <isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
        ObjectBase& operator%=(const U& val) { return applyFunction([&val](auto& v) {v %= val;}); }
    
        template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
        ObjectBase& operator&=(const U& val) { return applyFunction([&val](auto& v) {v &= val;}); }
    
        template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
        ObjectBase& operator|=(const U& val) { return applyFunction([&val](auto& v) {v |= val;}); }
    
        template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
        ObjectBase& operator^=(const U& val) { return applyFunction([&val](auto& v) {v ^= val;}); }
    
        template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
        ObjectBase& operator<<=(const U& val) { return applyFunction([&val](auto& v) {v <<= val;}); }
    
        template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
        ObjectBase& operator>>=(const U& val) { return applyFunction([&val](auto& v) {v >>= val;}); }
    };
    
    template <typename Derived>
    template <typename F>
    requires std::invocable<F, typename Derived::reference>
    ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f) {
        for (auto it = begin(); it != end(); ++it) {
            f(*it);
        }
        return *this;
    }
    
    template <typename Derived>
    template <typename F, typename... Args>
    requires std::invocable<F, typename Derived::reference, Args...>
    ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f, Args&&... args) {
        for (auto it = begin(); it != end(); ++it) {
            // expand one std::forward per argument
            f(*it, std::forward<Args>(args)...);
        }
        return *this;
    }
    
    template <typename Derived>
    template <typename DerivedOther, typename F>
    requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
    ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f) {
        for (auto it = begin(), it2 = other.begin(); it != end(); ++it, ++it2) {
            f(*it, *it2);
        }
        return *this;
    }
    
    template <typename Derived>
    template <typename DerivedOther, typename F, typename... Args>
    requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
    ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args) {
        for (auto it = begin(), it2 = other.begin(); it != end(); ++it, ++it2) {
            // expand one std::forward per argument
            f(*it, *it2, std::forward<Args>(args)...);
        }
        return *this;
    }
    
    template <typename Derived, isNotMatrix U> requires Addable<typename Derived::value_type, U>
    ObjectBase<Derived> operator+(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res += val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
    ObjectBase<Derived> operator-(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res -= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
    ObjectBase<Derived> operator*(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res *= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires Dividable<typename Derived::value_type, U>
    ObjectBase<Derived> operator/(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res /= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
    ObjectBase<Derived> operator%(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res %= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase<Derived> operator&(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res &= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase<Derived> operator^(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res ^= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase<Derived> operator|(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res |= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase<Derived> operator<<(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res <<= val;
        return res;
    }
    
    template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase<Derived> operator>>(const ObjectBase<Derived>& m, const U& val) {
        ObjectBase<Derived> res = m;
        res >>= val;
        return res;
    }
    
    } // namespace frozenca
    
    #endif //FROZENCA_OBJECTBASE_H

MatrixBase.h

    #ifndef FROZENCA_MATRIXBASE_H
    #define FROZENCA_MATRIXBASE_H
    
    #include <numeric>
    #include "ObjectBase.h"
    #include "MatrixInitializer.h"
    
    namespace frozenca {
    
    template <std::semiregular T, std::size_t N>
    class MatrixView;
    
    template <typename Derived, std::semiregular T, std::size_t N>
    class MatrixBase : public ObjectBase<MatrixBase<Derived, T, N>> {
        static_assert(N > 1);
    public:
        static constexpr std::size_t ndim = N;
    
    private:
        std::array<std::size_t, N> dims_;
        std::size_t size_;
        std::array<std::size_t, N> strides_;
    
    public:
        MatrixBase() = delete;
        using Base = ObjectBase<MatrixBase<Derived, T, N>>;
        using Base::applyFunction;
        using Base::operator=;
        using Base::operator+=;
        using Base::operator-=;
        using Base::operator*=;
        using Base::operator/=;
        using Base::operator%=;
    
        Derived& self() { return static_cast<Derived&>(*this); }
        const Derived& self() const { return static_cast<const Derived&>(*this); }
    
    protected:
        ~MatrixBase() noexcept = default;
        MatrixBase(const std::array<std::size_t, N>& dims);
    
        template <std::size_t M> requires (M < N)
        MatrixBase(const std::array<std::size_t, M>& dims);
    
        template <IndexType... Dims>
        explicit MatrixBase(Dims... dims);
    
        template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
        MatrixBase(const MatrixBase<DerivedOther, U, N>&);
    
        MatrixBase(typename MatrixInitializer<T, N>::type init);
    
    public:
        template <typename U>
        MatrixBase(std::initializer_list<U>) = delete;
    
        template <typename U>
        MatrixBase& operator=(std::initializer_list<U>) = delete;
    
        using value_type = T;
        using reference = T&;
        using const_reference = const T&;
        using pointer = T*;
    
    public:
        friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
            std::swap(a.size_, b.size_);
            std::swap(a.dims_, b.dims_);
            std::swap(a.strides_, b.strides_);
        }
    
        auto begin() { return self().begin(); }
        auto begin() const { return self().begin(); }
        auto cbegin() const { return self().cbegin(); }
        auto end() { return self().end(); }
        auto end() const { return self().end(); }
        auto cend() const { return self().cend(); }
        auto rbegin() { return self().rbegin(); }
        auto rbegin() const { return self().rbegin(); }
        auto crbegin() const { return self().crbegin(); }
        auto rend() { return self().rend(); }
        auto rend() const { return self().rend(); }
        auto crend() const { return self().crend(); }
    
        template <IndexType... Args>
        reference operator()(Args... args);
    
        template <IndexType... Args>
        const_reference operator()(Args... args) const;
    
        reference operator[](const std::array<std::size_t, N>& pos);
        const_reference operator[](const std::array<std::size_t, N>& pos) const;
    
        [[nodiscard]] std::size_t size() const { return size_; }
    
        [[nodiscard]] const std::array<std::size_t, N>& dims() const { return dims_; }
    
        [[nodiscard]] std::size_t dims(std::size_t n) const {
            if (n >= N) {
                throw std::out_of_range("Out of range in dims");
            }
            return dims_[n];
        }
    
        [[nodiscard]] const std::array<std::size_t, N>& strides() const { return strides_; }
    
        [[nodiscard]] std::size_t strides(std::size_t n) const {
            if (n >= N) {
                throw std::out_of_range("Out of range in strides");
            }
            return strides_[n];
        }
    
        auto dataView() const { return self().dataView(); }
        auto origStrides() const { return self().origStrides(); }
    
        MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin);
        MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin,
                                   const std::array<std::size_t, N>& pos_end);
        MatrixView<T, N - 1> row(std::size_t n);
        MatrixView<T, N - 1> col(std::size_t n);
        MatrixView<T, N - 1> operator[](std::size_t n) { return row(n); }
    
        MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin) const;
        MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin,
                                   const std::array<std::size_t, N>& pos_end) const;
        MatrixView<T, N - 1> row(std::size_t n) const;
        MatrixView<T, N - 1> col(std::size_t n) const;
        MatrixView<T, N - 1> operator[](std::size_t n) const { return row(n); }
    
        friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
            os << '{';
            for (std::size_t i = 0; i != m.dims(0); ++i) {
                os << m[i];
                if (i + 1 != m.dims(0)) {
                    os << ", ";
                }
            }
            return os << '}';
        }
    
        template <typename DerivedOther1, typename DerivedOther2,
                  std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2,
                  std::invocable<MatrixView<T, N - 1>&,
                                 const MatrixView<U, std::min(N1, N - 1)>&,
                                 const MatrixView<V, std::min(N2, N - 1)>&> F>
        requires (std::max(N1, N2) == N)
        MatrixBase& applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1,
                                               const MatrixBase<DerivedOther2, V, N2>& m2,
                                               F&& f);
    };
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, N>& dims) : dims_ {dims} {
        if (std::ranges::find(dims_, 0lu) != std::end(dims_)) {
            throw std::invalid_argument("Zero dimension not allowed");
        }
        size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{});
        strides_ = computeStrides(dims_);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <std::size_t M> requires (M < N)
    MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, M>& dims)
        : MatrixBase (prependDims<N, M>(dims)) {}
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <IndexType... Dims>
    MatrixBase<Derived, T, N>::MatrixBase(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} {
        static_assert(sizeof...(Dims) == N);
        static_assert((std::is_integral_v<Dims> && ...));
        if (std::ranges::find(dims_, 0lu) != std::end(dims_)) {
            throw std::invalid_argument("Zero dimension not allowed");
        }
        size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{});
        strides_ = computeStrides(dims_);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
    MatrixBase<Derived, T, N>::MatrixBase(const MatrixBase<DerivedOther, U, N>& other)
        : MatrixBase(other.dims()) {}
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixBase<Derived, T, N>::MatrixBase(typename MatrixInitializer<T, N>::type init)
        : MatrixBase(deriveDims<N>(init)) {}
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <IndexType... Args>
    typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator()(Args... args) {
        return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator()(args...));
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <IndexType... Args>
    typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator()(Args... args) const {
        static_assert(sizeof...(args) == N);
        std::array<std::size_t, N> pos {std::size_t(args)...};
        return operator[](pos);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    typename MatrixBase<Derived, T, N>::reference
    MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) {
        return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator[](pos));
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    typename MatrixBase<Derived, T, N>::const_reference
    MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) const {
        if (!std::equal(std::cbegin(pos), std::cend(pos), std::cbegin(dims_), std::less<>{})) {
            throw std::out_of_range("Out of range in element access");
        }
        return *(cbegin() + std::inner_product(std::cbegin(pos), std::cend(pos), std::cbegin(strides_), 0lu));
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) {
        return submatrix(pos_begin, dims_);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) const {
        return submatrix(pos_begin, dims_);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                          const std::array<std::size_t, N>& pos_end) {
        return std::as_const(*this).submatrix(pos_begin, pos_end);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                          const std::array<std::size_t, N>& pos_end) const {
        if (!std::equal(std::cbegin(pos_begin), std::cend(pos_begin), std::cbegin(pos_end), std::less<>{})) {
            throw std::out_of_range("submatrix begin/end position error");
        }
        std::array<std::size_t, N> view_dims;
        std::transform(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(pos_begin),
                       std::begin(view_dims), std::minus<>{});
        MatrixView<T, N> view(view_dims, const_cast<T*>(&operator[](pos_begin)), strides());
        return view;
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) {
        return std::as_const(*this).row(n);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) const {
        const auto& orig_dims = dims();
        if (n >= orig_dims[0]) {
            throw std::out_of_range("row index error");
        }
        std::array<std::size_t, N - 1> row_dims;
        std::copy(std::cbegin(orig_dims) + 1, std::cend(orig_dims), std::begin(row_dims));
        std::array<std::size_t, N> pos_begin = {n, };
        std::array<std::size_t, N - 1> row_strides;
    
        std::array<std::size_t, N> orig_strides;
        if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
            orig_strides = origStrides();
        } else {
            orig_strides = strides();
        }
        std::copy(std::cbegin(orig_strides) + 1, std::cend(orig_strides), std::begin(row_strides));
        MatrixView<T, N - 1> nth_row(row_dims, const_cast<T*>(&operator[](pos_begin)), row_strides);
        return nth_row;
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) {
        return std::as_const(*this).col(n);
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) const {
        const auto& orig_dims = dims();
        if (n >= orig_dims[N - 1]) {
            throw std::out_of_range("col index error");
        }
        std::array<std::size_t, N - 1> col_dims;
        std::copy(std::cbegin(orig_dims), std::cend(orig_dims) - 1, std::begin(col_dims));
        std::array<std::size_t, N> pos_begin = {0};
        pos_begin[N - 1] = n;
        std::array<std::size_t, N - 1> col_strides;
    
        std::array<std::size_t, N> orig_strides;
        if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
            orig_strides = origStrides();
        } else {
            orig_strides = strides();
        }
        std::copy(std::cbegin(orig_strides), std::cend(orig_strides) - 1, std::begin(col_strides));
        MatrixView<T, N - 1> nth_col(col_dims, const_cast<T*>(&operator[](pos_begin)), col_strides);
        return nth_col;
    }
    
    template <typename Derived, std::semiregular T, std::size_t N>
    template <typename DerivedOther1, typename DerivedOther2,
              std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2,
              std::invocable<MatrixView<T, N - 1>&,
                             const MatrixView<U, std::min(N1, N - 1)>&,
                             const MatrixView<V, std::min(N2, N - 1)>&> F>
    requires (std::max(N1, N2) == N)
    MatrixBase<Derived, T, N>& MatrixBase<Derived, T, N>::applyFunctionWithBroadcast(
            const MatrixBase<DerivedOther1, U, N1>& m1,
            const MatrixBase<DerivedOther2, V, N2>& m2,
            F&& f) {
        if constexpr (N1 == N) {
            if constexpr (N2 == N) {
                auto r = dims(0);
                auto r1 = m1.dims(0);
                auto r2 = m2.dims(0);
                if (r1 == r) {
                    if (r2 == r) {
                        for (std::size_t i = 0; i < r; ++i) {
                            auto row = this->row(i);
                            f(row, m1.row(i), m2.row(i));
                        }
                    } else { // r2 < r == r1
                        auto row2 = m2.row(0);
                        for (std::size_t i = 0; i < r; ++i) {
                            auto row = this->row(i);
                            f(row, m1.row(i), row2);
                        }
                    }
                } else if (r2 == r) { // r1 < r == r2
                    auto row1 = m1.row(0);
                    for (std::size_t i = 0; i < r; ++i) {
                        auto row = this->row(i);
                        f(row, row1, m2.row(i));
                    }
                } else {
                    assert(0); // cannot happen
                }
            } else { // N2 < N == N1
                auto r = dims(0);
                assert(r == m1.dims(0));
                MatrixView<V, N2> view2 (m2);
                for (std::size_t i = 0; i < r; ++i) {
                    auto row = this->row(i);
                    f(row, m1.row(i), view2);
                }
            }
        } else if constexpr (N2 == N) { // N1 < N == N2
            auto r = dims(0);
            assert(r == m2.dims(0));
            MatrixView<U, N1> view1 (m1);
            for (std::size_t i = 0; i < r; ++i) {
                auto row = this->row(i);
                f(row, view1, m2.row(i));
            }
        } else {
            assert(0); // cannot happen
        }
        return *this;
    }
    
    template <typename Derived, std::semiregular T>
    class MatrixBase<Derived, T, 1> : public ObjectBase<MatrixBase<Derived, T, 1>> {
    public:
        static constexpr std::size_t ndim = 1;
    
    private:
        std::size_t dims_;
        std::size_t strides_;
    
        Derived& self() { return static_cast<Derived&>(*this); }
        const Derived& self() const { return static_cast<const Derived&>(*this); }
    
    public:
        MatrixBase() = delete;
        using Base = ObjectBase<MatrixBase<Derived, T, 1>>;
        using Base::applyFunction;
        using Base::operator=;
        using Base::operator+=;
        using Base::operator-=;
        using Base::operator*=;
        using Base::operator/=;
        using Base::operator%=;
    
    protected:
        ~MatrixBase() noexcept = default;
    
        template <typename Dim> requires std::is_integral_v<Dim>
        explicit MatrixBase(Dim dim) : dims_(dim), strides_(1) {}
    
        template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
        MatrixBase(const MatrixBase<DerivedOther, U, 1>&);
    
        MatrixBase(typename MatrixInitializer<T, 1>::type init);
    
    public:
        using value_type = T;
        using reference = T&;
        using const_reference = const T&;
        using pointer = T*;
    
    public:
        // the 1-D specialization has no size_ member, so only dims_ and strides_ are swapped
        friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
            std::swap(a.dims_, b.dims_);
            std::swap(a.strides_, b.strides_);
        }
    
        auto begin() { return self().begin(); }
        auto begin() const { return self().begin(); }
        auto cbegin() const { return self().cbegin(); }
        auto end() { return self().end(); }
        auto end() const { return self().end(); }
        auto cend() const { return self().cend(); }
        auto rbegin() { return self().rbegin(); }
        auto rbegin() const { return self().rbegin(); }
        auto crbegin() const { return self().crbegin(); }
        auto rend() { return self().rend(); }
        auto rend() const { return self().rend(); }
        auto crend() const { return self().crend(); }
    
        template <typename Dim> requires std::is_integral_v<Dim>
        reference operator()(Dim dim) { return operator[](dim); }
    
        template <typename Dim> requires std::is_integral_v<Dim>
        const_reference operator()(Dim dim) const { return operator[](dim); }
    
        [[nodiscard]] std::array<std::size_t, 1> dims() const { return {dims_}; }
    
        [[nodiscard]] std::size_t dims(std::size_t n) const {
            if (n >= 1) {
                throw std::out_of_range("Out of range in dims");
            }
            return dims_;
        }
    
        [[nodiscard]] std::size_t strides() const { return strides_; }
    
        auto dataView() const { return self().dataView(); }
        auto origStrides() const { return self().origStrides(); }
    
        MatrixView<T, 1> submatrix(std::size_t pos_begin);
        MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end);
        T& row(std::size_t n);
        T& col(std::size_t n);
        T& operator[](std::size_t n) { return *(begin() + n); }
    
        MatrixView<T, 1> submatrix(std::size_t pos_begin) const;
        MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end) const;
        const T& row(std::size_t n) const;
        const T& col(std::size_t n) const;
        const T& operator[](std::size_t n) const { return *(cbegin() + n); }
    
        friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
            os << '{';
            for (std::size_t i = 0; i != m.dims_; ++i) {
                os << m[i];
                if (i + 1 != m.dims_) {
                    os << ", ";
                }
            }
            return os << '}';
        }
    
        template <typename DerivedOther1, typename DerivedOther2,
                  std::semiregular U, std::semiregular V,
                  std::invocable<T&, const U&, const V&> F>
        MatrixBase& applyFunctionWithBroadcast(const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
                                               const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
                                               F&& f);
    };
    
    template <typename Derived, std::semiregular T>
    MatrixBase<Derived, T, 1>::MatrixBase(typename MatrixInitializer<T, 1>::type init)
        : MatrixBase(deriveDims<1>(init)[0]) {}
    
    template <typename Derived, std::semiregular T>
    MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) {
        return submatrix(pos_begin, dims_);
    }
    
    template <typename Derived, std::semiregular T>
    MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) const {
        return submatrix(pos_begin, dims_);
    }
    
    template <typename Derived, std::semiregular T>
    MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin, std::size_t pos_end) {
        return std::as_const(*this).submatrix(pos_begin, pos_end);
    }
    
    template <typename Derived, std::semiregular T>
    MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin, std::size_t pos_end) const {
        if (pos_begin >= pos_end) {
            throw std::out_of_range("submatrix begin/end position error");
        }
        MatrixView<T, 1> view ({pos_end - pos_begin}, const_cast<T*>(&operator[](pos_begin)), {strides_});
        return view;
    }
    
    template <typename Derived, std::semiregular T>
    T& MatrixBase<Derived, T, 1>::row(std::size_t n) {
        return const_cast<T&>(std::as_const(*this).row(n));
    }
    
    template <typename Derived, std::semiregular T>
    const T& MatrixBase<Derived, T, 1>::row(std::size_t n) const {
        if (n >= dims_) {
            throw std::out_of_range("row index error");
        }
        const T& val = operator[](n);
        return val;
    }
    
    template <typename Derived, std::semiregular T>
    T& MatrixBase<Derived, T, 1>::col(std::size_t n) {
        return row(n);
    }
    
    template <typename Derived, std::semiregular T>
    const T& MatrixBase<Derived, T, 1>::col(std::size_t n) const {
        return row(n);
    }
    
    template <typename Derived, std::semiregular T>
    template <typename DerivedOther1, typename DerivedOther2,
              std::semiregular U, std::semiregular V,
              std::invocable<T&, const U&, const V&> F>
    MatrixBase<Derived, T, 1>& MatrixBase<Derived, T, 1>::applyFunctionWithBroadcast(
            const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
            const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
            F&& f) {
        // real update is done here by passing lvalue reference T&
        auto r = dims(0);
        auto r1 = m1.dims(0);
        auto r2 = m2.dims(0);
        if (r1 == r) {
            if (r2 == r) {
                for (std::size_t i = 0; i < r; ++i) {
                    f(this->row(i), m1.row(i), m2.row(i));
                }
            } else { // r2 < r == r1
                auto row2 = m2.row(0);
                for (std::size_t i = 0; i < r; ++i) {
                    f(this->row(i), m1.row(i), row2);
                }
            }
        } else if (r2 == r) { // r1 < r == r2
            auto row1 = m1.row(0);
            for (std::size_t i = 0; i < r; ++i) {
                f(this->row(i), row1, m2.row(i));
            }
        }
        return *this;
    }
    
    } // namespace frozenca
    
    #endif //FROZENCA_MATRIXBASE_H
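
`applyFunctionWithBroadcast` broadcasts by walking the leading axis and, when one operand has extent 1 there (or has fewer dimensions), reusing that operand's single row or whole view for every row of the result. A commonly used alternative, sketched below for the 2-D case only, gives every axis of extent 1 an effective stride of 0, so a single index computation covers all broadcast cases. `Dense2D`, `broadcastStrides`, and `broadcastAdd` are hypothetical names, and the sketch assumes dense row-major `int` data rather than the project's views.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical dense 2-D operand: row-major data of size dims[0] * dims[1].
struct Dense2D {
    std::array<std::size_t, 2> dims;
    std::vector<int> data;
};

// Row-major strides (dims[1], 1), replaced by 0 on any axis of extent 1,
// so indexing that axis always lands on its single element.
std::array<std::size_t, 2> broadcastStrides(const std::array<std::size_t, 2>& d) {
    return {d[0] == 1 ? std::size_t{0} : d[1],
            d[1] == 1 ? std::size_t{0} : std::size_t{1}};
}

// Elementwise a + b with NumPy-style broadcasting, using zero strides
// instead of per-row recursion.
Dense2D broadcastAdd(const Dense2D& a, const Dense2D& b) {
    Dense2D out;
    for (std::size_t ax = 0; ax < 2; ++ax) {
        const std::size_t x = a.dims[ax];
        const std::size_t y = b.dims[ax];
        if (x != y && x != 1 && y != 1) {
            throw std::invalid_argument("Cannot broadcast");
        }
        out.dims[ax] = x > y ? x : y;
    }
    const auto sa = broadcastStrides(a.dims);
    const auto sb = broadcastStrides(b.dims);
    out.data.resize(out.dims[0] * out.dims[1]);
    for (std::size_t i = 0; i < out.dims[0]; ++i) {
        for (std::size_t j = 0; j < out.dims[1]; ++j) {
            out.data[i * out.dims[1] + j] =
                a.data[i * sa[0] + j * sa[1]] + b.data[i * sb[0] + j * sb[1]];
        }
    }
    return out;
}
```

Adding a 1×3 row to a 2×3 matrix adds the row to each of its rows, and adding a 1×3 row to a 2×1 column yields a 2×3 outer sum; both fall out of the same loop.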

(Stack Exchange says the post is too long, so I've replaced two of the files with links.)

MatrixImpl.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixImpl.h)

MatrixView.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixView.h)

MatrixUtils.h

    #ifndef FROZENCA_MATRIXUTILS_H
    #define FROZENCA_MATRIXUTILS_H
    
    #include <algorithm>
    #include <array>
    #include <cassert>
    #include <concepts>
    #include <cstddef>
    #include <initializer_list>
    #include <iostream>
    #include <iterator>
    #include <memory>
    #include <stdexcept>
    #include <type_traits>
    
    namespace frozenca {
    
    template <std::semiregular T, std::size_t N>
    class Matrix;
    
    template <std::semiregular T, std::size_t N>
    class MatrixView;
    
    template <typename Derived, std::semiregular T, std::size_t N>
    class MatrixBase;
    
    template <typename Derived>
    class ObjectBase;
    
    template <typename T>
    constexpr bool NotMatrix = true;
    
    template <std::semiregular T, std::size_t N>
    constexpr bool NotMatrix<Matrix<T, N>> = false;
    
    template <std::semiregular T, std::size_t N>
    constexpr bool NotMatrix<MatrixView<T, N>> = false;
    
    template <typename Derived, std::semiregular T, std::size_t N>
    constexpr bool NotMatrix<MatrixBase<Derived, T, N>> = false;
    
    template <typename Derived>
    constexpr bool NotMatrix<ObjectBase<Derived>> = false;
    
    template <typename T>
    concept isNotMatrix = NotMatrix<T> && std::semiregular<T>;
    
    template <typename T>
    concept isMatrix = !NotMatrix<T>;
    
    template <typename T>
    concept OneExists = requires () {
        { T{0} } -> std::convertible_to<T>;
        { T{1} } -> std::convertible_to<T>;
    };
    
    template <typename A, typename B>
    concept WeakAddable = requires (A a, B b) { a + b; };
    
    template <typename A, typename B>
    concept WeakSubtractable = requires (A a, B b) { a - b; };
    
    template <typename A, typename B>
    concept WeakMultipliable = requires (A a, B b) { a * b; };
    
    template <typename A, typename B>
    concept WeakDividable = requires (A a, B b) { a / b; };
    
    template <typename A, typename B>
    concept WeakRemaindable = requires (A a, B b) {
        a / b;
        a % b;
    };
    
    template <typename A, typename B, typename C>
    concept AddableTo = requires (A a, B b) {
        { a + b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B, typename C>
    concept SubtractableTo = requires (A a, B b) {
        { a - b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B, typename C>
    concept MultipliableTo = requires (A a, B b) {
        { a * b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B, typename C>
    concept DividableTo = requires (A a, B b) {
        { a / b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B, typename C>
    concept RemaindableTo = requires (A a, B b) {
        { a / b } -> std::convertible_to<C>;
        { a % b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B, typename C>
    concept BitMaskableTo = requires (A a, B b) {
        { a & b } -> std::convertible_to<C>;
        { a | b } -> std::convertible_to<C>;
        { a ^ b } -> std::convertible_to<C>;
        { a << b } -> std::convertible_to<C>;
        { a >> b } -> std::convertible_to<C>;
    };
    
    template <typename A, typename B>
    concept Addable = AddableTo<A, B, A>;
    
    template <typename A, typename B>
    concept Subtractable = SubtractableTo<A, B, A>;
    
    template <typename A, typename B>
    concept Multipliable = MultipliableTo<A, B, A>;
    
    template <typename A, typename B>
    concept Dividable = DividableTo<A, B, A>;
    
    template <typename A, typename B>
    concept Remaindable = RemaindableTo<A, B, A>;
    
    template <typename A, typename B>
    concept BitMaskable = BitMaskableTo<A, B, A>;
    
    template <typename A, typename B> requires WeakAddable<A, B>
    inline decltype(auto) Plus(A a, B b) { return a + b; }
    
    template <typename A, typename B> requires WeakSubtractable<A, B>
    inline decltype(auto) Minus(A a, B b) { return a - b; }
    
    template <typename A, typename B> requires WeakMultipliable<A, B>
    inline decltype(auto) Multiplies(A a, B b) { return a * b; }
    
    template <typename A, typename B> requires WeakDividable<A, B>
    inline decltype(auto) Divides(A a, B b) { return a / b; }
    
    template <typename A, typename B> requires WeakRemaindable<A, B>
    inline decltype(auto) Modulus(A a, B b) { return a % b; }
    
    template <typename A, typename B>
    using AddType = std::invoke_result_t<decltype(Plus<A, B>), A, B>;
    
    template <typename A, typename B>
    using SubType = std::invoke_result_t<decltype(Minus<A, B>), A, B>;
    
    template <typename A, typename B>
    using MulType = std::invoke_result_t<decltype(Multiplies<A, B>), A, B>;
    
    template <typename A, typename B>
    using DivType = std::invoke_result_t<decltype(Divides<A, B>), A, B>;
    
    template <typename A, typename B>
    using ModType = std::invoke_result_t<decltype(Modulus<A, B>), A, B>;
    
    template <typename A, typename B>
    concept DotProductable = Addable<MulType<A, B>, MulType<A, B>>;
    
    template <typename A, typename B, typename C>
    concept DotProductableTo = DotProductable<A, B> && MultipliableTo<A, B, C> && Addable<C, C>;
    
    template <typename A, typename B, typename C> requires AddableTo<A, B, C>
    inline void PlusTo(C& c, const A& a, const B& b) { c = a + b; }
    
    template <typename A, typename B, typename C> requires SubtractableTo<A, B, C>
    inline void MinusTo(C& c, const A& a, const B& b) { c = a - b; }
    
    template <typename A, typename B, typename C> requires MultipliableTo<A, B, C>
    inline void MultipliesTo(C& c, const A& a, const B& b) { c = a * b; }
    
    template <typename A, typename B, typename C> requires DividableTo<A, B, C>
    inline void DividesTo(C& c, const A& a, const B& b) { c = a / b; }
    
    template <typename A, typename B, typename C> requires RemaindableTo<A, B, C>
    inline void ModulusTo(C& c, const A& a, const B& b) { c = a % b; }
    
    template <typename... Args>
    inline constexpr bool All(Args... args) { return (... && args); }
    
    template <typename... Args>
    inline constexpr bool Some(Args... args) { return (... || args); }
    
    template <std::size_t M, std::size_t N> requires (N < M)
    std::array<std::size_t, M> prependDims(const std::array<std::size_t, N>& arr) {
        std::array<std::size_t, M> dims;
        std::ranges::fill(dims, 1u);
        std::ranges::copy(arr, std::begin(dims) + (M - N));
        return dims;
    }
    
    template <std::size_t M, std::size_t N>
    bool bidirBroadcastable(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, N>& sz2) {
        if constexpr (M == N) {
            return (std::ranges::equal(sz1, sz2, [](const auto& d1, const auto& d2) {
                return (d1 == d2) || (d1 == 1) || (d2 == 1);
            }));
        } else if constexpr (M < N) {
            return bidirBroadcastable(prependDims<N, M>(sz1), sz2);
        } else {
            static_assert(M > N);
            return bidirBroadcastable(sz1, prependDims<M, N>(sz2));
        }
    }
    
    template <std::size_t M, std::size_t N>
    std::array<std::size_t, std::max(M, N)> bidirBroadcastedDims(const std::array<std::size_t, M>& sz1,
                                                                 const std::array<std::size_t, N>& sz2) {
        if constexpr (M == N) {
            if (!bidirBroadcastable(sz1, sz2)) {
                throw std::invalid_argument("Cannot broadcast");
            }
            std::array<std::size_t, M> sz;
            std::ranges::transform(sz1, sz2, std::begin(sz), [](const auto& d1, const auto& d2) {
                return std::max(d1, d2);
            });
            return sz;
        } else if constexpr (M < N) {
            return bidirBroadcastedDims(prependDims<N, M>(sz1), sz2);
        } else {
            static_assert(M > N);
            return bidirBroadcastedDims(sz1, prependDims<M, N>(sz2));
        }
    }
    
    template <std::size_t M> requires (M > 1)
    std::array<std::size_t, M - 1> dotDims(const std::array<std::size_t, M>& sz1,
                                           const std::array<std::size_t, 1>& sz2) {
        if (sz1[M - 1] != sz2[0]) {
            throw std::invalid_argument("Cannot do dot product, shape is not aligned");
        }
        std::array<std::size_t, M - 1> sz;
        std::copy(std::begin(sz1), std::begin(sz1) + (M - 1), std::begin(sz));
        return sz;
    }
    
    template <std::size_t M, std::size_t N> requires (N > 1)
    std::array<std::size_t, M + N - 2> dotDims(const std::array<std::size_t, M>& sz1,
                                               const std::array<std::size_t, N>& sz2) {
        if (sz1[M - 1] != sz2[N - 2]) {
            throw std::invalid_argument("Cannot do dot product, shape is not aligned");
        }
        std::array<std::size_t, M + N - 2> sz;
        std::copy(std::begin(sz1), std::begin(sz1) + (M - 1), std::begin(sz));
        std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz) + (M - 1));
        std::copy(std::begin(sz2) + (N - 1), std::end(sz2), std::begin(sz) + (M + N - 3));
        return sz;
    }
    
    template <std::size_t M, std::size_t N>
    std::array<std::size_t, std::max(M, N)> matmulDims(const std::array<std::size_t, M>& sz1,
                                                       const std::array<std::size_t, N>& sz2) {
        if constexpr (M == 1) {
            std::array<std::size_t, 2> sz1_ = {1, sz1[0]};
            return matmulDims(sz1_, sz2);
        } else if constexpr (N == 1) {
            std::array<std::size_t, 2> sz2_ = {sz2[0], 1};
            return matmulDims(sz1, sz2_);
        } else {
            // the rest is discarded when M == 1 or N == 1
            assert(M >= 2 && N >= 2);
            if (sz1[M - 1] != sz2[N - 2]) {
                throw std::invalid_argument("Cannot do dot product, shape is not aligned");
            }
            std::array<std::size_t, 2> last_sz = {sz1[M - 2], sz2[N - 1]};
            if constexpr (M == 2) {
                if constexpr (N == 2) {
                    return last_sz;
                } else { // M = 2, N > 2
                    std::array<std::size_t, N> res_sz;
                    std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(res_sz));
                    std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (N - 2));
                    return res_sz;
                }
            } else if constexpr (N == 2) { // M > 2, N = 2
                std::array<std::size_t, M> res_sz;
                std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(res_sz));
                std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (M - 2));
                return res_sz;
            } else { // M > 2, N > 2
                std::array<std::size_t, std::max(M, N)> res_sz;
                std::array<std::size_t, M - 2> sz1_front;
                std::array<std::size_t, N - 2> sz2_front;
                std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(sz1_front));
                std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz2_front));
                auto common_sz = bidirBroadcastedDims(sz1_front, sz2_front);
                std::copy(std::begin(common_sz), std::end(common_sz), std::begin(res_sz));
                std::copy(std::begin(last_sz), std::end(last_sz), std::end(res_sz) - 2);
                return res_sz;
            }
        }
    }
    
    template <typename... Args>
    concept IndexType = All(std::is_integral_v<Args>...);
    
    template <std::size_t N>
    std::array<std::size_t, N> computeStrides(const std::array<std::size_t, N>& dims) {
        std::array<std::size_t, N> strides;
        std::size_t str = 1;
        for (std::size_t i = N - 1; i < N; --i) {
            strides[i] = str;
            str *= dims[i];
        }
        return strides;
    }
    
    template <std::size_t N, typename Initializer>
    bool checkNonJagged(const Initializer& init) {
        auto i = std::cbegin(init);
        for (auto j = std::next(i); j != std::cend(init); ++j) {
            if (i->size() != j->size()) {
                return false;
            }
        }
        return true;
    }
    
    template <std::size_t N, typename Iter, typename Initializer>
    void addDims(Iter& first, const Initializer& init) {
        if constexpr (N > 1) {
            if (!checkNonJagged<N>(init)) {
                throw std::invalid_argument("Jagged matrix initializer");
            }
        }
        *first = std::size(init);
        ++first;
        if constexpr (N > 1) {
            addDims<N - 1>(first, *std::begin(init));
        }
    }
    
    template <std::size_t N, typename Initializer>
    std::array<std::size_t, N> deriveDims(const Initializer& init) {
        std::array<std::size_t, N> dims;
        auto f = std::begin(dims);
        addDims<N>(f, init);
        return dims;
    }
    
    template <std::semiregular T>
    void addList(std::unique_ptr<T[]>& data, const T* first, const T* last, std::size_t& index) {
        for (; first != last; ++first) {
            data[index] = *first;
            ++index;
        }
    }
    
    template <std::semiregular T, typename I>
    void addList(std::unique_ptr<T[]>& data, const std::initializer_list<I>* first,
                 const std::initializer_list<I>* last, std::size_t& index) {
        for (; first != last; ++first) {
            addList(data, first->begin(), first->end(), index);
        }
    }
    
    template <std::semiregular T, typename I>
    void insertFlat(std::unique_ptr<T[]>& data, std::initializer_list<I> list) {
        std::size_t index = 0;
        addList(data, std::begin(list), std::end(list), index);
    }
    
    inline long quot(long a, long b) {
        return (a / b) - (a % b < 0);
    }
    
    inline long mod(long a, long b) {
        return (a % b + b) % b;
    }
    
    } // namespace frozenca
    
    #endif
//FROZENCA_MATRIXUTILS_H 
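The rule that bidirBroadcastable and bidirBroadcastedDims implement is easier to check against the ONNX document linked at the end of the question in a runtime-rank form. This sketch uses std::vector instead of std::array, and broadcastShape is a hypothetical name, not part of the library. It differs in one small way: it takes the dimension that is not 1 rather than std::max. That also gives the NumPy result when one of the dimensions is 0.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical runtime-rank version of the multidirectional broadcasting rule.
// Shapes are aligned from the right: the shorter one is padded with leading 1s
// (what prependDims does), and each pair of dimensions must be equal or contain a 1.
std::vector<std::size_t> broadcastShape(std::vector<std::size_t> a, std::vector<std::size_t> b) {
    if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), std::size_t{1});
    if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), std::size_t{1});
    std::vector<std::size_t> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1) {
            throw std::invalid_argument("Cannot broadcast");
        }
        // Take the dimension that is not 1 (so 0 with 1 gives 0, as in NumPy).
        out[i] = (a[i] == 1) ? b[i] : a[i];
    }
    return out;
}
```

For example, shapes {2, 1, 4} and {3, 1} broadcast to {2, 3, 4}, while {2, 3} and {4} are rejected.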

MatrixInitializer.h

#ifndef FROZENCA_MATRIXINITIALIZER_H
#define FROZENCA_MATRIXINITIALIZER_H

#include <cstddef>
#include <concepts>
#include <initializer_list>

namespace frozenca {

template <std::semiregular T, std::size_t N>
struct MatrixInitializer {
    using type = std::initializer_list<typename MatrixInitializer<T, N - 1>::type>;
};

template <std::semiregular T>
struct MatrixInitializer<T, 1> {
    using type = std::initializer_list<T>;
};

template <std::semiregular T>
struct MatrixInitializer<T, 0>;

} // namespace frozenca

#endif //FROZENCA_MATRIXINITIALIZER_H

MatrixOps.h

#ifndef FROZENCA_MATRIXOPS_H
#define FROZENCA_MATRIXOPS_H

#include "MatrixImpl.h"

namespace frozenca {

// Matrix constructs

template <std::semiregular T, std::size_t N>
Matrix<T, N> empty(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> mat (arr);
    return mat;
}

template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> empty_like(const MatrixBase<Derived, T, N>& base) {
    Matrix<T, N> mat (base.dims());
    return mat;
}

template <OneExists T>
Matrix<T, 2> eye(std::size_t n, std::size_t m) {
    Matrix<T, 2> mat (n, m);
    for (std::size_t i = 0; i < std::min(n, m); ++i) {
        mat(i, i) = T{1};
    }
    return mat;
}

template <OneExists T>
Matrix<T, 2> eye(std::size_t n) {
    return eye<T>(n, n);
}

template <OneExists T>
Matrix<T, 2> identity(std::size_t n) {
    return eye<T>(n, n);
}

template <OneExists T, std::size_t N>
Matrix<T, N> ones(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> mat (arr);
    std::ranges::fill(mat, T{1});
    return mat;
}

template <typename Derived, OneExists T, std::size_t N>
Matrix<T, N> ones_like(const MatrixBase<Derived, T, N>& base) {
    Matrix<T, N> mat (base.dims());
    std::ranges::fill(mat, T{1});
    return mat;
}

template <std::semiregular T, std::size_t N>
Matrix<T, N> zeros(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> mat (arr);
    std::ranges::fill(mat, T{0});
    return mat;
}

template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> zeros_like(const MatrixBase<Derived, T, N>& base) {
    Matrix<T, N> mat (base.dims());
    std::ranges::fill(mat, T{0});
    return mat;
}

template <std::semiregular T, std::size_t N>
Matrix<T, N> full(const std::array<std::size_t, N>& arr, const T& fill_value) {
    Matrix<T, N> mat (arr);
    std::ranges::fill(mat, fill_value);
    return mat;
}

template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> full_like(const MatrixBase<Derived, T, N>& base, const T& fill_value) {
    Matrix<T, N> mat (base.dims());
    std::ranges::fill(mat, fill_value);
    return mat;
}

// binary matrix operators

namespace {

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires AddableTo<U, V, T> && (std::max(N1, N2) == N)
void AddTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires SubtractableTo<U, V, T> && (std::max(N1, N2) == N)
void SubtractTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires MultipliableTo<U, V, T> && (std::max(N1, N2) == N)
void MultiplyTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires DividableTo<U, V, T> && (std::max(N1, N2) == N)
void DivideTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires RemaindableTo<U, V, T> && (std::max(N1, N2) == N)
void ModuloTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

} // anonymous namespace

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = AddType<U, V>>
requires AddableTo<U, V, T>
decltype(auto) operator+ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = SubType<U, V>>
requires SubtractableTo<U, V, T>
decltype(auto) operator- (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = MulType<U, V>>
requires MultipliableTo<U, V, T>
decltype(auto) operator* (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = DivType<U, V>>
requires DividableTo<U, V, T>
decltype(auto) operator/ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = ModType<U, V>>
requires RemaindableTo<U, V, T>
decltype(auto) operator% (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

} // namespace frozenca

#endif //FROZENCA_MATRIXOPS_H
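The operators above broadcast by recursing through applyFunctionWithBroadcast one dimension at a time. A common alternative is to give every stretched size-1 dimension a stride of 0, so that a single loop nest serves every operand shape. This might also ease the template-bloat concern raised below. It is shown here only as a sketch for the equal-rank, rank-2 case with flat row-major std::vector<int> buffers. All names are hypothetical, and this is not the author's implementation.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape2 = std::array<std::size_t, 2>;

// Row-major strides for `dims`, with 0 wherever a size-1 dim is stretched to `out`.
Shape2 broadcastStrides(const Shape2& dims, const Shape2& out) {
    Shape2 strides{dims[1], 1};
    for (std::size_t i = 0; i < 2; ++i) {
        if (dims[i] == out[i]) continue;
        if (dims[i] != 1) throw std::invalid_argument("Cannot broadcast");
        strides[i] = 0;  // reuse the same element along this axis
    }
    return strides;
}

// Elementwise a + b with broadcasting; assumes each buffer's size matches its shape.
std::vector<int> addBroadcast(const std::vector<int>& a, const Shape2& da,
                              const std::vector<int>& b, const Shape2& db) {
    const Shape2 out{da[0] == 1 ? db[0] : da[0], da[1] == 1 ? db[1] : da[1]};
    const Shape2 sa = broadcastStrides(da, out);
    const Shape2 sb = broadcastStrides(db, out);
    std::vector<int> res(out[0] * out[1]);
    for (std::size_t i = 0; i < out[0]; ++i) {
        for (std::size_t j = 0; j < out[1]; ++j) {
            res[i * out[1] + j] = a[i * sa[0] + j * sa[1]] + b[i * sb[0] + j * sb[1]];
        }
    }
    return res;
}
```

For instance, adding a {2, 3} matrix to a {1, 3} row or a {2, 1} column reuses the same loop. Only the strides differ.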

LinalgOps.h (LINEAR ALGEBRA OPERATIONS)

#ifndef FROZENCA_LINALGOPS_H
#define FROZENCA_LINALGOPS_H

#include "Matrix.h"

namespace frozenca {

namespace {

template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(T& m, const MatrixView<U, 1>& m1, const MatrixView<V, 1>& m2) {
    m += std::inner_product(std::begin(m1), std::end(m1), std::begin(m2), T{0});
}

template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(MatrixView<T, 1>& m, const MatrixView<U, 1>& m1, const MatrixView<V, 2>& m2) {
    assert(m.dims(0) == m2.dims(1));
    std::size_t c = m2.dims(1);
    for (std::size_t j = 0; j < c; ++j) {
        auto col2 = m2.col(j);
        m[j] += std::inner_product(std::begin(m1), std::end(m1), std::begin(col2), T{0});
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N2>
requires DotProductableTo<U, V, T> && (N2 > 2)
void DotTo(MatrixView<T, N2 - 1>& m, const MatrixView<U, 1>& m1, const MatrixView<V, N2>& m2) {
    assert(m.dims(0) == m2.dims(0));
    std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        auto row0 = m.row(i);
        auto row2 = m2.row(i);
        DotTo(row0, m1, row2);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1>
requires DotProductableTo<U, V, T> && (N1 > 1)
void DotTo(MatrixView<T, N1 - 1>& m, const MatrixView<U, N1>& m1, const MatrixView<V, 1>& m2) {
    assert(m.dims(0) == m1.dims(0));
    std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        // auto&& so that, when m is 1-D and row() returns T&, the sum lands in m, not in a copy
        auto&& row0 = m.row(i);
        auto row1 = m1.row(i);
        DotTo(row0, row1, m2);
    }
}

template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T> && (N1 > 1) && (N2 > 1)
void DotTo(MatrixView<T, N1 + N2 - 2>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    assert(m.dims(0) == m1.dims(0));
    std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        auto row0 = m.row(i);
        auto row1 = m1.row(i);
        DotTo(row0, row1, m2);
    }
}

template <typename Derived0, typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T>
void DotTo(MatrixBase<Derived0, T, (N1 + N2 - 2)>& m,
           const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    MatrixView<T, (N1 + N2 - 2)> m_view (m);
    MatrixView<U, N1> m1_view (m1);
    MatrixView<V, N2> m2_view (m2);
    DotTo(m_view, m1_view, m2_view);
}

template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires DotProductableTo<U, V, T> && (std::max(N1, N2) == N)
void MatmulTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) {
    if constexpr (N == 2) {
        DotTo(m, m1, m2);
    } else {
        m.applyFunctionWithBroadcast(m1, m2,
            MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

} // anonymous namespace

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t M, std::size_t N, std::semiregular T = MulType<U, V>>
requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, M>& m1, const MatrixBase<Derived2, V, N>& m2) {
    auto dims = dotDims(m1.dims(), m2.dims());
    Matrix<T, (M + N - 2)> res = zeros<T, (M + N - 2)>(dims);
    DotTo(res, m1, m2);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::semiregular T = MulType<U, V>>
requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, 1>& m1, const MatrixBase<Derived2, V, 1>& m2) {
    // no dotDims overload accepts two 1-D shapes, so check the lengths here
    if (m1.dims()[0] != m2.dims()[0]) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    T res {0};
    MatrixView<U, 1> m1_view (m1);
    MatrixView<V, 1> m2_view (m2);
    DotTo(res, m1_view, m2_view);
    return res;
}

template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2, std::semiregular T = MulType<U, V>>
requires DotProductableTo<U, V, T>
decltype(auto) matmul(const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = matmulDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    res.applyFunctionWithBroadcast(m1, m2,
        MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    return res;
}

} // namespace frozenca

#endif //FROZENCA_LINALGOPS_H
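The shape rule that dotDims encodes (from the numpy.dot page linked at the end of the question) can be checked in isolation with a runtime-rank sketch. The contraction runs over the last axis of a and the second-to-last axis of b, or b's only axis when b is 1-D. dotShape is a hypothetical name. Unlike the library, it returns an empty shape for the 1-D by 1-D case instead of using a separate scalar overload.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Result shape of numpy.dot(a, b) for ranks >= 1:
// a.shape[:-1] + b.shape[:-2] + b.shape[-1:]  (b.shape[:-2] and b.shape[-1:] vanish if b is 1-D).
std::vector<std::size_t> dotShape(const std::vector<std::size_t>& a,
                                  const std::vector<std::size_t>& b) {
    if (a.empty() || b.empty()) {
        throw std::invalid_argument("rank must be at least 1");
    }
    const std::size_t contracted = (b.size() == 1) ? b[0] : b[b.size() - 2];
    if (a.back() != contracted) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    std::vector<std::size_t> out(a.begin(), a.end() - 1);  // a.shape[:-1]
    if (b.size() >= 2) {
        out.insert(out.end(), b.begin(), b.end() - 2);     // b.shape[:-2]
        out.push_back(b.back());                            // b.shape[-1]
    }
    return out;
}
```

For example, a {2, 3, 4} by {5, 4, 6} product has shape {2, 3, 5, 6}, which matches what the second dotDims overload computes.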

Matrix.h

#ifndef FROZENCA_MATRIX_H
#define FROZENCA_MATRIX_H

#include "MatrixImpl.h"
#include "MatrixOps.h"
#include "LinalgOps.h"

#endif //FROZENCA_MATRIX_H

Feel free to comment on anything!

What I don't like:

  • Too much of the API is exposed to users. I hate C++ includes; I'm really waiting for C++20 modules! (The current module implementation in MSVC 19.28 is broken; it fails to work with my code.)
  • .reshape(): it should move the buffer, so it takes an rvalue reference, but I wonder whether there is something better.
  • The helper functions are becoming too template-heavy. Specifically, I don't like that functions like applyFunctionWithBroadcast are instantiated separately for every N > 1 even though they all do the same thing. But I need a specialization for N = 1 (because submatrix, row and col should return T&), so I don't know a better way.
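On the .reshape() point, one option is a pair of ref-qualified overloads. Lvalues then get a copy, and rvalues (std::move(t).reshape(...)) reuse the buffer, without forcing every caller to move. The sketch below uses a hypothetical Tensor class backed by std::vector rather than the project's Matrix, so the names and members are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical minimal tensor, only to illustrate ref-qualified reshape overloads.
class Tensor {
public:
    Tensor(std::vector<std::size_t> dims, std::vector<double> data)
        : dims_(std::move(dims)), data_(std::move(data)) {
        check(dims_);
    }

    // Called on an lvalue: leave *this intact and reshape a copy.
    Tensor reshape(std::vector<std::size_t> dims) const& {
        Tensor copy = *this;
        return std::move(copy).reshape(std::move(dims));
    }

    // Called on an rvalue: keep the buffer and only replace the dimensions.
    Tensor reshape(std::vector<std::size_t> dims) && {
        check(dims);
        dims_ = std::move(dims);
        return std::move(*this);  // move-constructs the result, stealing data_
    }

    const std::vector<std::size_t>& dims() const { return dims_; }
    const double* data() const { return data_.data(); }

private:
    void check(const std::vector<std::size_t>& dims) const {
        std::size_t n = 1;
        for (auto d : dims) n *= d;
        if (n != data_.size()) throw std::invalid_argument("Cannot reshape");
    }

    std::vector<std::size_t> dims_;
    std::vector<double> data_;
};
```

With this split, u = t.reshape({3, 2}) copies and leaves t usable, while std::move(t).reshape({6}) hands the existing buffer to the result.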

For the semantics of broadcasting, dot product and matrix multiplication used here, see:

https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md

https://numpy.org/doc/stable/reference/generated/numpy.dot.html

https://numpy.org/doc/stable/reference/generated/numpy.matmul.html


1 Answer

ObjectBase.h can use a couple of simplifications.

template <typename F> requires std::invocable<F, typename Derived::reference> 

We can use the shorter type-constraint syntax here:

template <std::invocable<typename Derived::reference> F> 

The same applies to the other constrained templates, though for Addable etc. we might need to change the template-parameter order, because the shorthand always passes the constrained type as the concept's first argument.

The loop in applyFunction doesn't need iterators:

for (auto&& e: *this) { f(e); } 

Or even

std::for_each(begin(), end(), std::forward<F>(f)); 

With the first approach, I think we don't need a separate overload for zero Args..., since a parameter pack may be empty and the variadic template binds to that case too. I could be wrong there; that's a tricky area to get right!

The operator implementations accept a const ref, but then make a copy:

ObjectBase<Derived> operator+(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res += val;
    return res;
}

If we accept m by value, then we get a copy of our own (and avoid the copy entirely when the argument is an rvalue, since it can then be moved or the copy elided):

ObjectBase<Derived> operator+(ObjectBase<Derived> m, const U& val) {
    return m += val;
}
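Here is the same idea on a hypothetical Vec type with a heap buffer. One refinement is offered as a suggestion: write m += val; return m; rather than return m += val;. A return statement that names the parameter moves from it. m += val is an lvalue-reference expression, so returning it copies into the return value.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for a matrix type that owns a heap buffer.
struct Vec {
    std::vector<double> data;
    Vec& operator+=(double v) {
        for (auto& d : data) d += v;
        return *this;
    }
};

// m is taken by value: lvalue arguments are copied once, rvalue arguments are moved in.
Vec operator+(Vec m, double val) {
    m += val;   // modify our own copy
    return m;   // naming the parameter lets the return move instead of copy
}
```

The usage check below confirms that, for an rvalue argument, the buffer travels through operator+ without being copied.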

The MatrixBase constructors can initialise size_ and stride_ in the member-initializer list, rather than assigning them in the constructor body. That might suggest making them const members, but they can't be, because the class supports assignment and swap.
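Here is a sketch of the initializer-list version, using a hypothetical Shape class with the same member names. computeStrides follows the MatrixUtils.h helper of the same name, with the descending loop written as i-- > 0. Members are initialised in declaration order, so size_ must be declared before stride_ for stride_'s initializer to read it.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Row-major strides, as in computeStrides from MatrixUtils.h.
template <std::size_t N>
constexpr std::array<std::size_t, N> computeStrides(const std::array<std::size_t, N>& dims) {
    std::array<std::size_t, N> strides{};
    std::size_t str = 1;
    for (std::size_t i = N; i-- > 0;) {  // walk from the last axis to the first
        strides[i] = str;
        str *= dims[i];
    }
    return strides;
}

// Hypothetical stand-in for MatrixBase: members initialised, not assigned.
template <std::size_t N>
class Shape {
public:
    explicit Shape(const std::array<std::size_t, N>& dims)
        : size_(dims), stride_(computeStrides(size_)) {}

    const std::array<std::size_t, N>& dims() const { return size_; }
    const std::array<std::size_t, N>& strides() const { return stride_; }

private:
    std::array<std::size_t, N> size_;    // declared first, so initialised first
    std::array<std::size_t, N> stride_;
};

static_assert(computeStrides(std::array<std::size_t, 2>{5, 7})[0] == 7);
```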


MatrixUtils.h: std::array has a member begin(), which std::begin() simply calls. It's shorter and just as clear to use the member function directly. The same applies to std::end(), of course.


