Sleipnir C++ API
Loading...
Searching...
No Matches
variable_matrix.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <algorithm>
6#include <concepts>
7#include <initializer_list>
8#include <iterator>
9#include <span>
10#include <utility>
11#include <vector>
12
13#include <Eigen/Core>
14#include <Eigen/QR>
15#include <gch/small_vector.hpp>
16
17#include "sleipnir/autodiff/sleipnir_base.hpp"
18#include "sleipnir/autodiff/slice.hpp"
19#include "sleipnir/autodiff/variable.hpp"
20#include "sleipnir/autodiff/variable_block.hpp"
21#include "sleipnir/util/assert.hpp"
22#include "sleipnir/util/concepts.hpp"
23#include "sleipnir/util/empty.hpp"
24#include "sleipnir/util/function_ref.hpp"
25#include "sleipnir/util/symbol_exports.hpp"
26
27namespace slp {
28
34template <typename Scalar_>
36 public:
40 using Scalar = Scalar_;
41
45 VariableMatrix() = default;
46
53 explicit VariableMatrix(int rows) : VariableMatrix{rows, 1} {}
54
61 VariableMatrix(int rows, int cols) : m_rows{rows}, m_cols{cols} {
62 m_storage.reserve(rows * cols);
63 for (int index = 0; index < rows * cols; ++index) {
64 m_storage.emplace_back();
65 }
66 }
67
75 : m_rows{rows}, m_cols{cols} {
76 m_storage.reserve(rows * cols);
77 for (int index = 0; index < rows * cols; ++index) {
78 m_storage.emplace_back(nullptr);
79 }
80 }
81
88 std::initializer_list<std::initializer_list<Variable<Scalar>>> list) {
89 // Get row and column counts for destination matrix
90 m_rows = list.size();
91 m_cols = 0;
92 if (list.size() > 0) {
93 m_cols = list.begin()->size();
94 }
95
96 // Assert all column counts are the same
97 for ([[maybe_unused]]
98 const auto& row : list) {
99 slp_assert(static_cast<int>(row.size()) == m_cols);
100 }
101
102 m_storage.reserve(rows() * cols());
103 for (const auto& row : list) {
104 std::ranges::copy(row, std::back_inserter(m_storage));
105 }
106 }
107
115 // NOLINTNEXTLINE (google-explicit-constructor)
116 VariableMatrix(const std::vector<std::vector<Scalar>>& list) {
117 // Get row and column counts for destination matrix
118 m_rows = list.size();
119 m_cols = 0;
120 if (list.size() > 0) {
121 m_cols = list.begin()->size();
122 }
123
124 // Assert all column counts are the same
125 for ([[maybe_unused]]
126 const auto& row : list) {
127 slp_assert(static_cast<int>(row.size()) == m_cols);
128 }
129
130 m_storage.reserve(rows() * cols());
131 for (const auto& row : list) {
132 std::ranges::copy(row, std::back_inserter(m_storage));
133 }
134 }
135
143 // NOLINTNEXTLINE (google-explicit-constructor)
144 VariableMatrix(const std::vector<std::vector<Variable<Scalar>>>& list) {
145 // Get row and column counts for destination matrix
146 m_rows = list.size();
147 m_cols = 0;
148 if (list.size() > 0) {
149 m_cols = list.begin()->size();
150 }
151
152 // Assert all column counts are the same
153 for ([[maybe_unused]]
154 const auto& row : list) {
155 slp_assert(static_cast<int>(row.size()) == m_cols);
156 }
157
158 m_storage.reserve(rows() * cols());
159 for (const auto& row : list) {
160 std::ranges::copy(row, std::back_inserter(m_storage));
161 }
162 }
163
169 template <typename Derived>
170 // NOLINTNEXTLINE (google-explicit-constructor)
171 VariableMatrix(const Eigen::MatrixBase<Derived>& values)
172 : m_rows{static_cast<int>(values.rows())},
173 m_cols{static_cast<int>(values.cols())} {
174 m_storage.reserve(values.rows() * values.cols());
175 for (int row = 0; row < values.rows(); ++row) {
176 for (int col = 0; col < values.cols(); ++col) {
177 m_storage.emplace_back(values[row, col]);
178 }
179 }
180 }
181
187 template <typename Derived>
188 // NOLINTNEXTLINE (google-explicit-constructor)
189 VariableMatrix(const Eigen::DiagonalBase<Derived>& values)
190 : m_rows{static_cast<int>(values.rows())},
191 m_cols{static_cast<int>(values.cols())} {
192 m_storage.reserve(values.rows() * values.cols());
193 for (int row = 0; row < values.rows(); ++row) {
194 for (int col = 0; col < values.cols(); ++col) {
195 if (row == col) {
196 m_storage.emplace_back(values.diagonal()[row]);
197 } else {
198 m_storage.emplace_back(Scalar(0));
199 }
200 }
201 }
202 }
203
209 // NOLINTNEXTLINE (google-explicit-constructor)
210 VariableMatrix(const Variable<Scalar>& variable) : m_rows{1}, m_cols{1} {
211 m_storage.emplace_back(variable);
212 }
213
219 // NOLINTNEXTLINE (google-explicit-constructor)
220 VariableMatrix(Variable<Scalar>&& variable) : m_rows{1}, m_cols{1} {
221 m_storage.emplace_back(std::move(variable));
222 }
223
229 // NOLINTNEXTLINE (google-explicit-constructor)
231 : m_rows{values.rows()}, m_cols{values.cols()} {
232 m_storage.reserve(rows() * cols());
233 for (int row = 0; row < rows(); ++row) {
234 for (int col = 0; col < cols(); ++col) {
235 m_storage.emplace_back(values[row, col]);
236 }
237 }
238 }
239
245 // NOLINTNEXTLINE (google-explicit-constructor)
247 : m_rows{values.rows()}, m_cols{values.cols()} {
248 m_storage.reserve(rows() * cols());
249 for (int row = 0; row < rows(); ++row) {
250 for (int col = 0; col < cols(); ++col) {
251 m_storage.emplace_back(values[row, col]);
252 }
253 }
254 }
255
261 explicit VariableMatrix(std::span<const Variable<Scalar>> values)
262 : m_rows{static_cast<int>(values.size())}, m_cols{1} {
263 m_storage.reserve(rows() * cols());
264 for (int row = 0; row < rows(); ++row) {
265 for (int col = 0; col < cols(); ++col) {
266 m_storage.emplace_back(values[row * cols() + col]);
267 }
268 }
269 }
270
278 VariableMatrix(std::span<const Variable<Scalar>> values, int rows, int cols)
279 : m_rows{rows}, m_cols{cols} {
280 slp_assert(static_cast<int>(values.size()) == rows * cols);
281 m_storage.reserve(rows * cols);
282 for (int row = 0; row < rows; ++row) {
283 for (int col = 0; col < cols; ++col) {
284 m_storage.emplace_back(values[row * cols + col]);
285 }
286 }
287 }
288
295 template <typename Derived>
296 VariableMatrix& operator=(const Eigen::MatrixBase<Derived>& values) {
297 slp_assert(rows() == values.rows() && cols() == values.cols());
298
299 for (int row = 0; row < values.rows(); ++row) {
300 for (int col = 0; col < values.cols(); ++col) {
301 (*this)[row, col] = values[row, col];
302 }
303 }
304
305 return *this;
306 }
307
317 slp_assert(rows() == 1 && cols() == 1);
318
319 (*this)[0, 0] = value;
320
321 return *this;
322 }
323
329 template <typename Derived>
330 requires std::same_as<typename Derived::Scalar, Scalar>
331 void set_value(const Eigen::MatrixBase<Derived>& values) {
332 slp_assert(rows() == values.rows() && cols() == values.cols());
333
334 for (int row = 0; row < values.rows(); ++row) {
335 for (int col = 0; col < values.cols(); ++col) {
336 (*this)[row, col].set_value(values[row, col]);
337 }
338 }
339 }
340
349 slp_assert(row >= 0 && row < rows());
350 slp_assert(col >= 0 && col < cols());
351 return m_storage[row * cols() + col];
352 }
353
361 const Variable<Scalar>& operator[](int row, int col) const {
362 slp_assert(row >= 0 && row < rows());
363 slp_assert(col >= 0 && col < cols());
364 return m_storage[row * cols() + col];
365 }
366
374 slp_assert(index >= 0 && index < rows() * cols());
375 return m_storage[index];
376 }
377
384 const Variable<Scalar>& operator[](int index) const {
385 slp_assert(index >= 0 && index < rows() * cols());
386 return m_storage[index];
387 }
388
399 int block_rows, int block_cols) {
400 slp_assert(row_offset >= 0 && row_offset <= rows());
401 slp_assert(col_offset >= 0 && col_offset <= cols());
402 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
403 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
405 }
406
417 int col_offset,
418 int block_rows,
419 int block_cols) const {
420 slp_assert(row_offset >= 0 && row_offset <= rows());
421 slp_assert(col_offset >= 0 && col_offset <= cols());
422 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
423 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
425 }
426
440
449 Slice col_slice) const {
450 int row_slice_length = row_slice.adjust(rows());
451 int col_slice_length = col_slice.adjust(cols());
452 return VariableBlock{*this, std::move(row_slice), row_slice_length,
453 std::move(col_slice), col_slice_length};
454 }
455
476
495
504 slp_assert(cols() == 1);
505 slp_assert(offset >= 0 && offset < rows());
506 slp_assert(length >= 0 && length <= rows() - offset);
507 return block(offset, 0, length, 1);
508 }
509
518 int length) const {
519 slp_assert(cols() == 1);
520 slp_assert(offset >= 0 && offset < rows());
521 slp_assert(length >= 0 && length <= rows() - offset);
522 return block(offset, 0, length, 1);
523 }
524
532 slp_assert(row >= 0 && row < rows());
533 return block(row, 0, 1, cols());
534 }
535
543 slp_assert(row >= 0 && row < rows());
544 return block(row, 0, 1, cols());
545 }
546
554 slp_assert(col >= 0 && col < cols());
555 return block(0, col, rows(), 1);
556 }
557
565 slp_assert(col >= 0 && col < cols());
566 return block(0, col, rows(), 1);
567 }
568
575 template <EigenMatrixLike LHS, SleipnirMatrixLike<Scalar> RHS>
577 slp_assert(lhs.cols() == rhs.rows());
578
579 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), rhs.cols());
580
581 for (int i = 0; i < lhs.rows(); ++i) {
582 for (int j = 0; j < rhs.cols(); ++j) {
583 Variable sum{Scalar(0)};
584 for (int k = 0; k < lhs.cols(); ++k) {
585 sum += lhs(i, k) * rhs[k, j];
586 }
587 result[i, j] = sum;
588 }
589 }
590
591 return result;
592 }
593
600 template <SleipnirMatrixLike<Scalar> LHS, EigenMatrixLike RHS>
602 slp_assert(lhs.cols() == rhs.rows());
603
604 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), rhs.cols());
605
606 for (int i = 0; i < lhs.rows(); ++i) {
607 for (int j = 0; j < rhs.cols(); ++j) {
608 Variable sum{Scalar(0)};
609 for (int k = 0; k < lhs.cols(); ++k) {
610 sum += lhs[i, k] * rhs(k, j);
611 }
612 result[i, j] = sum;
613 }
614 }
615
616 return result;
617 }
618
625 template <SleipnirMatrixLike<Scalar> LHS, SleipnirMatrixLike<Scalar> RHS>
627 slp_assert(lhs.cols() == rhs.rows());
628
629 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), rhs.cols());
630
631 for (int i = 0; i < lhs.rows(); ++i) {
632 for (int j = 0; j < rhs.cols(); ++j) {
633 Variable sum{Scalar(0)};
634 for (int k = 0; k < lhs.cols(); ++k) {
635 sum += lhs[i, k] * rhs[k, j];
636 }
637 result[i, j] = sum;
638 }
639 }
640
641 return result;
642 }
643
650 template <EigenMatrixLike LHS>
652 const Variable<Scalar>& rhs) {
653 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
654
655 for (int row = 0; row < result.rows(); ++row) {
656 for (int col = 0; col < result.cols(); ++col) {
657 result[row, col] = lhs[row, col] * rhs;
658 }
659 }
660
661 return result;
662 }
663
670 template <SleipnirMatrixLike<Scalar> LHS, ScalarLike RHS>
672 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
673
674 for (int row = 0; row < result.rows(); ++row) {
675 for (int col = 0; col < result.cols(); ++col) {
676 result[row, col] = lhs[row, col] * rhs;
677 }
678 }
679
680 return result;
681 }
682
689 template <EigenMatrixLike RHS>
691 const RHS& rhs) {
692 VariableMatrix<Scalar> result(detail::empty, rhs.rows(), rhs.cols());
693
694 for (int row = 0; row < result.rows(); ++row) {
695 for (int col = 0; col < result.cols(); ++col) {
696 result[row, col] = rhs[row, col] * lhs;
697 }
698 }
699
700 return result;
701 }
702
709 template <ScalarLike LHS, SleipnirMatrixLike<Scalar> RHS>
711 VariableMatrix<Scalar> result(detail::empty, rhs.rows(), rhs.cols());
712
713 for (int row = 0; row < result.rows(); ++row) {
714 for (int col = 0; col < result.cols(); ++col) {
715 result[row, col] = rhs[row, col] * lhs;
716 }
717 }
718
719 return result;
720 }
721
729 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
730
731 for (int i = 0; i < rows(); ++i) {
732 for (int j = 0; j < rhs.cols(); ++j) {
733 Variable sum{Scalar(0)};
734 for (int k = 0; k < cols(); ++k) {
735 sum += (*this)[i, k] * rhs[k, j];
736 }
737 (*this)[i, j] = sum;
738 }
739 }
740
741 return *this;
742 }
743
751 for (int row = 0; row < rows(); ++row) {
752 for (int col = 0; col < rhs.cols(); ++col) {
753 (*this)[row, col] *= rhs;
754 }
755 }
756
757 return *this;
758 }
759
767 template <EigenMatrixLike LHS>
769 const Variable<Scalar>& rhs) {
770 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
771
772 for (int row = 0; row < result.rows(); ++row) {
773 for (int col = 0; col < result.cols(); ++col) {
774 result[row, col] = lhs[row, col] / rhs;
775 }
776 }
777
778 return result;
779 }
780
788 template <SleipnirMatrixLike<Scalar> LHS, ScalarLike RHS>
790 friend VariableMatrix<Scalar> operator/(const LHS& lhs, const RHS& rhs) {
791 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
792
793 for (int row = 0; row < result.rows(); ++row) {
794 for (int col = 0; col < result.cols(); ++col) {
795 result[row, col] = lhs[row, col] / rhs;
796 }
797 }
798
799 return result;
800 }
801
809 template <SleipnirMatrixLike<Scalar> LHS>
811 const Variable<Scalar>& rhs) {
812 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
813
814 for (int row = 0; row < result.rows(); ++row) {
815 for (int col = 0; col < result.cols(); ++col) {
816 result[row, col] = lhs[row, col] / rhs;
817 }
818 }
819
820 return result;
821 }
822
830 for (int row = 0; row < rows(); ++row) {
831 for (int col = 0; col < cols(); ++col) {
832 (*this)[row, col] /= rhs;
833 }
834 }
835
836 return *this;
837 }
838
846 template <EigenMatrixLike LHS, SleipnirMatrixLike<Scalar> RHS>
848 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
849
850 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
851
852 for (int row = 0; row < result.rows(); ++row) {
853 for (int col = 0; col < result.cols(); ++col) {
854 result[row, col] = lhs[row, col] + rhs[row, col];
855 }
856 }
857
858 return result;
859 }
860
868 template <SleipnirMatrixLike<Scalar> LHS, EigenMatrixLike RHS>
870 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
871
872 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
873
874 for (int row = 0; row < result.rows(); ++row) {
875 for (int col = 0; col < result.cols(); ++col) {
876 result[row, col] = lhs[row, col] + rhs[row, col];
877 }
878 }
879
880 return result;
881 }
882
890 template <SleipnirMatrixLike<Scalar> LHS, SleipnirMatrixLike<Scalar> RHS>
892 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
893
894 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
895
896 for (int row = 0; row < result.rows(); ++row) {
897 for (int col = 0; col < result.cols(); ++col) {
898 result[row, col] = lhs[row, col] + rhs[row, col];
899 }
900 }
901
902 return result;
903 }
904
912 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
913
914 for (int row = 0; row < rows(); ++row) {
915 for (int col = 0; col < cols(); ++col) {
916 (*this)[row, col] += rhs[row, col];
917 }
918 }
919
920 return *this;
921 }
922
930 slp_assert(rows() == 1 && cols() == 1);
931
932 for (int row = 0; row < rows(); ++row) {
933 for (int col = 0; col < cols(); ++col) {
934 (*this)[row, col] += rhs;
935 }
936 }
937
938 return *this;
939 }
940
948 template <EigenMatrixLike LHS, SleipnirMatrixLike<Scalar> RHS>
950 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
951
952 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
953
954 for (int row = 0; row < result.rows(); ++row) {
955 for (int col = 0; col < result.cols(); ++col) {
956 result[row, col] = lhs[row, col] - rhs[row, col];
957 }
958 }
959
960 return result;
961 }
962
970 template <SleipnirMatrixLike<Scalar> LHS, EigenMatrixLike RHS>
972 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
973
974 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
975
976 for (int row = 0; row < result.rows(); ++row) {
977 for (int col = 0; col < result.cols(); ++col) {
978 result[row, col] = lhs[row, col] - rhs[row, col];
979 }
980 }
981
982 return result;
983 }
984
992 template <SleipnirMatrixLike<Scalar> LHS, SleipnirMatrixLike<Scalar> RHS>
994 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
995
996 VariableMatrix<Scalar> result(detail::empty, lhs.rows(), lhs.cols());
997
998 for (int row = 0; row < result.rows(); ++row) {
999 for (int col = 0; col < result.cols(); ++col) {
1000 result[row, col] = lhs[row, col] - rhs[row, col];
1001 }
1002 }
1003
1004 return result;
1005 }
1006
1014 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
1015
1016 for (int row = 0; row < rows(); ++row) {
1017 for (int col = 0; col < cols(); ++col) {
1018 (*this)[row, col] -= rhs[row, col];
1019 }
1020 }
1021
1022 return *this;
1023 }
1024
1032 slp_assert(rows() == 1 && cols() == 1);
1033
1034 for (int row = 0; row < rows(); ++row) {
1035 for (int col = 0; col < cols(); ++col) {
1036 (*this)[row, col] -= rhs;
1037 }
1038 }
1039
1040 return *this;
1041 }
1042
1049 const SleipnirMatrixLike<Scalar> auto& lhs) {
1050 VariableMatrix<Scalar> result{detail::empty, lhs.rows(), lhs.cols()};
1051
1052 for (int row = 0; row < result.rows(); ++row) {
1053 for (int col = 0; col < result.cols(); ++col) {
1054 result[row, col] = -lhs[row, col];
1055 }
1056 }
1057
1058 return result;
1059 }
1060
1064 // NOLINTNEXTLINE (google-explicit-constructor)
1065 operator Variable<Scalar>() const {
1066 slp_assert(rows() == 1 && cols() == 1);
1067 return (*this)[0, 0];
1068 }
1069
1076 VariableMatrix<Scalar> result{detail::empty, cols(), rows()};
1077
1078 for (int row = 0; row < rows(); ++row) {
1079 for (int col = 0; col < cols(); ++col) {
1080 result[col, row] = (*this)[row, col];
1081 }
1082 }
1083
1084 return result;
1085 }
1086
1092 int rows() const { return m_rows; }
1093
1099 int cols() const { return m_cols; }
1100
1108 Scalar value(int row, int col) { return (*this)[row, col].value(); }
1109
1116 Scalar value(int index) { return (*this)[index].value(); }
1117
1123 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> value() {
1124 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> result{rows(),
1125 cols()};
1126
1127 for (int row = 0; row < rows(); ++row) {
1128 for (int col = 0; col < cols(); ++col) {
1129 result[row, col] = value(row, col);
1130 }
1131 }
1132
1133 return result;
1134 }
1135
1144 const {
1145 VariableMatrix<Scalar> result{detail::empty, rows(), cols()};
1146
1147 for (int row = 0; row < rows(); ++row) {
1148 for (int col = 0; col < cols(); ++col) {
1149 result[row, col] = unary_op((*this)[row, col]);
1150 }
1151 }
1152
1153 return result;
1154 }
1155
1156#ifndef DOXYGEN_SHOULD_SKIP_THIS
1157
1158 class iterator {
1159 public:
1160 using iterator_category = std::bidirectional_iterator_tag;
1161 using value_type = Variable<Scalar>;
1162 using difference_type = std::ptrdiff_t;
1163 using pointer = Variable<Scalar>*;
1164 using reference = Variable<Scalar>&;
1165
1166 constexpr iterator() noexcept = default;
1167
1170 : m_it{it} {}
1171
1172 constexpr iterator& operator++() noexcept {
1173 ++m_it;
1174 return *this;
1175 }
1176
1177 constexpr iterator operator++(int) noexcept {
1178 iterator retval = *this;
1179 ++(*this);
1180 return retval;
1181 }
1182
1183 constexpr iterator& operator--() noexcept {
1184 --m_it;
1185 return *this;
1186 }
1187
1188 constexpr iterator operator--(int) noexcept {
1189 iterator retval = *this;
1190 --(*this);
1191 return retval;
1192 }
1193
1194 constexpr bool operator==(const iterator&) const noexcept = default;
1195
1196 constexpr reference operator*() const noexcept { return *m_it; }
1197
1198 private:
1199 gch::small_vector<Variable<Scalar>>::iterator m_it;
1200 };
1201
1202 class const_iterator {
1203 public:
1204 using iterator_category = std::bidirectional_iterator_tag;
1205 using value_type = Variable<Scalar>;
1206 using difference_type = std::ptrdiff_t;
1207 using pointer = Variable<Scalar>*;
1208 using const_reference = const Variable<Scalar>&;
1209
1210 constexpr const_iterator() noexcept = default;
1211
1212 explicit constexpr const_iterator(
1213 gch::small_vector<Variable<Scalar>>::const_iterator it) noexcept
1214 : m_it{it} {}
1215
1216 constexpr const_iterator& operator++() noexcept {
1217 ++m_it;
1218 return *this;
1219 }
1220
1221 constexpr const_iterator operator++(int) noexcept {
1222 const_iterator retval = *this;
1223 ++(*this);
1224 return retval;
1225 }
1226
1227 constexpr const_iterator& operator--() noexcept {
1228 --m_it;
1229 return *this;
1230 }
1231
1232 constexpr const_iterator operator--(int) noexcept {
1233 const_iterator retval = *this;
1234 --(*this);
1235 return retval;
1236 }
1237
1238 constexpr bool operator==(const const_iterator&) const noexcept = default;
1239
1240 constexpr const_reference operator*() const noexcept { return *m_it; }
1241
1242 private:
1243 gch::small_vector<Variable<Scalar>>::const_iterator m_it;
1244 };
1245
1246 using reverse_iterator = std::reverse_iterator<iterator>;
1247 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
1248
1249#endif // DOXYGEN_SHOULD_SKIP_THIS
1250
1256 iterator begin() { return iterator{m_storage.begin()}; }
1257
1263 iterator end() { return iterator{m_storage.end()}; }
1264
1270 const_iterator begin() const { return const_iterator{m_storage.begin()}; }
1271
1277 const_iterator end() const { return const_iterator{m_storage.end()}; }
1278
1284 const_iterator cbegin() const { return const_iterator{m_storage.cbegin()}; }
1285
1291 const_iterator cend() const { return const_iterator{m_storage.cend()}; }
1292
1299
1306
1315
1324
1333
1342
1348 size_t size() const { return m_storage.size(); }
1349
1358 VariableMatrix<Scalar> result{detail::empty, rows, cols};
1359
1360 for (auto& elem : result) {
1361 elem = Scalar(0);
1362 }
1363
1364 return result;
1365 }
1366
1375 VariableMatrix<Scalar> result{detail::empty, rows, cols};
1376
1377 for (auto& elem : result) {
1378 elem = Scalar(1);
1379 }
1380
1381 return result;
1382 }
1383
1384 private:
1385 gch::small_vector<Variable<Scalar>> m_storage;
1386 int m_rows = 0;
1387 int m_cols = 0;
1388};
1389
1390template <typename Derived>
1391VariableMatrix(const Eigen::MatrixBase<Derived>&)
1392 -> VariableMatrix<typename Derived::Scalar>;
1393
1394template <typename Derived>
1395VariableMatrix(const Eigen::DiagonalBase<Derived>&)
1396 -> VariableMatrix<typename Derived::Scalar>;
1397
1406template <typename Scalar>
1407VariableMatrix<Scalar> cwise_reduce(
1408 const VariableMatrix<Scalar>& lhs, const VariableMatrix<Scalar>& rhs,
1409 function_ref<Variable<Scalar>(const Variable<Scalar>& x,
1410 const Variable<Scalar>& y)>
1411 binary_op) {
1412 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
1413
1414 VariableMatrix<Scalar> result{detail::empty, lhs.rows(), lhs.cols()};
1415
1416 for (int row = 0; row < lhs.rows(); ++row) {
1417 for (int col = 0; col < lhs.cols(); ++col) {
1418 result[row, col] = binary_op(lhs[row, col], rhs[row, col]);
1419 }
1420 }
1421
1422 return result;
1423}
1424
1436template <typename Scalar>
1437VariableMatrix<Scalar> block(
1438 std::initializer_list<std::initializer_list<VariableMatrix<Scalar>>> list) {
1439 // Get row and column counts for destination matrix
1440 int rows = 0;
1441 int cols = -1;
1442 for (const auto& row : list) {
1443 if (row.size() > 0) {
1444 rows += row.begin()->rows();
1445 }
1446
1447 // Get number of columns in this row
1448 int latest_cols = 0;
1449 for (const auto& elem : row) {
1450 // Assert the first and latest row have the same height
1451 slp_assert(row.begin()->rows() == elem.rows());
1452
1453 latest_cols += elem.cols();
1454 }
1455
1456 // If this is the first row, record the column count. Otherwise, assert the
1457 // first and latest column counts are the same.
1458 if (cols == -1) {
1459 cols = latest_cols;
1460 } else {
1461 slp_assert(cols == latest_cols);
1462 }
1463 }
1464
1465 VariableMatrix<Scalar> result{detail::empty, rows, cols};
1466
1467 int row_offset = 0;
1468 for (const auto& row : list) {
1469 int col_offset = 0;
1470 for (const auto& elem : row) {
1471 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1472 col_offset += elem.cols();
1473 }
1474 if (row.size() > 0) {
1475 row_offset += row.begin()->rows();
1476 }
1477 }
1478
1479 return result;
1480}
1481
1495template <typename Scalar>
1496VariableMatrix<Scalar> block(
1497 const std::vector<std::vector<VariableMatrix<Scalar>>>& list) {
1498 // Get row and column counts for destination matrix
1499 int rows = 0;
1500 int cols = -1;
1501 for (const auto& row : list) {
1502 if (row.size() > 0) {
1503 rows += row.begin()->rows();
1504 }
1505
1506 // Get number of columns in this row
1507 int latest_cols = 0;
1508 for (const auto& elem : row) {
1509 // Assert the first and latest row have the same height
1510 slp_assert(row.begin()->rows() == elem.rows());
1511
1512 latest_cols += elem.cols();
1513 }
1514
1515 // If this is the first row, record the column count. Otherwise, assert the
1516 // first and latest column counts are the same.
1517 if (cols == -1) {
1518 cols = latest_cols;
1519 } else {
1520 slp_assert(cols == latest_cols);
1521 }
1522 }
1523
1524 VariableMatrix<Scalar> result{detail::empty, rows, cols};
1525
1526 int row_offset = 0;
1527 for (const auto& row : list) {
1528 int col_offset = 0;
1529 for (const auto& elem : row) {
1530 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1531 col_offset += elem.cols();
1532 }
1533 if (row.size() > 0) {
1534 row_offset += row.begin()->rows();
1535 }
1536 }
1537
1538 return result;
1539}
1540
1549template <typename Scalar>
1550VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
1551 const VariableMatrix<Scalar>& B) {
1552 // m x n * n x p = m x p
1553 slp_assert(A.rows() == B.rows());
1554
1555 if (A.rows() == 1 && A.cols() == 1) {
1556 // Compute optimal inverse instead of using Eigen's general solver
1557 return B[0, 0] / A[0, 0];
1558 } else if (A.rows() == 2 && A.cols() == 2) {
1559 // Compute optimal inverse instead of using Eigen's general solver
1560 //
1561 // [a b]⁻¹ ___1___ [ d −b]
1562 // [c d] = ad − bc [−c a]
1563
1564 const auto& a = A[0, 0];
1565 const auto& b = A[0, 1];
1566 const auto& c = A[1, 0];
1567 const auto& d = A[1, 1];
1568
1569 VariableMatrix adj_A{{d, -b}, {-c, a}};
1570 auto det_A = a * d - b * c;
1571 return adj_A / det_A * B;
1572 } else if (A.rows() == 3 && A.cols() == 3) {
1573 // Compute optimal inverse instead of using Eigen's general solver
1574 //
1575 // [a b c]⁻¹
1576 // [d e f]
1577 // [g h i]
1578 // 1 [ei − fh ch − bi bf − ce]
1579 // = ------------------------------------ [fg − di ai − cg cd − af]
1580 // a(ei − fh) + b(fg − di) + c(dh − eg) [dh − eg bg − ah ae − bd]
1581 //
1582 // https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%7D%2C+%7Bd%2C+e%2C+f%7D%2C+%7Bg%2C+h%2C+i%7D%7D
1583
1584 const auto& a = A[0, 0];
1585 const auto& b = A[0, 1];
1586 const auto& c = A[0, 2];
1587 const auto& d = A[1, 0];
1588 const auto& e = A[1, 1];
1589 const auto& f = A[1, 2];
1590 const auto& g = A[2, 0];
1591 const auto& h = A[2, 1];
1592 const auto& i = A[2, 2];
1593
1594 auto ae = a * e;
1595 auto af = a * f;
1596 auto ah = a * h;
1597 auto ai = a * i;
1598 auto bd = b * d;
1599 auto bf = b * f;
1600 auto bg = b * g;
1601 auto bi = b * i;
1602 auto cd = c * d;
1603 auto ce = c * e;
1604 auto cg = c * g;
1605 auto ch = c * h;
1606 auto dh = d * h;
1607 auto di = d * i;
1608 auto eg = e * g;
1609 auto ei = e * i;
1610 auto fg = f * g;
1611 auto fh = f * h;
1612
1613 auto adj_A00 = ei - fh;
1614 auto adj_A10 = fg - di;
1615 auto adj_A20 = dh - eg;
1616
1617 VariableMatrix adj_A{{adj_A00, ch - bi, bf - ce},
1618 {adj_A10, ai - cg, cd - af},
1619 {adj_A20, bg - ah, ae - bd}};
1620 auto det_A = a * adj_A00 + b * adj_A10 + c * adj_A20;
1621 return adj_A / det_A * B;
1622 } else if (A.rows() == 4 && A.cols() == 4) {
1623 // Compute optimal inverse instead of using Eigen's general solver
1624 //
1625 // [a b c d]⁻¹
1626 // [e f g h]
1627 // [i j k l]
1628 // [m n o p]
1629 //
1630 // https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%2C+d%7D%2C+%7Be%2C+f%2C+g%2C+h%7D%2C+%7Bi%2C+j%2C+k%2C+l%7D%2C+%7Bm%2C+n%2C+o%2C+p%7D%7D
1631
1632 const auto& a = A[0, 0];
1633 const auto& b = A[0, 1];
1634 const auto& c = A[0, 2];
1635 const auto& d = A[0, 3];
1636 const auto& e = A[1, 0];
1637 const auto& f = A[1, 1];
1638 const auto& g = A[1, 2];
1639 const auto& h = A[1, 3];
1640 const auto& i = A[2, 0];
1641 const auto& j = A[2, 1];
1642 const auto& k = A[2, 2];
1643 const auto& l = A[2, 3];
1644 const auto& m = A[3, 0];
1645 const auto& n = A[3, 1];
1646 const auto& o = A[3, 2];
1647 const auto& p = A[3, 3];
1648
1649 auto afk = a * f * k;
1650 auto afl = a * f * l;
1651 auto afo = a * f * o;
1652 auto afp = a * f * p;
1653 auto agj = a * g * j;
1654 auto agl = a * g * l;
1655 auto agn = a * g * n;
1656 auto agp = a * g * p;
1657 auto ahj = a * h * j;
1658 auto ahk = a * h * k;
1659 auto ahn = a * h * n;
1660 auto aho = a * h * o;
1661 auto ajo = a * j * o;
1662 auto ajp = a * j * p;
1663 auto akn = a * k * n;
1664 auto akp = a * k * p;
1665 auto aln = a * l * n;
1666 auto alo = a * l * o;
1667 auto bek = b * e * k;
1668 auto bel = b * e * l;
1669 auto beo = b * e * o;
1670 auto bep = b * e * p;
1671 auto bgi = b * g * i;
1672 auto bgl = b * g * l;
1673 auto bgm = b * g * m;
1674 auto bgp = b * g * p;
1675 auto bhi = b * h * i;
1676 auto bhk = b * h * k;
1677 auto bhm = b * h * m;
1678 auto bho = b * h * o;
1679 auto bio = b * i * o;
1680 auto bip = b * i * p;
1681 auto bjp = b * j * p;
1682 auto bkm = b * k * m;
1683 auto bkp = b * k * p;
1684 auto blm = b * l * m;
1685 auto blo = b * l * o;
1686 auto cej = c * e * j;
1687 auto cel = c * e * l;
1688 auto cen = c * e * n;
1689 auto cep = c * e * p;
1690 auto cfi = c * f * i;
1691 auto cfl = c * f * l;
1692 auto cfm = c * f * m;
1693 auto cfp = c * f * p;
1694 auto chi = c * h * i;
1695 auto chj = c * h * j;
1696 auto chm = c * h * m;
1697 auto chn = c * h * n;
1698 auto cin = c * i * n;
1699 auto cip = c * i * p;
1700 auto cjm = c * j * m;
1701 auto cjp = c * j * p;
1702 auto clm = c * l * m;
1703 auto cln = c * l * n;
1704 auto dej = d * e * j;
1705 auto dek = d * e * k;
1706 auto den = d * e * n;
1707 auto deo = d * e * o;
1708 auto dfi = d * f * i;
1709 auto dfk = d * f * k;
1710 auto dfm = d * f * m;
1711 auto dfo = d * f * o;
1712 auto dgi = d * g * i;
1713 auto dgj = d * g * j;
1714 auto dgm = d * g * m;
1715 auto dgn = d * g * n;
1716 auto din = d * i * n;
1717 auto dio = d * i * o;
1718 auto djm = d * j * m;
1719 auto djo = d * j * o;
1720 auto dkm = d * k * m;
1721 auto dkn = d * k * n;
1722 auto ejo = e * j * o;
1723 auto ejp = e * j * p;
1724 auto ekn = e * k * n;
1725 auto ekp = e * k * p;
1726 auto eln = e * l * n;
1727 auto elo = e * l * o;
1728 auto fio = f * i * o;
1729 auto fip = f * i * p;
1730 auto fkm = f * k * m;
1731 auto fkp = f * k * p;
1732 auto flm = f * l * m;
1733 auto flo = f * l * o;
1734 auto gin = g * i * n;
1735 auto gip = g * i * p;
1736 auto gjm = g * j * m;
1737 auto gjp = g * j * p;
1738 auto glm = g * l * m;
1739 auto gln = g * l * n;
1740 auto hin = h * i * n;
1741 auto hio = h * i * o;
1742 auto hjm = h * j * m;
1743 auto hjo = h * j * o;
1744 auto hkm = h * k * m;
1745 auto hkn = h * k * n;
1746
1747 auto adj_A00 = fkp - flo - gjp + gln + hjo - hkn;
1748 auto adj_A01 = -bkp + blo + cjp - cln - djo + dkn;
1749 auto adj_A02 = bgp - bho - cfp + chn + dfo - dgn;
1750 auto adj_A03 = -bgl + bhk + cfl - chj - dfk + dgj;
1751 auto adj_A10 = -ekp + elo + gip - glm - hio + hkm;
1752 auto adj_A11 = akp - alo - cip + clm + dio - dkm;
1753 auto adj_A12 = -agp + aho + cep - chm - deo + dgm;
1754 auto adj_A13 = agl - ahk - cel + chi + dek - dgi;
1755 auto adj_A20 = ejp - eln - fip + flm + hin - hjm;
1756 auto adj_A21 = -ajp + aln + bip - blm - din + djm;
1757 auto adj_A22 = afp - ahn - bep + bhm + den - dfm;
1758 auto adj_A23 = -afl + ahj + bel - bhi - dej + dfi;
1759 auto adj_A30 = -ejo + ekn + fio - fkm - gin + gjm;
1760 // NOLINTNEXTLINE(build/include_what_you_use)
1761 auto adj_A31 = ajo - akn - bio + bkm + cin - cjm;
1762 auto adj_A32 = -afo + agn + beo - bgm - cen + cfm;
1763 auto adj_A33 = afk - agj - bek + bgi + cej - cfi;
1764
1765 VariableMatrix adj_A{{adj_A00, adj_A01, adj_A02, adj_A03},
1766 {adj_A10, adj_A11, adj_A12, adj_A13},
1767 {adj_A20, adj_A21, adj_A22, adj_A23},
1768 {adj_A30, adj_A31, adj_A32, adj_A33}};
1769 auto det_A = a * adj_A00 + b * adj_A10 + c * adj_A20 + d * adj_A30;
1770 return adj_A / det_A * B;
1771 } else {
1772 using MatrixXv =
1773 Eigen::Matrix<Variable<Scalar>, Eigen::Dynamic, Eigen::Dynamic>;
1774
1775 MatrixXv eigen_A{A.rows(), A.cols()};
1776 for (int row = 0; row < A.rows(); ++row) {
1777 for (int col = 0; col < A.cols(); ++col) {
1778 eigen_A[row, col] = A[row, col];
1779 }
1780 }
1781
1782 MatrixXv eigen_B{B.rows(), B.cols()};
1783 for (int row = 0; row < B.rows(); ++row) {
1784 for (int col = 0; col < B.cols(); ++col) {
1785 eigen_B[row, col] = B[row, col];
1786 }
1787 }
1788
1789 MatrixXv eigen_X = eigen_A.householderQr().solve(eigen_B);
1790
1791 VariableMatrix<Scalar> X{detail::empty, A.cols(), B.cols()};
1792 for (int row = 0; row < X.rows(); ++row) {
1793 for (int col = 0; col < X.cols(); ++col) {
1794 X[row, col] = eigen_X[row, col];
1795 }
1796 }
1797
1798 return X;
1799 }
1800}
1801
1802extern template SLEIPNIR_DLLEXPORT VariableMatrix<double> solve(
1803 const VariableMatrix<double>& A, const VariableMatrix<double>& B);
1804
1805} // namespace slp
Definition intrusive_shared_ptr.hpp:29
Definition sleipnir_base.hpp:11
Definition slice.hpp:31
Definition variable_block.hpp:26
Definition variable_matrix.hpp:35
const_reverse_iterator crbegin() const
Definition variable_matrix.hpp:1330
VariableMatrix(std::initializer_list< std::initializer_list< Variable< Scalar > > > list)
Definition variable_matrix.hpp:87
const_reverse_iterator crend() const
Definition variable_matrix.hpp:1339
Scalar_ Scalar
Definition variable_matrix.hpp:40
iterator end()
Definition variable_matrix.hpp:1263
const Variable< Scalar > & operator[](int row, int col) const
Definition variable_matrix.hpp:361
VariableMatrix(Variable< Scalar > &&variable)
Definition variable_matrix.hpp:220
const VariableBlock< const VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_matrix.hpp:416
size_t size() const
Definition variable_matrix.hpp:1348
static VariableMatrix< Scalar > zero(int rows, int cols)
Definition variable_matrix.hpp:1357
VariableBlock< VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_matrix.hpp:398
VariableMatrix(const std::vector< std::vector< Scalar > > &list)
Definition variable_matrix.hpp:116
VariableMatrix & operator=(ScalarLike auto value)
Definition variable_matrix.hpp:316
VariableMatrix & operator-=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:1013
Variable< Scalar > & operator[](int index)
Definition variable_matrix.hpp:373
VariableMatrix & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:296
friend VariableMatrix< Scalar > operator-(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:949
VariableMatrix & operator-=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:1031
friend VariableMatrix< Scalar > operator*(const Variable< Scalar > &lhs, const RHS &rhs)
Definition variable_matrix.hpp:690
VariableMatrix & operator*=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:750
Scalar value(int index)
Definition variable_matrix.hpp:1116
VariableMatrix(const VariableBlock< VariableMatrix > &values)
Definition variable_matrix.hpp:230
VariableMatrix(int rows, int cols)
Definition variable_matrix.hpp:61
const_iterator cbegin() const
Definition variable_matrix.hpp:1284
const_iterator begin() const
Definition variable_matrix.hpp:1270
const_iterator cend() const
Definition variable_matrix.hpp:1291
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic > value()
Definition variable_matrix.hpp:1123
const VariableBlock< const VariableMatrix > operator[](Slice row_slice, Slice col_slice) const
Definition variable_matrix.hpp:448
friend VariableMatrix< Scalar > operator/(const LHS &lhs, const Variable< Scalar > &rhs)
Definition variable_matrix.hpp:768
VariableMatrix(int rows)
Definition variable_matrix.hpp:53
VariableMatrix(std::span< const Variable< Scalar > > values, int rows, int cols)
Definition variable_matrix.hpp:278
VariableBlock< VariableMatrix > col(int col)
Definition variable_matrix.hpp:553
VariableMatrix(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:171
const VariableBlock< const VariableMatrix > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_matrix.hpp:489
const VariableBlock< const VariableMatrix > col(int col) const
Definition variable_matrix.hpp:564
VariableMatrix & operator+=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:929
friend VariableMatrix< Scalar > operator*(const LHS &lhs, const Variable< Scalar > &rhs)
Definition variable_matrix.hpp:651
VariableBlock< VariableMatrix > segment(int offset, int length)
Definition variable_matrix.hpp:503
Scalar value(int row, int col)
Definition variable_matrix.hpp:1108
const VariableBlock< const VariableMatrix > row(int row) const
Definition variable_matrix.hpp:542
VariableMatrix()=default
static VariableMatrix< Scalar > ones(int rows, int cols)
Definition variable_matrix.hpp:1374
VariableMatrix & operator+=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:911
VariableMatrix(std::span< const Variable< Scalar > > values)
Definition variable_matrix.hpp:261
friend VariableMatrix< Scalar > operator-(const SleipnirMatrixLike< Scalar > auto &lhs)
Definition variable_matrix.hpp:1048
VariableMatrix< Scalar > T() const
Definition variable_matrix.hpp:1075
VariableMatrix< Scalar > cwise_transform(function_ref< Variable< Scalar >(const Variable< Scalar > &x)> unary_op) const
Definition variable_matrix.hpp:1142
friend VariableMatrix< Scalar > operator*(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:576
const_reverse_iterator rbegin() const
Definition variable_matrix.hpp:1312
VariableMatrix(const Variable< Scalar > &variable)
Definition variable_matrix.hpp:210
reverse_iterator rend()
Definition variable_matrix.hpp:1305
VariableMatrix(const VariableBlock< const VariableMatrix > &values)
Definition variable_matrix.hpp:246
int rows() const
Definition variable_matrix.hpp:1092
VariableMatrix(const Eigen::DiagonalBase< Derived > &values)
Definition variable_matrix.hpp:189
VariableBlock< VariableMatrix > operator[](Slice row_slice, Slice col_slice)
Definition variable_matrix.hpp:434
const_reverse_iterator rend() const
Definition variable_matrix.hpp:1321
VariableMatrix & operator*=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:728
VariableMatrix(detail::empty_t, int rows, int cols)
Definition variable_matrix.hpp:74
const_iterator end() const
Definition variable_matrix.hpp:1277
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:331
Variable< Scalar > & operator[](int row, int col)
Definition variable_matrix.hpp:348
iterator begin()
Definition variable_matrix.hpp:1256
const VariableBlock< const VariableMatrix > segment(int offset, int length) const
Definition variable_matrix.hpp:517
reverse_iterator rbegin()
Definition variable_matrix.hpp:1298
VariableBlock< VariableMatrix > row(int row)
Definition variable_matrix.hpp:531
friend VariableMatrix< Scalar > operator+(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:847
VariableMatrix(const std::vector< std::vector< Variable< Scalar > > > &list)
Definition variable_matrix.hpp:144
VariableBlock< VariableMatrix > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_matrix.hpp:469
VariableMatrix & operator/=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:829
const Variable< Scalar > & operator[](int index) const
Definition variable_matrix.hpp:384
int cols() const
Definition variable_matrix.hpp:1099
Definition variable.hpp:49
Definition function_ref.hpp:13
Definition concepts.hpp:18
Definition concepts.hpp:24
Definition concepts.hpp:33
Definition concepts.hpp:38
Definition empty.hpp:10