11#include "sleipnir/autodiff/slice.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/function_ref.hpp"
23template <
typename Mat>
38 if (
this == &values) {
42 if (m_mat ==
nullptr) {
44 m_row_slice = values.m_row_slice;
45 m_row_slice_length = values.m_row_slice_length;
46 m_col_slice = values.m_col_slice;
47 m_col_slice_length = values.m_col_slice_length;
73 if (
this == &values) {
77 if (m_mat ==
nullptr) {
79 m_row_slice = values.m_row_slice;
80 m_row_slice_length = values.m_row_slice_length;
81 m_col_slice = values.m_col_slice;
82 m_col_slice_length = values.m_col_slice_length;
84 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
103 m_row_slice{0, mat.
rows(), 1},
104 m_row_slice_length{m_row_slice.adjust(mat.
rows())},
105 m_col_slice{0, mat.
cols(), 1},
106 m_col_slice_length{m_col_slice.adjust(mat.
cols())} {}
120 m_row_slice{row_offset, row_offset + block_rows, 1},
121 m_row_slice_length{m_row_slice.adjust(mat.
rows())},
122 m_col_slice{col_offset, col_offset + block_cols, 1},
123 m_col_slice_length{m_col_slice.adjust(mat.
cols())} {}
137 Slice col_slice,
int col_slice_length)
139 m_row_slice{std::move(row_slice)},
140 m_row_slice_length{row_slice_length},
141 m_col_slice{std::move(col_slice)},
142 m_col_slice_length{col_slice_length} {}
153 slp_assert(
rows() == 1 &&
cols() == 1);
155 (*this)(0, 0) =
value;
168 slp_assert(
rows() == 1 &&
cols() == 1);
179 template <
typename Derived>
181 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
197 template <
typename Derived>
198 requires std::same_as<typename Derived::Scalar, double>
199 void set_value(
const Eigen::MatrixBase<Derived>& values) {
200 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
216 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
233 slp_assert(
rows() == values.rows() &&
cols() == values.cols());
251 requires(!std::is_const_v<Mat>)
255 return (*m_mat)(m_row_slice.
start +
row * m_row_slice.
step,
269 return (*m_mat)(m_row_slice.
start +
row * m_row_slice.
step,
280 requires(!std::is_const_v<Mat>)
308 slp_assert(row_offset >= 0 && row_offset <=
rows());
309 slp_assert(col_offset >= 0 && col_offset <=
cols());
310 slp_assert(block_rows >= 0 && block_rows <=
rows() - row_offset);
311 slp_assert(block_cols >= 0 && block_cols <=
cols() - col_offset);
312 return (*
this)({row_offset, row_offset + block_rows, 1},
313 {col_offset, col_offset + block_cols, 1});
326 int block_rows,
int block_cols)
const {
327 slp_assert(row_offset >= 0 && row_offset <=
rows());
328 slp_assert(col_offset >= 0 && col_offset <=
cols());
329 slp_assert(block_rows >= 0 && block_rows <=
rows() - row_offset);
330 slp_assert(block_cols >= 0 && block_cols <=
cols() - col_offset);
331 return (*
this)({row_offset, row_offset + block_rows, 1},
332 {col_offset, col_offset + block_cols, 1});
343 int row_slice_length = row_slice.
adjust(m_row_slice_length);
344 int col_slice_length = col_slice.
adjust(m_col_slice_length);
363 Slice col_slice)
const {
364 int row_slice_length = row_slice.
adjust(m_row_slice_length);
365 int col_slice_length = col_slice.
adjust(m_col_slice_length);
389 Slice col_slice,
int col_slice_length) {
413 int row_slice_length,
415 int col_slice_length)
const {
434 slp_assert(offset >= 0 && offset <
rows() *
cols());
435 slp_assert(length >= 0 && length <=
rows() *
cols() - offset);
436 return block(offset, 0, length, 1);
447 slp_assert(offset >= 0 && offset <
rows() *
cols());
448 slp_assert(length >= 0 && length <=
rows() *
cols() - offset);
449 return block(offset, 0, length, 1);
503 slp_assert(
cols() == rhs.rows() &&
cols() == rhs.cols());
505 for (
int i = 0; i <
rows(); ++i) {
506 for (
int j = 0; j < rhs.cols(); ++j) {
508 for (
int k = 0; k <
cols(); ++k) {
509 sum += (*this)(i, k) * rhs(k, j);
541 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
545 (*this)(
row,
col) /= rhs(0, 0);
575 slp_assert(
rows() == rhs.rows() &&
cols() == rhs.cols());
593 slp_assert(
rows() == 1 &&
cols() == 1);
611 slp_assert(
rows() == rhs.rows() &&
cols() == rhs.cols());
629 slp_assert(
rows() == 1 &&
cols() == 1);
644 slp_assert(
rows() == 1 &&
cols() == 1);
645 return (*
this)(0, 0);
653 std::remove_cv_t<Mat>
T()
const {
654 std::remove_cv_t<Mat> result{Mat::empty,
cols(),
rows()};
670 int rows()
const {
return m_row_slice_length; }
677 int cols()
const {
return m_col_slice_length; }
689 return (*m_mat)(m_row_slice.
start +
row * m_row_slice.
step,
701 slp_assert(index >= 0 && index <
rows() *
cols());
711 Eigen::MatrixXd result{
rows(),
cols()};
730 std::remove_cv_t<Mat> result{Mat::empty,
rows(),
cols()};
741#ifndef DOXYGEN_SHOULD_SKIP_THIS
745 using iterator_category = std::forward_iterator_tag;
747 using difference_type = std::ptrdiff_t;
751 constexpr iterator() noexcept = default;
753 constexpr iterator(
VariableBlock<Mat>* mat,
int index) noexcept
754 : m_mat{mat}, m_index{index} {}
756 constexpr iterator& operator++() noexcept {
761 constexpr iterator operator++(
int)
noexcept {
762 iterator retval = *
this;
767 constexpr bool operator==(
const iterator&)
const noexcept =
default;
769 constexpr reference operator*() const noexcept {
return (*m_mat)[m_index]; }
772 VariableBlock<Mat>* m_mat =
nullptr;
776 class const_iterator {
778 using iterator_category = std::forward_iterator_tag;
779 using value_type = Variable;
780 using difference_type = std::ptrdiff_t;
781 using pointer = Variable*;
782 using const_reference =
const Variable&;
784 constexpr const_iterator() noexcept = default;
786 constexpr const_iterator(const VariableBlock<Mat>* mat,
int index) noexcept
787 : m_mat{mat}, m_index{index} {}
789 constexpr const_iterator& operator++() noexcept {
794 constexpr const_iterator operator++(
int)
noexcept {
795 const_iterator retval = *
this;
800 constexpr bool operator==(
const const_iterator&)
const noexcept =
default;
802 constexpr const_reference operator*() const noexcept {
803 return (*m_mat)[m_index];
807 const VariableBlock<Mat>* m_mat =
nullptr;
818 iterator
begin() {
return iterator(
this, 0); }
832 const_iterator
begin()
const {
return const_iterator(
this, 0); }
839 const_iterator
end()
const {
return const_iterator(
this,
rows() *
cols()); }
846 const_iterator
cbegin()
const {
return const_iterator(
this, 0); }
853 const_iterator
cend()
const {
return const_iterator(
this,
rows() *
cols()); }
863 Mat* m_mat =
nullptr;
866 int m_row_slice_length = 0;
869 int m_col_slice_length = 0;
int step
Step.
Definition slice.hpp:40
int stop
Stop index (exclusive).
Definition slice.hpp:37
constexpr int adjust(int length)
Definition slice.hpp:134
int start
Start index (inclusive).
Definition slice.hpp:34
Definition variable_block.hpp:24
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:215
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:524
const_iterator begin() const
Definition variable_block.hpp:832
const_iterator end() const
Definition variable_block.hpp:839
const Variable & operator()(int row, int col) const
Definition variable_block.hpp:266
Eigen::MatrixXd value()
Definition variable_block.hpp:710
const_iterator cbegin() const
Definition variable_block.hpp:846
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:458
Variable & operator[](int row)
Definition variable_block.hpp:279
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:152
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:574
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_block.hpp:728
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:502
iterator begin()
Definition variable_block.hpp:818
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:37
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:491
Variable & operator()(int row, int col)
Definition variable_block.hpp:250
int rows() const
Definition variable_block.hpp:670
double value(int row, int col)
Definition variable_block.hpp:686
const VariableBlock< const Mat > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:412
const Variable & operator[](int row) const
Definition variable_block.hpp:292
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:480
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:180
VariableBlock< Mat > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:388
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:232
const_iterator cend() const
Definition variable_block.hpp:853
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:199
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:540
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:628
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:117
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:558
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:433
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:610
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:592
iterator end()
Definition variable_block.hpp:825
double value(int index)
Definition variable_block.hpp:700
VariableBlock(Mat &mat)
Definition variable_block.hpp:101
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:469
const VariableBlock< const Mat > operator()(Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:362
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:446
int cols() const
Definition variable_block.hpp:677
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:306
size_t size() const
Definition variable_block.hpp:860
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:653
VariableBlock(const VariableBlock< Mat > &)=default
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:72
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:136
VariableBlock< Mat > operator()(Slice row_slice, Slice col_slice)
Definition variable_block.hpp:342
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:325
void set_value(double value)
Definition variable_block.hpp:167
Definition variable.hpp:41
Definition function_ref.hpp:13
Definition concepts.hpp:29
Definition concepts.hpp:13