Sleipnir C++ API
Loading...
Searching...
No Matches
variable_block.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <concepts>
6#include <type_traits>
7#include <utility>
8
9#include <Eigen/Core>
10
11#include "sleipnir/autodiff/sleipnir_base.hpp"
12#include "sleipnir/autodiff/slice.hpp"
13#include "sleipnir/autodiff/variable.hpp"
14#include "sleipnir/util/assert.hpp"
15#include "sleipnir/util/empty.hpp"
16#include "sleipnir/util/function_ref.hpp"
17
18namespace slp {
19
23template <typename Mat>
// NOTE(review): source line 24 (the class head of VariableBlock) is absent
// from this listing; the members that follow belong to that class.
25 public:
// Alias for the scalar type of the viewed matrix type Mat.
27 using Scalar = typename Mat::Scalar;
28
31
37 if (this == &values) {
38 return *this;
39 }
40
41 if (m_mat == nullptr) {
42 m_mat = values.m_mat;
43 m_row_slice = values.m_row_slice;
44 m_row_slice_length = values.m_row_slice_length;
45 m_col_slice = values.m_col_slice;
46 m_col_slice_length = values.m_col_slice_length;
47 } else {
48 slp_assert(rows() == values.rows() && cols() == values.cols());
49
50 for (int row = 0; row < rows(); ++row) {
51 for (int col = 0; col < cols(); ++col) {
52 (*this)[row, col] = values[row, col];
53 }
54 }
55 }
56
57 return *this;
58 }
59
62
68 if (this == &values) {
69 return *this;
70 }
71
72 if (m_mat == nullptr) {
73 m_mat = values.m_mat;
74 m_row_slice = values.m_row_slice;
75 m_row_slice_length = values.m_row_slice_length;
76 m_col_slice = values.m_col_slice;
77 m_col_slice_length = values.m_col_slice_length;
78 } else {
79 slp_assert(rows() == values.rows() && cols() == values.cols());
80
81 for (int row = 0; row < rows(); ++row) {
82 for (int col = 0; col < cols(); ++col) {
83 (*this)[row, col] = values[row, col];
84 }
85 }
86 }
87
88 return *this;
89 }
90
94 // NOLINTNEXTLINE (google-explicit-constructor)
96
105 int block_cols)
// Views the block_rows x block_cols region whose top-left corner is
// (row_offset, col_offset), using unit-step slices; each slice length is the
// result of Slice::adjust() against the viewed matrix's dimension.
// NOTE(review): the lengths are initialized from the already-initialized
// slices, which relies on m_row_slice/m_col_slice being declared before
// m_row_slice_length/m_col_slice_length in the private section.
106 : m_mat{&mat},
107 m_row_slice{row_offset, row_offset + block_rows, 1},
108 m_row_slice_length{m_row_slice.adjust(mat.rows())},
109 m_col_slice{col_offset, col_offset + block_cols, 1},
110 m_col_slice_length{m_col_slice.adjust(mat.cols())} {}
111
// Stores the given slices and lengths as provided; unlike the offset
// constructor, nothing is recomputed or bounds-checked here, so the caller
// presumably passes lengths consistent with the slices.
123 : m_mat{&mat},
124 m_row_slice{std::move(row_slice)},
125 m_row_slice_length{row_slice_length},
126 m_col_slice{std::move(col_slice)},
127 m_col_slice_length{col_slice_length} {}
128
136 slp_assert(rows() == 1 && cols() == 1);
137
138 (*this)[0, 0] = value;
139
140 return *this;
141 }
142
149 slp_assert(rows() == 1 && cols() == 1);
150
151 (*this)[0, 0].set_value(value);
152 }
153
158 template <typename Derived>
159 VariableBlock<Mat>& operator=(const Eigen::MatrixBase<Derived>& values) {
160 slp_assert(rows() == values.rows() && cols() == values.cols());
161
162 for (int row = 0; row < rows(); ++row) {
163 for (int col = 0; col < cols(); ++col) {
164 (*this)[row, col] = values[row, col];
165 }
166 }
167
168 return *this;
169 }
170
174 template <typename Derived>
175 requires std::same_as<typename Derived::Scalar, Scalar>
176 void set_value(const Eigen::MatrixBase<Derived>& values) {
177 slp_assert(rows() == values.rows() && cols() == values.cols());
178
179 for (int row = 0; row < rows(); ++row) {
180 for (int col = 0; col < cols(); ++col) {
181 (*this)[row, col].set_value(values[row, col]);
182 }
183 }
184 }
185
191 slp_assert(rows() == values.rows() && cols() == values.cols());
192
193 for (int row = 0; row < rows(); ++row) {
194 for (int col = 0; col < cols(); ++col) {
195 (*this)[row, col] = values[row, col];
196 }
197 }
198 return *this;
199 }
200
206 slp_assert(rows() == values.rows() && cols() == values.cols());
207
208 for (int row = 0; row < rows(); ++row) {
209 for (int col = 0; col < cols(); ++col) {
210 (*this)[row, col] = std::move(values[row, col]);
211 }
212 }
213 return *this;
214 }
215
222 requires(!std::is_const_v<Mat>)
223 {
224 slp_assert(row >= 0 && row < rows());
225 slp_assert(col >= 0 && col < cols());
226 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
227 m_col_slice.start + col * m_col_slice.step];
228 }
229
235 const Variable<Scalar>& operator[](int row, int col) const {
236 slp_assert(row >= 0 && row < rows());
237 slp_assert(col >= 0 && col < cols());
238 return (*m_mat)[m_row_slice.start + row * m_row_slice.step,
239 m_col_slice.start + col * m_col_slice.step];
240 }
241
247 requires(!std::is_const_v<Mat>)
248 {
249 slp_assert(index >= 0 && index < rows() * cols());
250 return (*this)[index / cols(), index % cols()];
251 }
252
257 const Variable<Scalar>& operator[](int index) const {
258 slp_assert(index >= 0 && index < rows() * cols());
259 return (*this)[index / cols(), index % cols()];
260 }
261
270 int block_cols) {
// The requested region must lie inside this block. An offset equal to the
// block's size is accepted, which permits an empty block at the end.
271 slp_assert(row_offset >= 0 && row_offset <= rows());
272 slp_assert(col_offset >= 0 && col_offset <= cols());
273 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
274 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
// Delegate to a Slice-based operator[] with unit-step slices covering
// [offset, offset + size) in this block's coordinates.
275 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
// NOTE(review): source line 276 (presumably the matching column Slice
// argument) is absent from this listing.
277 }
278
287 int block_rows, int block_cols) const {
// Read-only counterpart of the non-const block(); same bounds checks, and an
// offset equal to the block's size is accepted (permits an empty block).
288 slp_assert(row_offset >= 0 && row_offset <= rows());
289 slp_assert(col_offset >= 0 && col_offset <= cols());
290 slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset);
291 slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset);
// Delegate to a Slice-based operator[] with unit-step slices.
292 return (*this)[Slice{row_offset, row_offset + block_rows, 1},
// NOTE(review): source line 293 (presumably the matching column Slice
// argument) is absent from this listing.
294 }
295
// Each slice's adjust() is given this block's extent along that axis, and its
// result is kept as the slice length (presumably adjust() also normalizes the
// slice in place — confirm in slice.hpp).
302 int row_slice_length = row_slice.adjust(m_row_slice_length);
303 int col_slice_length = col_slice.adjust(m_col_slice_length);
// NOTE(review): source line 304 (presumably the return that forwards the
// slices and these lengths to the four-argument operator[]) is absent from
// this listing.
305 }
306
313 Slice col_slice) const {
// Read-only counterpart: the slice lengths come from adjust() against this
// block's extent along each axis.
314 int row_slice_length = row_slice.adjust(m_row_slice_length);
315 int col_slice_length = col_slice.adjust(m_col_slice_length);
// NOTE(review): source line 316 (presumably the forwarding return statement)
// is absent from this listing.
317 }
318
// Compose the requested slices with this block's slices so the new block
// indexes *m_mat directly: start/stop are mapped through this block's start
// and step, and the two steps multiply.
331 return VariableBlock{
332 *m_mat,
333 {m_row_slice.start + row_slice.start * m_row_slice.step,
334 m_row_slice.start + row_slice.stop * m_row_slice.step,
335 row_slice.step * m_row_slice.step},
// NOTE(review): source line 336 (presumably row_slice_length) is absent.
337 {m_col_slice.start + col_slice.start * m_col_slice.step,
338 m_col_slice.start + col_slice.stop * m_col_slice.step,
339 col_slice.step * m_col_slice.step},
// NOTE(review): source line 340 (presumably col_slice_length and the end of
// the braced initializer) is absent.
341 }
342
356 int col_slice_length) const {
// Read-only counterpart: composes the requested slices with this block's
// slices (mapped start/stop, multiplied steps) over the same *m_mat.
// NOTE(review): inside the class template the injected-class-name
// `VariableBlock` names VariableBlock<Mat>, while the generated index lists
// this overload as returning VariableBlock<const Mat>; confirm how that
// conversion is performed.
357 return VariableBlock{
358 *m_mat,
359 {m_row_slice.start + row_slice.start * m_row_slice.step,
360 m_row_slice.start + row_slice.stop * m_row_slice.step,
361 row_slice.step * m_row_slice.step},
// NOTE(review): source line 362 (presumably row_slice_length) is absent.
363 {m_col_slice.start + col_slice.start * m_col_slice.step,
364 m_col_slice.start + col_slice.stop * m_col_slice.step,
365 col_slice.step * m_col_slice.step},
// NOTE(review): source line 366 (presumably col_slice_length and the end of
// the braced initializer) is absent.
367 }
368
375 slp_assert(cols() == 1);
376 slp_assert(offset >= 0 && offset < rows());
377 slp_assert(length >= 0 && length <= rows() - offset);
378 return block(offset, 0, length, 1);
379 }
380
386 const VariableBlock<Mat> segment(int offset, int length) const {
387 slp_assert(cols() == 1);
388 slp_assert(offset >= 0 && offset < rows());
389 slp_assert(length >= 0 && length <= rows() - offset);
390 return block(offset, 0, length, 1);
391 }
392
398 slp_assert(row >= 0 && row < rows());
399 return block(row, 0, 1, cols());
400 }
401
407 slp_assert(row >= 0 && row < rows());
408 return block(row, 0, 1, cols());
409 }
410
416 slp_assert(col >= 0 && col < cols());
417 return block(0, col, rows(), 1);
418 }
419
425 slp_assert(col >= 0 && col < cols());
426 return block(0, col, rows(), 1);
427 }
428
434 slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
435
436 for (int i = 0; i < rows(); ++i) {
437 for (int j = 0; j < rhs.cols(); ++j) {
438 Variable sum{Scalar(0)};
439 for (int k = 0; k < cols(); ++k) {
440 sum += (*this)(i, k) * rhs(k, j);
441 }
442 (*this)(i, j) = sum;
443 }
444 }
445
446 return *this;
447 }
448
454 for (int row = 0; row < rows(); ++row) {
455 for (int col = 0; col < cols(); ++col) {
456 (*this)[row, col] *= rhs;
457 }
458 }
459
460 return *this;
461 }
462
468 slp_assert(rhs.rows() == 1 && rhs.cols() == 1);
469
470 for (int row = 0; row < rows(); ++row) {
471 for (int col = 0; col < cols(); ++col) {
472 (*this)[row, col] /= rhs[0, 0];
473 }
474 }
475
476 return *this;
477 }
478
484 for (int row = 0; row < rows(); ++row) {
485 for (int col = 0; col < cols(); ++col) {
486 (*this)[row, col] /= rhs;
487 }
488 }
489
490 return *this;
491 }
492
498 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
499
500 for (int row = 0; row < rows(); ++row) {
501 for (int col = 0; col < cols(); ++col) {
502 (*this)[row, col] += rhs[row, col];
503 }
504 }
505
506 return *this;
507 }
508
514 slp_assert(rows() == 1 && cols() == 1);
515
516 for (int row = 0; row < rows(); ++row) {
517 for (int col = 0; col < cols(); ++col) {
518 (*this)[row, col] += rhs;
519 }
520 }
521
522 return *this;
523 }
524
530 slp_assert(rows() == rhs.rows() && cols() == rhs.cols());
531
532 for (int row = 0; row < rows(); ++row) {
533 for (int col = 0; col < cols(); ++col) {
534 (*this)[row, col] -= rhs[row, col];
535 }
536 }
537
538 return *this;
539 }
540
546 slp_assert(rows() == 1 && cols() == 1);
547
548 for (int row = 0; row < rows(); ++row) {
549 for (int col = 0; col < cols(); ++col) {
550 (*this)[row, col] -= rhs;
551 }
552 }
553
554 return *this;
555 }
556
558 // NOLINTNEXTLINE (google-explicit-constructor)
559 operator Variable<Scalar>() const {
560 slp_assert(rows() == 1 && cols() == 1);
561 return (*this)[0, 0];
562 }
563
567 std::remove_cv_t<Mat> T() const {
568 std::remove_cv_t<Mat> result{detail::empty, cols(), rows()};
569
570 for (int row = 0; row < rows(); ++row) {
571 for (int col = 0; col < cols(); ++col) {
572 result[col, row] = (*this)[row, col];
573 }
574 }
575
576 return result;
577 }
578
582 int rows() const { return m_row_slice_length; }
583
587 int cols() const { return m_col_slice_length; }
588
594 Scalar value(int row, int col) { return (*this)[row, col].value(); }
595
600 Scalar value(int index) {
601 slp_assert(index >= 0 && index < rows() * cols());
602 return value(index / cols(), index % cols());
603 }
604
608 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> value() {
609 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> result{rows(),
610 cols()};
611
612 for (int row = 0; row < rows(); ++row) {
613 for (int col = 0; col < cols(); ++col) {
614 result[row, col] = value(row, col);
615 }
616 }
617
618 return result;
619 }
620
// Returns a new matrix with this block's shape whose element (row, col) is
// unary_op applied to this block's element (row, col).
625 std::remove_cv_t<Mat> cwise_transform(
// NOTE(review): source line 626 (the parameter list) is absent from this
// listing; the generated index gives it as
// function_ref<Variable<Scalar>(const Variable<Scalar>& x)> unary_op.
627 const {
628 std::remove_cv_t<Mat> result{detail::empty, rows(), cols()};
629
630 for (int row = 0; row < rows(); ++row) {
631 for (int col = 0; col < cols(); ++col) {
632 result[row, col] = unary_op((*this)[row, col]);
633 }
634 }
635
636 return result;
637 }
638
639#ifndef DOXYGEN_SHOULD_SKIP_THIS
640
641 class iterator {
642 public:
643 using iterator_category = std::bidirectional_iterator_tag;
644 using value_type = Variable<Scalar>;
645 using difference_type = std::ptrdiff_t;
646 using pointer = Variable<Scalar>*;
648
649 constexpr iterator() noexcept = default;
650
651 constexpr iterator(VariableBlock<Mat>* mat, int index) noexcept
652 : m_mat{mat}, m_index{index} {}
653
654 constexpr iterator& operator++() noexcept {
655 ++m_index;
656 return *this;
657 }
658
659 constexpr iterator operator++(int) noexcept {
660 iterator retval = *this;
661 ++(*this);
662 return retval;
663 }
664
665 constexpr iterator& operator--() noexcept {
666 --m_index;
667 return *this;
668 }
669
670 constexpr iterator operator--(int) noexcept {
671 iterator retval = *this;
672 --(*this);
673 return retval;
674 }
675
676 constexpr bool operator==(const iterator&) const noexcept = default;
677
678 constexpr reference operator*() const noexcept { return (*m_mat)[m_index]; }
679
680 private:
681 VariableBlock<Mat>* m_mat = nullptr;
682 int m_index = 0;
683 };
684
685 class const_iterator {
686 public:
687 using iterator_category = std::bidirectional_iterator_tag;
688 using value_type = Variable<Scalar>;
689 using difference_type = std::ptrdiff_t;
690 using pointer = Variable<Scalar>*;
691 using const_reference = const Variable<Scalar>&;
692
693 constexpr const_iterator() noexcept = default;
694
695 constexpr const_iterator(const VariableBlock<Mat>* mat, int index) noexcept
696 : m_mat{mat}, m_index{index} {}
697
698 constexpr const_iterator& operator++() noexcept {
699 ++m_index;
700 return *this;
701 }
702
703 constexpr const_iterator operator++(int) noexcept {
704 const_iterator retval = *this;
705 ++(*this);
706 return retval;
707 }
708
709 constexpr const_iterator& operator--() noexcept {
710 --m_index;
711 return *this;
712 }
713
714 constexpr const_iterator operator--(int) noexcept {
715 iterator retval = *this;
716 --(*this);
717 return retval;
718 }
719
720 constexpr bool operator==(const const_iterator&) const noexcept = default;
721
722 constexpr const_reference operator*() const noexcept {
723 return (*m_mat)[m_index];
724 }
725
726 private:
727 const VariableBlock<Mat>* m_mat = nullptr;
728 int m_index = 0;
729 };
730
// Reverse iterators adapt the bidirectional iterator classes defined above.
731 using reverse_iterator = std::reverse_iterator<iterator>;
732 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
733
734#endif // DOXYGEN_SHOULD_SKIP_THIS
735
739 iterator begin() { return iterator(this, 0); }
740
744 iterator end() { return iterator(this, rows() * cols()); }
745
749 const_iterator begin() const { return const_iterator(this, 0); }
750
754 const_iterator end() const { return const_iterator(this, rows() * cols()); }
755
759 const_iterator cbegin() const { return const_iterator(this, 0); }
760
764 const_iterator cend() const { return const_iterator(this, rows() * cols()); }
765
769 reverse_iterator rbegin() { return reverse_iterator{end()}; }
770
774 reverse_iterator rend() { return reverse_iterator{begin()}; }
775
779 const_reverse_iterator rbegin() const {
780 return const_reverse_iterator{end()};
781 }
782
786 const_reverse_iterator rend() const {
787 return const_reverse_iterator{begin()};
788 }
789
793 const_reverse_iterator crbegin() const {
794 return const_reverse_iterator{cend()};
795 }
796
800 const_reverse_iterator crend() const {
801 return const_reverse_iterator{cbegin()};
802 }
803
807 size_t size() const { return rows() * cols(); }
808
809 private:
// Matrix this block views; nullptr means the block is not bound to a matrix,
// in which case copy/move assignment adopts the source block's view.
810 Mat* m_mat = nullptr;
811
// Rows of *m_mat selected by this block, and how many there are (rows()).
// Each Slice is declared before its length because the offset constructor
// initializes the length from the already-initialized Slice.
812 Slice m_row_slice;
813 int m_row_slice_length = 0;
814
// Columns of *m_mat selected by this block, and how many there are (cols()).
815 Slice m_col_slice;
816 int m_col_slice_length = 0;
817};
818
819} // namespace slp
Definition intrusive_shared_ptr.hpp:27
Definition sleipnir_base.hpp:9
Represents a sequence of elements in an iterable object.
Definition slice.hpp:25
int step
Step.
Definition slice.hpp:34
int start
Start index (inclusive).
Definition slice.hpp:28
Definition variable_block.hpp:24
VariableBlock< Mat > & operator=(const Mat &values)
Definition variable_block.hpp:190
VariableBlock< Mat > & operator*=(const ScalarLike auto &rhs)
Definition variable_block.hpp:453
const_iterator begin() const
Definition variable_block.hpp:749
const_iterator end() const
Definition variable_block.hpp:754
void set_value(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:176
const_iterator cbegin() const
Definition variable_block.hpp:759
void set_value(Scalar value)
Definition variable_block.hpp:148
VariableBlock< Mat > row(int row)
Definition variable_block.hpp:397
VariableBlock< Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:329
VariableBlock< Mat > & operator=(ScalarLike auto value)
Definition variable_block.hpp:135
Variable< Scalar > & operator[](int index)
Definition variable_block.hpp:246
Variable< Scalar > & operator[](int row, int col)
Definition variable_block.hpp:221
VariableBlock< Mat > & operator+=(const MatrixLike auto &rhs)
Definition variable_block.hpp:497
VariableBlock< Mat > & operator*=(const MatrixLike auto &rhs)
Definition variable_block.hpp:433
iterator begin()
Definition variable_block.hpp:739
VariableBlock(VariableBlock< Mat > &&)=default
Move constructor.
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition variable_block.hpp:36
const Variable< Scalar > & operator[](int index) const
Definition variable_block.hpp:257
VariableBlock< const Mat > col(int col) const
Definition variable_block.hpp:424
int rows() const
Definition variable_block.hpp:582
VariableBlock< Mat > operator[](Slice row_slice, Slice col_slice)
Definition variable_block.hpp:301
VariableBlock< Mat > col(int col)
Definition variable_block.hpp:415
const VariableBlock< const Mat > operator[](Slice row_slice, Slice col_slice) const
Definition variable_block.hpp:312
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition variable_block.hpp:159
Scalar value(int index)
Definition variable_block.hpp:600
reverse_iterator rbegin()
Definition variable_block.hpp:769
VariableBlock< Mat > & operator=(Mat &&values)
Definition variable_block.hpp:205
const_iterator cend() const
Definition variable_block.hpp:764
VariableBlock< Mat > & operator/=(const MatrixLike auto &rhs)
Definition variable_block.hpp:467
const Variable< Scalar > & operator[](int row, int col) const
Definition variable_block.hpp:235
const_reverse_iterator crbegin() const
Definition variable_block.hpp:793
VariableBlock< Mat > & operator-=(const ScalarLike auto &rhs)
Definition variable_block.hpp:545
VariableBlock(Mat &mat, int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:104
VariableBlock< Mat > & operator/=(const ScalarLike auto &rhs)
Definition variable_block.hpp:483
VariableBlock< Mat > segment(int offset, int length)
Definition variable_block.hpp:374
const_reverse_iterator rend() const
Definition variable_block.hpp:786
VariableBlock< Mat > & operator-=(const MatrixLike auto &rhs)
Definition variable_block.hpp:529
VariableBlock< Mat > & operator+=(const ScalarLike auto &rhs)
Definition variable_block.hpp:513
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic > value()
Definition variable_block.hpp:608
iterator end()
Definition variable_block.hpp:744
VariableBlock(Mat &mat)
Definition variable_block.hpp:95
VariableBlock< const Mat > row(int row) const
Definition variable_block.hpp:406
const VariableBlock< Mat > segment(int offset, int length) const
Definition variable_block.hpp:386
Scalar value(int row, int col)
Definition variable_block.hpp:594
int cols() const
Definition variable_block.hpp:587
VariableBlock< Mat > block(int row_offset, int col_offset, int block_rows, int block_cols)
Definition variable_block.hpp:269
reverse_iterator rend()
Definition variable_block.hpp:774
size_t size() const
Definition variable_block.hpp:807
std::remove_cv_t< Mat > T() const
Definition variable_block.hpp:567
VariableBlock(const VariableBlock< Mat > &)=default
Copy constructor.
std::remove_cv_t< Mat > cwise_transform(function_ref< Variable< Scalar >(const Variable< Scalar > &x)> unary_op) const
Definition variable_block.hpp:625
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition variable_block.hpp:67
VariableBlock(Mat &mat, Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length)
Definition variable_block.hpp:121
const VariableBlock< const Mat > operator[](Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const
Definition variable_block.hpp:353
const_reverse_iterator crend() const
Definition variable_block.hpp:800
const_reverse_iterator rbegin() const
Definition variable_block.hpp:779
const VariableBlock< const Mat > block(int row_offset, int col_offset, int block_rows, int block_cols) const
Definition variable_block.hpp:286
typename Mat::Scalar Scalar
Scalar type alias.
Definition variable_block.hpp:27
Definition variable.hpp:47
Definition function_ref.hpp:13
Definition concepts.hpp:18
Definition concepts.hpp:24