7#include <initializer_list>
8#include <source_location>
13#include <gch/small_vector.hpp>
15#include "sleipnir/autodiff/expression.hpp"
16#include "sleipnir/autodiff/expression_graph.hpp"
17#include "sleipnir/autodiff/sleipnir_base.hpp"
18#include "sleipnir/util/assert.hpp"
19#include "sleipnir/util/concepts.hpp"
21#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
22#include "sleipnir/util/print.hpp"
29template <
typename Scalar>
34template <
typename Scalar>
35VariableMatrix<Scalar> gradient_tree(
36 const detail::ExpressionGraph<Scalar>& top_list,
37 const VariableMatrix<Scalar>& wrt);
41template <
typename Scalar,
int UpLo = Eigen::Lower | Eigen::Upper>
42 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
45template <
typename Scalar>
51template <
typename Scalar_>
69 : expr{detail::make_expression_ptr<detail::ConstantExpression<
Scalar>>(
77 slp_assert(
value.rows() == 1 &&
value.cols() == 1);
85 : expr{detail::make_expression_ptr<detail::ConstantExpression<
Scalar>>(
93 : expr{detail::make_expression_ptr<detail::ConstantExpression<
Scalar>>(
105 : expr{std::
move(expr)} {}
113 detail::make_expression_ptr<detail::ConstantExpression<Scalar>>(
value);
114 m_graph_initialized =
false;
123#ifndef SLEIPNIR_DISABLE_DIAGNOSTICS
126 if (expr->args[0] !=
nullptr) {
127 auto location = std::source_location::current();
130 "WARNING: {}:{}: {}: Modified the value of a dependent variable",
141 if (!m_graph_initialized) {
142 m_graph = detail::topological_sort(expr);
143 m_graph_initialized =
true;
145 detail::update_values(m_graph);
154 ExpressionType
type()
const {
return expr->type(); }
161 template <ScalarLike LHS, SleipnirScalarLike<Scalar> RHS>
171 template <SleipnirScalarLike<Scalar> LHS, ScalarLike RHS>
269 detail::make_expression_ptr<detail::DecisionVariableExpression<Scalar>>();
273 detail::ExpressionGraph<Scalar> m_graph;
276 bool m_graph_initialized =
false;
278 template <
typename Scalar>
280 template <
typename Scalar>
282 template <
typename Scalar>
284 template <
typename Scalar>
286 template <
typename Scalar>
289 template <
typename Scalar>
292 template <
typename Scalar>
295 template <
typename Scalar>
297 template <
typename Scalar>
299 template <
typename Scalar>
301 template <
typename Scalar>
303 template <
typename Scalar>
305 template <
typename Scalar>
308 template <
typename Scalar>
311 template <
typename Scalar>
314 template <
typename Scalar>
316 template <
typename Scalar>
318 template <
typename Scalar>
321 template <
typename Scalar>
324 template <
typename Scalar>
327 template <
typename Scalar>
330 template <
typename Scalar>
333 template <
typename Scalar>
336 template <
typename Scalar>
339 template <
typename Scalar>
342 template <
typename Scalar>
345 template <
typename Scalar>
347 template <
typename Scalar>
349 template <
typename Scalar>
351 template <
typename Scalar>
353 template <
typename Scalar>
355 template <
typename Scalar>
357 template <
typename Scalar>
362 template <
typename Scalar>
364 const detail::ExpressionGraph<Scalar>&
top_list,
366 template <
typename Scalar,
int UpLo>
367 requires(
UpLo == Eigen::Lower) || (
UpLo == (Eigen::Lower | Eigen::Upper))
369 template <
typename Scalar>
373template <
template <
typename>
typename T,
typename Scalar>
374 requires SleipnirMatrixLike<T<Scalar>, Scalar>
375Variable(T<Scalar>) -> Variable<Scalar>;
377template <std::
floating_po
int T>
378Variable(T) -> Variable<T>;
380template <std::
integral T>
381Variable(T) -> Variable<T>;
387template <
typename Scalar>
388Variable<Scalar> abs(
const Variable<Scalar>& x) {
389 return Variable{detail::abs(x.expr)};
396template <
typename Scalar>
397Variable<Scalar> acos(
const Variable<Scalar>& x) {
398 return Variable{detail::acos(x.expr)};
405template <
typename Scalar>
406Variable<Scalar> asin(
const Variable<Scalar>& x) {
407 return Variable{detail::asin(x.expr)};
414template <
typename Scalar>
415Variable<Scalar> atan(
const Variable<Scalar>& x) {
416 return Variable{detail::atan(x.expr)};
424template <
typename Scalar>
425Variable<Scalar> atan2(
const ScalarLike
auto& y,
const Variable<Scalar>& x) {
426 return Variable{detail::atan2(Variable<Scalar>(y).expr, x.expr)};
434template <
typename Scalar>
435Variable<Scalar> atan2(
const Variable<Scalar>& y,
const ScalarLike
auto& x) {
436 return Variable{detail::atan2(y.expr, Variable<Scalar>(x).expr)};
444template <
typename Scalar>
445Variable<Scalar> atan2(
const Variable<Scalar>& y,
const Variable<Scalar>& x) {
446 return Variable{detail::atan2(y.expr, x.expr)};
453template <
typename Scalar>
454Variable<Scalar> cbrt(
const Variable<Scalar>& x) {
455 return Variable{detail::cbrt(x.expr)};
462template <
typename Scalar>
463Variable<Scalar> cos(
const Variable<Scalar>& x) {
464 return Variable{detail::cos(x.expr)};
471template <
typename Scalar>
472Variable<Scalar> cosh(
const Variable<Scalar>& x) {
473 return Variable{detail::cosh(x.expr)};
480template <
typename Scalar>
481Variable<Scalar> erf(
const Variable<Scalar>& x) {
482 return Variable{detail::erf(x.expr)};
489template <
typename Scalar>
490Variable<Scalar> exp(
const Variable<Scalar>& x) {
491 return Variable{detail::exp(x.expr)};
499template <
typename Scalar>
500Variable<Scalar> hypot(
const ScalarLike
auto& x,
const Variable<Scalar>& y) {
501 return Variable{detail::hypot(Variable<Scalar>(x).expr, y.expr)};
509template <
typename Scalar>
510Variable<Scalar> hypot(
const Variable<Scalar>& x,
const ScalarLike
auto& y) {
511 return Variable{detail::hypot(x.expr, Variable<Scalar>(y).expr)};
519template <
typename Scalar>
520Variable<Scalar> hypot(
const Variable<Scalar>& x,
const Variable<Scalar>& y) {
521 return Variable{detail::hypot(x.expr, y.expr)};
528template <
typename Scalar>
529Variable<Scalar> log(
const Variable<Scalar>& x) {
530 return Variable{detail::log(x.expr)};
537template <
typename Scalar>
538Variable<Scalar> log10(
const Variable<Scalar>& x) {
539 return Variable{detail::log10(x.expr)};
549template <
typename Scalar>
550Variable<Scalar> max(
const ScalarLike
auto& a,
const Variable<Scalar>& b) {
551 return Variable{detail::max(Variable<Scalar>(a).expr, b.expr)};
561template <
typename Scalar>
562Variable<Scalar> max(
const Variable<Scalar>& a,
const ScalarLike
auto& b) {
563 return Variable{detail::max(a.expr, Variable<Scalar>(b).expr)};
573template <
typename Scalar>
574Variable<Scalar> max(
const Variable<Scalar>& a,
const Variable<Scalar>& b) {
575 return Variable{detail::max(a.expr, b.expr)};
585template <
typename Scalar>
586Variable<Scalar> min(
const ScalarLike
auto& a,
const Variable<Scalar>& b) {
587 return Variable{detail::min(Variable<Scalar>(a).expr, b.expr)};
597template <
typename Scalar>
598Variable<Scalar> min(
const Variable<Scalar>& a,
const ScalarLike
auto& b) {
599 return Variable{detail::min(a.expr, Variable<Scalar>(b).expr)};
609template <
typename Scalar>
610Variable<Scalar> min(
const Variable<Scalar>& a,
const Variable<Scalar>& b) {
611 return Variable{detail::min(a.expr, b.expr)};
619template <
typename Scalar>
620Variable<Scalar> pow(
const ScalarLike
auto& base,
621 const Variable<Scalar>& power) {
622 return Variable{detail::pow(Variable<Scalar>(base).expr, power.expr)};
630template <
typename Scalar>
631Variable<Scalar> pow(
const Variable<Scalar>& base,
632 const ScalarLike
auto& power) {
633 return Variable{detail::pow(base.expr, Variable<Scalar>(power).expr)};
641template <
typename Scalar>
642Variable<Scalar> pow(
const Variable<Scalar>& base,
643 const Variable<Scalar>& power) {
644 return Variable{detail::pow(base.expr, power.expr)};
651template <
typename Scalar>
652Variable<Scalar> sign(
const Variable<Scalar>& x) {
653 return Variable{detail::sign(x.expr)};
660template <
typename Scalar>
661Variable<Scalar> sin(
const Variable<Scalar>& x) {
662 return Variable{detail::sin(x.expr)};
669template <
typename Scalar>
670Variable<Scalar> sinh(
const Variable<Scalar>& x) {
671 return Variable{detail::sinh(x.expr)};
678template <
typename Scalar>
679Variable<Scalar> sqrt(
const Variable<Scalar>& x) {
680 return Variable{detail::sqrt(x.expr)};
687template <
typename Scalar>
688Variable<Scalar> tan(
const Variable<Scalar>& x) {
689 return Variable{detail::tan(x.expr)};
696template <
typename Scalar>
697Variable<Scalar> tanh(
const Variable<Scalar>& x) {
698 return Variable{detail::tanh(x.expr)};
707template <
typename Scalar>
708Variable<Scalar> hypot(
const Variable<Scalar>& x,
const Variable<Scalar>& y,
709 const Variable<Scalar>& z) {
710 return Variable{sqrt(pow(x, 2) + pow(y, 2) + pow(z, 2))};
718template <
typename Scalar, ScalarLike LHS, ScalarLike RHS>
719 requires SleipnirScalarLike<LHS, Scalar> || SleipnirScalarLike<RHS, Scalar>
720auto make_constraints(LHS&& lhs, RHS&& rhs) {
721 gch::small_vector<Variable<Scalar>> constraints;
722 constraints.emplace_back(lhs - rhs);
727template <
typename Scalar, ScalarLike LHS, MatrixLike RHS>
728 requires SleipnirScalarLike<LHS, Scalar> || SleipnirMatrixLike<RHS, Scalar>
729auto make_constraints(LHS&& lhs, RHS&& rhs) {
730 gch::small_vector<Variable<Scalar>> constraints;
731 constraints.reserve(rhs.rows() * rhs.cols());
733 for (
int row = 0; row < rhs.rows(); ++row) {
734 for (
int col = 0; col < rhs.cols(); ++col) {
736 constraints.emplace_back(lhs - rhs[row, col]);
743template <
typename Scalar, MatrixLike LHS, ScalarLike RHS>
744 requires SleipnirMatrixLike<LHS, Scalar> || SleipnirScalarLike<RHS, Scalar>
745auto make_constraints(LHS&& lhs, RHS&& rhs) {
746 gch::small_vector<Variable<Scalar>> constraints;
747 constraints.reserve(lhs.rows() * lhs.cols());
749 for (
int row = 0; row < lhs.rows(); ++row) {
750 for (
int col = 0; col < lhs.cols(); ++col) {
752 constraints.emplace_back(lhs[row, col] - rhs);
759template <
typename Scalar, MatrixLike LHS, MatrixLike RHS>
760 requires SleipnirMatrixLike<LHS, Scalar> || SleipnirMatrixLike<RHS, Scalar>
761auto make_constraints(LHS&& lhs, RHS&& rhs) {
762 slp_assert(lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols());
764 gch::small_vector<Variable<Scalar>> constraints;
765 constraints.reserve(lhs.rows() * lhs.cols());
767 for (
int row = 0; row < lhs.rows(); ++row) {
768 for (
int col = 0; col < lhs.cols(); ++col) {
770 constraints.emplace_back(lhs[row, col] - rhs[row, col]);
780template <
typename Scalar>
793 elem.constraints.end());
807 elem.constraints.end());
818 template <
typename LHS,
typename RHS>
837template <
typename Scalar>
850 elem.constraints.end());
864 elem.constraints.end());
876 template <
typename LHS,
typename RHS>
896template <
typename LHS,
typename RHS>
897 requires(ScalarLike<LHS> || MatrixLike<LHS>) && SleipnirType<LHS> &&
898 (ScalarLike<RHS> || MatrixLike<RHS>) && (!SleipnirType<RHS>)
899auto operator==(LHS&& lhs, RHS&& rhs) {
900 return EqualityConstraints<typename std::decay_t<LHS>::Scalar>{lhs, rhs};
907template <
typename LHS,
typename RHS>
908 requires(ScalarLike<LHS> || MatrixLike<LHS>) && (!SleipnirType<LHS>) &&
909 (ScalarLike<RHS> || MatrixLike<RHS>) && SleipnirType<RHS>
910auto operator==(LHS&& lhs, RHS&& rhs) {
911 return EqualityConstraints<typename std::decay_t<RHS>::Scalar>{lhs, rhs};
918template <
typename LHS,
typename RHS>
919 requires(ScalarLike<LHS> || MatrixLike<LHS>) && SleipnirType<LHS> &&
920 (ScalarLike<RHS> || MatrixLike<RHS>) && SleipnirType<RHS>
921auto operator==(LHS&& lhs, RHS&& rhs) {
922 return EqualityConstraints<typename std::decay_t<LHS>::Scalar>{lhs, rhs};
930template <
typename LHS,
typename RHS>
931 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
932 (ScalarLike<RHS> || MatrixLike<RHS>) &&
933 (SleipnirType<LHS> || SleipnirType<RHS>)
934auto operator<(LHS&& lhs, RHS&& rhs) {
943template <
typename LHS,
typename RHS>
944 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
945 (ScalarLike<RHS> || MatrixLike<RHS>) &&
946 (SleipnirType<LHS> || SleipnirType<RHS>)
947auto operator<=(LHS&& lhs, RHS&& rhs) {
956template <
typename LHS,
typename RHS>
957 requires(ScalarLike<LHS> || MatrixLike<LHS>) &&
958 (ScalarLike<RHS> || MatrixLike<RHS>) &&
959 (SleipnirType<LHS> || SleipnirType<RHS>)
960auto operator>(LHS&& lhs, RHS&& rhs) {
969template <
typename LHS,
typename RHS>
970 requires(ScalarLike<LHS> || MatrixLike<LHS>) && SleipnirType<LHS> &&
971 (ScalarLike<RHS> || MatrixLike<RHS>) && (!SleipnirType<RHS>)
972auto operator>=(LHS&& lhs, RHS&& rhs) {
973 return InequalityConstraints<typename std::decay_t<LHS>::Scalar>{lhs, rhs};
981template <
typename LHS,
typename RHS>
982 requires(ScalarLike<LHS> || MatrixLike<LHS>) && (!SleipnirType<LHS>) &&
983 (ScalarLike<RHS> || MatrixLike<RHS>) && SleipnirType<RHS>
984auto operator>=(LHS&& lhs, RHS&& rhs) {
985 return InequalityConstraints<typename std::decay_t<RHS>::Scalar>{lhs, rhs};
993template <
typename LHS,
typename RHS>
994 requires(ScalarLike<LHS> || MatrixLike<LHS>) && SleipnirType<LHS> &&
995 (ScalarLike<RHS> || MatrixLike<RHS>) && SleipnirType<RHS>
996auto operator>=(LHS&& lhs, RHS&& rhs) {
997 return InequalityConstraints<typename std::decay_t<LHS>::Scalar>{lhs, rhs};
1005template <
typename L,
typename X,
typename U>
1006 requires(ScalarLike<L> || MatrixLike<L>) && SleipnirType<X> &&
1007 (ScalarLike<U> || MatrixLike<U>)
1008auto bounds(L&& l, X&& x, U&& u) {
1009 return InequalityConstraints{l <= x, x <= u};
1020template <
typename Scalar>
1021struct NumTraits<slp::Variable<Scalar>> : NumTraits<Scalar> {
1030 static constexpr int IsComplex = 0;
1032 static constexpr int IsInteger = 0;
1034 static constexpr int IsSigned = 1;
1036 static constexpr int RequireInitialization = 1;
1038 static constexpr int ReadCost = 1;
1040 static constexpr int AddCost = 3;
1042 static constexpr int MulCost = 3;
Definition hessian.hpp:30
Definition intrusive_shared_ptr.hpp:27
Definition jacobian.hpp:27
Definition sleipnir_base.hpp:9
Definition variable.hpp:52
ExpressionType type() const
Definition variable.hpp:154
friend Variable< Scalar > operator*(const LHS &lhs, const RHS &rhs)
Definition variable.hpp:162
Variable(const detail::ExpressionPtr< Scalar > &expr)
Definition variable.hpp:99
Variable< Scalar > & operator/=(const Variable< Scalar > &rhs)
Definition variable.hpp:209
Variable< Scalar > & operator*=(const Variable< Scalar > &rhs)
Definition variable.hpp:190
friend Variable< Scalar > operator+(const Variable< Scalar > &lhs, const Variable< Scalar > &rhs)
Definition variable.hpp:219
void set_value(Scalar value)
Definition variable.hpp:122
friend Variable< Scalar > operator/(const Variable< Scalar > &lhs, const Variable< Scalar > &rhs)
Definition variable.hpp:200
Variable(std::floating_point auto value)
Definition variable.hpp:84
constexpr Variable(detail::ExpressionPtr< Scalar > &&expr)
Definition variable.hpp:104
friend Variable< Scalar > operator+(const Variable< Scalar > &lhs)
Definition variable.hpp:262
Variable(Scalar value)
Definition variable.hpp:67
constexpr Variable(std::nullptr_t)
Constructs an empty Variable.
Definition variable.hpp:61
Variable< Scalar > & operator=(ScalarLike auto value)
Definition variable.hpp:111
Variable()=default
Constructs a linear Variable with a value of zero.
friend Variable< Scalar > operator-(const Variable< Scalar > &lhs)
Definition variable.hpp:255
friend Variable< Scalar > operator*(const Variable< Scalar > &lhs, const Variable< Scalar > &rhs)
Definition variable.hpp:181
Variable(std::integral auto value)
Definition variable.hpp:92
Variable(SleipnirMatrixLike< Scalar > auto value)
Definition variable.hpp:76
Variable< Scalar > & operator+=(const Variable< Scalar > &rhs)
Definition variable.hpp:228
Scalar_ Scalar
Scalar type alias.
Definition variable.hpp:55
Scalar value()
Definition variable.hpp:140
Variable< Scalar > & operator-=(const Variable< Scalar > &rhs)
Definition variable.hpp:247
friend Variable< Scalar > operator-(const Variable< Scalar > &lhs, const Variable< Scalar > &rhs)
Definition variable.hpp:238
Definition concepts.hpp:18
Definition concepts.hpp:24
Definition concepts.hpp:33
Definition concepts.hpp:15
Definition variable.hpp:781
EqualityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:822
EqualityConstraints(std::initializer_list< EqualityConstraints > equality_constraints)
Definition variable.hpp:789
gch::small_vector< Variable< Scalar > > constraints
A vector of scalar equality constraints.
Definition variable.hpp:783
EqualityConstraints(const std::vector< EqualityConstraints > &equality_constraints)
Definition variable.hpp:803
Definition variable.hpp:838
gch::small_vector< Variable< Scalar > > constraints
A vector of scalar inequality constraints.
Definition variable.hpp:840
InequalityConstraints(LHS &&lhs, RHS &&rhs)
Definition variable.hpp:880
InequalityConstraints(const std::vector< InequalityConstraints > &inequality_constraints)
Definition variable.hpp:860
InequalityConstraints(std::initializer_list< InequalityConstraints > inequality_constraints)
Definition variable.hpp:846