Sleipnir C++ API
Loading...
Searching...
No Matches
VariableBlock.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <concepts>
6#include <type_traits>
7#include <utility>
8
9#include <Eigen/Core>
10
15
16namespace sleipnir {
17
23template <typename Mat>
25 public:
27
34 if (this == &values) {
35 return *this;
36 }
37
38 if (m_mat == nullptr) {
39 m_mat = values.m_mat;
40 m_rowSlice = values.m_rowSlice;
41 m_rowSliceLength = values.m_rowSliceLength;
42 m_colSlice = values.m_colSlice;
43 m_colSliceLength = values.m_colSliceLength;
44 } else {
45 Assert(Rows() == values.Rows());
46 Assert(Cols() == values.Cols());
47
48 for (int row = 0; row < Rows(); ++row) {
49 for (int col = 0; col < Cols(); ++col) {
50 (*this)(row, col) = values(row, col);
51 }
52 }
53 }
54
55 return *this;
56 }
57
59
66 if (this == &values) {
67 return *this;
68 }
69
70 if (m_mat == nullptr) {
71 m_mat = values.m_mat;
72 m_rowSlice = values.m_rowSlice;
73 m_rowSliceLength = values.m_rowSliceLength;
74 m_colSlice = values.m_colSlice;
75 m_colSliceLength = values.m_colSliceLength;
76 } else {
77 Assert(Rows() == values.Rows());
78 Assert(Cols() == values.Cols());
79
80 for (int row = 0; row < Rows(); ++row) {
81 for (int col = 0; col < Cols(); ++col) {
82 (*this)(row, col) = values(row, col);
83 }
84 }
85 }
86
87 return *this;
88 }
89
95 VariableBlock(Mat& mat) // NOLINT
96 : m_mat{&mat},
97 m_rowSlice{0, mat.Rows(), 1},
98 m_rowSliceLength{m_rowSlice.Adjust(mat.Rows())},
99 m_colSlice{0, mat.Cols(), 1},
100 m_colSliceLength{m_colSlice.Adjust(mat.Cols())} {}
101
112 int blockCols)
113 : m_mat{&mat},
114 m_rowSlice{rowOffset, rowOffset + blockRows, 1},
115 m_rowSliceLength{m_rowSlice.Adjust(mat.Rows())},
116 m_colSlice{colOffset, colOffset + blockCols, 1},
117 m_colSliceLength{m_colSlice.Adjust(mat.Cols())} {}
118
131 int colSliceLength)
132 : m_mat{&mat},
133 m_rowSlice{std::move(rowSlice)},
134 m_rowSliceLength{rowSliceLength},
135 m_colSlice{std::move(colSlice)},
136 m_colSliceLength{colSliceLength} {}
137
144 Assert(Rows() == 1 && Cols() == 1);
145
146 (*this)(0, 0) = value;
147
148 return *this;
149 }
150
158 void SetValue(double value) {
159 Assert(Rows() == 1 && Cols() == 1);
160
161 (*this)(0, 0).SetValue(value);
162 }
163
169 template <typename Derived>
170 VariableBlock<Mat>& operator=(const Eigen::MatrixBase<Derived>& values) {
171 Assert(Rows() == values.rows());
172 Assert(Cols() == values.cols());
173
174 for (int row = 0; row < Rows(); ++row) {
175 for (int col = 0; col < Cols(); ++col) {
176 (*this)(row, col) = values(row, col);
177 }
178 }
179
180 return *this;
181 }
182
188 template <typename Derived>
189 requires std::same_as<typename Derived::Scalar, double>
190 void SetValue(const Eigen::MatrixBase<Derived>& values) {
191 Assert(Rows() == values.rows());
192 Assert(Cols() == values.cols());
193
194 for (int row = 0; row < Rows(); ++row) {
195 for (int col = 0; col < Cols(); ++col) {
196 (*this)(row, col).SetValue(values(row, col));
197 }
198 }
199 }
200
207 Assert(Rows() == values.Rows());
208 Assert(Cols() == values.Cols());
209
210 for (int row = 0; row < Rows(); ++row) {
211 for (int col = 0; col < Cols(); ++col) {
212 (*this)(row, col) = values(row, col);
213 }
214 }
215 return *this;
216 }
217
224 Assert(Rows() == values.Rows());
225 Assert(Cols() == values.Cols());
226
227 for (int row = 0; row < Rows(); ++row) {
228 for (int col = 0; col < Cols(); ++col) {
229 (*this)(row, col) = std::move(values(row, col));
230 }
231 }
232 return *this;
233 }
234
241 Variable& operator()(int row, int col)
242 requires(!std::is_const_v<Mat>)
243 {
244 Assert(row >= 0 && row < Rows());
245 Assert(col >= 0 && col < Cols());
246 return (*m_mat)(m_rowSlice.start + row * m_rowSlice.step,
247 m_colSlice.start + col * m_colSlice.step);
248 }
249
256 const Variable& operator()(int row, int col) const {
257 Assert(row >= 0 && row < Rows());
258 Assert(col >= 0 && col < Cols());
259 return (*m_mat)(m_rowSlice.start + row * m_rowSlice.step,
260 m_colSlice.start + col * m_colSlice.step);
261 }
262
269 requires(!std::is_const_v<Mat>)
270 {
271 Assert(row >= 0 && row < Rows() * Cols());
272 return (*this)(row / Cols(), row % Cols());
273 }
274
280 const Variable& operator()(int row) const {
281 Assert(row >= 0 && row < Rows() * Cols());
282 return (*this)(row / Cols(), row % Cols());
283 }
284
294 int blockCols) {
295 Assert(rowOffset >= 0 && rowOffset <= Rows());
296 Assert(colOffset >= 0 && colOffset <= Cols());
297 Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset);
298 Assert(blockCols >= 0 && blockCols <= Cols() - colOffset);
299 return (*this)({rowOffset, rowOffset + blockRows, 1},
301 }
302
312 int blockRows, int blockCols) const {
313 Assert(rowOffset >= 0 && rowOffset <= Rows());
314 Assert(colOffset >= 0 && colOffset <= Cols());
315 Assert(blockRows >= 0 && blockRows <= Rows() - rowOffset);
316 Assert(blockCols >= 0 && blockCols <= Cols() - colOffset);
317 return (*this)({rowOffset, rowOffset + blockRows, 1},
319 }
320
328 int rowSliceLength = rowSlice.Adjust(m_rowSliceLength);
329 int colSliceLength = colSlice.Adjust(m_colSliceLength);
330 return VariableBlock{
331 *m_mat,
332 {m_rowSlice.start + rowSlice.start * m_rowSlice.step,
333 m_rowSlice.start + rowSlice.stop, m_rowSlice.step * rowSlice.step},
335 {m_colSlice.start + colSlice.start * m_colSlice.step,
336 m_colSlice.start + colSlice.stop, m_colSlice.step * colSlice.step},
338 }
339
347 Slice colSlice) const {
348 int rowSliceLength = rowSlice.Adjust(m_rowSliceLength);
349 int colSliceLength = colSlice.Adjust(m_colSliceLength);
350 return VariableBlock{
351 *m_mat,
352 {m_rowSlice.start + rowSlice.start * m_rowSlice.step,
353 m_rowSlice.start + rowSlice.stop, m_rowSlice.step * rowSlice.step},
355 {m_colSlice.start + colSlice.start * m_colSlice.step,
356 m_colSlice.start + colSlice.stop, m_colSlice.step * colSlice.step},
358 }
359
373 return VariableBlock{
374 *m_mat,
375 {m_rowSlice.start + rowSlice.start * m_rowSlice.step,
376 m_rowSlice.start + rowSlice.stop, m_rowSlice.step * rowSlice.step},
378 {m_colSlice.start + colSlice.start * m_colSlice.step,
379 m_colSlice.start + colSlice.stop, m_colSlice.step * colSlice.step},
381 }
382
396 int colSliceLength) const {
397 return VariableBlock{
398 *m_mat,
399 {m_rowSlice.start + rowSlice.start * m_rowSlice.step,
400 m_rowSlice.start + rowSlice.stop, m_rowSlice.step * rowSlice.step},
402 {m_colSlice.start + colSlice.start * m_colSlice.step,
403 m_colSlice.start + colSlice.stop, m_colSlice.step * colSlice.step},
405 }
406
414 Assert(offset >= 0 && offset < Rows() * Cols());
415 Assert(length >= 0 && length <= Rows() * Cols() - offset);
416 return Block(offset, 0, length, 1);
417 }
418
425 const VariableBlock<Mat> Segment(int offset, int length) const {
426 Assert(offset >= 0 && offset < Rows() * Cols());
427 Assert(length >= 0 && length <= Rows() * Cols() - offset);
428 return Block(offset, 0, length, 1);
429 }
430
437 Assert(row >= 0 && row < Rows());
438 return Block(row, 0, 1, Cols());
439 }
440
447 Assert(row >= 0 && row < Rows());
448 return Block(row, 0, 1, Cols());
449 }
450
457 Assert(col >= 0 && col < Cols());
458 return Block(0, col, Rows(), 1);
459 }
460
467 Assert(col >= 0 && col < Cols());
468 return Block(0, col, Rows(), 1);
469 }
470
477 Assert(Cols() == rhs.Rows() && Cols() == rhs.Cols());
478
479 for (int i = 0; i < Rows(); ++i) {
480 for (int j = 0; j < rhs.Cols(); ++j) {
482 for (int k = 0; k < Cols(); ++k) {
483 sum += (*this)(i, k) * rhs(k, j);
484 }
485 (*this)(i, j) = sum;
486 }
487 }
488
489 return *this;
490 }
491
499 for (int row = 0; row < Rows(); ++row) {
500 for (int col = 0; col < Cols(); ++col) {
501 (*this)(row, col) *= rhs;
502 }
503 }
504
505 return *this;
506 }
507
515 Assert(rhs.Rows() == 1 && rhs.Cols() == 1);
516
517 for (int row = 0; row < Rows(); ++row) {
518 for (int col = 0; col < Cols(); ++col) {
519 (*this)(row, col) /= rhs(0, 0);
520 }
521 }
522
523 return *this;
524 }
525
533 for (int row = 0; row < Rows(); ++row) {
534 for (int col = 0; col < Cols(); ++col) {
535 (*this)(row, col) /= rhs;
536 }
537 }
538
539 return *this;
540 }
541
548 for (int row = 0; row < Rows(); ++row) {
549 for (int col = 0; col < Cols(); ++col) {
550 (*this)(row, col) += rhs(row, col);
551 }
552 }
553
554 return *this;
555 }
556
563 for (int row = 0; row < Rows(); ++row) {
564 for (int col = 0; col < Cols(); ++col) {
565 (*this)(row, col) -= rhs(row, col);
566 }
567 }
568
569 return *this;
570 }
571
575 std::remove_cv_t<Mat> T() const {
576 std::remove_cv_t<Mat> result{Cols(), Rows()};
577
578 for (int row = 0; row < Rows(); ++row) {
579 for (int col = 0; col < Cols(); ++col) {
580 result(col, row) = (*this)(row, col);
581 }
582 }
583
584 return result;
585 }
586
590 int Rows() const { return m_rowSliceLength; }
591
595 int Cols() const { return m_colSliceLength; }
596
603 double Value(int row, int col) {
604 Assert(row >= 0 && row < Rows());
605 Assert(col >= 0 && col < Cols());
606 return (*m_mat)(m_rowSlice.start + row * m_rowSlice.step,
607 m_colSlice.start + col * m_colSlice.step)
608 .Value();
609 }
610
616 double Value(int index) {
617 Assert(index >= 0 && index < Rows() * Cols());
618 return Value(index / Cols(), index % Cols());
619 }
620
624 Eigen::MatrixXd Value() {
625 Eigen::MatrixXd result{Rows(), Cols()};
626
627 for (int row = 0; row < Rows(); ++row) {
628 for (int col = 0; col < Cols(); ++col) {
629 result(row, col) = Value(row, col);
630 }
631 }
632
633 return result;
634 }
635
641 std::remove_cv_t<Mat> CwiseTransform(
642 function_ref<Variable(const Variable& x)> unaryOp) const {
643 std::remove_cv_t<Mat> result{Rows(), Cols()};
644
645 for (int row = 0; row < Rows(); ++row) {
646 for (int col = 0; col < Cols(); ++col) {
647 result(row, col) = unaryOp((*this)(row, col));
648 }
649 }
650
651 return result;
652 }
653
654 class iterator {
655 public:
656 using iterator_category = std::forward_iterator_tag;
658 using difference_type = std::ptrdiff_t;
661
662 iterator(VariableBlock<Mat>* mat, int index) : m_mat{mat}, m_index{index} {}
663
665 ++m_index;
666 return *this;
667 }
669 iterator retval = *this;
670 ++(*this);
671 return retval;
672 }
673 bool operator==(const iterator&) const = default;
674 reference operator*() { return (*m_mat)(m_index); }
675
676 private:
677 VariableBlock<Mat>* m_mat;
678 int m_index;
679 };
680
682 public:
683 using iterator_category = std::forward_iterator_tag;
685 using difference_type = std::ptrdiff_t;
687 using const_reference = const Variable&;
688
690 : m_mat{mat}, m_index{index} {}
691
693 ++m_index;
694 return *this;
695 }
697 const_iterator retval = *this;
698 ++(*this);
699 return retval;
700 }
701 bool operator==(const const_iterator&) const = default;
702 const_reference operator*() const { return (*m_mat)(m_index); }
703
704 private:
705 const VariableBlock<Mat>* m_mat;
706 int m_index;
707 };
708
712 iterator begin() { return iterator(this, 0); }
713
717 iterator end() { return iterator(this, Rows() * Cols()); }
718
722 const_iterator begin() const { return const_iterator(this, 0); }
723
727 const_iterator end() const { return const_iterator(this, Rows() * Cols()); }
728
732 const_iterator cbegin() const { return const_iterator(this, 0); }
733
737 const_iterator cend() const { return const_iterator(this, Rows() * Cols()); }
738
742 size_t size() const { return Rows() * Cols(); }
743
744 private:
745 Mat* m_mat = nullptr;
746
747 Slice m_rowSlice;
748 int m_rowSliceLength = 0;
749
750 Slice m_colSlice;
751 int m_colSliceLength = 0;
752};
753
754} // namespace sleipnir
#define Assert(condition)
Definition Assert.hpp:24
Definition Slice.hpp:21
int step
Step.
Definition Slice.hpp:30
int start
Start index (inclusive).
Definition Slice.hpp:24
Definition VariableBlock.hpp:681
const_iterator(const VariableBlock< Mat > *mat, int index)
Definition VariableBlock.hpp:689
const_reference operator*() const
Definition VariableBlock.hpp:702
const_iterator operator++(int)
Definition VariableBlock.hpp:696
bool operator==(const const_iterator &) const =default
std::ptrdiff_t difference_type
Definition VariableBlock.hpp:685
const_iterator & operator++()
Definition VariableBlock.hpp:692
std::forward_iterator_tag iterator_category
Definition VariableBlock.hpp:683
Definition VariableBlock.hpp:654
reference operator*()
Definition VariableBlock.hpp:674
bool operator==(const iterator &) const =default
std::forward_iterator_tag iterator_category
Definition VariableBlock.hpp:656
std::ptrdiff_t difference_type
Definition VariableBlock.hpp:658
iterator(VariableBlock< Mat > *mat, int index)
Definition VariableBlock.hpp:662
iterator & operator++()
Definition VariableBlock.hpp:664
iterator operator++(int)
Definition VariableBlock.hpp:668
Definition VariableBlock.hpp:24
int Rows() const
Definition VariableBlock.hpp:590
VariableBlock< const Mat > Col(int col) const
Definition VariableBlock.hpp:466
const VariableBlock< const Mat > Block(int rowOffset, int colOffset, int blockRows, int blockCols) const
Definition VariableBlock.hpp:311
const_iterator cend() const
Definition VariableBlock.hpp:737
const VariableBlock< const Mat > operator()(Slice rowSlice, Slice colSlice) const
Definition VariableBlock.hpp:346
const_iterator end() const
Definition VariableBlock.hpp:727
iterator begin()
Definition VariableBlock.hpp:712
void SetValue(double value)
Definition VariableBlock.hpp:158
VariableBlock(const VariableBlock< Mat > &values)=default
VariableBlock< const Mat > Row(int row) const
Definition VariableBlock.hpp:446
VariableBlock< Mat > & operator/=(const VariableBlock< Mat > &rhs)
Definition VariableBlock.hpp:514
VariableBlock< Mat > Row(int row)
Definition VariableBlock.hpp:436
iterator end()
Definition VariableBlock.hpp:717
size_t size() const
Definition VariableBlock.hpp:742
VariableBlock(Mat &mat)
Definition VariableBlock.hpp:95
VariableBlock & operator*=(double rhs)
Definition VariableBlock.hpp:498
VariableBlock< Mat > & operator-=(const VariableBlock< Mat > &rhs)
Definition VariableBlock.hpp:562
void SetValue(const Eigen::MatrixBase< Derived > &values)
Definition VariableBlock.hpp:190
const VariableBlock< Mat > Segment(int offset, int length) const
Definition VariableBlock.hpp:425
int Cols() const
Definition VariableBlock.hpp:595
VariableBlock< Mat > & operator=(const Eigen::MatrixBase< Derived > &values)
Definition VariableBlock.hpp:170
const Variable & operator()(int row) const
Definition VariableBlock.hpp:280
VariableBlock< Mat > & operator=(VariableBlock< Mat > &&values)
Definition VariableBlock.hpp:65
const Variable & operator()(int row, int col) const
Definition VariableBlock.hpp:256
VariableBlock< Mat > & operator=(double value)
Definition VariableBlock.hpp:143
VariableBlock< Mat > & operator/=(double rhs)
Definition VariableBlock.hpp:532
std::remove_cv_t< Mat > T() const
Definition VariableBlock.hpp:575
Variable & operator()(int row)
Definition VariableBlock.hpp:268
const_iterator begin() const
Definition VariableBlock.hpp:722
VariableBlock(Mat &mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength)
Definition VariableBlock.hpp:130
const_iterator cbegin() const
Definition VariableBlock.hpp:732
VariableBlock< Mat > operator()(Slice rowSlice, Slice colSlice)
Definition VariableBlock.hpp:327
Variable & operator()(int row, int col)
Definition VariableBlock.hpp:241
VariableBlock< Mat > & operator=(const VariableBlock< Mat > &values)
Definition VariableBlock.hpp:33
double Value(int row, int col)
Definition VariableBlock.hpp:603
VariableBlock< Mat > Col(int col)
Definition VariableBlock.hpp:456
VariableBlock< Mat > & operator=(Mat &&values)
Definition VariableBlock.hpp:223
VariableBlock< Mat > Block(int rowOffset, int colOffset, int blockRows, int blockCols)
Definition VariableBlock.hpp:293
Eigen::MatrixXd Value()
Definition VariableBlock.hpp:624
VariableBlock< Mat > & operator*=(const VariableBlock< Mat > &rhs)
Definition VariableBlock.hpp:476
VariableBlock< Mat > Segment(int offset, int length)
Definition VariableBlock.hpp:413
VariableBlock< Mat > & operator+=(const VariableBlock< Mat > &rhs)
Definition VariableBlock.hpp:547
const VariableBlock< const Mat > operator()(Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) const
Definition VariableBlock.hpp:394
std::remove_cv_t< Mat > CwiseTransform(function_ref< Variable(const Variable &x)> unaryOp) const
Definition VariableBlock.hpp:641
double Value(int index)
Definition VariableBlock.hpp:616
VariableBlock< Mat > & operator=(const Mat &values)
Definition VariableBlock.hpp:206
VariableBlock(VariableBlock< Mat > &&)=default
VariableBlock(Mat &mat, int rowOffset, int colOffset, int blockRows, int blockCols)
Definition VariableBlock.hpp:111
VariableBlock< Mat > operator()(Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength)
Definition VariableBlock.hpp:371
Definition Variable.hpp:33
Definition FunctionRef.hpp:17
Definition Expression.hpp:18
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Definition IntrusiveSharedPtr.hpp:275