Sleipnir C++ API
Loading...
Searching...
No Matches
variable.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <algorithm>
6#include <concepts>
7#include <initializer_list>
8#include <source_location>
9#include <type_traits>
10#include <utility>
11#include <vector>
12
13#include <Eigen/Core>
14
15#include "sleipnir/autodiff/expression.hpp"
16#include "sleipnir/autodiff/expression_graph.hpp"
17#include "sleipnir/util/assert.hpp"
18#include "sleipnir/util/concepts.hpp"
19#include "sleipnir/util/small_vector.hpp"
20#include "sleipnir/util/symbol_exports.hpp"
21
22#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
23#include "sleipnir/util/print.hpp"
24#endif
25
26namespace slp {
27
// Forward declarations for friend declarations in Variable
namespace detail {
class AdjointExpressionGraph;
}  // namespace detail

// Hessian is a class template whose UpLo parameter is restricted by the
// requires-clause to Eigen::Lower or Eigen::Lower | Eigen::Upper (the
// default). NOTE(review): presumably UpLo selects which triangle(s) of the
// Hessian are produced — confirm against hessian.hpp.
template <int UpLo = Eigen::Lower | Eigen::Upper>
  requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
class SLEIPNIR_DLLEXPORT Hessian;
class SLEIPNIR_DLLEXPORT Jacobian;
36
40class SLEIPNIR_DLLEXPORT Variable {
41 public:
45 Variable() = default;
46
50 explicit constexpr Variable(std::nullptr_t) : expr{nullptr} {}
51
57 Variable(std::floating_point auto value) // NOLINT
58 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
59
65 Variable(std::integral auto value) // NOLINT
66 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
67
73 explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {}
74
80 explicit constexpr Variable(detail::ExpressionPtr&& expr)
81 : expr{std::move(expr)} {}
82
89 Variable& operator=(double value) {
90 expr = detail::make_expression_ptr<detail::ConstExpression>(value);
91
92 return *this;
93 }
94
100 void set_value(double value) {
101 if (expr->is_constant(0.0)) {
102 expr = detail::make_expression_ptr<detail::ConstExpression>(value);
103 } else {
104#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
105 // We only need to check the first argument since unary and binary
106 // operators both use it
107 if (expr->args[0] != nullptr) {
108 auto location = std::source_location::current();
109 slp::println(
110 stderr,
111 "WARNING: {}:{}: {}: Modified the value of a dependent variable",
112 location.file_name(), location.line(), location.function_name());
113 }
114#endif
115 expr->val = value;
116 }
117 }
118
126 friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable& lhs,
127 const Variable& rhs) {
128 return Variable{lhs.expr * rhs.expr};
129 }
130
138 *this = *this * rhs;
139 return *this;
140 }
141
149 friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable& lhs,
150 const Variable& rhs) {
151 return Variable{lhs.expr / rhs.expr};
152 }
153
161 *this = *this / rhs;
162 return *this;
163 }
164
172 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs,
173 const Variable& rhs) {
174 return Variable{lhs.expr + rhs.expr};
175 }
176
184 *this = *this + rhs;
185 return *this;
186 }
187
195 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs,
196 const Variable& rhs) {
197 return Variable{lhs.expr - rhs.expr};
198 }
199
207 *this = *this - rhs;
208 return *this;
209 }
210
216 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs) {
217 return Variable{-lhs.expr};
218 }
219
225 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs) {
226 return Variable{+lhs.expr};
227 }
228
234 double value() {
235 if (!m_graph_initialized) {
236 m_graph = detail::topological_sort(expr);
237 m_graph_initialized = true;
238 }
239 detail::update_values(m_graph);
240
241 return expr->val;
242 }
243
250 ExpressionType type() const { return expr->type(); }
251
252 private:
255 detail::make_expression_ptr<detail::DecisionVariableExpression>();
256
259 small_vector<detail::Expression*> m_graph;
260
262 bool m_graph_initialized = false;
263
264 friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x);
265 friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x);
266 friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x);
267 friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x);
268 friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y,
269 const Variable& x);
270 friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x);
271 friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x);
272 friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x);
273 friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x);
274 friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x,
275 const Variable& y);
276 friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x);
277 friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x);
278 friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base,
279 const Variable& power);
280 friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x);
281 friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x);
282 friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x);
283 friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x);
284 friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x);
285 friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x);
286 friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y,
287 const Variable& z);
288
290 template <int UpLo>
291 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
292 friend class SLEIPNIR_DLLEXPORT Hessian;
293 friend class SLEIPNIR_DLLEXPORT Jacobian;
294};
295
301SLEIPNIR_DLLEXPORT inline Variable abs(const Variable& x) {
302 return Variable{detail::abs(x.expr)};
303}
304
310SLEIPNIR_DLLEXPORT inline Variable acos(const Variable& x) {
311 return Variable{detail::acos(x.expr)};
312}
313
319SLEIPNIR_DLLEXPORT inline Variable asin(const Variable& x) {
320 return Variable{detail::asin(x.expr)};
321}
322
328SLEIPNIR_DLLEXPORT inline Variable atan(const Variable& x) {
329 return Variable{detail::atan(x.expr)};
330}
331
338SLEIPNIR_DLLEXPORT inline Variable atan2(const Variable& y, const Variable& x) {
339 return Variable{detail::atan2(y.expr, x.expr)};
340}
341
347SLEIPNIR_DLLEXPORT inline Variable cos(const Variable& x) {
348 return Variable{detail::cos(x.expr)};
349}
350
356SLEIPNIR_DLLEXPORT inline Variable cosh(const Variable& x) {
357 return Variable{detail::cosh(x.expr)};
358}
359
365SLEIPNIR_DLLEXPORT inline Variable erf(const Variable& x) {
366 return Variable{detail::erf(x.expr)};
367}
368
374SLEIPNIR_DLLEXPORT inline Variable exp(const Variable& x) {
375 return Variable{detail::exp(x.expr)};
376}
377
384SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y) {
385 return Variable{detail::hypot(x.expr, y.expr)};
386}
387
394SLEIPNIR_DLLEXPORT inline Variable pow(const Variable& base,
395 const Variable& power) {
396 return Variable{detail::pow(base.expr, power.expr)};
397}
398
404SLEIPNIR_DLLEXPORT inline Variable log(const Variable& x) {
405 return Variable{detail::log(x.expr)};
406}
407
413SLEIPNIR_DLLEXPORT inline Variable log10(const Variable& x) {
414 return Variable{detail::log10(x.expr)};
415}
416
422SLEIPNIR_DLLEXPORT inline Variable sign(const Variable& x) {
423 return Variable{detail::sign(x.expr)};
424}
425
431SLEIPNIR_DLLEXPORT inline Variable sin(const Variable& x) {
432 return Variable{detail::sin(x.expr)};
433}
434
440SLEIPNIR_DLLEXPORT inline Variable sinh(const Variable& x) {
441 return Variable{detail::sinh(x.expr)};
442}
443
449SLEIPNIR_DLLEXPORT inline Variable sqrt(const Variable& x) {
450 return Variable{detail::sqrt(x.expr)};
451}
452
458SLEIPNIR_DLLEXPORT inline Variable tan(const Variable& x) {
459 return Variable{detail::tan(x.expr)};
460}
461
467SLEIPNIR_DLLEXPORT inline Variable tanh(const Variable& x) {
468 return Variable{detail::tanh(x.expr)};
469}
470
478SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y,
479 const Variable& z) {
480 return Variable{slp::sqrt(slp::pow(x, 2) + slp::pow(y, 2) + slp::pow(z, 2))};
481}
482
494template <typename LHS, typename RHS>
495 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
498small_vector<Variable> make_constraints(LHS&& lhs, RHS&& rhs) {
499 small_vector<Variable> constraints;
500
501 if constexpr (ScalarLike<LHS> && ScalarLike<RHS>) {
502 constraints.emplace_back(lhs - rhs);
503 } else if constexpr (ScalarLike<LHS> && MatrixLike<RHS>) {
504 constraints.reserve(rhs.rows() * rhs.cols());
505
506 for (int row = 0; row < rhs.rows(); ++row) {
507 for (int col = 0; col < rhs.cols(); ++col) {
508 // Make right-hand side zero
509 if constexpr (EigenMatrixLike<std::decay_t<RHS>>) {
510 constraints.emplace_back(lhs - rhs(row, col));
511 } else {
512 constraints.emplace_back(lhs - rhs[row, col]);
513 }
514 }
515 }
516 } else if constexpr (MatrixLike<LHS> && ScalarLike<RHS>) {
517 constraints.reserve(lhs.rows() * lhs.cols());
518
519 for (int row = 0; row < lhs.rows(); ++row) {
520 for (int col = 0; col < lhs.cols(); ++col) {
521 // Make right-hand side zero
522 if constexpr (EigenMatrixLike<std::decay_t<LHS>>) {
523 constraints.emplace_back(lhs(row, col) - rhs);
524 } else {
525 constraints.emplace_back(lhs[row, col] - rhs);
526 }
527 }
528 }
529 } else if constexpr (MatrixLike<LHS> && MatrixLike<RHS>) {
530 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
531 constraints.reserve(lhs.rows() * lhs.cols());
532
533 for (int row = 0; row < lhs.rows(); ++row) {
534 for (int col = 0; col < lhs.cols(); ++col) {
535 // Make right-hand side zero
536 if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
537 EigenMatrixLike<std::decay_t<RHS>>) {
538 constraints.emplace_back(lhs(row, col) - rhs(row, col));
539 } else if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
540 SleipnirMatrixLike<std::decay_t<RHS>>) {
541 constraints.emplace_back(lhs(row, col) - rhs[row, col]);
542 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
543 EigenMatrixLike<std::decay_t<RHS>>) {
544 constraints.emplace_back(lhs[row, col] - rhs(row, col));
545 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
546 SleipnirMatrixLike<std::decay_t<RHS>>) {
547 constraints.emplace_back(lhs[row, col] - rhs[row, col]);
548 }
549 }
550 }
551 }
552
553 return constraints;
554}
555
559struct SLEIPNIR_DLLEXPORT EqualityConstraints {
561 small_vector<Variable> constraints;
562
569 std::initializer_list<EqualityConstraints> equality_constraints) {
570 for (const auto& elem : equality_constraints) {
571 constraints.insert(constraints.end(), elem.constraints.begin(),
572 elem.constraints.end());
573 }
574 }
575
584 const std::vector<EqualityConstraints>& equality_constraints) {
585 for (const auto& elem : equality_constraints) {
586 constraints.insert(constraints.end(), elem.constraints.begin(),
587 elem.constraints.end());
588 }
589 }
590
600 template <typename LHS, typename RHS>
601 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
604 EqualityConstraints(LHS&& lhs, RHS&& rhs)
605 : constraints{make_constraints(lhs, rhs)} {}
606
610 operator bool() { // NOLINT
611 return std::ranges::all_of(constraints, [](auto& constraint) {
612 return constraint.value() == 0.0;
613 });
614 }
615};
616
620struct SLEIPNIR_DLLEXPORT InequalityConstraints {
622 small_vector<Variable> constraints;
623
631 std::initializer_list<InequalityConstraints> inequality_constraints) {
632 for (const auto& elem : inequality_constraints) {
633 constraints.insert(constraints.end(), elem.constraints.begin(),
634 elem.constraints.end());
635 }
636 }
637
647 const std::vector<InequalityConstraints>& inequality_constraints) {
648 for (const auto& elem : inequality_constraints) {
649 constraints.insert(constraints.end(), elem.constraints.begin(),
650 elem.constraints.end());
651 }
652 }
653
663 template <typename LHS, typename RHS>
664 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
667 InequalityConstraints(LHS&& lhs, RHS&& rhs)
668 : constraints{make_constraints(lhs, rhs)} {}
669
673 operator bool() { // NOLINT
674 return std::ranges::all_of(constraints, [](auto& constraint) {
675 return constraint.value() >= 0.0;
676 });
677 }
678};
679
686template <typename LHS, typename RHS>
687 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
688 (ScalarLike<RHS> || MatrixLike<RHS>) &&
689 (SleipnirType<LHS> || SleipnirType<RHS>)
690EqualityConstraints operator==(LHS&& lhs, RHS&& rhs) {
691 return EqualityConstraints{lhs, rhs};
692}
693
701template <typename LHS, typename RHS>
702 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
703 (ScalarLike<RHS> || MatrixLike<RHS>) &&
704 (SleipnirType<LHS> || SleipnirType<RHS>)
705InequalityConstraints operator<(LHS&& lhs, RHS&& rhs) {
706 return rhs >= lhs;
707}
708
716template <typename LHS, typename RHS>
717 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
718 (ScalarLike<RHS> || MatrixLike<RHS>) &&
719 (SleipnirType<LHS> || SleipnirType<RHS>)
720InequalityConstraints operator<=(LHS&& lhs, RHS&& rhs) {
721 return rhs >= lhs;
722}
723
731template <typename LHS, typename RHS>
732 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
733 (ScalarLike<RHS> || MatrixLike<RHS>) &&
734 (SleipnirType<LHS> || SleipnirType<RHS>)
735InequalityConstraints operator>(LHS&& lhs, RHS&& rhs) {
736 return lhs >= rhs;
737}
738
746template <typename LHS, typename RHS>
747 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
748 (ScalarLike<RHS> || MatrixLike<RHS>) &&
749 (SleipnirType<LHS> || SleipnirType<RHS>)
750InequalityConstraints operator>=(LHS&& lhs, RHS&& rhs) {
751 return InequalityConstraints{lhs, rhs};
752}
753
754} // namespace slp
755
756namespace Eigen {
757
761template <>
762struct NumTraits<slp::Variable> : NumTraits<double> {
769
771 static constexpr int IsComplex = 0;
773 static constexpr int IsInteger = 0;
775 static constexpr int IsSigned = 1;
777 static constexpr int RequireInitialization = 1;
779 static constexpr int ReadCost = 1;
781 static constexpr int AddCost = 3;
783 static constexpr int MulCost = 3;
784};
785
786} // namespace Eigen
Definition hessian.hpp:29
Definition jacobian.hpp:25
Definition variable.hpp:40
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs)
Definition variable.hpp:225
constexpr Variable(detail::ExpressionPtr &&expr)
Definition variable.hpp:80
friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:126
Variable & operator=(double value)
Definition variable.hpp:89
Variable & operator-=(const Variable &rhs)
Definition variable.hpp:206
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs)
Definition variable.hpp:216
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:172
Variable & operator+=(const Variable &rhs)
Definition variable.hpp:183
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:195
Variable & operator*=(const Variable &rhs)
Definition variable.hpp:137
Variable(std::floating_point auto value)
Definition variable.hpp:57
constexpr Variable(std::nullptr_t)
Definition variable.hpp:50
friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:149
Variable()=default
Variable(std::integral auto value)
Definition variable.hpp:65
void set_value(double value)
Definition variable.hpp:100
Variable & operator/=(const Variable &rhs)
Definition variable.hpp:160
ExpressionType type() const
Definition variable.hpp:250
double value()
Definition variable.hpp:234
Variable(const detail::ExpressionPtr &expr)
Definition variable.hpp:73
Definition adjoint_expression_graph.hpp:21
Definition concepts.hpp:40
Definition concepts.hpp:13
Definition concepts.hpp:37
Definition variable.hpp:559
EqualityConstraints(std::initializer_list< EqualityConstraints > equality_constraints)
Definition variable.hpp:568
EqualityConstraints(const std::vector< EqualityConstraints > &equality_constraints)
Definition variable.hpp:583
EqualityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:604
small_vector< Variable > constraints
A vector of scalar equality constraints.
Definition variable.hpp:561
Definition variable.hpp:620
InequalityConstraints(const std::vector< InequalityConstraints > &inequality_constraints)
Definition variable.hpp:646
InequalityConstraints(std::initializer_list< InequalityConstraints > inequality_constraints)
Definition variable.hpp:630
small_vector< Variable > constraints
A vector of scalar inequality constraints.
Definition variable.hpp:622
InequalityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:667