I've improved my N-dimensional C++20 matrix project (a follow-up to my earlier question, "C++20 : N-dimensional minimal Matrix class").
I've implemented general matrix addition/subtraction, element-wise multiplication/division, dot product, matrix product, reshape, and transpose.
There's a lot of code, so it is split by file below:
ObjectBase.h
#ifndef FROZENCA_OBJECTBASE_H #define FROZENCA_OBJECTBASE_H #include <functional> #include <utility> #include "MatrixUtils.h" namespace frozenca { template <typename Derived> class ObjectBase { private: Derived& self() { return static_cast<Derived&>(*this); } const Derived& self() const { return static_cast<const Derived&>(*this); } protected: ObjectBase() = default; ~ObjectBase() noexcept = default; public: auto begin() { return self().begin(); } auto begin() const { return self().begin(); } auto cbegin() const { return self().cbegin(); } auto end() { return self().end(); } auto end() const { return self().end(); } auto cend() const { return self().cend(); } auto rbegin() { return self().rbegin(); } auto rbegin() const { return self().rbegin(); } auto crbegin() const { return self().crbegin(); } auto rend() { return self().rend(); } auto rend() const { return self().rend(); } auto crend() const { return self().crend(); } template <typename F> requires std::invocable<F, typename Derived::reference> ObjectBase& applyFunction(F&& f); template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...> ObjectBase& applyFunction(F&& f, Args&&... args); template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference> ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f); template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...> ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... 
args); template <isNotMatrix U> requires Addable<typename Derived::value_type, U> ObjectBase& operator=(const U& val) { return applyFunction([&val](auto& v) {v = val;}); } template <isNotMatrix U> requires Addable<typename Derived::value_type, U> ObjectBase& operator+=(const U& val) { return applyFunction([&val](auto& v) {v += val;}); } template <isNotMatrix U> requires Subtractable<typename Derived::value_type, U> ObjectBase& operator-=(const U& val) { return applyFunction([&val](auto& v) {v -= val;}); } template <isNotMatrix U> requires Multipliable<typename Derived::value_type, U> ObjectBase& operator*=(const U& val) { return applyFunction([&val](auto& v) {v *= val;}); } template <isNotMatrix U> requires Dividable<typename Derived::value_type, U> ObjectBase& operator/=(const U& val) { return applyFunction([&val](auto& v) {v /= val;}); } template <isNotMatrix U> requires Remaindable<typename Derived::value_type, U> ObjectBase& operator%=(const U& val) { return applyFunction([&val](auto& v) {v %= val;}); } template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase& operator&=(const U& val) { return applyFunction([&val](auto& v) {v &= val;}); } template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase& operator|=(const U& val) { return applyFunction([&val](auto& v) {v |= val;}); } template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase& operator^=(const U& val) { return applyFunction([&val](auto& v) {v ^= val;}); } template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase& operator<<=(const U& val) { return applyFunction([&val](auto& v) {v <<= val;}); } template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase& operator>>=(const U& val) { return applyFunction([&val](auto& v) {v >>= val;}); } }; template <typename Derived> template <typename F> requires std::invocable<F, typename Derived::reference> 
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f) { for (auto it = begin(); it != end(); ++it) { f(*it); } return *this; } template <typename Derived> template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...> ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f, Args&&... args) { for (auto it = begin(); it != end(); ++it) { f(*it, std::forward<Args...>(args...)); } return *this; } template <typename Derived> template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference> ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f) { for (auto it = begin(), it2 = other.begin(); it != end(); ++it, ++it2) { f(*it, *it2); } return *this; } template <typename Derived> template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...> ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... 
args) { for (auto it = begin(), it2 = other.begin(); it != end(); ++it, ++it2) { f(*it, *it2, std::forward<Args...>(args...)); } return *this; } template <typename Derived, isNotMatrix U> requires Addable<typename Derived::value_type, U> ObjectBase<Derived> operator+(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res += val; return res; } template <typename Derived, isNotMatrix U> requires Subtractable<typename Derived::value_type, U> ObjectBase<Derived> operator-(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res -= val; return res; } template <typename Derived, isNotMatrix U> requires Multipliable<typename Derived::value_type, U> ObjectBase<Derived> operator*(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res *= val; return res; } template <typename Derived, isNotMatrix U> requires Dividable<typename Derived::value_type, U> ObjectBase<Derived> operator/(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res /= val; return res; } template <typename Derived, isNotMatrix U> requires Remaindable<typename Derived::value_type, U> ObjectBase<Derived> operator%(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res %= val; return res; } template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase<Derived> operator&(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res &= val; return res; } template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase<Derived> operator^(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res ^= val; return res; } template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase<Derived> operator|(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res |= val; return res; } template <typename Derived, 
isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase<Derived> operator<<(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res <<= val; return res; } template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U> ObjectBase<Derived> operator>>(const ObjectBase<Derived>& m, const U& val) { ObjectBase<Derived> res = m; res >>= val; return res; } } // namespace frozenca #endif //FROZENCA_OBJECTBASE_H MatrixBase.h
#ifndef FROZENCA_MATRIXBASE_H #define FROZENCA_MATRIXBASE_H #include <numeric> #include "ObjectBase.h" #include "MatrixInitializer.h" namespace frozenca { template <std::semiregular T, std::size_t N> class MatrixView; template <typename Derived, std::semiregular T, std::size_t N> class MatrixBase : public ObjectBase<MatrixBase<Derived, T, N>> { static_assert(N > 1); public: static constexpr std::size_t ndim = N; private: std::array<std::size_t, N> dims_; std::size_t size_; std::array<std::size_t, N> strides_; public: MatrixBase() = delete; using Base = ObjectBase<MatrixBase<Derived, T, N>>; using Base::applyFunction; using Base::operator=; using Base::operator+=; using Base::operator-=; using Base::operator*=; using Base::operator/=; using Base::operator%=; Derived& self() { return static_cast<Derived&>(*this); } const Derived& self() const { return static_cast<const Derived&>(*this); } protected: ~MatrixBase() noexcept = default; MatrixBase(const std::array<std::size_t, N>& dims); template <std::size_t M> requires (M < N) MatrixBase(const std::array<std::size_t, M>& dims); template <IndexType... Dims> explicit MatrixBase(Dims... 
dims); template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T> MatrixBase(const MatrixBase<DerivedOther, U, N>&); MatrixBase(typename MatrixInitializer<T, N>::type init); public: template <typename U> MatrixBase(std::initializer_list<U>) = delete; template <typename U> MatrixBase& operator=(std::initializer_list<U>) = delete; using value_type = T; using reference = T&; using const_reference = const T&; using pointer = T*; public: friend void swap(MatrixBase& a, MatrixBase& b) noexcept { std::swap(a.size_, b.size_); std::swap(a.dims_, b.dims_); std::swap(a.strides_, b.strides_); } auto begin() { return self().begin(); } auto begin() const { return self().begin(); } auto cbegin() const { return self().cbegin(); } auto end() { return self().end(); } auto end() const { return self().end(); } auto cend() const { return self().cend(); } auto rbegin() { return self().rbegin(); } auto rbegin() const { return self().rbegin(); } auto crbegin() const { return self().crbegin(); } auto rend() { return self().rend(); } auto rend() const { return self().rend(); } auto crend() const { return self().crend(); } template <IndexType... Args> reference operator()(Args... args); template <IndexType... Args> const_reference operator()(Args... 
args) const; reference operator[](const std::array<std::size_t, N>& pos); const_reference operator[](const std::array<std::size_t, N>& pos) const; [[nodiscard]] std::size_t size() const { return size_;} [[nodiscard]] const std::array<std::size_t, N>& dims() const { return dims_; } [[nodiscard]] std::size_t dims(std::size_t n) const { if (n >= N) { throw std::out_of_range("Out of range in dims"); } return dims_[n]; } [[nodiscard]] const std::array<std::size_t, N>& strides() const { return strides_; } [[nodiscard]] std::size_t strides(std::size_t n) const { if (n >= N) { throw std::out_of_range("Out of range in strides"); } return strides_[n]; } auto dataView() const { return self().dataView(); } auto origStrides() const { return self().origStrides(); } MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin); MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end); MatrixView<T, N - 1> row(std::size_t n); MatrixView<T, N - 1> col(std::size_t n); MatrixView<T, N - 1> operator[](std::size_t n) { return row(n); } MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin) const; MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end) const; MatrixView<T, N - 1> row(std::size_t n) const; MatrixView<T, N - 1> col(std::size_t n) const; MatrixView<T, N - 1> operator[](std::size_t n) const { return row(n); } friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) { os << '{'; for (std::size_t i = 0; i != m.dims(0); ++i) { os << m[i]; if (i + 1 != m.dims(0)) { os << ", "; } } return os << '}'; } template <typename DerivedOther1, typename DerivedOther2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::invocable<MatrixView<T, N - 1>&, const MatrixView<U, std::min(N1, N - 1)>&, const MatrixView<V, std::min(N2, N - 1)>&> F> requires (std::max(N1, N2) == N) MatrixBase& 
applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1, const MatrixBase<DerivedOther2, V, N2>& m2, F&& f); }; template <typename Derived, std::semiregular T, std::size_t N> MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, N>& dims) : dims_ {dims} { if (std::ranges::find(dims_, 0lu) != std::end(dims_)) { throw std::invalid_argument("Zero dimension not allowed"); } size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{}); strides_ = computeStrides(dims_); } template <typename Derived, std::semiregular T, std::size_t N> template <std::size_t M> requires (M < N) MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, M>& dims) : MatrixBase (prepend<N, M>(dims)) {} template <typename Derived, std::semiregular T, std::size_t N> template <IndexType... Dims> MatrixBase<Derived, T, N>::MatrixBase(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} { static_assert(sizeof...(Dims) == N); static_assert((std::is_integral_v<Dims> && ...)); if (std::ranges::find(dims_, 0lu) != std::end(dims_)) { throw std::invalid_argument("Zero dimension not allowed"); } size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{}); strides_ = computeStrides(dims_); } template <typename Derived, std::semiregular T, std::size_t N> template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T> MatrixBase<Derived, T, N>::MatrixBase(const MatrixBase<DerivedOther, U, N>& other) : MatrixBase(other.dims()) {} template <typename Derived, std::semiregular T, std::size_t N> MatrixBase<Derived, T, N>::MatrixBase(typename MatrixInitializer<T, N>::type init) : MatrixBase(deriveDims<N>(init)) {} template <typename Derived, std::semiregular T, std::size_t N> template <IndexType... Args> typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator()(Args... 
args) { return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator()(args...)); } template <typename Derived, std::semiregular T, std::size_t N> template <IndexType... Args> typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator()(Args... args) const { static_assert(sizeof...(args) == N); std::array<std::size_t, N> pos {std::size_t(args)...}; return operator[](pos); } template <typename Derived, std::semiregular T, std::size_t N> typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) { return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator[](pos)); } template <typename Derived, std::semiregular T, std::size_t N> typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) const { if (!std::equal(std::cbegin(pos), std::cend(pos), std::cbegin(dims_), std::less<>{})) { throw std::out_of_range("Out of range in element access"); } return *(cbegin() + std::inner_product(std::cbegin(pos), std::cend(pos), std::cbegin(strides_), 0lu)); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) { return submatrix(pos_begin, dims_); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) const { return submatrix(pos_begin, dims_); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end) { return std::as_const(*this).submatrix(pos_begin, pos_end); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const 
std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end) const { if (!std::equal(std::cbegin(pos_begin), std::cend(pos_begin), std::cbegin(pos_end), std::less<>{})) { throw std::out_of_range("submatrix begin/end position error"); } std::array<std::size_t, N> view_dims; std::transform(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(pos_begin), std::begin(view_dims), std::minus<>{}); MatrixView<T, N> view(view_dims, const_cast<T*>(&operator[](pos_begin)), strides()); return view; } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) { return std::as_const(*this).row(n); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) const { const auto& orig_dims = dims(); if (n >= orig_dims[0]) { throw std::out_of_range("row index error"); } std::array<std::size_t, N - 1> row_dims; std::copy(std::cbegin(orig_dims) + 1, std::cend(orig_dims), std::begin(row_dims)); std::array<std::size_t, N> pos_begin = {n, }; std::array<std::size_t, N - 1> row_strides; std::array<std::size_t, N> orig_strides; if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) { orig_strides = origStrides(); } else { orig_strides = strides(); } std::copy(std::cbegin(orig_strides) + 1, std::cend(orig_strides), std::begin(row_strides)); MatrixView<T, N - 1> nth_row(row_dims, const_cast<T*>(&operator[](pos_begin)), row_strides); return nth_row; } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) { return std::as_const(*this).col(n); } template <typename Derived, std::semiregular T, std::size_t N> MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) const { const auto& orig_dims = dims(); if (n >= orig_dims[N - 1]) { throw std::out_of_range("row index error"); } std::array<std::size_t, N - 1> col_dims; 
std::copy(std::cbegin(orig_dims), std::cend(orig_dims) - 1, std::begin(col_dims)); std::array<std::size_t, N> pos_begin = {0}; pos_begin[N - 1] = n; std::array<std::size_t, N - 1> col_strides; std::array<std::size_t, N> orig_strides; if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) { orig_strides = origStrides(); } else { orig_strides = strides(); } std::copy(std::cbegin(orig_strides), std::cend(orig_strides) - 1, std::begin(col_strides)); MatrixView<T, N - 1> nth_col(col_dims, const_cast<T*>(&operator[](pos_begin)), col_strides); return nth_col; } template <typename Derived, std::semiregular T, std::size_t N> template <typename DerivedOther1, typename DerivedOther2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::invocable<MatrixView<T, N - 1>&, const MatrixView<U, std::min(N1, N - 1)>&, const MatrixView<V, std::min(N2, N - 1)>&> F> requires (std::max(N1, N2) == N) MatrixBase<Derived, T, N>& MatrixBase<Derived, T, N>::applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1, const MatrixBase<DerivedOther2, V, N2>& m2, F&& f) { if constexpr (N1 == N) { if constexpr (N2 == N) { auto r = dims(0); auto r1 = m1.dims(0); auto r2 = m2.dims(0); if (r1 == r) { if (r2 == r) { for (std::size_t i = 0; i < r; ++i) { auto row = this->row(i); f(row, m1.row(i), m2.row(i)); } } else { // r2 < r == r1 auto row2 = m2.row(0); for (std::size_t i = 0; i < r; ++i) { auto row = this->row(i); f(row, m1.row(i), row2); } } } else if (r2 == r) { // r1 < r == r2 auto row1 = m1.row(0); for (std::size_t i = 0; i < r; ++i) { auto row = this->row(i); f(row, row1, m2.row(i)); } } else { assert(0); // cannot happen } } else { // N2 < N == N1 auto r = dims(0); assert(r == m1.dims(0)); MatrixView<V, N2> view2 (m2); for (std::size_t i = 0; i < r; ++i) { auto row = this->row(i); f(row, m1.row(i), view2); } } } else if constexpr (N2 == N) { // N1 < N == N2 auto r = dims(0); assert(r == m2.dims(0)); MatrixView<U, N1> view1 (m1); for (std::size_t i = 0; i 
< r; ++i) { auto row = this->row(i); f(row, view1, m2.row(i)); } } else { assert(0); // cannot happen } return *this; } template <typename Derived, std::semiregular T> class MatrixBase<Derived, T, 1> : public ObjectBase<MatrixBase<Derived, T, 1>> { public: static constexpr std::size_t ndim = 1; private: std::size_t dims_; std::size_t strides_; Derived& self() { return static_cast<Derived&>(*this); } const Derived& self() const { return static_cast<const Derived&>(*this); } public: MatrixBase() = delete; using Base = ObjectBase<MatrixBase<Derived, T, 1>>; using Base::applyFunction; using Base::operator=; using Base::operator+=; using Base::operator-=; using Base::operator*=; using Base::operator/=; using Base::operator%=; protected: ~MatrixBase() noexcept = default; template <typename Dim> requires std::is_integral_v<Dim> explicit MatrixBase(Dim dim) : dims_(dim), strides_(1) {}; template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T> MatrixBase(const MatrixBase<DerivedOther, U, 1>&); MatrixBase(typename MatrixInitializer<T, 1>::type init); public: using value_type = T; using reference = T&; using const_reference = const T&; using pointer = T*; public: friend void swap(MatrixBase& a, MatrixBase& b) noexcept { std::swap(a.size_, b.size_); std::swap(a.dims_, b.dims_); std::swap(a.strides_, b.strides_); } auto begin() { return self().begin(); } auto begin() const { return self().begin(); } auto cbegin() const { return self().cbegin(); } auto end() { return self().end(); } auto end() const { return self().end(); } auto cend() const { return self().cend(); } auto rbegin() { return self().rbegin(); } auto rbegin() const { return self().rbegin(); } auto crbegin() const { return self().crbegin(); } auto rend() { return self().rend(); } auto rend() const { return self().rend(); } auto crend() const { return self().crend(); } template <typename Dim> requires std::is_integral_v<Dim> reference operator()(Dim dim) { return operator[](dim); } 
template <typename Dim> requires std::is_integral_v<Dim> const_reference operator()(Dim dim) const { return operator[](dim); } [[nodiscard]] std::array<std::size_t, 1> dims() const { return {dims_}; } [[nodiscard]] std::size_t dims(std::size_t n) const { if (n >= 1) { throw std::out_of_range("Out of range in dims"); } return dims_; } [[nodiscard]] std::size_t strides() const { return strides_; } auto dataView() const { return self().dataView(); } auto origStrides() const { return self().origStrides(); } MatrixView<T, 1> submatrix(std::size_t pos_begin); MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end); T& row(std::size_t n); T& col(std::size_t n); T& operator[](std::size_t n) { return *(begin() + n); } MatrixView<T, 1> submatrix(std::size_t pos_begin) const; MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end) const; const T& row(std::size_t n) const; const T& col(std::size_t n) const; const T& operator[](std::size_t n) const { return *(cbegin() + n); } friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) { os << '{'; for (std::size_t i = 0; i != m.dims_; ++i) { os << m[i]; if (i + 1 != m.dims_) { os << ", "; } } return os << '}'; } template <typename DerivedOther1, typename DerivedOther2, std::semiregular U, std::semiregular V, std::invocable<T&, const U&, const V&> F> MatrixBase& applyFunctionWithBroadcast(const frozenca::MatrixBase<DerivedOther1, U, 1>& m1, const frozenca::MatrixBase<DerivedOther2, V, 1>& m2, F&& f); }; template <typename Derived, std::semiregular T> MatrixBase<Derived, T, 1>::MatrixBase(typename MatrixInitializer<T, 1>::type init) : MatrixBase(deriveDims<1>(init)[0]) { } template <typename Derived, std::semiregular T> MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) { return submatrix(pos_begin, dims_); } template <typename Derived, std::semiregular T> MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) const { return 
submatrix(pos_begin, dims_); } template <typename Derived, std::semiregular T> MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin, std::size_t pos_end) { return std::as_const(*this).submatrix(pos_begin, pos_end); } template <typename Derived, std::semiregular T> MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin, std::size_t pos_end) const { if (pos_begin >= pos_end) { throw std::out_of_range("submatrix begin/end position error"); } MatrixView<T, 1> view ({pos_end - pos_begin}, const_cast<T*>(&operator[](pos_begin)), {strides_}); return view; } template <typename Derived, std::semiregular T> T& MatrixBase<Derived, T, 1>::row(std::size_t n) { return const_cast<T&>(std::as_const(*this).row(n)); } template <typename Derived, std::semiregular T> const T& MatrixBase<Derived, T, 1>::row(std::size_t n) const { if (n >= dims_) { throw std::out_of_range("row index error"); } const T& val = operator[](n); return val; } template <typename Derived, std::semiregular T> T& MatrixBase<Derived, T, 1>::col(std::size_t n) { return row(n); } template <typename Derived, std::semiregular T> const T& MatrixBase<Derived, T, 1>::col(std::size_t n) const { return row(n); } template <typename Derived, std::semiregular T> template <typename DerivedOther1, typename DerivedOther2, std::semiregular U, std::semiregular V, std::invocable<T&, const U&, const V&> F> MatrixBase<Derived, T, 1>& MatrixBase<Derived, T, 1>::applyFunctionWithBroadcast( const frozenca::MatrixBase<DerivedOther1, U, 1>& m1, const frozenca::MatrixBase<DerivedOther2, V, 1>& m2, F&& f) { // real update is done here by passing lvalue reference T& auto r = dims(0); auto r1 = m1.dims(0); auto r2 = m2.dims(0); if (r1 == r) { if (r2 == r) { for (std::size_t i = 0; i < r; ++i) { f(this->row(i), m1.row(i), m2.row(i)); } } else { // r2 < r == r1 auto row2 = m2.row(0); for (std::size_t i = 0; i < r; ++i) { f(this->row(i), m1.row(i), row2); } } } else if (r2 == r) { // r1 < r == r2 auto 
row1 = m1.row(0); for (std::size_t i = 0; i < r; ++i) { f(this->row(i), row1, m2.row(i)); } } return *this; } } // namespace frozenca #endif //FROZENCA_MATRIXBASE_H (Stackexchange says OP is too long so I replace two files as links)
MatrixImpl.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixImpl.h)
MatrixView.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixView.h)
MatrixUtils.h
#ifndef FROZENCA_MATRIXUTILS_H
#define FROZENCA_MATRIXUTILS_H

#include <algorithm>
#include <array>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>

namespace frozenca {

// Forward declarations of the matrix class family, needed by the
// NotMatrix specializations below.
template <std::semiregular T, std::size_t N>
class Matrix;

template <std::semiregular T, std::size_t N>
class MatrixView;

template <typename Derived, std::semiregular T, std::size_t N>
class MatrixBase;

template <typename Derived>
class ObjectBase;

// NotMatrix<T> is true for every type except the four matrix templates above.
// NOTE(review): only exact specializations are matched - cv-qualified types
// and user classes derived from these templates still count as "not a matrix".
template <typename T>
constexpr bool NotMatrix = true;

template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<Matrix<T, N>> = false;

template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixView<T, N>> = false;

template <typename Derived, std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixBase<Derived, T, N>> = false;

template <typename Derived>
constexpr bool NotMatrix<ObjectBase<Derived>> = false;

// Scalar-like operand: not a matrix type AND std::semiregular.
// Note that isNotMatrix and isMatrix are not exact complements: a
// non-semiregular, non-matrix type satisfies neither.
template <typename T>
concept isNotMatrix = NotMatrix<T> && std::semiregular<T>;

template <typename T>
concept isMatrix = !NotMatrix<T>;

// T can be brace-initialized from 0 and from 1.
template <typename T>
concept OneExists = requires () {
    { T{0} } -> std::convertible_to<T>;
    { T{1} } -> std::convertible_to<T>;
};

// "Weak" concepts: the operator expression is valid, result type unconstrained.
template <typename A, typename B>
concept WeakAddable = requires (A a, B b) {
    a + b;
};

template <typename A, typename B>
concept WeakSubtractable = requires (A a, B b) {
    a - b;
};

template <typename A, typename B>
concept WeakMultipliable = requires (A a, B b) {
    a * b;
};

template <typename A, typename B>
concept WeakDividable = requires (A a, B b) {
    a / b;
};

// Requires both division and remainder to be valid.
template <typename A, typename B>
concept WeakRemaindable = requires (A a, B b) {
    a / b;
    a % b;
};

// "...To" concepts: the operator result must be convertible to C.
template <typename A, typename B, typename C>
concept AddableTo = requires (A a, B b) {
    { a + b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept SubtractableTo = requires (A a, B b) {
    { a - b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept MultipliableTo = requires (A a, B b) {
    { a * b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept DividableTo = requires (A a, B b) {
    { a / b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept RemaindableTo = requires (A a, B b) {
    { a / b } -> std::convertible_to<C>;
    { a % b } -> std::convertible_to<C>;
};

// Covers &, |, ^ and both shift operators.
template <typename A, typename B, typename C>
concept BitMaskableTo = requires (A a, B b) {
    { a & b } -> std::convertible_to<C>;
    { a | b } -> std::convertible_to<C>;
    { a ^ b } -> std::convertible_to<C>;
    { a << b } -> std::convertible_to<C>;
    { a >> b } -> std::convertible_to<C>;
};

// The result of `A op B` is convertible back to A (used to constrain the
// compound-assignment operators of ObjectBase).
template <typename A, typename B>
concept Addable = AddableTo<A, B, A>;

template <typename A, typename B>
concept Subtractable = SubtractableTo<A, B, A>;

template <typename A, typename B>
concept Multipliable = MultipliableTo<A, B, A>;

template <typename A, typename B>
concept Dividable = DividableTo<A, B, A>;

template <typename A, typename B>
concept Remaindable = RemaindableTo<A, B, A>;

template <typename A, typename B>
concept BitMaskable = BitMaskableTo<A, B, A>;

// Operator wrappers whose result types feed the *Type aliases below.
// NOTE(review): decltype(auto) keeps reference results; if `a op b` ever
// yields an lvalue referring to a by-value parameter, calling these at run
// time would return a dangling reference. Here they are only used inside
// std::invoke_result_t (unevaluated).
template <typename A, typename B> requires WeakAddable<A, B>
inline decltype(auto) Plus(A a, B b) {
    return a + b;
}

template <typename A, typename B> requires WeakSubtractable<A, B>
inline decltype(auto) Minus(A a, B b) {
    return a - b;
}

template <typename A, typename B> requires WeakMultipliable<A, B>
inline decltype(auto) Multiplies(A a, B b) {
    return a * b;
}

template <typename A, typename B> requires WeakDividable<A, B>
inline decltype(auto) Divides(A a, B b) {
    return a / b;
}

template <typename A, typename B> requires WeakRemaindable<A, B>
inline decltype(auto) Modulus(A a, B b) {
    return a % b;
}

// Result type of `A + B` / `A - B`, as computed by the wrappers above.
template <typename A, typename B>
using AddType = std::invoke_result_t<decltype(Plus<A, B>), A, B>;

template <typename A, typename B>
using SubType = std::invoke_result_t<decltype(Minus<A, B>), A, B>;
template <typename A, typename B> using MulType = std::invoke_result_t<decltype(Multiplies<A, B>), A, B>; template <typename A, typename B> using DivType = std::invoke_result_t<decltype(Divides<A, B>), A, B>; template <typename A, typename B> using ModType = std::invoke_result_t<decltype(Modulus<A, B>), A, B>; template <typename A, typename B> concept DotProductable = Addable<MulType<A, B>, MulType<A, B>>; template <typename A, typename B, typename C> concept DotProductableTo = DotProductable<A, B> && MultipliableTo<A, B, C> && Addable<C, C>; template <typename A, typename B, typename C> requires AddableTo<A, B, C> inline void PlusTo(C& c, const A& a, const B& b) { c = a + b; } template <typename A, typename B, typename C> requires SubtractableTo<A, B, C> inline void MinusTo(C& c, const A& a, const B& b) { c = a - b; } template <typename A, typename B, typename C> requires MultipliableTo<A, B, C> inline void MultipliesTo(C& c, const A& a, const B& b) { c = a * b; } template <typename A, typename B, typename C> requires DividableTo<A, B, C> inline void DividesTo(C& c, const A& a, const B& b) { c = a / b; } template <typename A, typename B, typename C> requires RemaindableTo<A, B, C> inline void ModulusTo(C& c, const A& a, const B& b) { c = a % b; } template <typename... Args> inline constexpr bool All(Args... args) { return (... && args); }; template <typename... Args> inline constexpr bool Some(Args... args) { return (... 
|| args); }; template <std::size_t M, std::size_t N> requires (N < M) std::array<std::size_t, M> prependDims(const std::array<std::size_t, N>& arr) { std::array<std::size_t, M> dims; std::ranges::fill(dims, 1u); std::ranges::copy(arr, std::begin(dims) + (M - N)); return dims; } template <std::size_t M, std::size_t N> bool bidirBroadcastable(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, N>& sz2) { if constexpr (M == N) { return (std::ranges::equal(sz1, sz2, [](const auto& d1, const auto& d2) { return (d1 == d2) || (d1 == 1) || (d2 == 1);})); } else if constexpr (M < N) { return bidirBroadcastable(prependDims<N, M>(sz1), sz2); } else { static_assert(M > N); return bidirBroadcastable(sz1, prependDims<M, N>(sz2)); } } template <std::size_t M, std::size_t N> std::array<std::size_t, std::max(M, N)> bidirBroadcastedDims(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, N>& sz2) { if constexpr (M == N) { if (!bidirBroadcastable(sz1, sz2)) { throw std::invalid_argument("Cannot broadcast"); } std::array<std::size_t, M> sz; std::ranges::transform(sz1, sz2, std::begin(sz), [](const auto& d1, const auto& d2) { return std::max(d1, d2); }); return sz; } else if constexpr (M < N) { return bidirBroadcastedDims(prependDims<N, M>(sz1), sz2); } else { static_assert(M > N); return bidirBroadcastedDims(sz1, prependDims<M, N>(sz2)); } } template <std::size_t M> requires (M > 1) std::array<std::size_t, M - 1> dotDims(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, 1>& sz2) { if (sz1[M - 1] != sz2[0]) { throw std::invalid_argument("Cannot do dot product, shape is not aligned"); } std::array<std::size_t, M - 1> sz; std::copy(std::begin(sz1), std::begin(sz1) + (M - 1), std::begin(sz)); return sz; } template <std::size_t M, std::size_t N> requires (N > 1) std::array<std::size_t, M + N - 2> dotDims(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, N>& sz2) { if (sz1[M - 1] != sz2[N - 2]) { throw 
std::invalid_argument("Cannot do dot product, shape is not aligned"); } std::array<std::size_t, M + N - 2> sz; std::copy(std::begin(sz1), std::begin(sz1) + (M - 1), std::begin(sz)); std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz) + (M - 1)); std::copy(std::begin(sz2) + (N - 1), std::end(sz2), std::begin(sz) + (M + N - 3)); return sz; } template <std::size_t M, std::size_t N> std::array<std::size_t, std::max(M, N)> matmulDims(const std::array<std::size_t, M>& sz1, const std::array<std::size_t, N>& sz2) { if constexpr (M == 1) { std::array<std::size_t, 2> sz1_ = {1, sz1[0]}; return matmulDims(sz1_, sz2); } else if constexpr (N == 1) { std::array<std::size_t, 2> sz2_ = {sz2[0], 1}; return matmulDims(sz1, sz2_); } assert(M >= 2 && N >= 2); if (sz1[M - 1] != sz2[N - 2]) { throw std::invalid_argument("Cannot do dot product, shape is not aligned"); } std::array<std::size_t, 2> last_sz = {sz1[M - 2], sz2[N - 1]}; if constexpr (M == 2) { if constexpr (N == 2) { return last_sz; } else { // M = 2, N > 2 std::array<std::size_t, N> res_sz; std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(res_sz)); std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (N - 2)); return res_sz; } } else if constexpr (N == 2) { // M > 2, N = 2 std::array<std::size_t, M> res_sz; std::copy(std::begin(sz1), std::begin(sz2) + (M - 2), std::begin(res_sz)); std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (M - 2)); return res_sz; } else { // M > 2, N > 2 std::array<std::size_t, std::max(M, N)> res_sz; std::array<std::size_t, M - 2> sz1_front; std::array<std::size_t, N - 2> sz2_front; std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(sz1_front)); std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz2_front)); auto common_sz = bidirBroadcastedDims(sz1_front, sz2_front); std::copy(std::begin(common_sz), std::end(common_sz), std::begin(res_sz)); std::copy(std::begin(last_sz), std::end(last_sz), 
std::end(res_sz) - 2); return res_sz; } } template <typename... Args> concept IndexType = All(std::is_integral_v<Args>...); template <std::size_t N> std::array<std::size_t, N> computeStrides(const std::array<std::size_t, N>& dims) { std::array<std::size_t, N> strides; std::size_t str = 1; for (std::size_t i = N - 1; i < N; --i) { strides[i] = str; str *= dims[i]; } return strides; } template <std::size_t N, typename Initializer> bool checkNonJagged(const Initializer& init) { auto i = std::cbegin(init); for (auto j = std::next(i); j != std::cend(init); ++j) { if (i->size() != j->size()) { return false; } } return true; } template <std::size_t N, typename Iter, typename Initializer> void addDims(Iter& first, const Initializer& init) { if constexpr (N > 1) { if (!checkNonJagged<N>(init)) { throw std::invalid_argument("Jagged matrix initializer"); } } *first = std::size(init); ++first; if constexpr (N > 1) { addDims<N - 1>(first, *std::begin(init)); } } template <std::size_t N, typename Initializer> std::array<std::size_t, N> deriveDims(const Initializer& init) { std::array<std::size_t, N> dims; auto f = std::begin(dims); addDims<N>(f, init); return dims; } template <std::semiregular T> void addList(std::unique_ptr<T[]>& data, const T* first, const T* last, std::size_t& index) { for (; first != last; ++first) { data[index] = *first; ++index; } } template <std::semiregular T, typename I> void addList(std::unique_ptr<T[]>& data, const std::initializer_list<I>* first, const std::initializer_list<I>* last, std::size_t& index) { for (; first != last; ++first) { addList(data, first->begin(), first->end(), index); } } template <std::semiregular T, typename I> void insertFlat(std::unique_ptr<T[]>& data, std::initializer_list<I> list) { std::size_t index = 0; addList(data, std::begin(list), std::end(list), index); } inline long quot(long a, long b) { return (a / b) - (a % b < 0); } inline long mod(long a, long b) { return (a % b + b) % b; } } // namespace frozenca #endif 
//FROZENCA_MATRIXUTILS_H MatrixInitializer.h
#ifndef FROZENCA_MATRIXINITIALIZER_H #define FROZENCA_MATRIXINITIALIZER_H #include <cstddef> #include <concepts> #include <initializer_list> namespace frozenca { template <std::semiregular T, std::size_t N> struct MatrixInitializer { using type = std::initializer_list<typename MatrixInitializer<T, N - 1>::type>; }; template <std::semiregular T> struct MatrixInitializer<T, 1> { using type = std::initializer_list<T>; }; template <std::semiregular T> struct MatrixInitializer<T, 0>; } // namespace frozenca #endif //FROZENCA_MATRIXINITIALIZER_H MatrixOps.h
#ifndef FROZENCA_MATRIXOPS_H #define FROZENCA_MATRIXOPS_H #include "MatrixImpl.h" namespace frozenca { // Matrix constructs template <std::semiregular T, std::size_t N> Matrix<T, N> empty(const std::array<std::size_t, N>& arr) { Matrix<T, N> mat (arr); return mat; } template <typename Derived, std::semiregular T, std::size_t N> Matrix<T, N> empty_like(const MatrixBase<Derived, T, N>& base) { Matrix<T, N> mat (base.dims()); return mat; } template <OneExists T> Matrix<T, 2> eye(std::size_t n, std::size_t m) { Matrix<T, 2> mat (n, m); for (std::size_t i = 0; i < std::min(n, m); ++i) { mat(i, i) = T{1}; } return mat; } template <OneExists T> Matrix<T, 2> eye(std::size_t n) { return eye<T>(n, n); } template <OneExists T> Matrix<T, 2> identity(std::size_t n) { return eye<T>(n, n); } template <OneExists T, std::size_t N> Matrix<T, N> ones(const std::array<std::size_t, N>& arr) { Matrix<T, N> mat (arr); std::ranges::fill(mat, T{1}); return mat; } template <typename Derived, OneExists T, std::size_t N> Matrix<T, N> ones_like(const MatrixBase<Derived, T, N>& base) { Matrix<T, N> mat (base.dims()); std::ranges::fill(mat, T{1}); return mat; } template <std::semiregular T, std::size_t N> Matrix<T, N> zeros(const std::array<std::size_t, N>& arr) { Matrix<T, N> mat (arr); std::ranges::fill(mat, T{0}); return mat; } template <typename Derived, std::semiregular T, std::size_t N> Matrix<T, N> zeros_like(const MatrixBase<Derived, T, N>& base) { Matrix<T, N> mat (base.dims()); std::ranges::fill(mat, T{0}); return mat; } template <std::semiregular T, std::size_t N> Matrix<T, N> full(const std::array<std::size_t, N>& arr, const T& fill_value) { Matrix<T, N> mat (arr); std::ranges::fill(mat, fill_value); return mat; } template <typename Derived, std::semiregular T, std::size_t N> Matrix<T, N> full_like(const MatrixBase<Derived, T, N>& base, const T& fill_value) { Matrix<T, N> mat (base.dims()); std::ranges::fill(mat, fill_value); return mat; } // binary matrix operators namespace { 
template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t N> requires AddableTo<U, V, T> && (std::max(N1, N2) == N) void AddTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 1) { m.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>); } else { m.applyFunctionWithBroadcast(m1, m2, AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t N> requires SubtractableTo<U, V, T> && (std::max(N1, N2) == N) void SubtractTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 1) { m.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>); } else { m.applyFunctionWithBroadcast(m1, m2, SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t N> requires MultipliableTo<U, V, T> && (std::max(N1, N2) == N) void MultiplyTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 1) { m.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>); } else { m.applyFunctionWithBroadcast(m1, m2, MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t N> requires DividableTo<U, V, T> && (std::max(N1, N2) == N) void DivideTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 1) { m.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>); } else { m.applyFunctionWithBroadcast(m1, m2, DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t 
N> requires RemaindableTo<U, V, T> && (std::max(N1, N2) == N) void ModuloTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 1) { m.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>); } else { m.applyFunctionWithBroadcast(m1, m2, ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } } // anonymous namespace template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = AddType<U, V>> requires AddableTo<U, V, T> decltype(auto) operator+ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = bidirBroadcastedDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = SubType<U, V>> requires SubtractableTo<U, V, T> decltype(auto) operator- (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = bidirBroadcastedDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = MulType<U, V>> requires MultipliableTo<U, V, T> decltype(auto) operator* (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = bidirBroadcastedDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, MultiplyTo<U, V, T, std::min(N1, N - 1), 
std::min(N2, N - 1), N - 1>); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = DivType<U, V>> requires DividableTo<U, V, T> decltype(auto) operator/ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = bidirBroadcastedDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = ModType<U, V>> requires RemaindableTo<U, V, T> decltype(auto) operator% (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = bidirBroadcastedDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); return res; } } // namespace frozenca #endif //FROZENCA_MATRIXOPS_H LinalgOps.h (LINEAR ALGEBRA OPERATIONS)
#ifndef FROZENCA_LINALGOPS_H #define FROZENCA_LINALGOPS_H #include "Matrix.h" namespace frozenca { namespace { template <std::semiregular U, std::semiregular V, std::semiregular T> requires DotProductableTo<U, V, T> void DotTo(T& m, const MatrixView<U, 1>& m1, const MatrixView<V, 1>& m2) { m += std::inner_product(std::begin(m1), std::end(m1), std::begin(m2), T{0}); } template <std::semiregular U, std::semiregular V, std::semiregular T> requires DotProductableTo<U, V, T> void DotTo(MatrixView<T, 1>& m, const MatrixView<U, 1>& m1, const MatrixView<V, 2>& m2) { assert(m.dims(0) == m2.dims(1)); std::size_t c = m2.dims(1); for (std::size_t j = 0; j < c; ++j) { auto col2 = m2.col(j); m[j] += std::inner_product(std::begin(m1), std::end(m1), std::begin(col2), T{0}); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N2> requires DotProductableTo<U, V, T> && (N2 > 2) void DotTo(MatrixView<T, N2 - 1>& m, const MatrixView<U, 1>& m1, const MatrixView<V, N2>& m2) { assert(m.dims(0) == m2.dims(0)); std::size_t r = m.dims(0); for (std::size_t i = 0; i < r; ++i) { auto row0 = m.row(i); auto row2 = m2.row(i); DotTo(row0, m1, row2); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2> requires DotProductableTo<U, V, T> && (N1 > 1) void DotTo(MatrixView<T, N1 - 1>& m, const MatrixView<U, N1>& m1, const MatrixView<V, 1>& m2) { assert(m.dims(0) == m1.dims(0)); std::size_t r = m.dims(0); for (std::size_t i = 0; i < r; ++i) { auto row0 = m.row(i); auto row1 = m1.row(i); DotTo(row0, row1, m2); } } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2> requires DotProductableTo<U, V, T> && (N1 > 1) && (N2 > 1) void DotTo(MatrixView<T, N1 + N2 - 2>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { assert(m.dims(0) == m1.dims(0)); std::size_t r = m.dims(0); for (std::size_t i = 0; i < r; ++i) { auto row0 = m.row(i); auto row1 = m1.row(i); 
DotTo(row0, row1, m2); } } template <typename Derived0, typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2> requires DotProductableTo<U, V, T> void DotTo(MatrixBase<Derived0, T, (N1 + N2 - 2)>& m, const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) { MatrixView<T, (N1 + N2 - 2)> m_view (m); MatrixView<U, N1> m1_view (m1); MatrixView<V, N2> m2_view (m2); DotTo(m_view, m1_view, m2_view); } template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N1, std::size_t N2, std::size_t N> requires DotProductableTo<U, V, T> && (std::max(N1, N2) == N) void MatmulTo(MatrixView<T, N>& m, const MatrixView<U, N1>& m1, const MatrixView<V, N2>& m2) { if constexpr (N == 2) { DotTo(m, m1, m2); } else { m.applyFunctionWithBroadcast(m1, m2, MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); } } } // anonymous namespace template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t M, std::size_t N, std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T> decltype(auto) dot(const MatrixBase<Derived1, U, M>& m1, const MatrixBase<Derived2, V, N>& m2) { auto dims = dotDims(m1.dims(), m2.dims()); Matrix<T, (M + N - 2)> res = zeros<T, (M + N - 2)>(dims); DotTo(res, m1, m2); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T> decltype(auto) dot(const MatrixBase<Derived1, U, 1>& m1, const MatrixBase<Derived2, V, 1>& m2) { auto dims = dotDims(m1.dims(), m2.dims()); T res {0}; DotTo(res, m1, m2); return res; } template <typename Derived1, typename Derived2, std::semiregular U, std::semiregular V, std::size_t N1, std::size_t N2, std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T> decltype(auto) matmul(const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, 
V, N2>& m2) { constexpr std::size_t N = std::max(N1, N2); auto dims = matmulDims(m1.dims(), m2.dims()); Matrix<T, N> res = zeros<T, N>(dims); res.applyFunctionWithBroadcast(m1, m2, MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>); return res; } } // namespace frozenca #endif //FROZENCA_LINALGOPS_H Matrix.h
#ifndef FROZENCA_MATRIX_H #define FROZENCA_MATRIX_H #include "MatrixImpl.h" #include "MatrixOps.h" #include "LinalgOps.h" #endif //FROZENCA_MATRIX_H Feel free to comment anything!
What I don't like:
- Too many APIs are exposed to users. I hate C++ includes; I'm really waiting for C++20 modules! (The current module implementation in MSVC 19.28 is broken; it fails to work with my code.)
- `.reshape()`: it should move the buffer, so it takes an rvalue reference, but I wonder if there is something better.
- The helper functions are becoming too template-heavy. Specifically, I don't like that functions like `applyFunctionWithBroadcast` are instantiated separately for every different N > 1, even though what they do is the same. But I need a specialization for N = 1 (because `submatrix`, `row`, and `col` should return `T&`), so I don't know a better way.
For the semantics of broadcasting, dot product, and matrix multiplication, see:
https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
https://numpy.org/doc/stable/reference/generated/numpy.dot.html
https://numpy.org/doc/stable/reference/generated/numpy.matmul.html