Sleipnir C++ API
Loading...
Searching...
No Matches
variable.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <algorithm>
6#include <concepts>
7#include <initializer_list>
8#include <optional>
9#include <source_location>
10#include <type_traits>
11#include <utility>
12#include <vector>
13
14#include <Eigen/Core>
15
16#include "sleipnir/autodiff/expression.hpp"
17#include "sleipnir/autodiff/expression_graph.hpp"
18#include "sleipnir/util/assert.hpp"
19#include "sleipnir/util/concepts.hpp"
20#include "sleipnir/util/small_vector.hpp"
21#include "sleipnir/util/symbol_exports.hpp"
22
23#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
24#include "sleipnir/util/print.hpp"
25#endif
26
27namespace slp {
28
29// Forward declarations for friend declarations in Variable
30namespace detail {
31class AdjointExpressionGraph;
32} // namespace detail
33template <int UpLo = Eigen::Lower | Eigen::Upper>
34 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
35class SLEIPNIR_DLLEXPORT Hessian;
36class SLEIPNIR_DLLEXPORT Jacobian;
37
41class SLEIPNIR_DLLEXPORT Variable {
42 public:
46 Variable() = default;
47
51 explicit constexpr Variable(std::nullptr_t) : expr{nullptr} {}
52
58 Variable(std::floating_point auto value) // NOLINT
59 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
60
66 Variable(std::integral auto value) // NOLINT
67 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
68
74 explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {}
75
81 explicit constexpr Variable(detail::ExpressionPtr&& expr)
82 : expr{std::move(expr)} {}
83
90 Variable& operator=(double value) {
91 expr = detail::make_expression_ptr<detail::ConstExpression>(value);
92
93 return *this;
94 }
95
101 void set_value(double value) {
102 if (expr->is_constant(0.0)) {
103 expr = detail::make_expression_ptr<detail::ConstExpression>(value);
104 } else {
105#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
106 // We only need to check the first argument since unary and binary
107 // operators both use it
108 if (expr->args[0] != nullptr) {
109 auto location = std::source_location::current();
110 slp::println(
111 stderr,
112 "WARNING: {}:{}: {}: Modified the value of a dependent variable",
113 location.file_name(), location.line(), location.function_name());
114 }
115#endif
116 expr->val = value;
117 }
118 }
119
127 friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable& lhs,
128 const Variable& rhs) {
129 return Variable{lhs.expr * rhs.expr};
130 }
131
139 *this = *this * rhs;
140 return *this;
141 }
142
150 friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable& lhs,
151 const Variable& rhs) {
152 return Variable{lhs.expr / rhs.expr};
153 }
154
162 *this = *this / rhs;
163 return *this;
164 }
165
173 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs,
174 const Variable& rhs) {
175 return Variable{lhs.expr + rhs.expr};
176 }
177
185 *this = *this + rhs;
186 return *this;
187 }
188
196 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs,
197 const Variable& rhs) {
198 return Variable{lhs.expr - rhs.expr};
199 }
200
208 *this = *this - rhs;
209 return *this;
210 }
211
217 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs) {
218 return Variable{-lhs.expr};
219 }
220
226 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs) {
227 return Variable{+lhs.expr};
228 }
229
235 double value() {
236 if (!m_graph) {
237 m_graph = detail::topological_sort(expr);
238 }
239 detail::update_values(m_graph.value());
240
241 return expr->val;
242 }
243
250 ExpressionType type() const { return expr->type(); }
251
252 private:
255 detail::make_expression_ptr<detail::DecisionVariableExpression>();
256
259 std::optional<small_vector<detail::Expression*>> m_graph;
260
261 friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x);
262 friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x);
263 friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x);
264 friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x);
265 friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y,
266 const Variable& x);
267 friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x);
268 friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x);
269 friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x);
270 friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x);
271 friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x,
272 const Variable& y);
273 friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x);
274 friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x);
275 friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base,
276 const Variable& power);
277 friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x);
278 friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x);
279 friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x);
280 friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x);
281 friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x);
282 friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x);
283 friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y,
284 const Variable& z);
285
287 template <int UpLo>
288 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
289 friend class SLEIPNIR_DLLEXPORT Hessian;
290 friend class SLEIPNIR_DLLEXPORT Jacobian;
291};
292
298SLEIPNIR_DLLEXPORT inline Variable abs(const Variable& x) {
299 return Variable{detail::abs(x.expr)};
300}
301
307SLEIPNIR_DLLEXPORT inline Variable acos(const Variable& x) {
308 return Variable{detail::acos(x.expr)};
309}
310
316SLEIPNIR_DLLEXPORT inline Variable asin(const Variable& x) {
317 return Variable{detail::asin(x.expr)};
318}
319
325SLEIPNIR_DLLEXPORT inline Variable atan(const Variable& x) {
326 return Variable{detail::atan(x.expr)};
327}
328
335SLEIPNIR_DLLEXPORT inline Variable atan2(const Variable& y, const Variable& x) {
336 return Variable{detail::atan2(y.expr, x.expr)};
337}
338
344SLEIPNIR_DLLEXPORT inline Variable cos(const Variable& x) {
345 return Variable{detail::cos(x.expr)};
346}
347
353SLEIPNIR_DLLEXPORT inline Variable cosh(const Variable& x) {
354 return Variable{detail::cosh(x.expr)};
355}
356
362SLEIPNIR_DLLEXPORT inline Variable erf(const Variable& x) {
363 return Variable{detail::erf(x.expr)};
364}
365
371SLEIPNIR_DLLEXPORT inline Variable exp(const Variable& x) {
372 return Variable{detail::exp(x.expr)};
373}
374
381SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y) {
382 return Variable{detail::hypot(x.expr, y.expr)};
383}
384
391SLEIPNIR_DLLEXPORT inline Variable pow(const Variable& base,
392 const Variable& power) {
393 return Variable{detail::pow(base.expr, power.expr)};
394}
395
401SLEIPNIR_DLLEXPORT inline Variable log(const Variable& x) {
402 return Variable{detail::log(x.expr)};
403}
404
410SLEIPNIR_DLLEXPORT inline Variable log10(const Variable& x) {
411 return Variable{detail::log10(x.expr)};
412}
413
419SLEIPNIR_DLLEXPORT inline Variable sign(const Variable& x) {
420 return Variable{detail::sign(x.expr)};
421}
422
428SLEIPNIR_DLLEXPORT inline Variable sin(const Variable& x) {
429 return Variable{detail::sin(x.expr)};
430}
431
437SLEIPNIR_DLLEXPORT inline Variable sinh(const Variable& x) {
438 return Variable{detail::sinh(x.expr)};
439}
440
446SLEIPNIR_DLLEXPORT inline Variable sqrt(const Variable& x) {
447 return Variable{detail::sqrt(x.expr)};
448}
449
455SLEIPNIR_DLLEXPORT inline Variable tan(const Variable& x) {
456 return Variable{detail::tan(x.expr)};
457}
458
464SLEIPNIR_DLLEXPORT inline Variable tanh(const Variable& x) {
465 return Variable{detail::tanh(x.expr)};
466}
467
475SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y,
476 const Variable& z) {
477 return Variable{slp::sqrt(slp::pow(x, 2) + slp::pow(y, 2) + slp::pow(z, 2))};
478}
479
491template <typename LHS, typename RHS>
492 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
495small_vector<Variable> make_constraints(LHS&& lhs, RHS&& rhs) {
496 small_vector<Variable> constraints;
497
498 if constexpr (ScalarLike<LHS> && ScalarLike<RHS>) {
499 constraints.emplace_back(lhs - rhs);
500 } else if constexpr (ScalarLike<LHS> && MatrixLike<RHS>) {
501 constraints.reserve(rhs.rows() * rhs.cols());
502
503 for (int row = 0; row < rhs.rows(); ++row) {
504 for (int col = 0; col < rhs.cols(); ++col) {
505 // Make right-hand side zero
506 if constexpr (EigenMatrixLike<std::decay_t<RHS>>) {
507 constraints.emplace_back(lhs - rhs(row, col));
508 } else {
509 constraints.emplace_back(lhs - rhs[row, col]);
510 }
511 }
512 }
513 } else if constexpr (MatrixLike<LHS> && ScalarLike<RHS>) {
514 constraints.reserve(lhs.rows() * lhs.cols());
515
516 for (int row = 0; row < lhs.rows(); ++row) {
517 for (int col = 0; col < lhs.cols(); ++col) {
518 // Make right-hand side zero
519 if constexpr (EigenMatrixLike<std::decay_t<LHS>>) {
520 constraints.emplace_back(lhs(row, col) - rhs);
521 } else {
522 constraints.emplace_back(lhs[row, col] - rhs);
523 }
524 }
525 }
526 } else if constexpr (MatrixLike<LHS> && MatrixLike<RHS>) {
527 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
528 constraints.reserve(lhs.rows() * lhs.cols());
529
530 for (int row = 0; row < lhs.rows(); ++row) {
531 for (int col = 0; col < lhs.cols(); ++col) {
532 // Make right-hand side zero
533 if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
534 EigenMatrixLike<std::decay_t<RHS>>) {
535 constraints.emplace_back(lhs(row, col) - rhs(row, col));
536 } else if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
537 SleipnirMatrixLike<std::decay_t<RHS>>) {
538 constraints.emplace_back(lhs(row, col) - rhs[row, col]);
539 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
540 EigenMatrixLike<std::decay_t<RHS>>) {
541 constraints.emplace_back(lhs[row, col] - rhs(row, col));
542 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
543 SleipnirMatrixLike<std::decay_t<RHS>>) {
544 constraints.emplace_back(lhs[row, col] - rhs[row, col]);
545 }
546 }
547 }
548 }
549
550 return constraints;
551}
552
556struct SLEIPNIR_DLLEXPORT EqualityConstraints {
558 small_vector<Variable> constraints;
559
566 std::initializer_list<EqualityConstraints> equality_constraints) {
567 for (const auto& elem : equality_constraints) {
568 constraints.insert(constraints.end(), elem.constraints.begin(),
569 elem.constraints.end());
570 }
571 }
572
581 const std::vector<EqualityConstraints>& equality_constraints) {
582 for (const auto& elem : equality_constraints) {
583 constraints.insert(constraints.end(), elem.constraints.begin(),
584 elem.constraints.end());
585 }
586 }
587
597 template <typename LHS, typename RHS>
598 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
601 EqualityConstraints(LHS&& lhs, RHS&& rhs)
602 : constraints{make_constraints(lhs, rhs)} {}
603
607 operator bool() { // NOLINT
608 return std::ranges::all_of(constraints, [](auto& constraint) {
609 return constraint.value() == 0.0;
610 });
611 }
612};
613
617struct SLEIPNIR_DLLEXPORT InequalityConstraints {
619 small_vector<Variable> constraints;
620
628 std::initializer_list<InequalityConstraints> inequality_constraints) {
629 for (const auto& elem : inequality_constraints) {
630 constraints.insert(constraints.end(), elem.constraints.begin(),
631 elem.constraints.end());
632 }
633 }
634
644 const std::vector<InequalityConstraints>& inequality_constraints) {
645 for (const auto& elem : inequality_constraints) {
646 constraints.insert(constraints.end(), elem.constraints.begin(),
647 elem.constraints.end());
648 }
649 }
650
660 template <typename LHS, typename RHS>
661 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
664 InequalityConstraints(LHS&& lhs, RHS&& rhs)
665 : constraints{make_constraints(lhs, rhs)} {}
666
670 operator bool() { // NOLINT
671 return std::ranges::all_of(constraints, [](auto& constraint) {
672 return constraint.value() >= 0.0;
673 });
674 }
675};
676
683template <typename LHS, typename RHS>
684 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
685 (ScalarLike<RHS> || MatrixLike<RHS>) &&
686 (SleipnirType<LHS> || SleipnirType<RHS>)
687EqualityConstraints operator==(LHS&& lhs, RHS&& rhs) {
688 return EqualityConstraints{lhs, rhs};
689}
690
698template <typename LHS, typename RHS>
699 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
700 (ScalarLike<RHS> || MatrixLike<RHS>) &&
701 (SleipnirType<LHS> || SleipnirType<RHS>)
702InequalityConstraints operator<(LHS&& lhs, RHS&& rhs) {
703 return rhs >= lhs;
704}
705
713template <typename LHS, typename RHS>
714 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
715 (ScalarLike<RHS> || MatrixLike<RHS>) &&
716 (SleipnirType<LHS> || SleipnirType<RHS>)
717InequalityConstraints operator<=(LHS&& lhs, RHS&& rhs) {
718 return rhs >= lhs;
719}
720
728template <typename LHS, typename RHS>
729 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
730 (ScalarLike<RHS> || MatrixLike<RHS>) &&
731 (SleipnirType<LHS> || SleipnirType<RHS>)
732InequalityConstraints operator>(LHS&& lhs, RHS&& rhs) {
733 return lhs >= rhs;
734}
735
743template <typename LHS, typename RHS>
744 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
745 (ScalarLike<RHS> || MatrixLike<RHS>) &&
746 (SleipnirType<LHS> || SleipnirType<RHS>)
747InequalityConstraints operator>=(LHS&& lhs, RHS&& rhs) {
748 return InequalityConstraints{lhs, rhs};
749}
750
751} // namespace slp
752
753namespace Eigen {
754
// Eigen NumTraits specialization that lets Eigen containers use slp::Variable
// as their scalar type. Anything not overridden below is inherited from
// NumTraits<double>.
758template <>
759struct NumTraits<slp::Variable> : NumTraits<double> {
  // NOTE(review): source lines 760-765 are missing from this listing (likely
  // type aliases such as Real/NonInteger/Nested) — confirm against the
  // original header.
766
  // Variable is not a complex type.
768 static constexpr int IsComplex = 0;
  // Variable is not an integer type.
770 static constexpr int IsInteger = 0;
  // Variable values may be negative.
772 static constexpr int IsSigned = 1;
  // Eigen must run Variable's constructors; a default-constructed Variable
  // allocates its own expression node.
774 static constexpr int RequireInitialization = 1;
  // Relative cost estimates Eigen uses when choosing evaluation strategies.
776 static constexpr int ReadCost = 1;
778 static constexpr int AddCost = 3;
780 static constexpr int MulCost = 3;
781};
782
783} // namespace Eigen
Definition hessian.hpp:31
Definition jacobian.hpp:27
Definition variable.hpp:41
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs)
Definition variable.hpp:226
constexpr Variable(detail::ExpressionPtr &&expr)
Definition variable.hpp:81
friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:127
Variable & operator=(double value)
Definition variable.hpp:90
Variable & operator-=(const Variable &rhs)
Definition variable.hpp:207
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs)
Definition variable.hpp:217
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:173
Variable & operator+=(const Variable &rhs)
Definition variable.hpp:184
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:196
Variable & operator*=(const Variable &rhs)
Definition variable.hpp:138
Variable(std::floating_point auto value)
Definition variable.hpp:58
constexpr Variable(std::nullptr_t)
Definition variable.hpp:51
friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:150
Variable()=default
Variable(std::integral auto value)
Definition variable.hpp:66
void set_value(double value)
Definition variable.hpp:101
Variable & operator/=(const Variable &rhs)
Definition variable.hpp:161
ExpressionType type() const
Definition variable.hpp:250
double value()
Definition variable.hpp:235
Variable(const detail::ExpressionPtr &expr)
Definition variable.hpp:74
Definition adjoint_expression_graph.hpp:21
Definition concepts.hpp:40
Definition concepts.hpp:13
Definition concepts.hpp:37
Definition variable.hpp:556
EqualityConstraints(std::initializer_list< EqualityConstraints > equality_constraints)
Definition variable.hpp:565
EqualityConstraints(const std::vector< EqualityConstraints > &equality_constraints)
Definition variable.hpp:580
EqualityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:601
small_vector< Variable > constraints
A vector of scalar equality constraints.
Definition variable.hpp:558
Definition variable.hpp:617
InequalityConstraints(const std::vector< InequalityConstraints > &inequality_constraints)
Definition variable.hpp:643
InequalityConstraints(std::initializer_list< InequalityConstraints > inequality_constraints)
Definition variable.hpp:627
small_vector< Variable > constraints
A vector of scalar inequality constraints.
Definition variable.hpp:619
InequalityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:664