Sleipnir C++ API
Loading...
Searching...
No Matches
variable_matrix.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <algorithm>
6#include <concepts>
7#include <initializer_list>
8#include <iterator>
9#include <span>
10#include <utility>
11#include <vector>
12
13#include <Eigen/Core>
14#include <gch/small_vector.hpp>
15
16#include "sleipnir/autodiff/slice.hpp"
17#include "sleipnir/autodiff/variable.hpp"
18#include "sleipnir/autodiff/variable_block.hpp"
19#include "sleipnir/util/assert.hpp"
20#include "sleipnir/util/concepts.hpp"
21#include "sleipnir/util/function_ref.hpp"
22#include "sleipnir/util/symbol_exports.hpp"
23
24namespace slp {
25
29class SLEIPNIR_DLLEXPORT VariableMatrix {
30 public:
/// Tag type used to select the VariableMatrix(empty_t, int, int) constructor.
struct empty_t {};

/// Tag value to pass as the first argument of
/// VariableMatrix(empty_t, int, int).
static constexpr empty_t empty{};
40
44 VariableMatrix() = default;
45
52 explicit VariableMatrix(int rows) : VariableMatrix{rows, 1} {}
53
60 VariableMatrix(int rows, int cols) : m_rows{rows}, m_cols{cols} {
61 m_storage.reserve(rows * cols);
62 for (int index = 0; index < rows * cols; ++index) {
63 m_storage.emplace_back();
64 }
65 }
66
73 VariableMatrix(empty_t, int rows, int cols) : m_rows{rows}, m_cols{cols} {
74 m_storage.reserve(rows * cols);
75 for (int index = 0; index < rows * cols; ++index) {
76 m_storage.emplace_back(nullptr);
77 }
78 }
79
85 VariableMatrix( // NOLINT
86 std::initializer_list<std::initializer_list<Variable>> list) {
87 // Get row and column counts for destination matrix
88 m_rows = list.size();
89 m_cols = 0;
90 if (list.size() > 0) {
91 m_cols = list.begin()->size();
92 }
93
94 // Assert all column counts are the same
95 for ([[maybe_unused]]
96 const auto& row : list) {
97 slp_assert(static_cast<int>(row.size()) == m_cols);
98 }
99
100 m_storage.reserve(rows() * cols());
101 for (const auto& row : list) {
102 std::ranges::copy(row, std::back_inserter(m_storage));
103 }
104 }
105
113 VariableMatrix(const std::vector<std::vector<double>>& list) { // NOLINT
114 // Get row and column counts for destination matrix
115 m_rows = list.size();
116 m_cols = 0;
117 if (list.size() > 0) {
118 m_cols = list.begin()->size();
119 }
120
121 // Assert all column counts are the same
122 for ([[maybe_unused]]
123 const auto& row : list) {
124 slp_assert(static_cast<int>(row.size()) == m_cols);
125 }
126
127 m_storage.reserve(rows() * cols());
128 for (const auto& row : list) {
129 std::ranges::copy(row, std::back_inserter(m_storage));
130 }
131 }
132
140 VariableMatrix(const std::vector<std::vector<Variable>>& list) { // NOLINT
141 // Get row and column counts for destination matrix
142 m_rows = list.size();
143 m_cols = 0;
144 if (list.size() > 0) {
145 m_cols = list.begin()->size();
146 }
147
148 // Assert all column counts are the same
149 for ([[maybe_unused]]
150 const auto& row : list) {
151 slp_assert(static_cast<int>(row.size()) == m_cols);
152 }
153
154 m_storage.reserve(rows() * cols());
155 for (const auto& row : list) {
156 std::ranges::copy(row, std::back_inserter(m_storage));
157 }
158 }
159
165 template <typename Derived>
166 VariableMatrix(const Eigen::MatrixBase<Derived>& values) // NOLINT
167 : m_rows{static_cast<int>(values.rows())},
168 m_cols{static_cast<int>(values.cols())} {
169 m_storage.reserve(values.rows() * values.cols());
170 for (int row = 0; row < values.rows(); ++row) {
171 for (int col = 0; col < values.cols(); ++col) {
172 m_storage.emplace_back(values(row, col));
173 }
174 }
175 }
176
182 template <typename Derived>
183 VariableMatrix(const Eigen::DiagonalBase<Derived>& values) // NOLINT
184 : m_rows{static_cast<int>(values.rows())},
185 m_cols{static_cast<int>(values.cols())} {
186 m_storage.reserve(values.rows() * values.cols());
187 for (int row = 0; row < values.rows(); ++row) {
188 for (int col = 0; col < values.cols(); ++col) {
189 if (row == col) {
190 m_storage.emplace_back(values.diagonal()[row]);
191 } else {
192 m_storage.emplace_back(0.0);
193 }
194 }
195 }
196 }
197
203 VariableMatrix(const Variable& variable) // NOLINT
204 : m_rows{1}, m_cols{1} {
205 m_storage.emplace_back(variable);
206 }
207
213 VariableMatrix(Variable&& variable) : m_rows{1}, m_cols{1} { // NOLINT
214 m_storage.emplace_back(std::move(variable));
215 }
216
223 : m_rows{values.rows()}, m_cols{values.cols()} {
224 m_storage.reserve(rows() * cols());
225 for (int row = 0; row < rows(); ++row) {
226 for (int col = 0; col < cols(); ++col) {
227 m_storage.emplace_back(values[row, col]);
228 }
229 }
230 }
231
238 : m_rows{values.rows()}, m_cols{values.cols()} {
239 m_storage.reserve(rows() * cols());
240 for (int row = 0; row < rows(); ++row) {
241 for (int col = 0; col < cols(); ++col) {
242 m_storage.emplace_back(values[row, col]);
243 }
244 }
245 }
246
252 explicit VariableMatrix(std::span<const Variable> values)
253 : m_rows{static_cast<int>(values.size())}, m_cols{1} {
254 m_storage.reserve(rows() * cols());
255 for (int row = 0; row < rows(); ++row) {
256 for (int col = 0; col < cols(); ++col) {
257 m_storage.emplace_back(values[row * cols() + col]);
258 }
259 }
260 }
261
269 VariableMatrix(std::span<const Variable> values, int rows, int cols)
270 : m_rows{rows}, m_cols{cols} {
271 slp_assert(static_cast<int>(values.size()) == rows * cols);
272 m_storage.reserve(rows * cols);
273 for (int row = 0; row < rows; ++row) {
274 for (int col = 0; col < cols; ++col) {
275 m_storage.emplace_back(values[row * cols + col]);
276 }
277 }
278 }
279
286 template <typename Derived>
287 VariableMatrix& operator=(const Eigen::MatrixBase<Derived>& values) {
288 slp_assert(rows() == values.rows() && cols() == values.cols());
289
290 for (int row = 0; row < values.rows(); ++row) {
291 for (int col = 0; col < values.cols(); ++col) {
292 (*this)[row, col] = values(row, col);
293 }
294 }
295
296 return *this;
297 }
298
308 slp_assert(rows() == 1 && cols() == 1);
309
310 (*this)[0, 0] = value;
311
312 return *this;
313 }
314
320 template <typename Derived>
321 requires std::same_as<typename Derived::Scalar, double>
322 void set_value(const Eigen::MatrixBase<Derived>& values) {
323 slp_assert(rows() == values.rows() && cols() == values.cols());
324
325 for (int row = 0; row < values.rows(); ++row) {
326 for (int col = 0; col < values.cols(); ++col) {
327 (*this)[row, col].set_value(values(row, col));
328 }
329 }
330 }
331
339 Variable& operator[](int row, int col) {
340 slp_assert(row >= 0 && row < rows());
341 slp_assert(col >= 0 && col < cols());
342 return m_storage[row * cols() + col];
343 }
344
352 const Variable& operator[](int row, int col) const {
353 slp_assert(row >= 0 && row < rows());
354 slp_assert(col >= 0 && col < cols());
355 return m_storage[row * cols() + col];
356 }
357
364 Variable& operator[](int index) {
365 slp_assert(index >= 0 && index < rows() * cols());
366 return m_storage[index];
367 }
368
375 const Variable& operator[](int index) const {
376 slp_assert(index >= 0 && index < rows() * cols());
377 return m_storage[index];
378 }
379
389 VariableBlock<VariableMatrix> block(int row_offset, int col_offset,
390 int block_rows, int block_cols) {
391 slp_assert(row_offset >= 0 && row_offset <= rows());
392 slp_assert(col_offset >= 0 && col_offset <= cols());
393 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
394 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
395 return VariableBlock{*this, row_offset, col_offset, block_rows, block_cols};
396 }
397
408 int col_offset,
409 int block_rows,
410 int block_cols) const {
411 slp_assert(row_offset >= 0 && row_offset <= rows());
412 slp_assert(col_offset >= 0 && col_offset <= cols());
413 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
414 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
415 return VariableBlock{*this, row_offset, col_offset, block_rows, block_cols};
416 }
417
426 int row_slice_length = row_slice.adjust(rows());
427 int col_slice_length = col_slice.adjust(cols());
428 return VariableBlock{*this, std::move(row_slice), row_slice_length,
429 std::move(col_slice), col_slice_length};
430 }
431
440 Slice col_slice) const {
441 int row_slice_length = row_slice.adjust(rows());
442 int col_slice_length = col_slice.adjust(cols());
443 return VariableBlock{*this, std::move(row_slice), row_slice_length,
444 std::move(col_slice), col_slice_length};
445 }
446
461 int row_slice_length,
462 Slice col_slice,
463 int col_slice_length) {
464 return VariableBlock{*this, std::move(row_slice), row_slice_length,
465 std::move(col_slice), col_slice_length};
466 }
467
481 Slice row_slice, int row_slice_length, Slice col_slice,
482 int col_slice_length) const {
483 return VariableBlock{*this, std::move(row_slice), row_slice_length,
484 std::move(col_slice), col_slice_length};
485 }
486
494 VariableBlock<VariableMatrix> segment(int offset, int length) {
495 slp_assert(cols() == 1);
496 slp_assert(offset >= 0 && offset < rows());
497 slp_assert(length >= 0 && length <= rows() - offset);
498 return block(offset, 0, length, 1);
499 }
500
509 int length) const {
510 slp_assert(cols() == 1);
511 slp_assert(offset >= 0 && offset < rows());
512 slp_assert(length >= 0 && length <= rows() - offset);
513 return block(offset, 0, length, 1);
514 }
515
523 slp_assert(row >= 0 && row < rows());
524 return block(row, 0, 1, cols());
525 }
526
534 slp_assert(row >= 0 && row < rows());
535 return block(row, 0, 1, cols());
536 }
537
545 slp_assert(col >= 0 && col < cols());
546 return block(0, col, rows(), 1);
547 }
548
556 slp_assert(col >= 0 && col < cols());
557 return block(0, col, rows(), 1);
558 }
559
566 template <MatrixLike LHS, MatrixLike RHS>
568 friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const LHS& lhs,
569 const RHS& rhs) {
570 slp_assert(lhs.cols() == rhs.rows());
571
572 VariableMatrix result(VariableMatrix::empty, lhs.rows(), rhs.cols());
573
574 for (int i = 0; i < lhs.rows(); ++i) {
575 for (int j = 0; j < rhs.cols(); ++j) {
576 Variable sum{0.0};
577 for (int k = 0; k < lhs.cols(); ++k) {
579 sum += lhs[i, k] * rhs[k, j];
580 } else if constexpr (SleipnirMatrixLike<LHS> &&
582 sum += lhs[i, k] * rhs(k, j);
583 } else if constexpr (EigenMatrixLike<LHS> &&
585 sum += lhs(i, k) * rhs[k, j];
586 }
587 }
588 result[i, j] = sum;
589 }
590 }
591
592 return result;
593 }
594
601 friend SLEIPNIR_DLLEXPORT VariableMatrix
602 operator*(const SleipnirMatrixLike auto& lhs, const ScalarLike auto& rhs) {
603 VariableMatrix result{VariableMatrix::empty, lhs.rows(), lhs.cols()};
604
605 for (int row = 0; row < result.rows(); ++row) {
606 for (int col = 0; col < result.cols(); ++col) {
607 result[row, col] = lhs[row, col] * rhs;
608 }
609 }
610
611 return result;
612 }
613
620 friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const MatrixLike auto& lhs,
621 const Variable& rhs) {
622 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
623
624 for (int row = 0; row < result.rows(); ++row) {
625 for (int col = 0; col < result.cols(); ++col) {
626 if constexpr (SleipnirMatrixLike<decltype(lhs)>) {
627 result[row, col] = lhs[row, col] * rhs;
628 } else {
629 result[row, col] = lhs(row, col) * rhs;
630 }
631 }
632 }
633
634 return result;
635 }
636
643 friend SLEIPNIR_DLLEXPORT VariableMatrix
644 operator*(const ScalarLike auto& lhs, const SleipnirMatrixLike auto& rhs) {
645 VariableMatrix result{VariableMatrix::empty, rhs.rows(), rhs.cols()};
646
647 for (int row = 0; row < result.rows(); ++row) {
648 for (int col = 0; col < result.cols(); ++col) {
649 result[row, col] = rhs[row, col] * lhs;
650 }
651 }
652
653 return result;
654 }
655
662 friend SLEIPNIR_DLLEXPORT VariableMatrix
663 operator*(const Variable& lhs, const MatrixLike auto& rhs) {
664 VariableMatrix result(VariableMatrix::empty, rhs.rows(), rhs.cols());
665
666 for (int row = 0; row < result.rows(); ++row) {
667 for (int col = 0; col < result.cols(); ++col) {
668 if constexpr (SleipnirMatrixLike<decltype(rhs)>) {
669 result[row, col] = rhs[row, col] * lhs;
670 } else {
671 result[row, col] = rhs(row, col) * lhs;
672 }
673 }
674 }
675
676 return result;
677 }
678
686 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
687
688 for (int i = 0; i < rows(); ++i) {
689 for (int j = 0; j < rhs.cols(); ++j) {
690 Variable sum{0.0};
691 for (int k = 0; k < cols(); ++k) {
692 if constexpr (SleipnirMatrixLike<decltype(rhs)>) {
693 sum += (*this)[i, k] * rhs[k, j];
694 } else {
695 sum += (*this)[i, k] * rhs(k, j);
696 }
697 }
698 (*this)[i, j] = sum;
699 }
700 }
701
702 return *this;
703 }
704
712 for (int row = 0; row < rows(); ++row) {
713 for (int col = 0; col < rhs.cols(); ++col) {
714 (*this)[row, col] *= rhs;
715 }
716 }
717
718 return *this;
719 }
720
728 friend SLEIPNIR_DLLEXPORT VariableMatrix
729 operator/(const MatrixLike auto& lhs, const ScalarLike auto& rhs) {
730 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
731
732 for (int row = 0; row < result.rows(); ++row) {
733 for (int col = 0; col < result.cols(); ++col) {
734 if constexpr (SleipnirMatrixLike<decltype(lhs)>) {
735 result[row, col] = lhs[row, col] / rhs;
736 } else {
737 result[row, col] = lhs(row, col) / rhs;
738 }
739 }
740 }
741
742 return result;
743 }
744
752 for (int row = 0; row < rows(); ++row) {
753 for (int col = 0; col < cols(); ++col) {
754 (*this)[row, col] /= rhs;
755 }
756 }
757
758 return *this;
759 }
760
768 template <MatrixLike LHS, MatrixLike RHS>
770 friend SLEIPNIR_DLLEXPORT VariableMatrix operator+(const LHS& lhs,
771 const RHS& rhs) {
772 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
773
774 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
775
776 for (int row = 0; row < result.rows(); ++row) {
777 for (int col = 0; col < result.cols(); ++col) {
779 result[row, col] = lhs[row, col] + rhs[row, col];
780 } else if constexpr (SleipnirMatrixLike<LHS> && EigenMatrixLike<RHS>) {
781 result[row, col] = lhs[row, col] + rhs(row, col);
782 } else if constexpr (EigenMatrixLike<LHS> && SleipnirMatrixLike<RHS>) {
783 result[row, col] = lhs(row, col) + rhs[row, col];
784 }
785 }
786 }
787
788 return result;
789 }
790
798 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
799
800 for (int row = 0; row < rows(); ++row) {
801 for (int col = 0; col < cols(); ++col) {
802 if constexpr (SleipnirMatrixLike<decltype(rhs)>) {
803 (*this)[row, col] += rhs[row, col];
804 } else {
805 (*this)[row, col] += rhs(row, col);
806 }
807 }
808 }
809
810 return *this;
811 }
812
820 slp_assert(rows() == 1 && cols() == 1);
821
822 for (int row = 0; row < rows(); ++row) {
823 for (int col = 0; col < cols(); ++col) {
824 (*this)[row, col] += rhs;
825 }
826 }
827
828 return *this;
829 }
830
838 template <MatrixLike LHS, MatrixLike RHS>
840 friend SLEIPNIR_DLLEXPORT VariableMatrix operator-(const LHS& lhs,
841 const RHS& rhs) {
842 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
843
844 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
845
846 for (int row = 0; row < result.rows(); ++row) {
847 for (int col = 0; col < result.cols(); ++col) {
849 result[row, col] = lhs[row, col] - rhs[row, col];
850 } else if constexpr (SleipnirMatrixLike<LHS> && EigenMatrixLike<RHS>) {
851 result[row, col] = lhs[row, col] - rhs(row, col);
852 } else if constexpr (EigenMatrixLike<LHS> && SleipnirMatrixLike<RHS>) {
853 result[row, col] = lhs(row, col) - rhs[row, col];
854 }
855 }
856 }
857
858 return result;
859 }
860
868 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
869
870 for (int row = 0; row < rows(); ++row) {
871 for (int col = 0; col < cols(); ++col) {
872 if constexpr (SleipnirMatrixLike<decltype(rhs)>) {
873 (*this)[row, col] -= rhs[row, col];
874 } else {
875 (*this)[row, col] -= rhs(row, col);
876 }
877 }
878 }
879
880 return *this;
881 }
882
890 slp_assert(rows() == 1 && cols() == 1);
891
892 for (int row = 0; row < rows(); ++row) {
893 for (int col = 0; col < cols(); ++col) {
894 (*this)[row, col] -= rhs;
895 }
896 }
897
898 return *this;
899 }
900
906 friend SLEIPNIR_DLLEXPORT VariableMatrix
907 operator-(const SleipnirMatrixLike auto& lhs) {
908 VariableMatrix result{VariableMatrix::empty, lhs.rows(), lhs.cols()};
909
910 for (int row = 0; row < result.rows(); ++row) {
911 for (int col = 0; col < result.cols(); ++col) {
912 result[row, col] = -lhs[row, col];
913 }
914 }
915
916 return result;
917 }
918
922 operator Variable() const { // NOLINT
923 slp_assert(rows() == 1 && cols() == 1);
924 return (*this)[0, 0];
925 }
926
933 VariableMatrix result{VariableMatrix::empty, cols(), rows()};
934
935 for (int row = 0; row < rows(); ++row) {
936 for (int col = 0; col < cols(); ++col) {
937 result[col, row] = (*this)[row, col];
938 }
939 }
940
941 return result;
942 }
943
949 int rows() const { return m_rows; }
950
956 int cols() const { return m_cols; }
957
965 double value(int row, int col) { return (*this)[row, col].value(); }
966
973 double value(int index) { return (*this)[index].value(); }
974
980 Eigen::MatrixXd value() {
981 Eigen::MatrixXd result{rows(), cols()};
982
983 for (int row = 0; row < rows(); ++row) {
984 for (int col = 0; col < cols(); ++col) {
985 result(row, col) = value(row, col);
986 }
987 }
988
989 return result;
990 }
991
999 function_ref<Variable(const Variable& x)> unary_op) const {
1000 VariableMatrix result{VariableMatrix::empty, rows(), cols()};
1001
1002 for (int row = 0; row < rows(); ++row) {
1003 for (int col = 0; col < cols(); ++col) {
1004 result[row, col] = unary_op((*this)[row, col]);
1005 }
1006 }
1007
1008 return result;
1009 }
1010
1011#ifndef DOXYGEN_SHOULD_SKIP_THIS
1012
1013 class iterator {
1014 public:
1015 using iterator_category = std::bidirectional_iterator_tag;
1016 using value_type = Variable;
1017 using difference_type = std::ptrdiff_t;
1018 using pointer = Variable*;
1019 using reference = Variable&;
1020
1021 constexpr iterator() noexcept = default;
1022
1023 explicit constexpr iterator(
1024 gch::small_vector<Variable>::iterator it) noexcept
1025 : m_it{it} {}
1026
1027 constexpr iterator& operator++() noexcept {
1028 ++m_it;
1029 return *this;
1030 }
1031
1032 constexpr iterator operator++(int) noexcept {
1033 iterator retval = *this;
1034 ++(*this);
1035 return retval;
1036 }
1037
1038 constexpr iterator& operator--() noexcept {
1039 --m_it;
1040 return *this;
1041 }
1042
1043 constexpr iterator operator--(int) noexcept {
1044 iterator retval = *this;
1045 --(*this);
1046 return retval;
1047 }
1048
1049 constexpr bool operator==(const iterator&) const noexcept = default;
1050
1051 constexpr reference operator*() const noexcept { return *m_it; }
1052
1053 private:
1054 gch::small_vector<Variable>::iterator m_it;
1055 };
1056
1057 class const_iterator {
1058 public:
1059 using iterator_category = std::bidirectional_iterator_tag;
1060 using value_type = Variable;
1061 using difference_type = std::ptrdiff_t;
1062 using pointer = Variable*;
1063 using const_reference = const Variable&;
1064
1065 constexpr const_iterator() noexcept = default;
1066
1067 explicit constexpr const_iterator(
1068 gch::small_vector<Variable>::const_iterator it) noexcept
1069 : m_it{it} {}
1070
1071 constexpr const_iterator& operator++() noexcept {
1072 ++m_it;
1073 return *this;
1074 }
1075
1076 constexpr const_iterator operator++(int) noexcept {
1077 const_iterator retval = *this;
1078 ++(*this);
1079 return retval;
1080 }
1081
1082 constexpr const_iterator& operator--() noexcept {
1083 --m_it;
1084 return *this;
1085 }
1086
1087 constexpr const_iterator operator--(int) noexcept {
1088 const_iterator retval = *this;
1089 --(*this);
1090 return retval;
1091 }
1092
1093 constexpr bool operator==(const const_iterator&) const noexcept = default;
1094
1095 constexpr const_reference operator*() const noexcept { return *m_it; }
1096
1097 private:
1098 gch::small_vector<Variable>::const_iterator m_it;
1099 };
1100
1101 using reverse_iterator = std::reverse_iterator<iterator>;
1102 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
1103
1104#endif // DOXYGEN_SHOULD_SKIP_THIS
1105
1111 iterator begin() { return iterator{m_storage.begin()}; }
1112
1118 iterator end() { return iterator{m_storage.end()}; }
1119
1125 const_iterator begin() const { return const_iterator{m_storage.begin()}; }
1126
1132 const_iterator end() const { return const_iterator{m_storage.end()}; }
1133
1139 const_iterator cbegin() const { return const_iterator{m_storage.cbegin()}; }
1140
1146 const_iterator cend() const { return const_iterator{m_storage.cend()}; }
1147
1153 reverse_iterator rbegin() { return reverse_iterator{end()}; }
1154
1160 reverse_iterator rend() { return reverse_iterator{begin()}; }
1161
1167 const_reverse_iterator rbegin() const {
1168 return const_reverse_iterator{end()};
1169 }
1170
1176 const_reverse_iterator rend() const {
1177 return const_reverse_iterator{begin()};
1178 }
1179
1185 const_reverse_iterator crbegin() const {
1186 return const_reverse_iterator{cend()};
1187 }
1188
1194 const_reverse_iterator crend() const {
1195 return const_reverse_iterator{cbegin()};
1196 }
1197
1203 size_t size() const { return m_storage.size(); }
1204
1212 static VariableMatrix zero(int rows, int cols) {
1213 VariableMatrix result{VariableMatrix::empty, rows, cols};
1214
1215 for (auto& elem : result) {
1216 elem = 0.0;
1217 }
1218
1219 return result;
1220 }
1221
1229 static VariableMatrix ones(int rows, int cols) {
1230 VariableMatrix result{VariableMatrix::empty, rows, cols};
1231
1232 for (auto& elem : result) {
1233 elem = 1.0;
1234 }
1235
1236 return result;
1237 }
1238
1239 private:
1240 gch::small_vector<Variable> m_storage;
1241 int m_rows = 0;
1242 int m_cols = 0;
1243};
1244
1252SLEIPNIR_DLLEXPORT inline VariableMatrix cwise_reduce(
1253 const VariableMatrix& lhs, const VariableMatrix& rhs,
1254 function_ref<Variable(const Variable& x, const Variable& y)> binary_op) {
1255 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
1256
1257 VariableMatrix result{VariableMatrix::empty, lhs.rows(), lhs.cols()};
1258
1259 for (int row = 0; row < lhs.rows(); ++row) {
1260 for (int col = 0; col < lhs.cols(); ++col) {
1261 result[row, col] = binary_op(lhs[row, col], rhs[row, col]);
1262 }
1263 }
1264
1265 return result;
1266}
1267
1278SLEIPNIR_DLLEXPORT inline VariableMatrix block(
1279 std::initializer_list<std::initializer_list<VariableMatrix>> list) {
1280 // Get row and column counts for destination matrix
1281 int rows = 0;
1282 int cols = -1;
1283 for (const auto& row : list) {
1284 if (row.size() > 0) {
1285 rows += row.begin()->rows();
1286 }
1287
1288 // Get number of columns in this row
1289 int latest_cols = 0;
1290 for (const auto& elem : row) {
1291 // Assert the first and latest row have the same height
1292 slp_assert(row.begin()->rows() == elem.rows());
1293
1294 latest_cols += elem.cols();
1295 }
1296
1297 // If this is the first row, record the column count. Otherwise, assert the
1298 // first and latest column counts are the same.
1299 if (cols == -1) {
1300 cols = latest_cols;
1301 } else {
1302 slp_assert(cols == latest_cols);
1303 }
1304 }
1305
1306 VariableMatrix result{VariableMatrix::empty, rows, cols};
1307
1308 int row_offset = 0;
1309 for (const auto& row : list) {
1310 int col_offset = 0;
1311 for (const auto& elem : row) {
1312 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1313 col_offset += elem.cols();
1314 }
1315 if (row.size() > 0) {
1316 row_offset += row.begin()->rows();
1317 }
1318 }
1319
1320 return result;
1321}
1322
1335SLEIPNIR_DLLEXPORT inline VariableMatrix block(
1336 const std::vector<std::vector<VariableMatrix>>& list) {
1337 // Get row and column counts for destination matrix
1338 int rows = 0;
1339 int cols = -1;
1340 for (const auto& row : list) {
1341 if (row.size() > 0) {
1342 rows += row.begin()->rows();
1343 }
1344
1345 // Get number of columns in this row
1346 int latest_cols = 0;
1347 for (const auto& elem : row) {
1348 // Assert the first and latest row have the same height
1349 slp_assert(row.begin()->rows() == elem.rows());
1350
1351 latest_cols += elem.cols();
1352 }
1353
1354 // If this is the first row, record the column count. Otherwise, assert the
1355 // first and latest column counts are the same.
1356 if (cols == -1) {
1357 cols = latest_cols;
1358 } else {
1359 slp_assert(cols == latest_cols);
1360 }
1361 }
1362
1363 VariableMatrix result{VariableMatrix::empty, rows, cols};
1364
1365 int row_offset = 0;
1366 for (const auto& row : list) {
1367 int col_offset = 0;
1368 for (const auto& elem : row) {
1369 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1370 col_offset += elem.cols();
1371 }
1372 if (row.size() > 0) {
1373 row_offset += row.begin()->rows();
1374 }
1375 }
1376
1377 return result;
1378}
1379
1387SLEIPNIR_DLLEXPORT VariableMatrix solve(const VariableMatrix& A,
1388 const VariableMatrix& B);
1389
1390} // namespace slp
Definition slice.hpp:31
constexpr int adjust(int length)
Definition slice.hpp:134
Definition variable_block.hpp:24
Definition variable_matrix.hpp:29
VariableMatrix & operator=(ScalarLike auto value)
Definition variable_matrix.hpp:307
Eigen::MatrixXd value()
Definition variable_matrix.hpp:980
const VariableBlock< const VariableMatrix > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_matrix.hpp:480
const VariableBlock< const VariableMatrix > row(int row) const
Definition variable_matrix.hpp:533
VariableBlock< VariableMatrix > segment(int offset, int length)
Definition variable_matrix.hpp:494
VariableMatrix(const Eigen::DiagonalBase< Derived > &values)
Definition variable_matrix.hpp:183
const_iterator end() const
Definition variable_matrix.hpp:1132
VariableBlock< VariableMatrix > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_matrix.hpp:460
const_reverse_iterator rbegin() const
Definition variable_matrix.hpp:1167
const_iterator cend() const
Definition variable_matrix.hpp:1146
VariableMatrix(const Variable &variable)
Definition variable_matrix.hpp:203
VariableBlock< VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_matrix.hpp:389
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const ScalarLike auto &lhs, const SleipnirMatrixLike auto &rhs)
Definition variable_matrix.hpp:644
VariableMatrix & operator*=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:685
size_t size() const
Definition variable_matrix.hpp:1203
VariableMatrix & operator+=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:819
friend SLEIPNIR_DLLEXPORT VariableMatrix operator-(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:840
const Variable & operator[](int index) const
Definition variable_matrix.hpp:375
VariableMatrix & operator/=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:751
VariableMatrix()=default
const VariableBlock< const VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_matrix.hpp:407
VariableMatrix(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:166
reverse_iterator rbegin()
Definition variable_matrix.hpp:1153
const VariableBlock< const VariableMatrix > operator[](Slice row_slice, Slice col_slice) const
Definition variable_matrix.hpp:439
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const Variable &lhs, const MatrixLike auto &rhs)
Definition variable_matrix.hpp:663
VariableMatrix & operator+=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:797
VariableMatrix T() const
Definition variable_matrix.hpp:932
VariableMatrix & operator*=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:711
VariableMatrix cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_matrix.hpp:998
reverse_iterator rend()
Definition variable_matrix.hpp:1160
const Variable & operator[](int row, int col) const
Definition variable_matrix.hpp:352
static VariableMatrix ones(int rows, int cols)
Definition variable_matrix.hpp:1229
const_reverse_iterator rend() const
Definition variable_matrix.hpp:1176
double value(int index)
Definition variable_matrix.hpp:973
VariableMatrix & operator-=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:867
VariableMatrix(int rows, int cols)
Definition variable_matrix.hpp:60
VariableBlock< VariableMatrix > operator[](Slice row_slice, Slice col_slice)
Definition variable_matrix.hpp:425
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const MatrixLike auto &lhs, const Variable &rhs)
Definition variable_matrix.hpp:620
const_reverse_iterator crend() const
Definition variable_matrix.hpp:1194
VariableMatrix(int rows)
Definition variable_matrix.hpp:52
VariableMatrix(const VariableBlock< const VariableMatrix > &values)
Definition variable_matrix.hpp:237
friend SLEIPNIR_DLLEXPORT VariableMatrix operator/(const MatrixLike auto &lhs, const ScalarLike auto &rhs)
Definition variable_matrix.hpp:729
iterator end()
Definition variable_matrix.hpp:1118
const_reverse_iterator crbegin() const
Definition variable_matrix.hpp:1185
VariableBlock< VariableMatrix > row(int row)
Definition variable_matrix.hpp:522
VariableMatrix & operator-=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:889
const_iterator cbegin() const
Definition variable_matrix.hpp:1139
VariableMatrix(empty_t, int rows, int cols)
Definition variable_matrix.hpp:73
VariableMatrix(std::span< const Variable > values)
Definition variable_matrix.hpp:252
VariableMatrix(std::span< const Variable > values, int rows, int cols)
Definition variable_matrix.hpp:269
const_iterator begin() const
Definition variable_matrix.hpp:1125
Variable & operator[](int row, int col)
Definition variable_matrix.hpp:339
double value(int row, int col)
Definition variable_matrix.hpp:965
iterator begin()
Definition variable_matrix.hpp:1111
static VariableMatrix zero(int rows, int cols)
Definition variable_matrix.hpp:1212
VariableMatrix & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:287
friend SLEIPNIR_DLLEXPORT VariableMatrix operator-(const SleipnirMatrixLike auto &lhs)
Definition variable_matrix.hpp:907
const VariableBlock< const VariableMatrix > col(int col) const
Definition variable_matrix.hpp:555
int rows() const
Definition variable_matrix.hpp:949
int cols() const
Definition variable_matrix.hpp:956
VariableMatrix(Variable &&variable)
Definition variable_matrix.hpp:213
const VariableBlock< const VariableMatrix > segment(int offset, int length) const
Definition variable_matrix.hpp:508
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const SleipnirMatrixLike auto &lhs, const ScalarLike auto &rhs)
Definition variable_matrix.hpp:602
VariableMatrix(std::initializer_list< std::initializer_list< Variable > > list)
Definition variable_matrix.hpp:85
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:322
VariableMatrix(const std::vector< std::vector< double > > &list)
Definition variable_matrix.hpp:113
Variable & operator[](int index)
Definition variable_matrix.hpp:364
VariableMatrix(const std::vector< std::vector< Variable > > &list)
Definition variable_matrix.hpp:140
VariableBlock< VariableMatrix > col(int col)
Definition variable_matrix.hpp:544
static constexpr empty_t empty
Definition variable_matrix.hpp:39
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:568
friend SLEIPNIR_DLLEXPORT VariableMatrix operator+(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:770
VariableMatrix(const VariableBlock< VariableMatrix > &values)
Definition variable_matrix.hpp:222
Definition variable.hpp:40
Definition function_ref.hpp:13
Definition concepts.hpp:26
Definition concepts.hpp:40
Definition concepts.hpp:13
Definition concepts.hpp:30
Definition variable_matrix.hpp:34