15#include <gch/small_vector.hpp>
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
21namespace slp::detail {
26inline constexpr bool USE_POOL_ALLOCATOR =
false;
28inline constexpr bool USE_POOL_ALLOCATOR =
true;
31template <
typename Scalar>
34template <
typename Scalar>
35constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <
typename Scalar>
37constexpr void dec_ref_count(Expression<Scalar>* expr);
42template <
typename Scalar>
43using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
50template <
typename T,
typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
60template <
typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
63template <
typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
66template <
typename Scalar>
67struct ConstantExpression;
69template <
typename Scalar, ExpressionType T>
72template <
typename Scalar, ExpressionType T>
75template <
typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
82template <
typename Scalar>
83ExpressionPtr<Scalar> constant_ptr(Scalar value);
88template <
typename Scalar_>
113 std::array<ExpressionPtr<Scalar>, 2>
args{
nullptr,
nullptr};
152 using enum ExpressionType;
158 }
else if (
rhs->is_constant(
Scalar(0))) {
161 }
else if (
lhs->is_constant(
Scalar(1))) {
163 }
else if (
rhs->is_constant(
Scalar(1))) {
168 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
169 return constant_ptr(
lhs->val *
rhs->val);
173 if (
lhs->type() == CONSTANT) {
174 if (
rhs->type() == LINEAR) {
176 }
else if (
rhs->type() == QUADRATIC) {
181 }
else if (
rhs->type() == CONSTANT) {
182 if (
lhs->type() == LINEAR) {
184 }
else if (
lhs->type() == QUADRATIC) {
189 }
else if (
lhs->type() == LINEAR &&
rhs->type() == LINEAR) {
202 using enum ExpressionType;
208 }
else if (
rhs->is_constant(
Scalar(1))) {
213 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
214 return constant_ptr(
lhs->val /
rhs->val);
218 if (
rhs->type() == CONSTANT) {
219 if (
lhs->type() == LINEAR) {
221 }
else if (
lhs->type() == QUADRATIC) {
237 using enum ExpressionType;
242 }
else if (
rhs ==
nullptr ||
rhs->is_constant(
Scalar(0))) {
247 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
248 return constant_ptr(
lhs->val +
rhs->val);
251 auto type = std::max(
lhs->type(),
rhs->type());
252 if (
type == LINEAR) {
255 }
else if (
type == QUADRATIC) {
279 using enum ExpressionType;
289 }
else if (
rhs->is_constant(
Scalar(0))) {
294 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
295 return constant_ptr(
lhs->val -
rhs->val);
298 auto type = std::max(
lhs->type(),
rhs->type());
299 if (
type == LINEAR) {
302 }
else if (
type == QUADRATIC) {
315 using enum ExpressionType;
324 if (
lhs->type() == CONSTANT) {
325 return constant_ptr(-
lhs->val);
328 if (
lhs->type() == LINEAR) {
330 }
else if (
lhs->type() == QUADRATIC) {
358 virtual ExpressionType
type()
const = 0;
363 virtual std::string_view
name()
const = 0;
399 return constant_ptr(
Scalar(0));
412 return constant_ptr(
Scalar(0));
416template <
typename Scalar>
421template <
typename Scalar>
422ExpressionPtr<Scalar> cbrt(
const ExpressionPtr<Scalar>& x);
423template <
typename Scalar>
424ExpressionPtr<Scalar> exp(
const ExpressionPtr<Scalar>& x);
425template <
typename Scalar>
426ExpressionPtr<Scalar> sin(
const ExpressionPtr<Scalar>& x);
427template <
typename Scalar>
428ExpressionPtr<Scalar> sinh(
const ExpressionPtr<Scalar>& x);
429template <
typename Scalar>
430ExpressionPtr<Scalar> sqrt(
const ExpressionPtr<Scalar>& x);
436template <
typename Scalar, ExpressionType T>
448 ExpressionType
type()
const override {
return T; }
450 std::string_view
name()
const override {
return "binary minus"; }
477template <
typename Scalar, ExpressionType T>
489 ExpressionType
type()
const override {
return T; }
491 std::string_view
name()
const override {
return "binary plus"; }
517template <
typename Scalar>
530 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
532 std::string_view
name()
const override {
return "cbrt"; }
553template <
typename Scalar>
555 using enum ExpressionType;
559 if (x->type() == CONSTANT) {
560 if (x->val == Scalar(0)) {
563 }
else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
566 return constant_ptr(cbrt(x->val));
570 return make_expression_ptr<CbrtExpression<Scalar>>(x);
576template <
typename Scalar>
586 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
588 std::string_view
name()
const override {
return "constant"; }
594template <
typename Scalar>
607 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
609 std::string_view
name()
const override {
return "decision variable"; }
616template <
typename Scalar, ExpressionType T>
627 ExpressionType
type()
const override {
return T; }
629 std::string_view
name()
const override {
return "division"; }
656template <
typename Scalar, ExpressionType T>
667 ExpressionType
type()
const override {
return T; }
669 std::string_view
name()
const override {
return "multiplication"; }
700template <
typename Scalar, ExpressionType T>
710 ExpressionType
type()
const override {
return T; }
712 std::string_view
name()
const override {
return "unary minus"; }
729template <
typename Scalar>
738template <
typename Scalar>
739constexpr void dec_ref_count(Expression<Scalar>* expr) {
744 gch::small_vector<Expression<Scalar>*> stack;
745 stack.emplace_back(expr);
747 while (!stack.empty()) {
748 auto elem = stack.back();
753 if (--elem->ref_count == 0) {
754 if (elem->adjoint_expr !=
nullptr) {
755 stack.emplace_back(elem->adjoint_expr.get());
757 for (
auto& arg : elem->args) {
758 if (arg !=
nullptr) {
759 stack.emplace_back(arg.get());
765 if constexpr (USE_POOL_ALLOCATOR) {
766 auto alloc = global_pool_allocator<Expression<Scalar>>();
767 std::allocator_traits<
decltype(alloc)>::deallocate(
768 alloc, elem,
sizeof(Expression<Scalar>));
777template <
typename Scalar>
790 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
792 std::string_view
name()
const override {
return "abs"; }
797 }
else if (x >
Scalar(0)) {
809 }
else if (x->val >
Scalar(0)) {
812 return constant_ptr(
Scalar(0));
821template <
typename Scalar>
823 using enum ExpressionType;
827 if (x->is_constant(Scalar(0))) {
833 if (x->type() == CONSTANT) {
834 return constant_ptr(abs(x->val));
837 return make_expression_ptr<AbsExpression<Scalar>>(x);
843template <
typename Scalar>
856 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
858 std::string_view
name()
const override {
return "acos"; }
876template <
typename Scalar>
878 using enum ExpressionType;
882 if (x->is_constant(Scalar(0))) {
883 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
887 if (x->type() == CONSTANT) {
888 return constant_ptr(acos(x->val));
891 return make_expression_ptr<AcosExpression<Scalar>>(x);
897template <
typename Scalar>
910 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
912 std::string_view
name()
const override {
return "asin"; }
930template <
typename Scalar>
932 using enum ExpressionType;
936 if (x->is_constant(Scalar(0))) {
942 if (x->type() == CONSTANT) {
943 return constant_ptr(asin(x->val));
946 return make_expression_ptr<AsinExpression<Scalar>>(x);
952template <
typename Scalar>
965 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
967 std::string_view
name()
const override {
return "atan"; }
984template <
typename Scalar>
986 using enum ExpressionType;
990 if (x->is_constant(Scalar(0))) {
996 if (x->type() == CONSTANT) {
997 return constant_ptr(atan(x->val));
1000 return make_expression_ptr<AtanExpression<Scalar>>(x);
1006template <
typename Scalar>
1021 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1023 std::string_view
name()
const override {
return "atan2"; }
1051template <
typename Scalar>
1054 using enum ExpressionType;
1058 if (
y->is_constant(Scalar(0))) {
1061 }
else if (x->is_constant(Scalar(0))) {
1062 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1066 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1067 return constant_ptr(atan2(y->val, x->val));
1070 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1076template <
typename Scalar>
1089 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1091 std::string_view
name()
const override {
return "cos"; }
1109template <
typename Scalar>
1111 using enum ExpressionType;
1115 if (x->is_constant(Scalar(0))) {
1116 return constant_ptr(Scalar(1));
1120 if (x->type() == CONSTANT) {
1121 return constant_ptr(cos(x->val));
1124 return make_expression_ptr<CosExpression<Scalar>>(x);
1130template <
typename Scalar>
1143 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1145 std::string_view
name()
const override {
return "cosh"; }
1163template <
typename Scalar>
1165 using enum ExpressionType;
1169 if (x->is_constant(Scalar(0))) {
1170 return constant_ptr(Scalar(1));
1174 if (x->type() == CONSTANT) {
1175 return constant_ptr(cosh(x->val));
1178 return make_expression_ptr<CoshExpression<Scalar>>(x);
1184template <
typename Scalar>
1197 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1199 std::string_view
name()
const override {
return "erf"; }
1211 constant_ptr(
Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1219template <
typename Scalar>
1221 using enum ExpressionType;
1225 if (x->is_constant(Scalar(0))) {
1231 if (x->type() == CONSTANT) {
1232 return constant_ptr(erf(x->val));
1235 return make_expression_ptr<ErfExpression<Scalar>>(x);
1241template <
typename Scalar>
1254 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1256 std::string_view
name()
const override {
return "exp"; }
1274template <
typename Scalar>
1276 using enum ExpressionType;
1280 if (x->is_constant(Scalar(0))) {
1281 return constant_ptr(Scalar(1));
1285 if (x->type() == CONSTANT) {
1286 return constant_ptr(exp(x->val));
1289 return make_expression_ptr<ExpExpression<Scalar>>(x);
1292template <
typename Scalar>
1293ExpressionPtr<Scalar> hypot(
const ExpressionPtr<Scalar>& x,
1294 const ExpressionPtr<Scalar>& y);
1299template <
typename Scalar>
1314 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1316 std::string_view
name()
const override {
return "hypot"; }
1346template <
typename Scalar>
1349 using enum ExpressionType;
1353 if (x->is_constant(Scalar(0))) {
1355 }
else if (y->is_constant(Scalar(0))) {
1360 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1361 return constant_ptr(hypot(x->val, y->val));
1364 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1370template <
typename Scalar>
1383 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1385 std::string_view
name()
const override {
return "log"; }
1402template <
typename Scalar>
1404 using enum ExpressionType;
1408 if (x->is_constant(Scalar(0))) {
1414 if (x->type() == CONSTANT) {
1415 return constant_ptr(log(x->val));
1418 return make_expression_ptr<LogExpression<Scalar>>(x);
1424template <
typename Scalar>
1437 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1439 std::string_view
name()
const override {
return "log10"; }
1456template <
typename Scalar>
1458 using enum ExpressionType;
1462 if (x->is_constant(Scalar(0))) {
1468 if (x->type() == CONSTANT) {
1469 return constant_ptr(log10(x->val));
1472 return make_expression_ptr<Log10Expression<Scalar>>(x);
1475template <
typename Scalar>
1476ExpressionPtr<Scalar> pow(
const ExpressionPtr<Scalar>& base,
1477 const ExpressionPtr<Scalar>& power);
1483template <
typename Scalar, ExpressionType T>
1497 ExpressionType
type()
const override {
return T; }
1499 std::string_view
name()
const override {
return "pow"; }
1544template <
typename Scalar>
1547 using enum ExpressionType;
1551 if (
base->is_constant(Scalar(0))) {
1554 }
else if (base->is_constant(Scalar(1))) {
1558 if (power->is_constant(Scalar(0))) {
1559 return constant_ptr(Scalar(1));
1560 }
else if (power->is_constant(Scalar(1))) {
1565 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1566 return constant_ptr(pow(base->val, power->val));
1569 if (power->is_constant(Scalar(2))) {
1570 if (base->type() == LINEAR) {
1571 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1573 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1577 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1583template <
typename Scalar>
1594 }
else if (x ==
Scalar(0)) {
1601 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1603 std::string_view
name()
const override {
return "sign"; }
1610template <
typename Scalar>
1612 using enum ExpressionType;
1615 if (x->type() == CONSTANT) {
1616 if (x->val < Scalar(0)) {
1617 return constant_ptr(Scalar(-1));
1618 }
else if (x->val == Scalar(0)) {
1622 return constant_ptr(Scalar(1));
1626 return make_expression_ptr<SignExpression<Scalar>>(x);
1632template <
typename Scalar>
1645 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1647 std::string_view
name()
const override {
return "sin"; }
1665template <
typename Scalar>
1667 using enum ExpressionType;
1671 if (x->is_constant(Scalar(0))) {
1677 if (x->type() == CONSTANT) {
1678 return constant_ptr(sin(x->val));
1681 return make_expression_ptr<SinExpression<Scalar>>(x);
1687template <
typename Scalar>
1700 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1702 std::string_view
name()
const override {
return "sinh"; }
1720template <
typename Scalar>
1722 using enum ExpressionType;
1726 if (x->is_constant(Scalar(0))) {
1732 if (x->type() == CONSTANT) {
1733 return constant_ptr(sinh(x->val));
1736 return make_expression_ptr<SinhExpression<Scalar>>(x);
1742template <
typename Scalar>
1755 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1757 std::string_view
name()
const override {
return "sqrt"; }
1775template <
typename Scalar>
1777 using enum ExpressionType;
1781 if (x->type() == CONSTANT) {
1782 if (x->val == Scalar(0)) {
1785 }
else if (x->val == Scalar(1)) {
1788 return constant_ptr(sqrt(x->val));
1792 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1798template <
typename Scalar>
1811 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1813 std::string_view
name()
const override {
return "tan"; }
1834template <
typename Scalar>
1836 using enum ExpressionType;
1840 if (x->is_constant(Scalar(0))) {
1846 if (x->type() == CONSTANT) {
1847 return constant_ptr(tan(x->val));
1850 return make_expression_ptr<TanExpression<Scalar>>(x);
1856template <
typename Scalar>
1869 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1871 std::string_view
name()
const override {
return "tanh"; }
1892template <
typename Scalar>
1894 using enum ExpressionType;
1898 if (x->is_constant(Scalar(0))) {
1904 if (x->type() == CONSTANT) {
1905 return constant_ptr(tanh(x->val));
1908 return make_expression_ptr<TanhExpression<Scalar>>(x);
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:778
std::string_view name() const override
Definition expression.hpp:792
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:782
ExpressionType type() const override
Definition expression.hpp:790
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:794
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:804
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:785
Definition expression.hpp:844
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:865
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:851
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:848
ExpressionType type() const override
Definition expression.hpp:856
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:860
std::string_view name() const override
Definition expression.hpp:858
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:914
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:905
std::string_view name() const override
Definition expression.hpp:912
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:919
ExpressionType type() const override
Definition expression.hpp:910
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:902
Definition expression.hpp:1007
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1033
std::string_view name() const override
Definition expression.hpp:1023
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1016
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1012
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1029
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1039
ExpressionType type() const override
Definition expression.hpp:1021
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1025
Definition expression.hpp:953
std::string_view name() const override
Definition expression.hpp:967
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:973
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:957
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:969
ExpressionType type() const override
Definition expression.hpp:965
Definition expression.hpp:437
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:442
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:466
std::string_view name() const override
Definition expression.hpp:450
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:456
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:446
ExpressionType type() const override
Definition expression.hpp:448
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:452
Definition expression.hpp:478
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:501
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:507
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:483
ExpressionType type() const override
Definition expression.hpp:489
std::string_view name() const override
Definition expression.hpp:491
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:497
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:493
Definition expression.hpp:518
std::string_view name() const override
Definition expression.hpp:532
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:522
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:541
ExpressionType type() const override
Definition expression.hpp:530
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:525
Definition expression.hpp:577
ExpressionType type() const override
Definition expression.hpp:586
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:584
std::string_view name() const override
Definition expression.hpp:588
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:581
Definition expression.hpp:1077
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1084
ExpressionType type() const override
Definition expression.hpp:1089
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1081
std::string_view name() const override
Definition expression.hpp:1091
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1098
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1093
Definition expression.hpp:1131
ExpressionType type() const override
Definition expression.hpp:1143
std::string_view name() const override
Definition expression.hpp:1145
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1138
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1152
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1135
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1147
Definition expression.hpp:595
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:609
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:605
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:602
ExpressionType type() const override
Definition expression.hpp:607
Definition expression.hpp:617
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:622
ExpressionType type() const override
Definition expression.hpp:627
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:635
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:645
std::string_view name() const override
Definition expression.hpp:629
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:631
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:625
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:639
Definition expression.hpp:1185
std::string_view name() const override
Definition expression.hpp:1199
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1189
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1201
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1192
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1207
ExpressionType type() const override
Definition expression.hpp:1197
Definition expression.hpp:1242
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1258
std::string_view name() const override
Definition expression.hpp:1256
ExpressionType type() const override
Definition expression.hpp:1254
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1246
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1263
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1249
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:314
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:103
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:383
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:113
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:110
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:126
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:395
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:408
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:235
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:142
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:133
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:107
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:200
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:100
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:277
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:150
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:371
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:268
constexpr Expression(Scalar value)
Definition expression.hpp:121
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:340
Definition expression.hpp:1300
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1305
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1334
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1328
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1318
std::string_view name() const override
Definition expression.hpp:1316
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1309
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1323
Definition expression.hpp:1425
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1432
std::string_view name() const override
Definition expression.hpp:1439
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1429
ExpressionType type() const override
Definition expression.hpp:1437
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1445
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1441
Definition expression.hpp:1371
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1391
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1387
ExpressionType type() const override
Definition expression.hpp:1383
std::string_view name() const override
Definition expression.hpp:1385
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1375
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1378
Definition expression.hpp:657
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:665
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:681
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:671
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:676
std::string_view name() const override
Definition expression.hpp:669
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:688
Definition expression.hpp:1484
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1492
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1501
ExpressionType type() const override
Definition expression.hpp:1497
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1520
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1507
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1526
std::string_view name() const override
Definition expression.hpp:1499
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1489
Definition expression.hpp:1584
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1588
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1591
ExpressionType type() const override
Definition expression.hpp:1601
std::string_view name() const override
Definition expression.hpp:1603
Definition expression.hpp:1633
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1649
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1654
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1640
std::string_view name() const override
Definition expression.hpp:1647
ExpressionType type() const override
Definition expression.hpp:1645
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1637
Definition expression.hpp:1688
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1704
std::string_view name() const override
Definition expression.hpp:1702
ExpressionType type() const override
Definition expression.hpp:1700
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1709
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1692
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1695
Definition expression.hpp:1743
ExpressionType type() const override
Definition expression.hpp:1755
std::string_view name() const override
Definition expression.hpp:1757
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1764
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1747
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1759
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1750
Definition expression.hpp:1799
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1822
std::string_view name() const override
Definition expression.hpp:1813
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1803
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1815
ExpressionType type() const override
Definition expression.hpp:1811
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1806
Definition expression.hpp:1857
ExpressionType type() const override
Definition expression.hpp:1869
std::string_view name() const override
Definition expression.hpp:1871
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1873
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1864
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1880
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1861
Definition expression.hpp:701
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:705
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:718
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:714
ExpressionType type() const override
Definition expression.hpp:710
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:708
std::string_view name() const override
Definition expression.hpp:712