Sleipnir C++ API
Loading...
Searching...
No Matches
variable_block.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <concepts>
6#include <type_traits>
7#include <utility>
8
9#include <Eigen/Core>
10
11#include "sleipnir/autodiff/slice.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/function_ref.hpp"
15
16namespace slp {
17
23template <typename Mat>
25 public:
30
38 if (this == &values) {
39 return *this;
40 }
41
42 if (m_mat == nullptr) {
43 m_mat = values.m_mat;
44 m_row_slice = values.m_row_slice;
45 m_row_slice_length = values.m_row_slice_length;
46 m_col_slice = values.m_col_slice;
47 m_col_slice_length = values.m_col_slice_length;
48 } else {
49 slp_assert(rows() == values.rows() && cols() == values.cols());
50
51 for (int row = 0; row < rows(); ++row) {
52 for (int col = 0; col < cols(); ++col) {
53 (*this)[row, col] = values[row, col];
54 }
55 }
56 }
57
58 return *this;
59 }
60
65
73 if (this == &values) {
74 return *this;
75 }
76
77 if (m_mat == nullptr) {
78 m_mat = values.m_mat;
79 m_row_slice = values.m_row_slice;
80 m_row_slice_length = values.m_row_slice_length;
81 m_col_slice = values.m_col_slice;
82 m_col_slice_length = values.m_col_slice_length;
83 } else {
84 slp_assert(rows() == values.rows() && cols() == values.cols());
85
86 for (int row = 0; row < rows(); ++row) {
87 for (int col = 0; col < cols(); ++col) {
88 (*this)[row, col] = values[row, col];
89 }
90 }
91 }
92
93 return *this;
94 }
95
101 VariableBlock(Mat& mat) // NOLINT
102 : VariableBlock{mat, 0, 0, mat.rows(), mat.cols()} {}
103
113 VariableBlock(Mat& mat, int row_offset, int col_offset, int block_rows,
114 int block_cols)
115 : m_mat{&mat},
116 m_row_slice{row_offset, row_offset + block_rows, 1},
117 m_row_slice_length{m_row_slice.adjust(mat.rows())},
118 m_col_slice{col_offset, col_offset + block_cols, 1},
119 m_col_slice_length{m_col_slice.adjust(mat.cols())} {}
120
132 VariableBlock(Mat& mat, Slice row_slice, int row_slice_length,
133 Slice col_slice, int col_slice_length)
134 : m_mat{&mat},
135 m_row_slice{std::move(row_slice)},
136 m_row_slice_length{row_slice_length},
137 m_col_slice{std::move(col_slice)},
138 m_col_slice_length{col_slice_length} {}
139
149 slp_assert(rows() == 1 && cols() == 1);
150
151 (*this)[0, 0] = value;
152
153 return *this;
154 }
155
163 void set_value(double value) {
164 slp_assert(rows() == 1 && cols() == 1);
165
166 (*this)[0, 0].set_value(value);
167 }
168
175 template <typename Derived>
176 VariableBlock<Mat>& operator=(const Eigen::MatrixBase<Derived>& values) {
177 slp_assert(rows() == values.rows() && cols() == values.cols());
178
179 for (int row = 0; row < rows(); ++row) {
180 for (int col = 0; col < cols(); ++col) {
181 (*this)[row, col] = values(row, col);
182 }
183 }
184
185 return *this;
186 }
187
193 template <typename Derived>
194 requires std::same_as<typename Derived::Scalar, double>
195 void set_value(const Eigen::MatrixBase<Derived>& values) {
196 slp_assert(rows() == values.rows() && cols() == values.cols());
197
198 for (int row = 0; row < rows(); ++row) {
199 for (int col = 0; col < cols(); ++col) {
200 (*this)[row, col].set_value(values(row, col));
201 }
202 }
203 }
204
211 VariableBlock<Mat>& operator=(const Mat& values) {
212 slp_assert(rows() == values.rows() && cols() == values.cols());
213
214 for (int row = 0; row < rows(); ++row) {
215 for (int col = 0; col < cols(); ++col) {
216 (*this)[row, col] = values[row, col];
217 }
218 }
219 return *this;
220 }
221
229 slp_assert(rows() == values.rows() && cols() == values.cols());
230
231 for (int row = 0; row < rows(); ++row) {
232 for (int col = 0; col < cols(); ++col) {
233 (*this)[row, col] = std::move(values[row, col]);
234 }
235 }
236 return *this;
237 }
238
247 requires(!std::is_const_v<Mat>)
248 {
249 slp_assert(row >= 0 && row < rows());
250 slp_assert(col >= 0 && col < cols());
251 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
252 m_col_slice.start + col * m_col_slice.step];
253 }
254
262 const Variable& operator[](int row, int col) const {
263 slp_assert(row >= 0 && row < rows());
264 slp_assert(col >= 0 && col < cols());
265 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
266 m_col_slice.start + col * m_col_slice.step];
267 }
268
276 requires(!std::is_const_v<Mat>)
277 {
278 slp_assert(index >= 0 && index < rows() * cols());
279 return (*this)[index / cols(), index % cols()];
280 }
281
288 const Variable& operator[](int index) const {
289 slp_assert(index >= 0 && index < rows() * cols());
290 return (*this)[index / cols(), index % cols()];
291 }
292
302 VariableBlock<Mat> block(int row_offset, int col_offset, int block_rows,
303 int block_cols) {
304 slp_assert(row_offset >= 0 && row_offset <= rows());
305 slp_assert(col_offset >= 0 && col_offset <= cols());
306 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
307 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
308 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
309 Slice{col_offset, col_offset + block_cols, 1}];
310 }
311
321 const VariableBlock<const Mat> block(int row_offset, int col_offset,
322 int block_rows, int block_cols) const {
323 slp_assert(row_offset >= 0 && row_offset <= rows());
324 slp_assert(col_offset >= 0 && col_offset <= cols());
325 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
326 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
327 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
328 Slice{col_offset, col_offset + block_cols, 1}];
329 }
330
338 VariableBlock<Mat> operator[](Slice row_slice, Slice col_slice) {
339 int row_slice_length = row_slice.adjust(m_row_slice_length);
340 int col_slice_length = col_slice.adjust(m_col_slice_length);
341 return (*this)[row_slice, row_slice_length, col_slice, col_slice_length];
342 }
343
352 Slice col_slice) const {
353 int row_slice_length = row_slice.adjust(m_row_slice_length);
354 int col_slice_length = col_slice.adjust(m_col_slice_length);
355 return (*this)[row_slice, row_slice_length, col_slice, col_slice_length];
356 }
357
370 VariableBlock<Mat> operator[](Slice row_slice, int row_slice_length,
371 Slice col_slice, int col_slice_length) {
372 return VariableBlock{
373 *m_mat,
374 {m_row_slice.start + row_slice.start * m_row_slice.step,
375 m_row_slice.start + row_slice.stop * m_row_slice.step,
376 row_slice.step * m_row_slice.step},
377 row_slice_length,
378 {m_col_slice.start + col_slice.start * m_col_slice.step,
379 m_col_slice.start + col_slice.stop * m_col_slice.step,
380 col_slice.step * m_col_slice.step},
381 col_slice_length};
382 }
383
397 int row_slice_length,
398 Slice col_slice,
399 int col_slice_length) const {
400 return VariableBlock{
401 *m_mat,
402 {m_row_slice.start + row_slice.start * m_row_slice.step,
403 m_row_slice.start + row_slice.stop * m_row_slice.step,
404 row_slice.step * m_row_slice.step},
405 row_slice_length,
406 {m_col_slice.start + col_slice.start * m_col_slice.step,
407 m_col_slice.start + col_slice.stop * m_col_slice.step,
408 col_slice.step * m_col_slice.step},
409 col_slice_length};
410 }
411
419 VariableBlock<Mat> segment(int offset, int length) {
420 slp_assert(cols() == 1);
421 slp_assert(offset >= 0 && offset < rows());
422 slp_assert(length >= 0 && length <= rows() - offset);
423 return block(offset, 0, length, 1);
424 }
425
433 const VariableBlock<Mat> segment(int offset, int length) const {
434 slp_assert(cols() == 1);
435 slp_assert(offset >= 0 && offset < rows());
436 slp_assert(length >= 0 && length <= rows() - offset);
437 return block(offset, 0, length, 1);
438 }
439
447 slp_assert(row >= 0 && row < rows());
448 return block(row, 0, 1, cols());
449 }
450
458 slp_assert(row >= 0 && row < rows());
459 return block(row, 0, 1, cols());
460 }
461
469 slp_assert(col >= 0 && col < cols());
470 return block(0, col, rows(), 1);
471 }
472
480 slp_assert(col >= 0 && col < cols());
481 return block(0, col, rows(), 1);
482 }
483
  // Body of the matrix compound multiply: intends *this = *this * rhs,
  // computed in place one output element at a time.
491 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
492
  // NOTE(review): two apparent defects, to be confirmed and fixed:
  //  1. Elements are accessed as (*this)(i, k) / (*this)(i, j), but the
  //     visible VariableBlock API only provides operator[] (e.g.
  //     (*this)[i, k]). This likely fails to compile once instantiated.
  //  2. (*this)(i, j) is overwritten while later j iterations of the same
  //     row still read (*this)(i, k) for k < j. When cols() > 1 the result
  //     therefore mixes old and new values. Accumulating into a separate
  //     result and copying it back afterwards would avoid this.
493 for (int i = 0; i < rows(); ++i) {
494 for (int j = 0; j < rhs.cols(); ++j) {
495 Variable sum{0.0};
496 for (int k = 0; k < cols(); ++k) {
497 sum += (*this)(i, k) * rhs(k, j);
498 }
499 (*this)(i, j) = sum;
500 }
501 }
502
503 return *this;
504 }
505
513 for (int row = 0; row < rows(); ++row) {
514 for (int col = 0; col < cols(); ++col) {
515 (*this)[row, col] *= rhs;
516 }
517 }
518
519 return *this;
520 }
521
529 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
530
531 for (int row = 0; row < rows(); ++row) {
532 for (int col = 0; col < cols(); ++col) {
533 (*this)[row, col] /= rhs[0, 0];
534 }
535 }
536
537 return *this;
538 }
539
547 for (int row = 0; row < rows(); ++row) {
548 for (int col = 0; col < cols(); ++col) {
549 (*this)[row, col] /= rhs;
550 }
551 }
552
553 return *this;
554 }
555
563 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
564
565 for (int row = 0; row < rows(); ++row) {
566 for (int col = 0; col < cols(); ++col) {
567 (*this)[row, col] += rhs[row, col];
568 }
569 }
570
571 return *this;
572 }
573
581 slp_assert(rows() == 1 && cols() == 1);
582
583 for (int row = 0; row < rows(); ++row) {
584 for (int col = 0; col < cols(); ++col) {
585 (*this)[row, col] += rhs;
586 }
587 }
588
589 return *this;
590 }
591
599 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
600
601 for (int row = 0; row < rows(); ++row) {
602 for (int col = 0; col < cols(); ++col) {
603 (*this)[row, col] -= rhs[row, col];
604 }
605 }
606
607 return *this;
608 }
609
617 slp_assert(rows() == 1 && cols() == 1);
618
619 for (int row = 0; row < rows(); ++row) {
620 for (int col = 0; col < cols(); ++col) {
621 (*this)[row, col] -= rhs;
622 }
623 }
624
625 return *this;
626 }
627
631 operator Variable() const { // NOLINT
632 slp_assert(rows() == 1 && cols() == 1);
633 return (*this)(0, 0);
634 }
635
641 std::remove_cv_t<Mat> T() const {
642 std::remove_cv_t<Mat> result{Mat::empty, cols(), rows()};
643
644 for (int row = 0; row < rows(); ++row) {
645 for (int col = 0; col < cols(); ++col) {
646 result[col, row] = (*this)[row, col];
647 }
648 }
649
650 return result;
651 }
652
658 int rows() const { return m_row_slice_length; }
659
665 int cols() const { return m_col_slice_length; }
666
674 double value(int row, int col) { return (*this)[row, col].value(); }
675
682 double value(int index) {
683 slp_assert(index >= 0 && index < rows() * cols());
684 return value(index / cols(), index % cols());
685 }
686
692 Eigen::MatrixXd value() {
693 Eigen::MatrixXd result{rows(), cols()};
694
695 for (int row = 0; row < rows(); ++row) {
696 for (int col = 0; col < cols(); ++col) {
697 result(row, col) = value(row, col);
698 }
699 }
700
701 return result;
702 }
703
710 std::remove_cv_t<Mat> cwise_transform(
711 function_ref<Variable(const Variable& x)> unary_op) const {
712 std::remove_cv_t<Mat> result{Mat::empty, rows(), cols()};
713
714 for (int row = 0; row < rows(); ++row) {
715 for (int col = 0; col < cols(); ++col) {
716 result[row, col] = unary_op((*this)[row, col]);
717 }
718 }
719
720 return result;
721 }
722
723#ifndef DOXYGEN_SHOULD_SKIP_THIS
724
725 class iterator {
726 public:
727 using iterator_category = std::bidirectional_iterator_tag;
728 using value_type = Variable;
729 using difference_type = std::ptrdiff_t;
730 using pointer = Variable*;
731 using reference = Variable&;
732
733 constexpr iterator() noexcept = default;
734
735 constexpr iterator(VariableBlock<Mat>* mat, int index) noexcept
736 : m_mat{mat}, m_index{index} {}
737
738 constexpr iterator& operator++() noexcept {
739 ++m_index;
740 return *this;
741 }
742
743 constexpr iterator operator++(int) noexcept {
744 iterator retval = *this;
745 ++(*this);
746 return retval;
747 }
748
749 constexpr iterator& operator--() noexcept {
750 --m_index;
751 return *this;
752 }
753
754 constexpr iterator operator--(int) noexcept {
755 iterator retval = *this;
756 --(*this);
757 return retval;
758 }
759
760 constexpr bool operator==(const iterator&) const noexcept = default;
761
762 constexpr reference operator*() const noexcept { return (*m_mat)[m_index]; }
763
764 private:
765 VariableBlock<Mat>* m_mat = nullptr;
766 int m_index = 0;
767 };
768
769 class const_iterator {
770 public:
771 using iterator_category = std::bidirectional_iterator_tag;
772 using value_type = Variable;
773 using difference_type = std::ptrdiff_t;
774 using pointer = Variable*;
775 using const_reference = const Variable&;
776
777 constexpr const_iterator() noexcept = default;
778
779 constexpr const_iterator(const VariableBlock<Mat>* mat, int index) noexcept
780 : m_mat{mat}, m_index{index} {}
781
782 constexpr const_iterator& operator++() noexcept {
783 ++m_index;
784 return *this;
785 }
786
787 constexpr const_iterator operator++(int) noexcept {
788 const_iterator retval = *this;
789 ++(*this);
790 return retval;
791 }
792
793 constexpr const_iterator& operator--() noexcept {
794 --m_index;
795 return *this;
796 }
797
798 constexpr const_iterator operator--(int) noexcept {
799 iterator retval = *this;
800 --(*this);
801 return retval;
802 }
803
804 constexpr bool operator==(const const_iterator&) const noexcept = default;
805
806 constexpr const_reference operator*() const noexcept {
807 return (*m_mat)[m_index];
808 }
809
810 private:
811 const VariableBlock<Mat>* m_mat = nullptr;
812 int m_index = 0;
813 };
814
815 using reverse_iterator = std::reverse_iterator<iterator>;
816 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
817
818#endif // DOXYGEN_SHOULD_SKIP_THIS
819
825 iterator begin() { return iterator(this, 0); }
826
832 iterator end() { return iterator(this, rows() * cols()); }
833
839 const_iterator begin() const { return const_iterator(this, 0); }
840
846 const_iterator end() const { return const_iterator(this, rows() * cols()); }
847
853 const_iterator cbegin() const { return const_iterator(this, 0); }
854
860 const_iterator cend() const { return const_iterator(this, rows() * cols()); }
861
867 reverse_iterator rbegin() { return reverse_iterator{end()}; }
868
874 reverse_iterator rend() { return reverse_iterator{begin()}; }
875
881 const_reverse_iterator rbegin() const {
882 return const_reverse_iterator{end()};
883 }
884
890 const_reverse_iterator rend() const {
891 return const_reverse_iterator{begin()};
892 }
893
899 const_reverse_iterator crbegin() const {
900 return const_reverse_iterator{cend()};
901 }
902
908 const_reverse_iterator crend() const {
909 return const_reverse_iterator{cbegin()};
910 }
911
917 size_t size() const { return rows() * cols(); }
918
919 private:
  /// Non-owning pointer to the viewed matrix. The copy/move assignment
  /// operators treat nullptr as "not bound yet" and rebind the view instead
  /// of copying elements.
920 Mat* m_mat = nullptr;
921
  /// Rows of *m_mat covered by this block: block-local row i maps to parent
  /// row start + i * step. This member must stay declared before
  /// m_row_slice_length, because the offset-based constructor initializes
  /// that length from this slice via adjust().
922 Slice m_row_slice;
  /// Number of rows in the block; returned by rows().
923 int m_row_slice_length = 0;
924
  /// Columns of *m_mat covered by this block, with the same mapping and
  /// declaration-order requirement as m_row_slice.
925 Slice m_col_slice;
  /// Number of columns in the block; returned by cols().
926 int m_col_slice_length = 0;
927};
928
929} // namespace slp
Definition slice.hpp:31
int step
Step.
Definition slice.hpp:40
int stop
Stop index (exclusive).
Definition slice.hpp:37
constexpr int adjust(int length)
Definition slice.hpp:134
int start
Start index (inclusive).
Definition slice.hpp:34
Definition variable_block.hpp:24
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:211
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:512
const Variable & operator[](int row, int col) const
Definition variable_block.hpp:262
const_iterator begin() const
Definition variable_block.hpp:839
const_iterator end() const
Definition variable_block.hpp:846
Eigen::MatrixXd value()
Definition variable_block.hpp:692
const_iterator cbegin() const
Definition variable_block.hpp:853
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:446
VariableBlock< Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:370
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:148
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:562
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_block.hpp:710
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:490
iterator begin()
Definition variable_block.hpp:825
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:37
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:479
int rows() const
Definition variable_block.hpp:658
VariableBlock< Mat > operator[](Slice row_slice, Slice col_slice)
Definition variable_block.hpp:338
double value(int row, int col)
Definition variable_block.hpp:674
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:468
const VariableBlock< const Mat > operator[](Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:351
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:176
reverse_iterator rbegin()
Definition variable_block.hpp:867
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:228
const_iterator cend() const
Definition variable_block.hpp:860
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:195
Variable & operator[](int row, int col)
Definition variable_block.hpp:246
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:528
const_reverse_iterator crbegin() const
Definition variable_block.hpp:899
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:616
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:113
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:546
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:419
const_reverse_iterator rend() const
Definition variable_block.hpp:890
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:598
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:580
Variable & operator[](int index)
Definition variable_block.hpp:275
iterator end()
Definition variable_block.hpp:832
double value(int index)
Definition variable_block.hpp:682
VariableBlock(Mat &mat)
Definition variable_block.hpp:101
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:457
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:433
int cols() const
Definition variable_block.hpp:665
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:302
reverse_iterator rend()
Definition variable_block.hpp:874
size_t size() const
Definition variable_block.hpp:917
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:641
VariableBlock(const VariableBlock< Mat > &)=default
const Variable & operator[](int index) const
Definition variable_block.hpp:288
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:72
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:132
const VariableBlock< const Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:396
const_reverse_iterator crend() const
Definition variable_block.hpp:908
const_reverse_iterator rbegin() const
Definition variable_block.hpp:881
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:321
void set_value(double value)
Definition variable_block.hpp:163
Definition variable.hpp:40
Definition function_ref.hpp:13
Definition concepts.hpp:40
Definition concepts.hpp:13