Sleipnir C++ API
Loading...
Searching...
No Matches
variable_block.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
#include <concepts>
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>
8
9#include <Eigen/Core>
10
11#include "sleipnir/autodiff/slice.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/function_ref.hpp"
15
16namespace slp {
17
23template <typename Mat>
25 public:
30
38 if (this == &values) {
39 return *this;
40 }
41
42 if (m_mat == nullptr) {
43 m_mat = values.m_mat;
44 m_row_slice = values.m_row_slice;
45 m_row_slice_length = values.m_row_slice_length;
46 m_col_slice = values.m_col_slice;
47 m_col_slice_length = values.m_col_slice_length;
48 } else {
49 slp_assert(rows() == values.rows() && cols() == values.cols());
50
51 for (int row = 0; row < rows(); ++row) {
52 for (int col = 0; col < cols(); ++col) {
53 (*this)(row, col) = values(row, col);
54 }
55 }
56 }
57
58 return *this;
59 }
60
65
73 if (this == &values) {
74 return *this;
75 }
76
77 if (m_mat == nullptr) {
78 m_mat = values.m_mat;
79 m_row_slice = values.m_row_slice;
80 m_row_slice_length = values.m_row_slice_length;
81 m_col_slice = values.m_col_slice;
82 m_col_slice_length = values.m_col_slice_length;
83 } else {
84 slp_assert(rows() == values.rows() && cols() == values.cols());
85
86 for (int row = 0; row < rows(); ++row) {
87 for (int col = 0; col < cols(); ++col) {
88 (*this)(row, col) = values(row, col);
89 }
90 }
91 }
92
93 return *this;
94 }
95
101 VariableBlock(Mat& mat) // NOLINT
102 : m_mat{&mat},
103 m_row_slice{0, mat.rows(), 1},
104 m_row_slice_length{m_row_slice.adjust(mat.rows())},
105 m_col_slice{0, mat.cols(), 1},
106 m_col_slice_length{m_col_slice.adjust(mat.cols())} {}
107
117 VariableBlock(Mat& mat, int row_offset, int col_offset, int block_rows,
118 int block_cols)
119 : m_mat{&mat},
120 m_row_slice{row_offset, row_offset + block_rows, 1},
121 m_row_slice_length{m_row_slice.adjust(mat.rows())},
122 m_col_slice{col_offset, col_offset + block_cols, 1},
123 m_col_slice_length{m_col_slice.adjust(mat.cols())} {}
124
136 VariableBlock(Mat& mat, Slice row_slice, int row_slice_length,
137 Slice col_slice, int col_slice_length)
138 : m_mat{&mat},
139 m_row_slice{std::move(row_slice)},
140 m_row_slice_length{row_slice_length},
141 m_col_slice{std::move(col_slice)},
142 m_col_slice_length{col_slice_length} {}
143
153 slp_assert(rows() == 1 && cols() == 1);
154
155 (*this)(0, 0) = value;
156
157 return *this;
158 }
159
167 void set_value(double value) {
168 slp_assert(rows() == 1 && cols() == 1);
169
170 (*this)(0, 0).set_value(value);
171 }
172
179 template <typename Derived>
180 VariableBlock<Mat>& operator=(const Eigen::MatrixBase<Derived>& values) {
181 slp_assert(rows() == values.rows() && cols() == values.cols());
182
183 for (int row = 0; row < rows(); ++row) {
184 for (int col = 0; col < cols(); ++col) {
185 (*this)(row, col) = values(row, col);
186 }
187 }
188
189 return *this;
190 }
191
197 template <typename Derived>
198 requires std::same_as<typename Derived::Scalar, double>
199 void set_value(const Eigen::MatrixBase<Derived>& values) {
200 slp_assert(rows() == values.rows() && cols() == values.cols());
201
202 for (int row = 0; row < rows(); ++row) {
203 for (int col = 0; col < cols(); ++col) {
204 (*this)(row, col).set_value(values(row, col));
205 }
206 }
207 }
208
215 VariableBlock<Mat>& operator=(const Mat& values) {
216 slp_assert(rows() == values.rows() && cols() == values.cols());
217
218 for (int row = 0; row < rows(); ++row) {
219 for (int col = 0; col < cols(); ++col) {
220 (*this)(row, col) = values(row, col);
221 }
222 }
223 return *this;
224 }
225
233 slp_assert(rows() == values.rows() && cols() == values.cols());
234
235 for (int row = 0; row < rows(); ++row) {
236 for (int col = 0; col < cols(); ++col) {
237 (*this)(row, col) = std::move(values(row, col));
238 }
239 }
240 return *this;
241 }
242
251 requires(!std::is_const_v<Mat>)
252 {
253 slp_assert(row >= 0 && row < rows());
254 slp_assert(col >= 0 && col < cols());
255 return (*m_mat)(m_row_slice.start + row * m_row_slice.step,
256 m_col_slice.start + col * m_col_slice.step);
257 }
258
266 const Variable& operator()(int row, int col) const {
267 slp_assert(row >= 0 && row < rows());
268 slp_assert(col >= 0 && col < cols());
269 return (*m_mat)(m_row_slice.start + row * m_row_slice.step,
270 m_col_slice.start + col * m_col_slice.step);
271 }
272
280 requires(!std::is_const_v<Mat>)
281 {
282 slp_assert(row >= 0 && row < rows() * cols());
283 return (*this)(row / cols(), row % cols());
284 }
285
292 const Variable& operator[](int row) const {
293 slp_assert(row >= 0 && row < rows() * cols());
294 return (*this)(row / cols(), row % cols());
295 }
296
306 VariableBlock<Mat> block(int row_offset, int col_offset, int block_rows,
307 int block_cols) {
308 slp_assert(row_offset >= 0 && row_offset <= rows());
309 slp_assert(col_offset >= 0 && col_offset <= cols());
310 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
311 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
312 return (*this)({row_offset, row_offset + block_rows, 1},
313 {col_offset, col_offset + block_cols, 1});
314 }
315
325 const VariableBlock<const Mat> block(int row_offset, int col_offset,
326 int block_rows, int block_cols) const {
327 slp_assert(row_offset >= 0 && row_offset <= rows());
328 slp_assert(col_offset >= 0 && col_offset <= cols());
329 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
330 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
331 return (*this)({row_offset, row_offset + block_rows, 1},
332 {col_offset, col_offset + block_cols, 1});
333 }
334
342 VariableBlock<Mat> operator()(Slice row_slice, Slice col_slice) {
343 int row_slice_length = row_slice.adjust(m_row_slice_length);
344 int col_slice_length = col_slice.adjust(m_col_slice_length);
345 return VariableBlock{
346 *m_mat,
347 {m_row_slice.start + row_slice.start * m_row_slice.step,
348 m_row_slice.start + row_slice.stop, m_row_slice.step * row_slice.step},
349 row_slice_length,
350 {m_col_slice.start + col_slice.start * m_col_slice.step,
351 m_col_slice.start + col_slice.stop, m_col_slice.step * col_slice.step},
352 col_slice_length};
353 }
354
363 Slice col_slice) const {
364 int row_slice_length = row_slice.adjust(m_row_slice_length);
365 int col_slice_length = col_slice.adjust(m_col_slice_length);
366 return VariableBlock{
367 *m_mat,
368 {m_row_slice.start + row_slice.start * m_row_slice.step,
369 m_row_slice.start + row_slice.stop, m_row_slice.step * row_slice.step},
370 row_slice_length,
371 {m_col_slice.start + col_slice.start * m_col_slice.step,
372 m_col_slice.start + col_slice.stop, m_col_slice.step * col_slice.step},
373 col_slice_length};
374 }
375
388 VariableBlock<Mat> operator()(Slice row_slice, int row_slice_length,
389 Slice col_slice, int col_slice_length) {
390 return VariableBlock{
391 *m_mat,
392 {m_row_slice.start + row_slice.start * m_row_slice.step,
393 m_row_slice.start + row_slice.stop, m_row_slice.step * row_slice.step},
394 row_slice_length,
395 {m_col_slice.start + col_slice.start * m_col_slice.step,
396 m_col_slice.start + col_slice.stop, m_col_slice.step * col_slice.step},
397 col_slice_length};
398 }
399
413 int row_slice_length,
414 Slice col_slice,
415 int col_slice_length) const {
416 return VariableBlock{
417 *m_mat,
418 {m_row_slice.start + row_slice.start * m_row_slice.step,
419 m_row_slice.start + row_slice.stop, m_row_slice.step * row_slice.step},
420 row_slice_length,
421 {m_col_slice.start + col_slice.start * m_col_slice.step,
422 m_col_slice.start + col_slice.stop, m_col_slice.step * col_slice.step},
423 col_slice_length};
424 }
425
433 VariableBlock<Mat> segment(int offset, int length) {
434 slp_assert(offset >= 0 && offset < rows() * cols());
435 slp_assert(length >= 0 && length <= rows() * cols() - offset);
436 return block(offset, 0, length, 1);
437 }
438
446 const VariableBlock<Mat> segment(int offset, int length) const {
447 slp_assert(offset >= 0 && offset < rows() * cols());
448 slp_assert(length >= 0 && length <= rows() * cols() - offset);
449 return block(offset, 0, length, 1);
450 }
451
459 slp_assert(row >= 0 && row < rows());
460 return block(row, 0, 1, cols());
461 }
462
470 slp_assert(row >= 0 && row < rows());
471 return block(row, 0, 1, cols());
472 }
473
481 slp_assert(col >= 0 && col < cols());
482 return block(0, col, rows(), 1);
483 }
484
492 slp_assert(col >= 0 && col < cols());
493 return block(0, col, rows(), 1);
494 }
495
503 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
504
505 for (int i = 0; i < rows(); ++i) {
506 for (int j = 0; j < rhs.cols(); ++j) {
507 Variable sum;
508 for (int k = 0; k < cols(); ++k) {
509 sum += (*this)(i, k) * rhs(k, j);
510 }
511 (*this)(i, j) = sum;
512 }
513 }
514
515 return *this;
516 }
517
525 for (int row = 0; row < rows(); ++row) {
526 for (int col = 0; col < cols(); ++col) {
527 (*this)(row, col) *= rhs;
528 }
529 }
530
531 return *this;
532 }
533
541 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
542
543 for (int row = 0; row < rows(); ++row) {
544 for (int col = 0; col < cols(); ++col) {
545 (*this)(row, col) /= rhs(0, 0);
546 }
547 }
548
549 return *this;
550 }
551
559 for (int row = 0; row < rows(); ++row) {
560 for (int col = 0; col < cols(); ++col) {
561 (*this)(row, col) /= rhs;
562 }
563 }
564
565 return *this;
566 }
567
575 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
576
577 for (int row = 0; row < rows(); ++row) {
578 for (int col = 0; col < cols(); ++col) {
579 (*this)(row, col) += rhs(row, col);
580 }
581 }
582
583 return *this;
584 }
585
593 slp_assert(rows() == 1 && cols() == 1);
594
595 for (int row = 0; row < rows(); ++row) {
596 for (int col = 0; col < cols(); ++col) {
597 (*this)(row, col) += rhs;
598 }
599 }
600
601 return *this;
602 }
603
611 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
612
613 for (int row = 0; row < rows(); ++row) {
614 for (int col = 0; col < cols(); ++col) {
615 (*this)(row, col) -= rhs(row, col);
616 }
617 }
618
619 return *this;
620 }
621
629 slp_assert(rows() == 1 && cols() == 1);
630
631 for (int row = 0; row < rows(); ++row) {
632 for (int col = 0; col < cols(); ++col) {
633 (*this)(row, col) -= rhs;
634 }
635 }
636
637 return *this;
638 }
639
643 operator Variable() const { // NOLINT
644 slp_assert(rows() == 1 && cols() == 1);
645 return (*this)(0, 0);
646 }
647
653 std::remove_cv_t<Mat> T() const {
654 std::remove_cv_t<Mat> result{Mat::empty, cols(), rows()};
655
656 for (int row = 0; row < rows(); ++row) {
657 for (int col = 0; col < cols(); ++col) {
658 result(col, row) = (*this)(row, col);
659 }
660 }
661
662 return result;
663 }
664
670 int rows() const { return m_row_slice_length; }
671
677 int cols() const { return m_col_slice_length; }
678
686 double value(int row, int col) {
687 slp_assert(row >= 0 && row < rows());
688 slp_assert(col >= 0 && col < cols());
689 return (*m_mat)(m_row_slice.start + row * m_row_slice.step,
690 m_col_slice.start + col * m_col_slice.step)
691 .value();
692 }
693
700 double value(int index) {
701 slp_assert(index >= 0 && index < rows() * cols());
702 return value(index / cols(), index % cols());
703 }
704
710 Eigen::MatrixXd value() {
711 Eigen::MatrixXd result{rows(), cols()};
712
713 for (int row = 0; row < rows(); ++row) {
714 for (int col = 0; col < cols(); ++col) {
715 result(row, col) = value(row, col);
716 }
717 }
718
719 return result;
720 }
721
728 std::remove_cv_t<Mat> cwise_transform(
729 function_ref<Variable(const Variable& x)> unary_op) const {
730 std::remove_cv_t<Mat> result{Mat::empty, rows(), cols()};
731
732 for (int row = 0; row < rows(); ++row) {
733 for (int col = 0; col < cols(); ++col) {
734 result(row, col) = unary_op((*this)(row, col));
735 }
736 }
737
738 return result;
739 }
740
741#ifndef DOXYGEN_SHOULD_SKIP_THIS
742
743 class iterator {
744 public:
745 using iterator_category = std::forward_iterator_tag;
746 using value_type = Variable;
747 using difference_type = std::ptrdiff_t;
748 using pointer = Variable*;
749 using reference = Variable&;
750
751 constexpr iterator() noexcept = default;
752
753 constexpr iterator(VariableBlock<Mat>* mat, int index) noexcept
754 : m_mat{mat}, m_index{index} {}
755
756 constexpr iterator& operator++() noexcept {
757 ++m_index;
758 return *this;
759 }
760
761 constexpr iterator operator++(int) noexcept {
762 iterator retval = *this;
763 ++(*this);
764 return retval;
765 }
766
767 constexpr bool operator==(const iterator&) const noexcept = default;
768
769 constexpr reference operator*() const noexcept { return (*m_mat)[m_index]; }
770
771 private:
772 VariableBlock<Mat>* m_mat = nullptr;
773 int m_index = 0;
774 };
775
776 class const_iterator {
777 public:
778 using iterator_category = std::forward_iterator_tag;
779 using value_type = Variable;
780 using difference_type = std::ptrdiff_t;
781 using pointer = Variable*;
782 using const_reference = const Variable&;
783
784 constexpr const_iterator() noexcept = default;
785
786 constexpr const_iterator(const VariableBlock<Mat>* mat, int index) noexcept
787 : m_mat{mat}, m_index{index} {}
788
789 constexpr const_iterator& operator++() noexcept {
790 ++m_index;
791 return *this;
792 }
793
794 constexpr const_iterator operator++(int) noexcept {
795 const_iterator retval = *this;
796 ++(*this);
797 return retval;
798 }
799
800 constexpr bool operator==(const const_iterator&) const noexcept = default;
801
802 constexpr const_reference operator*() const noexcept {
803 return (*m_mat)[m_index];
804 }
805
806 private:
807 const VariableBlock<Mat>* m_mat = nullptr;
808 int m_index = 0;
809 };
810
811#endif // DOXYGEN_SHOULD_SKIP_THIS
812
818 iterator begin() { return iterator(this, 0); }
819
825 iterator end() { return iterator(this, rows() * cols()); }
826
832 const_iterator begin() const { return const_iterator(this, 0); }
833
839 const_iterator end() const { return const_iterator(this, rows() * cols()); }
840
846 const_iterator cbegin() const { return const_iterator(this, 0); }
847
853 const_iterator cend() const { return const_iterator(this, rows() * cols()); }
854
860 size_t size() const { return rows() * cols(); }
861
862 private:
863 Mat* m_mat = nullptr;
864
865 Slice m_row_slice;
866 int m_row_slice_length = 0;
867
868 Slice m_col_slice;
869 int m_col_slice_length = 0;
870};
871
872} // namespace slp
Definition slice.hpp:31
int step
Step.
Definition slice.hpp:40
int stop
Stop index (exclusive).
Definition slice.hpp:37
constexpr int adjust(int length)
Definition slice.hpp:134
int start
Start index (inclusive).
Definition slice.hpp:34
Definition variable_block.hpp:24
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:215
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:524
const_iterator begin() const
Definition variable_block.hpp:832
const_iterator end() const
Definition variable_block.hpp:839
const Variable & operator()(int row, int col) const
Definition variable_block.hpp:266
Eigen::MatrixXd value()
Definition variable_block.hpp:710
const_iterator cbegin() const
Definition variable_block.hpp:846
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:458
Variable & operator[](int row)
Definition variable_block.hpp:279
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:152
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:574
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable(const Variable &x)> unary_op) const
Definition variable_block.hpp:728
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:502
iterator begin()
Definition variable_block.hpp:818
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:37
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:491
Variable & operator()(int row, int col)
Definition variable_block.hpp:250
int rows() const
Definition variable_block.hpp:670
double value(int row, int col)
Definition variable_block.hpp:686
const VariableBlock< const Mat > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:412
const Variable & operator[](int row) const
Definition variable_block.hpp:292
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:480
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:180
VariableBlock< Mat > operator()(Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:388
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:232
const_iterator cend() const
Definition variable_block.hpp:853
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:199
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:540
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:628
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:117
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:558
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:433
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:610
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:592
iterator end()
Definition variable_block.hpp:825
double value(int index)
Definition variable_block.hpp:700
VariableBlock(Mat &mat)
Definition variable_block.hpp:101
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:469
const VariableBlock< const Mat > operator()(Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:362
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:446
int cols() const
Definition variable_block.hpp:677
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:306
size_t size() const
Definition variable_block.hpp:860
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:653
VariableBlock(const VariableBlock< Mat > &)=default
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:72
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:136
VariableBlock< Mat > operator()(Slice row_slice, Slice col_slice)
Definition variable_block.hpp:342
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:325
void set_value(double value)
Definition variable_block.hpp:167
Definition variable.hpp:41
Definition function_ref.hpp:13
Definition concepts.hpp:29
Definition concepts.hpp:13