7#include <initializer_list>
15#include "sleipnir/autodiff/slice.hpp"
16#include "sleipnir/autodiff/variable.hpp"
17#include "sleipnir/autodiff/variable_block.hpp"
18#include "sleipnir/util/assert.hpp"
19#include "sleipnir/util/concepts.hpp"
20#include "sleipnir/util/function_ref.hpp"
21#include "sleipnir/util/small_vector.hpp"
22#include "sleipnir/util/symbol_exports.hpp"
52 m_storage.reserve(rows);
53 for (
int row = 0; row < rows; ++row) {
54 m_storage.emplace_back();
65 m_storage.reserve(rows * cols);
66 for (
int index = 0; index < rows * cols; ++index) {
67 m_storage.emplace_back();
78 m_storage.reserve(rows * cols);
79 for (
int index = 0; index < rows * cols; ++index) {
80 m_storage.emplace_back(
nullptr);
90 std::initializer_list<std::initializer_list<Variable>> list) {
94 if (list.size() > 0) {
95 m_cols = list.begin()->size();
100 const auto& row : list) {
101 slp_assert(list.begin()->size() == row.size());
104 m_storage.reserve(rows() * cols());
105 for (
const auto& row : list) {
106 std::ranges::copy(row, std::back_inserter(m_storage));
119 m_rows = list.size();
121 if (list.size() > 0) {
122 m_cols = list.begin()->size();
126 for ([[maybe_unused]]
127 const auto& row : list) {
128 slp_assert(list.begin()->size() == row.size());
131 m_storage.reserve(rows() * cols());
132 for (
const auto& row : list) {
133 std::ranges::copy(row, std::back_inserter(m_storage));
146 m_rows = list.size();
148 if (list.size() > 0) {
149 m_cols = list.begin()->size();
153 for ([[maybe_unused]]
154 const auto& row : list) {
155 slp_assert(list.begin()->size() == row.size());
158 m_storage.reserve(rows() * cols());
159 for (
const auto& row : list) {
160 std::ranges::copy(row, std::back_inserter(m_storage));
169 template <
typename Derived>
171 : m_rows{static_cast<int>(values.rows())},
172 m_cols{static_cast<int>(values.cols())} {
173 m_storage.reserve(values.rows() * values.cols());
174 for (
int row = 0; row < values.rows(); ++row) {
175 for (
int col = 0; col < values.cols(); ++col) {
176 m_storage.emplace_back(values(row, col));
186 template <
typename Derived>
188 : m_rows{static_cast<int>(values.rows())},
189 m_cols{static_cast<int>(values.cols())} {
190 m_storage.reserve(values.rows() * values.cols());
191 for (
int row = 0; row < values.rows(); ++row) {
192 for (
int col = 0; col < values.cols(); ++col) {
194 m_storage.emplace_back(values.diagonal()[row]);
196 m_storage.emplace_back(0.0);
208 template <
typename Derived>
210 slp_assert(rows() == values.rows() && cols() == values.cols());
212 for (
int row = 0; row < values.rows(); ++row) {
213 for (
int col = 0; col < values.cols(); ++col) {
214 (*this)(row, col) = values(row, col);
230 slp_assert(rows() == 1 && cols() == 1);
232 (*this)(0, 0) = value;
242 template <
typename Derived>
243 requires std::same_as<typename Derived::Scalar, double>
244 void set_value(
const Eigen::MatrixBase<Derived>& values) {
245 slp_assert(rows() == values.rows() && cols() == values.cols());
247 for (
int row = 0; row < values.rows(); ++row) {
248 for (
int col = 0; col < values.cols(); ++col) {
249 (*this)(row, col).set_value(values(row, col));
260 : m_rows{1}, m_cols{1} {
261 m_storage.emplace_back(variable);
270 m_storage.emplace_back(std::move(variable));
279 : m_rows{values.rows()}, m_cols{values.cols()} {
280 m_storage.reserve(rows() * cols());
281 for (
int row = 0; row < rows(); ++row) {
282 for (
int col = 0; col < cols(); ++col) {
283 m_storage.emplace_back(values(row, col));
294 : m_rows{values.rows()}, m_cols{values.cols()} {
295 m_storage.reserve(rows() * cols());
296 for (
int row = 0; row < rows(); ++row) {
297 for (
int col = 0; col < cols(); ++col) {
298 m_storage.emplace_back(values(row, col));
309 : m_rows{static_cast<int>(values.size())}, m_cols{1} {
310 m_storage.reserve(rows() * cols());
311 for (
int row = 0; row < rows(); ++row) {
312 for (
int col = 0; col < cols(); ++col) {
313 m_storage.emplace_back(values[row * cols() + col]);
326 : m_rows{rows}, m_cols{cols} {
327 slp_assert(
static_cast<int>(values.size()) == rows * cols);
328 m_storage.reserve(rows * cols);
329 for (
int row = 0; row < rows; ++row) {
330 for (
int col = 0; col < cols; ++col) {
331 m_storage.emplace_back(values[row * cols + col]);
344 slp_assert(row >= 0 && row < rows());
345 slp_assert(col >= 0 && col < cols());
346 return m_storage[row * cols() + col];
357 slp_assert(row >= 0 && row < rows());
358 slp_assert(col >= 0 && col < cols());
359 return m_storage[row * cols() + col];
369 slp_assert(row >= 0 && row < rows() * cols());
370 return m_storage[row];
380 slp_assert(row >= 0 && row < rows() * cols());
381 return m_storage[row];
394 int block_rows,
int block_cols) {
395 slp_assert(row_offset >= 0 && row_offset <= rows());
396 slp_assert(col_offset >= 0 && col_offset <= cols());
397 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
398 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
399 return VariableBlock{*
this, row_offset, col_offset, block_rows, block_cols};
414 int block_cols)
const {
415 slp_assert(row_offset >= 0 && row_offset <= rows());
416 slp_assert(col_offset >= 0 && col_offset <= cols());
417 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
418 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
419 return VariableBlock{*
this, row_offset, col_offset, block_rows, block_cols};
430 int row_slice_length = row_slice.
adjust(rows());
431 int col_slice_length = col_slice.
adjust(cols());
432 return VariableBlock{*
this, std::move(row_slice), row_slice_length,
433 std::move(col_slice), col_slice_length};
444 Slice col_slice)
const {
445 int row_slice_length = row_slice.
adjust(rows());
446 int col_slice_length = col_slice.
adjust(cols());
447 return VariableBlock{*
this, std::move(row_slice), row_slice_length,
448 std::move(col_slice), col_slice_length};
465 int row_slice_length,
467 int col_slice_length) {
468 return VariableBlock{*
this, std::move(row_slice), row_slice_length,
469 std::move(col_slice), col_slice_length};
485 Slice row_slice,
int row_slice_length,
Slice col_slice,
486 int col_slice_length)
const {
487 return VariableBlock{*
this, std::move(row_slice), row_slice_length,
488 std::move(col_slice), col_slice_length};
499 slp_assert(offset >= 0 && offset < rows() * cols());
500 slp_assert(length >= 0 && length <= rows() * cols() - offset);
501 return block(offset, 0, length, 1);
513 slp_assert(offset >= 0 && offset < rows() * cols());
514 slp_assert(length >= 0 && length <= rows() * cols() - offset);
515 return block(offset, 0, length, 1);
525 slp_assert(row >= 0 && row < rows());
526 return block(row, 0, 1, cols());
536 slp_assert(row >= 0 && row < rows());
537 return block(row, 0, 1, cols());
547 slp_assert(col >= 0 && col < cols());
548 return block(0, col, rows(), 1);
558 slp_assert(col >= 0 && col < cols());
559 return block(0, col, rows(), 1);
568 template <MatrixLike LHS, MatrixLike RHS>
572 slp_assert(lhs.cols() == rhs.rows());
574 VariableMatrix result(VariableMatrix::empty, lhs.rows(), rhs.cols());
576 for (
int i = 0; i < lhs.rows(); ++i) {
577 for (
int j = 0; j < rhs.cols(); ++j) {
579 for (
int k = 0; k < lhs.cols(); ++k) {
580 sum += lhs(i, k) * rhs(k, j);
599 for (
int row = 0; row < result.rows(); ++row) {
600 for (
int col = 0; col < result.cols(); ++col) {
601 result(row, col) = lhs(row, col) * rhs;
616 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
618 for (
int row = 0; row < result.
rows(); ++row) {
619 for (
int col = 0; col < result.
cols(); ++col) {
620 result(row, col) = lhs(row, col) * rhs;
637 for (
int row = 0; row < result.rows(); ++row) {
638 for (
int col = 0; col < result.cols(); ++col) {
639 result(row, col) = rhs(row, col) * lhs;
654 VariableMatrix result(VariableMatrix::empty, rhs.rows(), rhs.cols());
656 for (
int row = 0; row < result.
rows(); ++row) {
657 for (
int col = 0; col < result.
cols(); ++col) {
658 result(row, col) = rhs(row, col) * lhs;
672 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
674 for (
int i = 0; i < rows(); ++i) {
675 for (
int j = 0; j < rhs.cols(); ++j) {
677 for (
int k = 0; k < cols(); ++k) {
678 sum += (*this)(i, k) * rhs(k, j);
694 for (
int row = 0; row < rows(); ++row) {
695 for (
int col = 0; col < rhs.cols(); ++col) {
696 (*this)(row, col) *= rhs;
712 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
714 for (
int row = 0; row < result.
rows(); ++row) {
715 for (
int col = 0; col < result.
cols(); ++col) {
716 result(row, col) = lhs(row, col) / rhs;
730 for (
int row = 0; row < rows(); ++row) {
731 for (
int col = 0; col < cols(); ++col) {
732 (*this)(row, col) /= rhs;
746 template <MatrixLike LHS, MatrixLike RHS>
750 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
752 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
754 for (
int row = 0; row < result.
rows(); ++row) {
755 for (
int col = 0; col < result.
cols(); ++col) {
756 result(row, col) = lhs(row, col) + rhs(row, col);
770 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
772 for (
int row = 0; row < rows(); ++row) {
773 for (
int col = 0; col < cols(); ++col) {
774 (*this)(row, col) += rhs(row, col);
788 slp_assert(rows() == 1 && cols() == 1);
790 for (
int row = 0; row < rows(); ++row) {
791 for (
int col = 0; col < cols(); ++col) {
792 (*this)(row, col) += rhs;
806 template <MatrixLike LHS, MatrixLike RHS>
810 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
812 VariableMatrix result(VariableMatrix::empty, lhs.rows(), lhs.cols());
814 for (
int row = 0; row < result.
rows(); ++row) {
815 for (
int col = 0; col < result.
cols(); ++col) {
816 result(row, col) = lhs(row, col) - rhs(row, col);
830 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
832 for (
int row = 0; row < rows(); ++row) {
833 for (
int col = 0; col < cols(); ++col) {
834 (*this)(row, col) -= rhs(row, col);
848 slp_assert(rows() == 1 && cols() == 1);
850 for (
int row = 0; row < rows(); ++row) {
851 for (
int col = 0; col < cols(); ++col) {
852 (*this)(row, col) -= rhs;
868 for (
int row = 0; row < result.rows(); ++row) {
869 for (
int col = 0; col < result.cols(); ++col) {
870 result(row, col) = -lhs(row, col);
881 slp_assert(rows() == 1 && cols() == 1);
882 return (*
this)(0, 0);
893 for (
int row = 0; row < rows(); ++row) {
894 for (
int col = 0; col < cols(); ++col) {
895 result(col, row) = (*this)(row, col);
907 int rows()
const {
return m_rows; }
914 int cols()
const {
return m_cols; }
924 slp_assert(row >= 0 && row < rows());
925 slp_assert(col >= 0 && col < cols());
926 return m_storage[row * cols() + col].value();
936 slp_assert(index >= 0 && index < rows() * cols());
937 return m_storage[index].value();
946 Eigen::MatrixXd result{rows(), cols()};
948 for (
int row = 0; row < rows(); ++row) {
949 for (
int col = 0; col < cols(); ++col) {
950 result(row, col) = value(row, col);
967 for (
int row = 0; row < rows(); ++row) {
968 for (
int col = 0; col < cols(); ++col) {
969 result(row, col) = unary_op((*
this)(row, col));
976#ifndef DOXYGEN_SHOULD_SKIP_THIS
980 using iterator_category = std::forward_iterator_tag;
982 using difference_type = std::ptrdiff_t;
986 constexpr iterator() noexcept = default;
988 explicit constexpr iterator(small_vector<
Variable>::iterator it) noexcept
991 constexpr iterator& operator++() noexcept {
996 constexpr iterator operator++(
int)
noexcept {
997 iterator retval = *
this;
1002 constexpr bool operator==(
const iterator&)
const noexcept =
default;
1004 constexpr reference operator*() const noexcept {
return *m_it; }
1007 small_vector<Variable>::iterator m_it;
1010 class const_iterator {
1012 using iterator_category = std::forward_iterator_tag;
1013 using value_type = Variable;
1014 using difference_type = std::ptrdiff_t;
1015 using pointer = Variable*;
1016 using const_reference =
const Variable&;
1018 constexpr const_iterator() noexcept = default;
1020 explicit constexpr const_iterator(
1021 small_vector<Variable>::const_iterator it) noexcept
1024 constexpr const_iterator& operator++() noexcept {
1029 constexpr const_iterator operator++(
int)
noexcept {
1030 const_iterator retval = *
this;
1035 constexpr bool operator==(
const const_iterator&)
const noexcept =
default;
1037 constexpr const_reference operator*() const noexcept {
return *m_it; }
1040 small_vector<Variable>::const_iterator m_it;
1050 iterator
begin() {
return iterator{m_storage.begin()}; }
1057 iterator
end() {
return iterator{m_storage.end()}; }
1064 const_iterator
begin()
const {
return const_iterator{m_storage.begin()}; }
1071 const_iterator
end()
const {
return const_iterator{m_storage.end()}; }
1078 const_iterator
cbegin()
const {
return const_iterator{m_storage.cbegin()}; }
1085 const_iterator
cend()
const {
return const_iterator{m_storage.cend()}; }
1092 size_t size()
const {
return m_storage.size(); }
1104 for (
auto& elem : result) {
1121 for (
auto& elem : result) {
1129 small_vector<Variable> m_storage;
1141SLEIPNIR_DLLEXPORT
inline VariableMatrix cwise_reduce(
1142 const VariableMatrix& lhs,
const VariableMatrix& rhs,
1143 function_ref<Variable(
const Variable& x,
const Variable& y)> binary_op) {
1144 slp_assert(lhs.rows() == rhs.rows() && lhs.rows() == rhs.rows());
1148 for (
int row = 0; row < lhs.rows(); ++row) {
1149 for (
int col = 0; col < lhs.cols(); ++col) {
1150 result(row, col) = binary_op(lhs(row, col), rhs(row, col));
1167SLEIPNIR_DLLEXPORT
inline VariableMatrix block(
1168 std::initializer_list<std::initializer_list<VariableMatrix>> list) {
1172 for (
const auto& row : list) {
1173 if (row.size() > 0) {
1174 rows += row.
begin()->rows();
1178 int latest_cols = 0;
1179 for (
const auto& elem : row) {
1181 slp_assert(row.begin()->rows() == elem.rows());
1183 latest_cols += elem.cols();
1191 slp_assert(cols == latest_cols);
1198 for (
const auto& row : list) {
1200 for (
const auto& elem : row) {
1201 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1202 col_offset += elem.cols();
1204 row_offset += row.begin()->rows();
1222SLEIPNIR_DLLEXPORT
inline VariableMatrix block(
1223 const std::vector<std::vector<VariableMatrix>>& list) {
1227 for (
const auto& row : list) {
1228 if (row.size() > 0) {
1229 rows += row.
begin()->rows();
1233 int latest_cols = 0;
1234 for (
const auto& elem : row) {
1236 slp_assert(row.begin()->rows() == elem.rows());
1238 latest_cols += elem.cols();
1246 slp_assert(cols == latest_cols);
1253 for (
const auto& row : list) {
1255 for (
const auto& elem : row) {
1256 result.block(row_offset, col_offset, elem.rows(), elem.cols()) = elem;
1257 col_offset += elem.cols();
1259 row_offset += row.begin()->rows();
1272SLEIPNIR_DLLEXPORT VariableMatrix solve(
const VariableMatrix& A,
1273 const VariableMatrix& B);
constexpr int adjust(int length)
Definition slice.hpp:134
Definition variable_block.hpp:24
Definition variable_matrix.hpp:29
VariableMatrix & operator=(ScalarLike auto value)
Definition variable_matrix.hpp:229
Eigen::MatrixXd value()
Definition variable_matrix.hpp:945
const VariableBlock< const VariableMatrix > row(int row) const
Definition variable_matrix.hpp:535
VariableBlock< VariableMatrix > segment(int offset, int length)
Definition variable_matrix.hpp:498
VariableMatrix(const Eigen::DiagonalBase< Derived > &values)
Definition variable_matrix.hpp:187
const_iterator end() const
Definition variable_matrix.hpp:1071
const_iterator cend() const
Definition variable_matrix.hpp:1085
VariableMatrix(const Variable &variable)
Definition variable_matrix.hpp:259
VariableBlock< VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_matrix.hpp:393
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const ScalarLike auto &lhs, const SleipnirMatrixLike auto &rhs)
Definition variable_matrix.hpp:634
VariableMatrix & operator*=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:671
size_t size() const
Definition variable_matrix.hpp:1092
VariableMatrix & operator+=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:787
friend SLEIPNIR_DLLEXPORT VariableMatrix operator-(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:808
const Variable & operator()(int row, int col) const
Definition variable_matrix.hpp:356
VariableMatrix & operator/=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:729
const VariableBlock< const VariableMatrix > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_matrix.hpp:411
Variable & operator()(int row, int col)
Definition variable_matrix.hpp:343
const Variable & operator[](int row) const
Definition variable_matrix.hpp:379
VariableMatrix(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:170
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const Variable &lhs, const MatrixLike auto &rhs)
Definition variable_matrix.hpp:653
VariableMatrix & operator+=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:769
VariableMatrix T() const
Definition variable_matrix.hpp:890
VariableMatrix & operator*=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:693
VariableMatrix cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_matrix.hpp:963
const VariableBlock< const VariableMatrix > operator()(Slice row_slice, Slice col_slice) const
Definition variable_matrix.hpp:443
static VariableMatrix ones(int rows, int cols)
Definition variable_matrix.hpp:1118
double value(int index)
Definition variable_matrix.hpp:935
VariableMatrix & operator-=(const MatrixLike auto &rhs)
Definition variable_matrix.hpp:829
VariableMatrix(int rows, int cols)
Definition variable_matrix.hpp:64
VariableBlock< VariableMatrix > operator()(Slice row_slice, Slice col_slice)
Definition variable_matrix.hpp:429
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const MatrixLike auto &lhs, const Variable &rhs)
Definition variable_matrix.hpp:614
VariableMatrix(int rows)
Definition variable_matrix.hpp:51
const VariableBlock< const VariableMatrix > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_matrix.hpp:484
VariableMatrix(const VariableBlock< const VariableMatrix > &values)
Definition variable_matrix.hpp:293
friend SLEIPNIR_DLLEXPORT VariableMatrix operator/(const MatrixLike auto &lhs, const ScalarLike auto &rhs)
Definition variable_matrix.hpp:711
iterator end()
Definition variable_matrix.hpp:1057
VariableBlock< VariableMatrix > row(int row)
Definition variable_matrix.hpp:524
VariableMatrix & operator-=(const ScalarLike auto &rhs)
Definition variable_matrix.hpp:847
const_iterator cbegin() const
Definition variable_matrix.hpp:1078
VariableMatrix(empty_t, int rows, int cols)
Definition variable_matrix.hpp:77
Variable & operator[](int row)
Definition variable_matrix.hpp:368
VariableMatrix(std::span< const Variable > values)
Definition variable_matrix.hpp:308
VariableMatrix(std::span< const Variable > values, int rows, int cols)
Definition variable_matrix.hpp:325
const_iterator begin() const
Definition variable_matrix.hpp:1064
double value(int row, int col)
Definition variable_matrix.hpp:923
iterator begin()
Definition variable_matrix.hpp:1050
static VariableMatrix zero(int rows, int cols)
Definition variable_matrix.hpp:1101
VariableMatrix & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:209
friend SLEIPNIR_DLLEXPORT VariableMatrix operator-(const SleipnirMatrixLike auto &lhs)
Definition variable_matrix.hpp:865
const VariableBlock< const VariableMatrix > col(int col) const
Definition variable_matrix.hpp:557
int rows() const
Definition variable_matrix.hpp:907
int cols() const
Definition variable_matrix.hpp:914
VariableMatrix(Variable &&variable)
Definition variable_matrix.hpp:269
const VariableBlock< const VariableMatrix > segment(int offset, int length) const
Definition variable_matrix.hpp:511
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const SleipnirMatrixLike auto &lhs, const ScalarLike auto &rhs)
Definition variable_matrix.hpp:596
VariableMatrix(std::initializer_list< std::initializer_list< Variable > > list)
Definition variable_matrix.hpp:89
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_matrix.hpp:244
VariableMatrix(const std::vector< std::vector< double > > &list)
Definition variable_matrix.hpp:117
VariableMatrix(const std::vector< std::vector< Variable > > &list)
Definition variable_matrix.hpp:144
VariableBlock< VariableMatrix > col(int col)
Definition variable_matrix.hpp:546
static constexpr empty_t empty
Definition variable_matrix.hpp:39
friend SLEIPNIR_DLLEXPORT VariableMatrix operator*(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:570
VariableBlock< VariableMatrix > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_matrix.hpp:464
friend SLEIPNIR_DLLEXPORT VariableMatrix operator+(const LHS &lhs, const RHS &rhs)
Definition variable_matrix.hpp:748
VariableMatrix(const VariableBlock< VariableMatrix > &values)
Definition variable_matrix.hpp:278
Definition variable.hpp:41
Definition function_ref.hpp:13
Definition concepts.hpp:29
Definition concepts.hpp:13
Definition concepts.hpp:23
Definition variable_matrix.hpp:34