Sleipnir C++ API
Loading...
Searching...
No Matches
variable_block.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <concepts>
6#include <type_traits>
7#include <utility>
8
9#include <Eigen/Core>
10
11#include "sleipnir/autodiff/sleipnir_base.hpp"
12#include "sleipnir/autodiff/slice.hpp"
13#include "sleipnir/autodiff/variable.hpp"
14#include "sleipnir/util/assert.hpp"
15#include "sleipnir/util/empty.hpp"
16#include "sleipnir/util/function_ref.hpp"
17
18namespace slp {
19
25template <typename Mat>
27 public:
31 using Scalar = typename Mat::Scalar;
32
37
45 if (this == &values) {
46 return *this;
47 }
48
49 if (m_mat == nullptr) {
50 m_mat = values.m_mat;
51 m_row_slice = values.m_row_slice;
52 m_row_slice_length = values.m_row_slice_length;
53 m_col_slice = values.m_col_slice;
54 m_col_slice_length = values.m_col_slice_length;
55 } else {
56 slp_assert(rows() == values.rows() && cols() == values.cols());
57
58 for (int row = 0; row < rows(); ++row) {
59 for (int col = 0; col < cols(); ++col) {
60 (*this)[row, col] = values[row, col];
61 }
62 }
63 }
64
65 return *this;
66 }
67
72
80 if (this == &values) {
81 return *this;
82 }
83
84 if (m_mat == nullptr) {
85 m_mat = values.m_mat;
86 m_row_slice = values.m_row_slice;
87 m_row_slice_length = values.m_row_slice_length;
88 m_col_slice = values.m_col_slice;
89 m_col_slice_length = values.m_col_slice_length;
90 } else {
91 slp_assert(rows() == values.rows() && cols() == values.cols());
92
93 for (int row = 0; row < rows(); ++row) {
94 for (int col = 0; col < cols(); ++col) {
95 (*this)[row, col] = values[row, col];
96 }
97 }
98 }
99
100 return *this;
101 }
102
108 // NOLINTNEXTLINE (google-explicit-constructor)
110
121 int block_cols)
122 : m_mat{&mat},
123 m_row_slice{row_offset, row_offset + block_rows, 1},
124 m_row_slice_length{m_row_slice.adjust(mat.rows())},
125 m_col_slice{col_offset, col_offset + block_cols, 1},
126 m_col_slice_length{m_col_slice.adjust(mat.cols())} {}
127
141 : m_mat{&mat},
142 m_row_slice{std::move(row_slice)},
143 m_row_slice_length{row_slice_length},
144 m_col_slice{std::move(col_slice)},
145 m_col_slice_length{col_slice_length} {}
146
156 slp_assert(rows() == 1 && cols() == 1);
157
158 (*this)[0, 0] = value;
159
160 return *this;
161 }
162
171 slp_assert(rows() == 1 && cols() == 1);
172
173 (*this)[0, 0].set_value(value);
174 }
175
182 template <typename Derived>
183 VariableBlock<Mat>& operator=(const Eigen::MatrixBase<Derived>& values) {
184 slp_assert(rows() == values.rows() && cols() == values.cols());
185
186 for (int row = 0; row < rows(); ++row) {
187 for (int col = 0; col < cols(); ++col) {
188 (*this)[row, col] = values[row, col];
189 }
190 }
191
192 return *this;
193 }
194
200 template <typename Derived>
201 requires std::same_as<typename Derived::Scalar, Scalar>
202 void set_value(const Eigen::MatrixBase<Derived>& values) {
203 slp_assert(rows() == values.rows() && cols() == values.cols());
204
205 for (int row = 0; row < rows(); ++row) {
206 for (int col = 0; col < cols(); ++col) {
207 (*this)[row, col].set_value(values[row, col]);
208 }
209 }
210 }
211
219 slp_assert(rows() == values.rows() && cols() == values.cols());
220
221 for (int row = 0; row < rows(); ++row) {
222 for (int col = 0; col < cols(); ++col) {
223 (*this)[row, col] = values[row, col];
224 }
225 }
226 return *this;
227 }
228
236 slp_assert(rows() == values.rows() && cols() == values.cols());
237
238 for (int row = 0; row < rows(); ++row) {
239 for (int col = 0; col < cols(); ++col) {
240 (*this)[row, col] = std::move(values[row, col]);
241 }
242 }
243 return *this;
244 }
245
254 requires(!std::is_const_v<Mat>)
255 {
256 slp_assert(row >= 0 && row < rows());
257 slp_assert(col >= 0 && col < cols());
258 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
259 m_col_slice.start + col * m_col_slice.step];
260 }
261
269 const Variable<Scalar>& operator[](int row, int col) const {
270 slp_assert(row >= 0 && row < rows());
271 slp_assert(col >= 0 && col < cols());
272 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
273 m_col_slice.start + col * m_col_slice.step];
274 }
275
283 requires(!std::is_const_v<Mat>)
284 {
285 slp_assert(index >= 0 && index < rows() * cols());
286 return (*this)[index / cols(), index % cols()];
287 }
288
295 const Variable<Scalar>& operator[](int index) const {
296 slp_assert(index >= 0 && index < rows() * cols());
297 return (*this)[index / cols(), index % cols()];
298 }
299
310 int block_cols) {
311 slp_assert(row_offset >= 0 && row_offset <= rows());
312 slp_assert(col_offset >= 0 && col_offset <= cols());
313 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
314 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
315 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
317 }
318
329 int block_rows, int block_cols) const {
330 slp_assert(row_offset >= 0 && row_offset <= rows());
331 slp_assert(col_offset >= 0 && col_offset <= cols());
332 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
333 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
334 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
336 }
337
346 int row_slice_length = row_slice.adjust(m_row_slice_length);
347 int col_slice_length = col_slice.adjust(m_col_slice_length);
349 }
350
359 Slice col_slice) const {
360 int row_slice_length = row_slice.adjust(m_row_slice_length);
361 int col_slice_length = col_slice.adjust(m_col_slice_length);
363 }
364
379 return VariableBlock{
380 *m_mat,
381 {m_row_slice.start + row_slice.start * m_row_slice.step,
382 m_row_slice.start + row_slice.stop * m_row_slice.step,
383 row_slice.step * m_row_slice.step},
385 {m_col_slice.start + col_slice.start * m_col_slice.step,
386 m_col_slice.start + col_slice.stop * m_col_slice.step,
387 col_slice.step * m_col_slice.step},
389 }
390
406 int col_slice_length) const {
407 return VariableBlock{
408 *m_mat,
409 {m_row_slice.start + row_slice.start * m_row_slice.step,
410 m_row_slice.start + row_slice.stop * m_row_slice.step,
411 row_slice.step * m_row_slice.step},
413 {m_col_slice.start + col_slice.start * m_col_slice.step,
414 m_col_slice.start + col_slice.stop * m_col_slice.step,
415 col_slice.step * m_col_slice.step},
417 }
418
427 slp_assert(cols() == 1);
428 slp_assert(offset >= 0 && offset < rows());
429 slp_assert(length >= 0 && length <= rows() - offset);
430 return block(offset, 0, length, 1);
431 }
432
440 const VariableBlock<Mat> segment(int offset, int length) const {
441 slp_assert(cols() == 1);
442 slp_assert(offset >= 0 && offset < rows());
443 slp_assert(length >= 0 && length <= rows() - offset);
444 return block(offset, 0, length, 1);
445 }
446
454 slp_assert(row >= 0 && row < rows());
455 return block(row, 0, 1, cols());
456 }
457
465 slp_assert(row >= 0 && row < rows());
466 return block(row, 0, 1, cols());
467 }
468
476 slp_assert(col >= 0 && col < cols());
477 return block(0, col, rows(), 1);
478 }
479
487 slp_assert(col >= 0 && col < cols());
488 return block(0, col, rows(), 1);
489 }
490
498 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
499
500 for (int i = 0; i < rows(); ++i) {
501 for (int j = 0; j < rhs.cols(); ++j) {
502 Variable sum{Scalar(0)};
503 for (int k = 0; k < cols(); ++k) {
504 sum += (*this)(i, k) * rhs(k, j);
505 }
506 (*this)(i, j) = sum;
507 }
508 }
509
510 return *this;
511 }
512
520 for (int row = 0; row < rows(); ++row) {
521 for (int col = 0; col < cols(); ++col) {
522 (*this)[row, col] *= rhs;
523 }
524 }
525
526 return *this;
527 }
528
536 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
537
538 for (int row = 0; row < rows(); ++row) {
539 for (int col = 0; col < cols(); ++col) {
540 (*this)[row, col] /= rhs[0, 0];
541 }
542 }
543
544 return *this;
545 }
546
554 for (int row = 0; row < rows(); ++row) {
555 for (int col = 0; col < cols(); ++col) {
556 (*this)[row, col] /= rhs;
557 }
558 }
559
560 return *this;
561 }
562
570 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
571
572 for (int row = 0; row < rows(); ++row) {
573 for (int col = 0; col < cols(); ++col) {
574 (*this)[row, col] += rhs[row, col];
575 }
576 }
577
578 return *this;
579 }
580
588 slp_assert(rows() == 1 && cols() == 1);
589
590 for (int row = 0; row < rows(); ++row) {
591 for (int col = 0; col < cols(); ++col) {
592 (*this)[row, col] += rhs;
593 }
594 }
595
596 return *this;
597 }
598
606 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
607
608 for (int row = 0; row < rows(); ++row) {
609 for (int col = 0; col < cols(); ++col) {
610 (*this)[row, col] -= rhs[row, col];
611 }
612 }
613
614 return *this;
615 }
616
624 slp_assert(rows() == 1 && cols() == 1);
625
626 for (int row = 0; row < rows(); ++row) {
627 for (int col = 0; col < cols(); ++col) {
628 (*this)[row, col] -= rhs;
629 }
630 }
631
632 return *this;
633 }
634
638 // NOLINTNEXTLINE (google-explicit-constructor)
639 operator Variable<Scalar>() const {
640 slp_assert(rows() == 1 && cols() == 1);
641 return (*this)(0, 0);
642 }
643
649 std::remove_cv_t<Mat> T() const {
650 std::remove_cv_t<Mat> result{detail::empty, cols(), rows()};
651
652 for (int row = 0; row < rows(); ++row) {
653 for (int col = 0; col < cols(); ++col) {
654 result[col, row] = (*this)[row, col];
655 }
656 }
657
658 return result;
659 }
660
666 int rows() const { return m_row_slice_length; }
667
673 int cols() const { return m_col_slice_length; }
674
682 Scalar value(int row, int col) { return (*this)[row, col].value(); }
683
690 Scalar value(int index) {
691 slp_assert(index >= 0 && index < rows() * cols());
692 return value(index / cols(), index % cols());
693 }
694
700 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> value() {
701 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> result{rows(),
702 cols()};
703
704 for (int row = 0; row < rows(); ++row) {
705 for (int col = 0; col < cols(); ++col) {
706 result[row, col] = value(row, col);
707 }
708 }
709
710 return result;
711 }
712
719 std::remove_cv_t<Mat> cwise_transform(
721 const {
722 std::remove_cv_t<Mat> result{detail::empty, rows(), cols()};
723
724 for (int row = 0; row < rows(); ++row) {
725 for (int col = 0; col < cols(); ++col) {
726 result[row, col] = unary_op((*this)[row, col]);
727 }
728 }
729
730 return result;
731 }
732
733#ifndef DOXYGEN_SHOULD_SKIP_THIS
734
735 class iterator {
736 public:
737 using iterator_category = std::bidirectional_iterator_tag;
738 using value_type = Variable<Scalar>;
739 using difference_type = std::ptrdiff_t;
740 using pointer = Variable<Scalar>*;
742
743 constexpr iterator() noexcept = default;
744
746 : m_mat{mat}, m_index{index} {}
747
748 constexpr iterator& operator++() noexcept {
749 ++m_index;
750 return *this;
751 }
752
753 constexpr iterator operator++(int) noexcept {
754 iterator retval = *this;
755 ++(*this);
756 return retval;
757 }
758
759 constexpr iterator& operator--() noexcept {
760 --m_index;
761 return *this;
762 }
763
764 constexpr iterator operator--(int) noexcept {
765 iterator retval = *this;
766 --(*this);
767 return retval;
768 }
769
770 constexpr bool operator==(const iterator&) const noexcept = default;
771
772 constexpr reference operator*() const noexcept { return (*m_mat)[m_index]; }
773
774 private:
775 VariableBlock<Mat>* m_mat = nullptr;
776 int m_index = 0;
777 };
778
779 class const_iterator {
780 public:
781 using iterator_category = std::bidirectional_iterator_tag;
782 using value_type = Variable<Scalar>;
783 using difference_type = std::ptrdiff_t;
784 using pointer = Variable<Scalar>*;
785 using const_reference = const Variable<Scalar>&;
786
787 constexpr const_iterator() noexcept = default;
788
789 constexpr const_iterator(const VariableBlock<Mat>* mat, int index) noexcept
790 : m_mat{mat}, m_index{index} {}
791
792 constexpr const_iterator& operator++() noexcept {
793 ++m_index;
794 return *this;
795 }
796
797 constexpr const_iterator operator++(int) noexcept {
798 const_iterator retval = *this;
799 ++(*this);
800 return retval;
801 }
802
803 constexpr const_iterator& operator--() noexcept {
804 --m_index;
805 return *this;
806 }
807
808 constexpr const_iterator operator--(int) noexcept {
809 iterator retval = *this;
810 --(*this);
811 return retval;
812 }
813
814 constexpr bool operator==(const const_iterator&) const noexcept = default;
815
816 constexpr const_reference operator*() const noexcept {
817 return (*m_mat)[m_index];
818 }
819
820 private:
821 const VariableBlock<Mat>* m_mat = nullptr;
822 int m_index = 0;
823 };
824
825 using reverse_iterator = std::reverse_iterator<iterator>;
826 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
827
828#endif // DOXYGEN_SHOULD_SKIP_THIS
829
835 iterator begin() { return iterator(this, 0); }
836
842 iterator end() { return iterator(this, rows() * cols()); }
843
849 const_iterator begin() const { return const_iterator(this, 0); }
850
856 const_iterator end() const { return const_iterator(this, rows() * cols()); }
857
863 const_iterator cbegin() const { return const_iterator(this, 0); }
864
870 const_iterator cend() const { return const_iterator(this, rows() * cols()); }
871
878
885
894
903
912
921
927 size_t size() const { return rows() * cols(); }
928
929 private:
930 Mat* m_mat = nullptr;
931
932 Slice m_row_slice;
933 int m_row_slice_length = 0;
934
935 Slice m_col_slice;
936 int m_col_slice_length = 0;
937};
938
939} // namespace slp
Definition intrusive_shared_ptr.hpp:29
Definition sleipnir_base.hpp:11
Definition slice.hpp:31
int step
Step.
Definition slice.hpp:40
int start
Start index (inclusive).
Definition slice.hpp:34
Definition variable_block.hpp:26
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:218
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:519
const_iterator begin() const
Definition variable_block.hpp:849
const_iterator end() const
Definition variable_block.hpp:856
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:202
const_iterator cbegin() const
Definition variable_block.hpp:863
void set_value(Scalar value)
Definition variable_block.hpp:170
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:453
VariableBlock< Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:377
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:155
Variable< Scalar > & operator[](int index)
Definition variable_block.hpp:282
Variable< Scalar > & operator[](int row, int col)
Definition variable_block.hpp:253
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:569
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:497
iterator begin()
Definition variable_block.hpp:835
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:44
const Variable< Scalar > & operator[](int index) const
Definition variable_block.hpp:295
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:486
int rows() const
Definition variable_block.hpp:666
VariableBlock< Mat > operator[](Slice row_slice, Slice col_slice)
Definition variable_block.hpp:345
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:475
const VariableBlock< const Mat > operator[](Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:358
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:183
Scalar value(int index)
Definition variable_block.hpp:690
reverse_iterator rbegin()
Definition variable_block.hpp:877
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:235
const_iterator cend() const
Definition variable_block.hpp:870
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:535
const Variable< Scalar > & operator[](int row, int col) const
Definition variable_block.hpp:269
const_reverse_iterator crbegin() const
Definition variable_block.hpp:909
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:623
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:120
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:553
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:426
const_reverse_iterator rend() const
Definition variable_block.hpp:900
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:605
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:587
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic > value()
Definition variable_block.hpp:700
iterator end()
Definition variable_block.hpp:842
VariableBlock(Mat &mat)
Definition variable_block.hpp:109
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:464
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:440
Scalar value(int row, int col)
Definition variable_block.hpp:682
int cols() const
Definition variable_block.hpp:673
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:309
reverse_iterator rend()
Definition variable_block.hpp:884
size_t size() const
Definition variable_block.hpp:927
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:649
VariableBlock(const VariableBlock< Mat > &)=default
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable< Scalar >(const Variable< Scalar > &x)> unary_op) const
Definition variable_block.hpp:719
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:79
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:139
const VariableBlock< const Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:403
const_reverse_iterator crend() const
Definition variable_block.hpp:918
const_reverse_iterator rbegin() const
Definition variable_block.hpp:891
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:328
typename Mat::Scalar Scalar
Definition variable_block.hpp:31
Definition variable.hpp:49
Definition function_ref.hpp:13
Definition concepts.hpp:18
Definition concepts.hpp:24