Sleipnir C++ API
Loading...
Searching...
No Matches
variable.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <algorithm>
6#include <concepts>
7#include <initializer_list>
8#include <source_location>
9#include <type_traits>
10#include <utility>
11#include <vector>
12
13#include <Eigen/Core>
14#include <gch/small_vector.hpp>
15
16#include "sleipnir/autodiff/expression.hpp"
17#include "sleipnir/autodiff/expression_graph.hpp"
18#include "sleipnir/util/assert.hpp"
19#include "sleipnir/util/concepts.hpp"
20#include "sleipnir/util/symbol_exports.hpp"
21
22#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
23#include "sleipnir/util/print.hpp"
24#endif
25
26namespace slp {
27
28// Forward declarations for friend declarations in Variable
29namespace detail {
30class AdjointExpressionGraph;
31} // namespace detail
32template <int UpLo = Eigen::Lower | Eigen::Upper>
33 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
34class SLEIPNIR_DLLEXPORT Hessian;
35class SLEIPNIR_DLLEXPORT Jacobian;
36
40class SLEIPNIR_DLLEXPORT Variable {
41 public:
45 Variable() = default;
46
50 explicit constexpr Variable(std::nullptr_t) : expr{nullptr} {}
51
57 Variable(std::floating_point auto value) // NOLINT
58 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
59
65 Variable(std::integral auto value) // NOLINT
66 : expr{detail::make_expression_ptr<detail::ConstExpression>(value)} {}
67
73 explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {}
74
80 explicit constexpr Variable(detail::ExpressionPtr&& expr)
81 : expr{std::move(expr)} {}
82
89 Variable& operator=(double value) {
90 expr = detail::make_expression_ptr<detail::ConstExpression>(value);
91 m_graph_initialized = false;
92
93 return *this;
94 }
95
101 void set_value(double value) {
102#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
103 // We only need to check the first argument since unary and binary operators
104 // both use it
105 if (expr->args[0] != nullptr) {
106 auto location = std::source_location::current();
107 slp::println(
108 stderr,
109 "WARNING: {}:{}: {}: Modified the value of a dependent variable",
110 location.file_name(), location.line(), location.function_name());
111 }
112#endif
113 expr->val = value;
114 }
115
121 double value() {
122 if (!m_graph_initialized) {
123 m_graph = detail::topological_sort(expr);
124 m_graph_initialized = true;
125 }
126 detail::update_values(m_graph);
127
128 return expr->val;
129 }
130
137 ExpressionType type() const { return expr->type(); }
138
146 friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable& lhs,
147 const Variable& rhs) {
148 return Variable{lhs.expr * rhs.expr};
149 }
150
158 *this = *this * rhs;
159 return *this;
160 }
161
169 friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable& lhs,
170 const Variable& rhs) {
171 return Variable{lhs.expr / rhs.expr};
172 }
173
181 *this = *this / rhs;
182 return *this;
183 }
184
192 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs,
193 const Variable& rhs) {
194 return Variable{lhs.expr + rhs.expr};
195 }
196
204 *this = *this + rhs;
205 return *this;
206 }
207
215 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs,
216 const Variable& rhs) {
217 return Variable{lhs.expr - rhs.expr};
218 }
219
227 *this = *this - rhs;
228 return *this;
229 }
230
236 friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs) {
237 return Variable{-lhs.expr};
238 }
239
245 friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs) {
246 return Variable{+lhs.expr};
247 }
248
249 private:
252 detail::make_expression_ptr<detail::DecisionVariableExpression>();
253
  /// Evaluation order cached from detail::topological_sort(expr); value()
  /// rebuilds it whenever m_graph_initialized is false.
  gch::small_vector<detail::Expression*> m_graph;

  /// Whether m_graph is current for expr. operator=(double), which replaces
  /// expr, resets this to false.
  bool m_graph_initialized = false;
260
  // Free math functions defined after this class. Most of them read the
  // private expr member directly, so they are declared friends here.
  friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y,
                                           const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable cbrt(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x,
                                           const Variable& y);
  friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base,
                                         const Variable& power);
  friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x);
  friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y,
                                           const Variable& z);

  // Hessian and Jacobian are granted access to Variable's private members.
  // NOTE(review): the forward declaration of detail::AdjointExpressionGraph at
  // the top of this header says it exists for friend declarations in Variable,
  // but no such friend declaration is visible here. Confirm it is present.
  template <int UpLo>
    requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
  friend class SLEIPNIR_DLLEXPORT Hessian;
  friend class SLEIPNIR_DLLEXPORT Jacobian;
292};
293
299SLEIPNIR_DLLEXPORT inline Variable abs(const Variable& x) {
300 return Variable{detail::abs(x.expr)};
301}
302
308SLEIPNIR_DLLEXPORT inline Variable acos(const Variable& x) {
309 return Variable{detail::acos(x.expr)};
310}
311
317SLEIPNIR_DLLEXPORT inline Variable asin(const Variable& x) {
318 return Variable{detail::asin(x.expr)};
319}
320
326SLEIPNIR_DLLEXPORT inline Variable atan(const Variable& x) {
327 return Variable{detail::atan(x.expr)};
328}
329
336SLEIPNIR_DLLEXPORT inline Variable atan2(const Variable& y, const Variable& x) {
337 return Variable{detail::atan2(y.expr, x.expr)};
338}
339
345SLEIPNIR_DLLEXPORT inline Variable cbrt(const Variable& x) {
346 return Variable{detail::cbrt(x.expr)};
347}
348
354SLEIPNIR_DLLEXPORT inline Variable cos(const Variable& x) {
355 return Variable{detail::cos(x.expr)};
356}
357
363SLEIPNIR_DLLEXPORT inline Variable cosh(const Variable& x) {
364 return Variable{detail::cosh(x.expr)};
365}
366
372SLEIPNIR_DLLEXPORT inline Variable erf(const Variable& x) {
373 return Variable{detail::erf(x.expr)};
374}
375
381SLEIPNIR_DLLEXPORT inline Variable exp(const Variable& x) {
382 return Variable{detail::exp(x.expr)};
383}
384
391SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y) {
392 return Variable{detail::hypot(x.expr, y.expr)};
393}
394
401SLEIPNIR_DLLEXPORT inline Variable pow(const Variable& base,
402 const Variable& power) {
403 return Variable{detail::pow(base.expr, power.expr)};
404}
405
411SLEIPNIR_DLLEXPORT inline Variable log(const Variable& x) {
412 return Variable{detail::log(x.expr)};
413}
414
420SLEIPNIR_DLLEXPORT inline Variable log10(const Variable& x) {
421 return Variable{detail::log10(x.expr)};
422}
423
429SLEIPNIR_DLLEXPORT inline Variable sign(const Variable& x) {
430 return Variable{detail::sign(x.expr)};
431}
432
438SLEIPNIR_DLLEXPORT inline Variable sin(const Variable& x) {
439 return Variable{detail::sin(x.expr)};
440}
441
447SLEIPNIR_DLLEXPORT inline Variable sinh(const Variable& x) {
448 return Variable{detail::sinh(x.expr)};
449}
450
456SLEIPNIR_DLLEXPORT inline Variable sqrt(const Variable& x) {
457 return Variable{detail::sqrt(x.expr)};
458}
459
465SLEIPNIR_DLLEXPORT inline Variable tan(const Variable& x) {
466 return Variable{detail::tan(x.expr)};
467}
468
474SLEIPNIR_DLLEXPORT inline Variable tanh(const Variable& x) {
475 return Variable{detail::tanh(x.expr)};
476}
477
485SLEIPNIR_DLLEXPORT inline Variable hypot(const Variable& x, const Variable& y,
486 const Variable& z) {
487 return Variable{slp::sqrt(slp::pow(x, 2) + slp::pow(y, 2) + slp::pow(z, 2))};
488}
489
501template <typename LHS, typename RHS>
502 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
505gch::small_vector<Variable> make_constraints(LHS&& lhs, RHS&& rhs) {
506 gch::small_vector<Variable> constraints;
507
508 if constexpr (ScalarLike<LHS> && ScalarLike<RHS>) {
509 constraints.emplace_back(lhs - rhs);
510 } else if constexpr (ScalarLike<LHS> && MatrixLike<RHS>) {
511 constraints.reserve(rhs.rows() * rhs.cols());
512
513 for (int row = 0; row < rhs.rows(); ++row) {
514 for (int col = 0; col < rhs.cols(); ++col) {
515 // Make right-hand side zero
516 if constexpr (EigenMatrixLike<std::decay_t<RHS>>) {
517 constraints.emplace_back(lhs - rhs(row, col));
518 } else {
519 constraints.emplace_back(lhs - rhs[row, col]);
520 }
521 }
522 }
523 } else if constexpr (MatrixLike<LHS> && ScalarLike<RHS>) {
524 constraints.reserve(lhs.rows() * lhs.cols());
525
526 for (int row = 0; row < lhs.rows(); ++row) {
527 for (int col = 0; col < lhs.cols(); ++col) {
528 // Make right-hand side zero
529 if constexpr (EigenMatrixLike<std::decay_t<LHS>>) {
530 constraints.emplace_back(lhs(row, col) - rhs);
531 } else {
532 constraints.emplace_back(lhs[row, col] - rhs);
533 }
534 }
535 }
536 } else if constexpr (MatrixLike<LHS> && MatrixLike<RHS>) {
537 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
538 constraints.reserve(lhs.rows() * lhs.cols());
539
540 for (int row = 0; row < lhs.rows(); ++row) {
541 for (int col = 0; col < lhs.cols(); ++col) {
542 // Make right-hand side zero
543 if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
544 EigenMatrixLike<std::decay_t<RHS>>) {
545 constraints.emplace_back(lhs(row, col) - rhs(row, col));
546 } else if constexpr (EigenMatrixLike<std::decay_t<LHS>> &&
547 SleipnirMatrixLike<std::decay_t<RHS>>) {
548 constraints.emplace_back(lhs(row, col) - rhs[row, col]);
549 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
550 EigenMatrixLike<std::decay_t<RHS>>) {
551 constraints.emplace_back(lhs[row, col] - rhs(row, col));
552 } else if constexpr (SleipnirMatrixLike<std::decay_t<LHS>> &&
553 SleipnirMatrixLike<std::decay_t<RHS>>) {
554 constraints.emplace_back(lhs[row, col] - rhs[row, col]);
555 }
556 }
557 }
558 }
559
560 return constraints;
561}
562
566struct SLEIPNIR_DLLEXPORT EqualityConstraints {
  /// A vector of scalar equality constraints. operator bool() treats each
  /// element as satisfied when its value() compares equal to 0.0.
  gch::small_vector<Variable> constraints;
569
576 std::initializer_list<EqualityConstraints> equality_constraints) {
577 for (const auto& elem : equality_constraints) {
578 constraints.insert(constraints.end(), elem.constraints.begin(),
579 elem.constraints.end());
580 }
581 }
582
591 const std::vector<EqualityConstraints>& equality_constraints) {
592 for (const auto& elem : equality_constraints) {
593 constraints.insert(constraints.end(), elem.constraints.begin(),
594 elem.constraints.end());
595 }
596 }
597
607 template <typename LHS, typename RHS>
608 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
611 EqualityConstraints(LHS&& lhs, RHS&& rhs)
612 : constraints{make_constraints(lhs, rhs)} {}
613
617 operator bool() { // NOLINT
618 return std::ranges::all_of(constraints, [](auto& constraint) {
619 return constraint.value() == 0.0;
620 });
621 }
622};
623
627struct SLEIPNIR_DLLEXPORT InequalityConstraints {
  /// A vector of scalar inequality constraints. operator bool() treats each
  /// element as satisfied when its value() is >= 0.0.
  gch::small_vector<Variable> constraints;
630
638 std::initializer_list<InequalityConstraints> inequality_constraints) {
639 for (const auto& elem : inequality_constraints) {
640 constraints.insert(constraints.end(), elem.constraints.begin(),
641 elem.constraints.end());
642 }
643 }
644
654 const std::vector<InequalityConstraints>& inequality_constraints) {
655 for (const auto& elem : inequality_constraints) {
656 constraints.insert(constraints.end(), elem.constraints.begin(),
657 elem.constraints.end());
658 }
659 }
660
670 template <typename LHS, typename RHS>
671 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
674 InequalityConstraints(LHS&& lhs, RHS&& rhs)
675 : constraints{make_constraints(lhs, rhs)} {}
676
680 operator bool() { // NOLINT
681 return std::ranges::all_of(constraints, [](auto& constraint) {
682 return constraint.value() >= 0.0;
683 });
684 }
685};
686
693template <typename LHS, typename RHS>
694 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
695 (ScalarLike<RHS> || MatrixLike<RHS>) &&
696 (SleipnirType<LHS> || SleipnirType<RHS>)
697EqualityConstraints operator==(LHS&& lhs, RHS&& rhs) {
698 return EqualityConstraints{lhs, rhs};
699}
700
708template <typename LHS, typename RHS>
709 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
710 (ScalarLike<RHS> || MatrixLike<RHS>) &&
711 (SleipnirType<LHS> || SleipnirType<RHS>)
712InequalityConstraints operator<(LHS&& lhs, RHS&& rhs) {
713 return rhs >= lhs;
714}
715
723template <typename LHS, typename RHS>
724 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
725 (ScalarLike<RHS> || MatrixLike<RHS>) &&
726 (SleipnirType<LHS> || SleipnirType<RHS>)
727InequalityConstraints operator<=(LHS&& lhs, RHS&& rhs) {
728 return rhs >= lhs;
729}
730
738template <typename LHS, typename RHS>
739 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
740 (ScalarLike<RHS> || MatrixLike<RHS>) &&
741 (SleipnirType<LHS> || SleipnirType<RHS>)
742InequalityConstraints operator>(LHS&& lhs, RHS&& rhs) {
743 return lhs >= rhs;
744}
745
753template <typename LHS, typename RHS>
754 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
755 (ScalarLike<RHS> || MatrixLike<RHS>) &&
756 (SleipnirType<LHS> || SleipnirType<RHS>)
757InequalityConstraints operator>=(LHS&& lhs, RHS&& rhs) {
758 return InequalityConstraints{lhs, rhs};
759}
760
761} // namespace slp
762
763namespace Eigen {
764
768template <>
769struct NumTraits<slp::Variable> : NumTraits<double> {
776
778 static constexpr int IsComplex = 0;
780 static constexpr int IsInteger = 0;
782 static constexpr int IsSigned = 1;
784 static constexpr int RequireInitialization = 1;
786 static constexpr int ReadCost = 1;
788 static constexpr int AddCost = 3;
790 static constexpr int MulCost = 3;
791};
792
793} // namespace Eigen
Definition hessian.hpp:30
Definition jacobian.hpp:26
Definition variable.hpp:40
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs)
Definition variable.hpp:245
constexpr Variable(detail::ExpressionPtr &&expr)
Definition variable.hpp:80
friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:146
Variable & operator=(double value)
Definition variable.hpp:89
Variable & operator-=(const Variable &rhs)
Definition variable.hpp:226
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs)
Definition variable.hpp:236
friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:192
Variable & operator+=(const Variable &rhs)
Definition variable.hpp:203
friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:215
Variable & operator*=(const Variable &rhs)
Definition variable.hpp:157
Variable(std::floating_point auto value)
Definition variable.hpp:57
constexpr Variable(std::nullptr_t)
Definition variable.hpp:50
friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable &lhs, const Variable &rhs)
Definition variable.hpp:169
Variable()=default
Variable(std::integral auto value)
Definition variable.hpp:65
void set_value(double value)
Definition variable.hpp:101
Variable & operator/=(const Variable &rhs)
Definition variable.hpp:180
ExpressionType type() const
Definition variable.hpp:137
double value()
Definition variable.hpp:121
Variable(const detail::ExpressionPtr &expr)
Definition variable.hpp:73
Definition adjoint_expression_graph.hpp:22
Definition concepts.hpp:40
Definition concepts.hpp:13
Definition concepts.hpp:37
Definition variable.hpp:566
gch::small_vector< Variable > constraints
A vector of scalar equality constraints.
Definition variable.hpp:568
EqualityConstraints(std::initializer_list< EqualityConstraints > equality_constraints)
Definition variable.hpp:575
EqualityConstraints(const std::vector< EqualityConstraints > &equality_constraints)
Definition variable.hpp:590
EqualityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:611
Definition variable.hpp:627
InequalityConstraints(const std::vector< InequalityConstraints > &inequality_constraints)
Definition variable.hpp:653
InequalityConstraints(std::initializer_list< InequalityConstraints > inequality_constraints)
Definition variable.hpp:637
InequalityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:674
gch::small_vector< Variable > constraints
A vector of scalar inequality constraints.
Definition variable.hpp:629