11#include "sleipnir/autodiff/slice.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/function_ref.hpp"
23template <
typename Mat>
38 if (
this == &values) {
42 if (m_mat ==
nullptr) {
44 m_row_slice = values.m_row_slice;
45 m_row_slice_length = values.m_row_slice_length;
46 m_col_slice = values.m_col_slice;
47 m_col_slice_length = values.m_col_slice_length;
73 if (
this == &values) {
77 if (m_mat ==
nullptr) {
79 m_row_slice = values.m_row_slice;
80 m_row_slice_length = values.m_row_slice_length;
81 m_col_slice = values.m_col_slice;
82 m_col_slice_length = values.m_col_slice_length;
84 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
116 m_row_slice{row_offset, row_offset + block_rows, 1},
117 m_row_slice_length{m_row_slice.adjust(mat.
rows())},
118 m_col_slice{col_offset, col_offset + block_cols, 1},
119 m_col_slice_length{m_col_slice.adjust(mat.
cols())} {}
133 Slice col_slice,
int col_slice_length)
135 m_row_slice{std::move(row_slice)},
136 m_row_slice_length{row_slice_length},
137 m_col_slice{std::move(col_slice)},
138 m_col_slice_length{col_slice_length} {}
149 slp_assert(
rows() == 1 &&
cols() == 1);
151 (*this)[0, 0] =
value;
164 slp_assert(
rows() == 1 &&
cols() == 1);
166 (*this)[0, 0].set_value(
value);
175 template <
typename Derived>
177 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
193 template <
typename Derived>
194 requires std::same_as<typename Derived::Scalar, double>
195 void set_value(
const Eigen::MatrixBase<Derived>& values) {
196 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
212 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
229 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
247 requires(!std::is_const_v<Mat>)
251 return (*m_mat)[m_row_slice.
start +
row * m_row_slice.
step,
265 return (*m_mat)[m_row_slice.
start +
row * m_row_slice.
step,
276 requires(!std::is_const_v<Mat>)
278 slp_assert(index >= 0 && index <
rows() *
cols());
279 return (*
this)[index /
cols(), index %
cols()];
289 slp_assert(index >= 0 && index <
rows() *
cols());
290 return (*
this)[index /
cols(), index %
cols()];
304 slp_assert(row_offset >= 0 && row_offset <=
rows());
305 slp_assert(col_offset >= 0 && col_offset <=
cols());
306 slp_assert(block_rows >= 0 && block_rows <=
rows() - row_offset);
307 slp_assert(block_cols >= 0 && block_cols <=
cols() - col_offset);
308 return (*
this)[
Slice{row_offset, row_offset + block_rows, 1},
309 Slice{col_offset, col_offset + block_cols, 1}];
322 int block_rows,
int block_cols)
const {
323 slp_assert(row_offset >= 0 && row_offset <=
rows());
324 slp_assert(col_offset >= 0 && col_offset <=
cols());
325 slp_assert(block_rows >= 0 && block_rows <=
rows() - row_offset);
326 slp_assert(block_cols >= 0 && block_cols <=
cols() - col_offset);
327 return (*
this)[
Slice{row_offset, row_offset + block_rows, 1},
328 Slice{col_offset, col_offset + block_cols, 1}];
339 int row_slice_length = row_slice.
adjust(m_row_slice_length);
340 int col_slice_length = col_slice.
adjust(m_col_slice_length);
341 return (*
this)[row_slice, row_slice_length, col_slice, col_slice_length];
352 Slice col_slice)
const {
353 int row_slice_length = row_slice.
adjust(m_row_slice_length);
354 int col_slice_length = col_slice.
adjust(m_col_slice_length);
355 return (*
this)[row_slice, row_slice_length, col_slice, col_slice_length];
371 Slice col_slice,
int col_slice_length) {
397 int row_slice_length,
399 int col_slice_length)
const {
420 slp_assert(
cols() == 1);
421 slp_assert(offset >= 0 && offset <
rows());
422 slp_assert(length >= 0 && length <=
rows() - offset);
423 return block(offset, 0, length, 1);
434 slp_assert(
cols() == 1);
435 slp_assert(offset >= 0 && offset <
rows());
436 slp_assert(length >= 0 && length <=
rows() - offset);
437 return block(offset, 0, length, 1);
491 slp_assert(
cols() == rhs.rows() &&
cols() == rhs.cols());
493 for (
int i = 0; i <
rows(); ++i) {
494 for (
int j = 0; j < rhs.cols(); ++j) {
496 for (
int k = 0; k <
cols(); ++k) {
497 sum += (*this)(i, k) * rhs(k, j);
529 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
533 (*this)[
row,
col] /= rhs[0, 0];
563 slp_assert(
rows() == rhs.rows() &&
cols() == rhs.cols());
581 slp_assert(
rows() == 1 &&
cols() == 1);
599 slp_assert(
rows() == rhs.rows() &&
cols() == rhs.cols());
617 slp_assert(
rows() == 1 &&
cols() == 1);
632 slp_assert(
rows() == 1 &&
cols() == 1);
633 return (*
this)(0, 0);
641 std::remove_cv_t<Mat>
T()
const {
642 std::remove_cv_t<Mat> result{Mat::empty,
cols(),
rows()};
658 int rows()
const {
return m_row_slice_length; }
665 int cols()
const {
return m_col_slice_length; }
683 slp_assert(index >= 0 && index <
rows() *
cols());
693 Eigen::MatrixXd result{
rows(),
cols()};
712 std::remove_cv_t<Mat> result{Mat::empty,
rows(),
cols()};
723#ifndef DOXYGEN_SHOULD_SKIP_THIS
727 using iterator_category = std::bidirectional_iterator_tag;
729 using difference_type = std::ptrdiff_t;
733 constexpr iterator() noexcept = default;
735 constexpr iterator(
VariableBlock<Mat>* mat,
int index) noexcept
736 : m_mat{mat}, m_index{index} {}
738 constexpr iterator& operator++() noexcept {
743 constexpr iterator operator++(
int)
noexcept {
744 iterator retval = *
this;
749 constexpr iterator& operator--() noexcept {
754 constexpr iterator operator--(
int)
noexcept {
755 iterator retval = *
this;
760 constexpr bool operator==(
const iterator&)
const noexcept =
default;
762 constexpr reference operator*() const noexcept {
return (*m_mat)[m_index]; }
765 VariableBlock<Mat>* m_mat =
nullptr;
769 class const_iterator {
771 using iterator_category = std::bidirectional_iterator_tag;
772 using value_type = Variable;
773 using difference_type = std::ptrdiff_t;
774 using pointer = Variable*;
775 using const_reference =
const Variable&;
777 constexpr const_iterator() noexcept = default;
779 constexpr const_iterator(const VariableBlock<Mat>* mat,
int index) noexcept
780 : m_mat{mat}, m_index{index} {}
782 constexpr const_iterator& operator++() noexcept {
787 constexpr const_iterator operator++(
int)
noexcept {
788 const_iterator retval = *
this;
793 constexpr const_iterator& operator--() noexcept {
798 constexpr const_iterator operator--(
int)
noexcept {
799 iterator retval = *
this;
804 constexpr bool operator==(
const const_iterator&)
const noexcept =
default;
806 constexpr const_reference operator*() const noexcept {
807 return (*m_mat)[m_index];
811 const VariableBlock<Mat>* m_mat =
nullptr;
815 using reverse_iterator = std::reverse_iterator<iterator>;
816 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
825 iterator
begin() {
return iterator(
this, 0); }
839 const_iterator
begin()
const {
return const_iterator(
this, 0); }
846 const_iterator
end()
const {
return const_iterator(
this,
rows() *
cols()); }
853 const_iterator
cbegin()
const {
return const_iterator(
this, 0); }
860 const_iterator
cend()
const {
return const_iterator(
this,
rows() *
cols()); }
867 reverse_iterator
rbegin() {
return reverse_iterator{
end()}; }
874 reverse_iterator
rend() {
return reverse_iterator{
begin()}; }
882 return const_reverse_iterator{
end()};
890 const_reverse_iterator
rend()
const {
891 return const_reverse_iterator{
begin()};
900 return const_reverse_iterator{
cend()};
908 const_reverse_iterator
crend()
const {
909 return const_reverse_iterator{
cbegin()};
920 Mat* m_mat =
nullptr;
923 int m_row_slice_length = 0;
926 int m_col_slice_length = 0;
int step
Step.
Definition slice.hpp:40
int stop
Stop index (exclusive).
Definition slice.hpp:37
constexpr int adjust(int length)
Definition slice.hpp:134
int start
Start index (inclusive).
Definition slice.hpp:34
Definition variable_block.hpp:24
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:211
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:512
const Variable & operator[](int row, int col) const
Definition variable_block.hpp:262
const_iterator begin() const
Definition variable_block.hpp:839
const_iterator end() const
Definition variable_block.hpp:846
Eigen::MatrixXd value()
Definition variable_block.hpp:692
const_iterator cbegin() const
Definition variable_block.hpp:853
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:446
VariableBlock< Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:370
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:148
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:562
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_block.hpp:710
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:490
iterator begin()
Definition variable_block.hpp:825
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:37
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:479
int rows() const
Definition variable_block.hpp:658
VariableBlock< Mat > operator[](Slice row_slice, Slice col_slice)
Definition variable_block.hpp:338
double value(int row, int col)
Definition variable_block.hpp:674
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:468
const VariableBlock< const Mat > operator[](Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:351
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:176
reverse_iterator rbegin()
Definition variable_block.hpp:867
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:228
const_iterator cend() const
Definition variable_block.hpp:860
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:195
Variable & operator[](int row, int col)
Definition variable_block.hpp:246
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:528
const_reverse_iterator crbegin() const
Definition variable_block.hpp:899
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:616
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:113
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:546
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:419
const_reverse_iterator rend() const
Definition variable_block.hpp:890
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:598
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:580
Variable & operator[](int index)
Definition variable_block.hpp:275
iterator end()
Definition variable_block.hpp:832
double value(int index)
Definition variable_block.hpp:682
VariableBlock(Mat &mat)
Definition variable_block.hpp:101
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:457
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:433
int cols() const
Definition variable_block.hpp:665
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:302
reverse_iterator rend()
Definition variable_block.hpp:874
size_t size() const
Definition variable_block.hpp:917
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:641
VariableBlock(const VariableBlock< Mat > &)=default
const Variable & operator[](int index) const
Definition variable_block.hpp:288
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:72
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:132
const VariableBlock< const Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:396
const_reverse_iterator crend() const
Definition variable_block.hpp:908
const_reverse_iterator rbegin() const
Definition variable_block.hpp:881
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:321
void set_value(double value)
Definition variable_block.hpp:163
Definition variable.hpp:40
Definition function_ref.hpp:13
Definition concepts.hpp:40
Definition concepts.hpp:13