15#include <gch/small_vector.hpp>
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
21namespace slp::detail {
26inline constexpr bool USE_POOL_ALLOCATOR =
false;
28inline constexpr bool USE_POOL_ALLOCATOR =
true;
31template <
typename Scalar>
34template <
typename Scalar>
35constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <
typename Scalar>
37constexpr void dec_ref_count(Expression<Scalar>* expr);
42template <
typename Scalar>
43using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
50template <
typename T,
typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
60template <
typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
63template <
typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
66template <
typename Scalar>
67struct ConstantExpression;
69template <
typename Scalar, ExpressionType T>
72template <
typename Scalar, ExpressionType T>
75template <
typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
82template <
typename Scalar>
83ExpressionPtr<Scalar> constant_ptr(Scalar value);
88template <
typename Scalar_>
113 std::array<ExpressionPtr<Scalar>, 2>
args{
nullptr,
nullptr};
152 using enum ExpressionType;
158 }
else if (
rhs->is_constant(
Scalar(0))) {
161 }
else if (
lhs->is_constant(
Scalar(1))) {
163 }
else if (
rhs->is_constant(
Scalar(1))) {
168 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
169 return constant_ptr(
lhs->val *
rhs->val);
173 if (
lhs->type() == CONSTANT) {
174 if (
rhs->type() == LINEAR) {
176 }
else if (
rhs->type() == QUADRATIC) {
181 }
else if (
rhs->type() == CONSTANT) {
182 if (
lhs->type() == LINEAR) {
184 }
else if (
lhs->type() == QUADRATIC) {
189 }
else if (
lhs->type() == LINEAR &&
rhs->type() == LINEAR) {
202 using enum ExpressionType;
208 }
else if (
rhs->is_constant(
Scalar(1))) {
213 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
214 return constant_ptr(
lhs->val /
rhs->val);
218 if (
rhs->type() == CONSTANT) {
219 if (
lhs->type() == LINEAR) {
221 }
else if (
lhs->type() == QUADRATIC) {
237 using enum ExpressionType;
242 }
else if (
rhs ==
nullptr ||
rhs->is_constant(
Scalar(0))) {
247 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
248 return constant_ptr(
lhs->val +
rhs->val);
251 auto type = std::max(
lhs->type(),
rhs->type());
252 if (
type == LINEAR) {
255 }
else if (
type == QUADRATIC) {
279 using enum ExpressionType;
289 }
else if (
rhs->is_constant(
Scalar(0))) {
294 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
295 return constant_ptr(
lhs->val -
rhs->val);
298 auto type = std::max(
lhs->type(),
rhs->type());
299 if (
type == LINEAR) {
302 }
else if (
type == QUADRATIC) {
315 using enum ExpressionType;
324 if (
lhs->type() == CONSTANT) {
325 return constant_ptr(-
lhs->val);
328 if (
lhs->type() == LINEAR) {
330 }
else if (
lhs->type() == QUADRATIC) {
358 virtual ExpressionType
type()
const = 0;
363 virtual std::string_view
name()
const = 0;
399 return constant_ptr(
Scalar(0));
412 return constant_ptr(
Scalar(0));
416template <
typename Scalar>
421template <
typename Scalar>
422ExpressionPtr<Scalar> cbrt(
const ExpressionPtr<Scalar>& x);
423template <
typename Scalar>
424ExpressionPtr<Scalar> exp(
const ExpressionPtr<Scalar>& x);
425template <
typename Scalar>
426ExpressionPtr<Scalar> sin(
const ExpressionPtr<Scalar>& x);
427template <
typename Scalar>
428ExpressionPtr<Scalar> sinh(
const ExpressionPtr<Scalar>& x);
429template <
typename Scalar>
430ExpressionPtr<Scalar> sqrt(
const ExpressionPtr<Scalar>& x);
436template <
typename Scalar, ExpressionType T>
448 ExpressionType
type()
const override {
return T; }
450 std::string_view
name()
const override {
return "binary minus"; }
477template <
typename Scalar, ExpressionType T>
489 ExpressionType
type()
const override {
return T; }
491 std::string_view
name()
const override {
return "binary plus"; }
517template <
typename Scalar>
530 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
532 std::string_view
name()
const override {
return "cbrt"; }
553template <
typename Scalar>
555 using enum ExpressionType;
559 if (x->type() == CONSTANT) {
560 if (x->val == Scalar(0)) {
563 }
else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
566 return constant_ptr(cbrt(x->val));
570 return make_expression_ptr<CbrtExpression<Scalar>>(x);
576template <
typename Scalar>
586 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
588 std::string_view
name()
const override {
return "constant"; }
594template <
typename Scalar>
607 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
609 std::string_view
name()
const override {
return "decision variable"; }
616template <
typename Scalar, ExpressionType T>
627 ExpressionType
type()
const override {
return T; }
629 std::string_view
name()
const override {
return "division"; }
656template <
typename Scalar, ExpressionType T>
667 ExpressionType
type()
const override {
return T; }
669 std::string_view
name()
const override {
return "multiplication"; }
700template <
typename Scalar, ExpressionType T>
710 ExpressionType
type()
const override {
return T; }
712 std::string_view
name()
const override {
return "unary minus"; }
729template <
typename Scalar>
738template <
typename Scalar>
739constexpr void dec_ref_count(Expression<Scalar>* expr) {
744 gch::small_vector<Expression<Scalar>*> stack;
745 stack.emplace_back(expr);
747 while (!stack.empty()) {
748 auto elem = stack.back();
753 if (--elem->ref_count == 0) {
754 if (elem->adjoint_expr !=
nullptr) {
755 stack.emplace_back(elem->adjoint_expr.get());
757 for (
auto& arg : elem->args) {
758 if (arg !=
nullptr) {
759 stack.emplace_back(arg.get());
765 if constexpr (USE_POOL_ALLOCATOR) {
766 auto alloc = global_pool_allocator<Expression<Scalar>>();
767 std::allocator_traits<
decltype(alloc)>::deallocate(
768 alloc, elem,
sizeof(Expression<Scalar>));
777template <
typename Scalar>
790 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
792 std::string_view
name()
const override {
return "abs"; }
797 }
else if (x >
Scalar(0)) {
809 }
else if (x->val >
Scalar(0)) {
812 return constant_ptr(
Scalar(0));
821template <
typename Scalar>
823 using enum ExpressionType;
827 if (x->is_constant(Scalar(0))) {
833 if (x->type() == CONSTANT) {
834 return constant_ptr(abs(x->val));
837 return make_expression_ptr<AbsExpression<Scalar>>(x);
843template <
typename Scalar>
856 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
858 std::string_view
name()
const override {
return "acos"; }
876template <
typename Scalar>
878 using enum ExpressionType;
882 if (x->is_constant(Scalar(0))) {
883 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
887 if (x->type() == CONSTANT) {
888 return constant_ptr(acos(x->val));
891 return make_expression_ptr<AcosExpression<Scalar>>(x);
897template <
typename Scalar>
910 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
912 std::string_view
name()
const override {
return "asin"; }
930template <
typename Scalar>
932 using enum ExpressionType;
936 if (x->is_constant(Scalar(0))) {
942 if (x->type() == CONSTANT) {
943 return constant_ptr(asin(x->val));
946 return make_expression_ptr<AsinExpression<Scalar>>(x);
952template <
typename Scalar>
965 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
967 std::string_view
name()
const override {
return "atan"; }
984template <
typename Scalar>
986 using enum ExpressionType;
990 if (x->is_constant(Scalar(0))) {
996 if (x->type() == CONSTANT) {
997 return constant_ptr(atan(x->val));
1000 return make_expression_ptr<AtanExpression<Scalar>>(x);
1006template <
typename Scalar>
1021 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1023 std::string_view
name()
const override {
return "atan2"; }
1051template <
typename Scalar>
1054 using enum ExpressionType;
1058 if (
y->is_constant(Scalar(0))) {
1061 }
else if (x->is_constant(Scalar(0))) {
1062 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1066 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1067 return constant_ptr(atan2(y->val, x->val));
1070 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1076template <
typename Scalar>
1089 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1091 std::string_view
name()
const override {
return "cos"; }
1109template <
typename Scalar>
1111 using enum ExpressionType;
1115 if (x->is_constant(Scalar(0))) {
1116 return constant_ptr(Scalar(1));
1120 if (x->type() == CONSTANT) {
1121 return constant_ptr(cos(x->val));
1124 return make_expression_ptr<CosExpression<Scalar>>(x);
1130template <
typename Scalar>
1143 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1145 std::string_view
name()
const override {
return "cosh"; }
1163template <
typename Scalar>
1165 using enum ExpressionType;
1169 if (x->is_constant(Scalar(0))) {
1170 return constant_ptr(Scalar(1));
1174 if (x->type() == CONSTANT) {
1175 return constant_ptr(cosh(x->val));
1178 return make_expression_ptr<CoshExpression<Scalar>>(x);
1184template <
typename Scalar>
1197 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1199 std::string_view
name()
const override {
return "erf"; }
1211 constant_ptr(
Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1219template <
typename Scalar>
1221 using enum ExpressionType;
1225 if (x->is_constant(Scalar(0))) {
1231 if (x->type() == CONSTANT) {
1232 return constant_ptr(erf(x->val));
1235 return make_expression_ptr<ErfExpression<Scalar>>(x);
1241template <
typename Scalar>
1254 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1256 std::string_view
name()
const override {
return "exp"; }
1274template <
typename Scalar>
1276 using enum ExpressionType;
1280 if (x->is_constant(Scalar(0))) {
1281 return constant_ptr(Scalar(1));
1285 if (x->type() == CONSTANT) {
1286 return constant_ptr(exp(x->val));
1289 return make_expression_ptr<ExpExpression<Scalar>>(x);
1292template <
typename Scalar>
1293ExpressionPtr<Scalar> hypot(
const ExpressionPtr<Scalar>& x,
1294 const ExpressionPtr<Scalar>& y);
1299template <
typename Scalar>
1314 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1316 std::string_view
name()
const override {
return "hypot"; }
1346template <
typename Scalar>
1349 using enum ExpressionType;
1353 if (x->is_constant(Scalar(0))) {
1355 }
else if (y->is_constant(Scalar(0))) {
1360 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1361 return constant_ptr(hypot(x->val, y->val));
1364 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1370template <
typename Scalar>
1383 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1385 std::string_view
name()
const override {
return "log"; }
1402template <
typename Scalar>
1404 using enum ExpressionType;
1408 if (x->is_constant(Scalar(0))) {
1414 if (x->type() == CONSTANT) {
1415 return constant_ptr(log(x->val));
1418 return make_expression_ptr<LogExpression<Scalar>>(x);
1424template <
typename Scalar>
1437 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1439 std::string_view
name()
const override {
return "log10"; }
1456template <
typename Scalar>
1458 using enum ExpressionType;
1462 if (x->is_constant(Scalar(0))) {
1468 if (x->type() == CONSTANT) {
1469 return constant_ptr(log10(x->val));
1472 return make_expression_ptr<Log10Expression<Scalar>>(x);
1480template <
typename Scalar>
1494 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1496 std::string_view
name()
const override {
return "max"; }
1517 if (
a->val >=
b->val) {
1520 return constant_ptr(
Scalar(0));
1527 if (
b->val >
a->val) {
1530 return constant_ptr(
Scalar(0));
1540template <
typename Scalar>
1543 using enum ExpressionType;
1547 if (
a->type() == CONSTANT &&
b->type() == CONSTANT) {
1548 return constant_ptr(max(
a->val,
b->val));
1551 return make_expression_ptr<MaxExpression<Scalar>>(a, b);
1559template <
typename Scalar>
1573 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1575 std::string_view
name()
const override {
return "min"; }
1597 if (
a->val <=
b->val) {
1600 return constant_ptr(
Scalar(0));
1607 if (
b->val <
a->val) {
1610 return constant_ptr(
Scalar(0));
1620template <
typename Scalar>
1623 using enum ExpressionType;
1627 if (
a->type() == CONSTANT &&
b->type() == CONSTANT) {
1628 return constant_ptr(min(
a->val,
b->val));
1631 return make_expression_ptr<MinExpression<Scalar>>(a, b);
1634template <
typename Scalar>
1635ExpressionPtr<Scalar> pow(
const ExpressionPtr<Scalar>& base,
1636 const ExpressionPtr<Scalar>& power);
1642template <
typename Scalar, ExpressionType T>
1656 ExpressionType
type()
const override {
return T; }
1658 std::string_view
name()
const override {
return "pow"; }
1703template <
typename Scalar>
1706 using enum ExpressionType;
1710 if (
base->is_constant(Scalar(0))) {
1713 }
else if (base->is_constant(Scalar(1))) {
1717 if (power->is_constant(Scalar(0))) {
1718 return constant_ptr(Scalar(1));
1719 }
else if (power->is_constant(Scalar(1))) {
1724 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1725 return constant_ptr(pow(base->val, power->val));
1728 if (power->is_constant(Scalar(2))) {
1729 if (base->type() == LINEAR) {
1730 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1732 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1736 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1742template <
typename Scalar>
1753 }
else if (x ==
Scalar(0)) {
1760 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1762 std::string_view
name()
const override {
return "sign"; }
1769template <
typename Scalar>
1771 using enum ExpressionType;
1774 if (x->type() == CONSTANT) {
1775 if (x->val < Scalar(0)) {
1776 return constant_ptr(Scalar(-1));
1777 }
else if (x->val == Scalar(0)) {
1781 return constant_ptr(Scalar(1));
1785 return make_expression_ptr<SignExpression<Scalar>>(x);
1791template <
typename Scalar>
1804 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1806 std::string_view
name()
const override {
return "sin"; }
1824template <
typename Scalar>
1826 using enum ExpressionType;
1830 if (x->is_constant(Scalar(0))) {
1836 if (x->type() == CONSTANT) {
1837 return constant_ptr(sin(x->val));
1840 return make_expression_ptr<SinExpression<Scalar>>(x);
1846template <
typename Scalar>
1859 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1861 std::string_view
name()
const override {
return "sinh"; }
1879template <
typename Scalar>
1881 using enum ExpressionType;
1885 if (x->is_constant(Scalar(0))) {
1891 if (x->type() == CONSTANT) {
1892 return constant_ptr(sinh(x->val));
1895 return make_expression_ptr<SinhExpression<Scalar>>(x);
1901template <
typename Scalar>
1914 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1916 std::string_view
name()
const override {
return "sqrt"; }
1934template <
typename Scalar>
1936 using enum ExpressionType;
1940 if (x->type() == CONSTANT) {
1941 if (x->val == Scalar(0)) {
1944 }
else if (x->val == Scalar(1)) {
1947 return constant_ptr(sqrt(x->val));
1951 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1957template <
typename Scalar>
1970 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1972 std::string_view
name()
const override {
return "tan"; }
1993template <
typename Scalar>
1995 using enum ExpressionType;
1999 if (x->is_constant(Scalar(0))) {
2005 if (x->type() == CONSTANT) {
2006 return constant_ptr(tan(x->val));
2009 return make_expression_ptr<TanExpression<Scalar>>(x);
2015template <
typename Scalar>
2028 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
2030 std::string_view
name()
const override {
return "tanh"; }
2051template <
typename Scalar>
2053 using enum ExpressionType;
2057 if (x->is_constant(Scalar(0))) {
2063 if (x->type() == CONSTANT) {
2064 return constant_ptr(tanh(x->val));
2067 return make_expression_ptr<TanhExpression<Scalar>>(x);
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:778
std::string_view name() const override
Definition expression.hpp:792
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:782
ExpressionType type() const override
Definition expression.hpp:790
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:794
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:804
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:785
Definition expression.hpp:844
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:865
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:851
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:848
ExpressionType type() const override
Definition expression.hpp:856
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:860
std::string_view name() const override
Definition expression.hpp:858
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:914
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:905
std::string_view name() const override
Definition expression.hpp:912
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:919
ExpressionType type() const override
Definition expression.hpp:910
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:902
Definition expression.hpp:1007
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1033
std::string_view name() const override
Definition expression.hpp:1023
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1016
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1012
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1029
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1039
ExpressionType type() const override
Definition expression.hpp:1021
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1025
Definition expression.hpp:953
std::string_view name() const override
Definition expression.hpp:967
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:973
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:957
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:969
ExpressionType type() const override
Definition expression.hpp:965
Definition expression.hpp:437
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:442
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:466
std::string_view name() const override
Definition expression.hpp:450
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:456
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:446
ExpressionType type() const override
Definition expression.hpp:448
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:452
Definition expression.hpp:478
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:501
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:507
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:483
ExpressionType type() const override
Definition expression.hpp:489
std::string_view name() const override
Definition expression.hpp:491
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:497
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:493
Definition expression.hpp:518
std::string_view name() const override
Definition expression.hpp:532
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:522
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:541
ExpressionType type() const override
Definition expression.hpp:530
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:525
Definition expression.hpp:577
ExpressionType type() const override
Definition expression.hpp:586
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:584
std::string_view name() const override
Definition expression.hpp:588
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:581
Definition expression.hpp:1077
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1084
ExpressionType type() const override
Definition expression.hpp:1089
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1081
std::string_view name() const override
Definition expression.hpp:1091
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1098
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1093
Definition expression.hpp:1131
ExpressionType type() const override
Definition expression.hpp:1143
std::string_view name() const override
Definition expression.hpp:1145
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1138
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1152
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1135
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1147
Definition expression.hpp:595
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:609
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:605
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:602
ExpressionType type() const override
Definition expression.hpp:607
Definition expression.hpp:617
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:622
ExpressionType type() const override
Definition expression.hpp:627
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:635
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:645
std::string_view name() const override
Definition expression.hpp:629
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:631
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:625
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:639
Definition expression.hpp:1185
std::string_view name() const override
Definition expression.hpp:1199
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1189
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1201
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1192
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1207
ExpressionType type() const override
Definition expression.hpp:1197
Definition expression.hpp:1242
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1258
std::string_view name() const override
Definition expression.hpp:1256
ExpressionType type() const override
Definition expression.hpp:1254
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1246
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1263
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1249
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:314
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:103
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:383
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:113
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:110
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:126
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:395
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:408
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:235
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:142
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:133
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:107
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:200
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:100
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:277
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:150
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:371
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:268
constexpr Expression(Scalar value)
Definition expression.hpp:121
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:340
Definition expression.hpp:1300
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1305
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1334
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1328
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1318
std::string_view name() const override
Definition expression.hpp:1316
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1309
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1323
Definition expression.hpp:1425
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1432
std::string_view name() const override
Definition expression.hpp:1439
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1429
ExpressionType type() const override
Definition expression.hpp:1437
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1445
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1441
Definition expression.hpp:1371
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1391
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1387
ExpressionType type() const override
Definition expression.hpp:1383
std::string_view name() const override
Definition expression.hpp:1385
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1375
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1378
Definition expression.hpp:1481
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1506
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1498
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1489
constexpr MaxExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1486
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1514
ExpressionType type() const override
Definition expression.hpp:1494
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1524
std::string_view name() const override
Definition expression.hpp:1496
Definition expression.hpp:1560
std::string_view name() const override
Definition expression.hpp:1575
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1568
ExpressionType type() const override
Definition expression.hpp:1573
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1585
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1604
constexpr MinExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1565
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1594
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1577
Definition expression.hpp:657
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:665
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:681
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:671
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:676
std::string_view name() const override
Definition expression.hpp:669
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:688
Definition expression.hpp:1643
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1651
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1660
ExpressionType type() const override
Definition expression.hpp:1656
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1679
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1666
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1685
std::string_view name() const override
Definition expression.hpp:1658
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1648
Definition expression.hpp:1743
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1747
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1750
ExpressionType type() const override
Definition expression.hpp:1760
std::string_view name() const override
Definition expression.hpp:1762
Definition expression.hpp:1792
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1808
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1813
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1799
std::string_view name() const override
Definition expression.hpp:1806
ExpressionType type() const override
Definition expression.hpp:1804
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1796
Definition expression.hpp:1847
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1863
std::string_view name() const override
Definition expression.hpp:1861
ExpressionType type() const override
Definition expression.hpp:1859
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1868
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1851
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1854
Definition expression.hpp:1902
ExpressionType type() const override
Definition expression.hpp:1914
std::string_view name() const override
Definition expression.hpp:1916
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1923
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1906
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1918
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1909
Definition expression.hpp:1958
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1981
std::string_view name() const override
Definition expression.hpp:1972
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1962
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1974
ExpressionType type() const override
Definition expression.hpp:1970
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1965
Definition expression.hpp:2016
ExpressionType type() const override
Definition expression.hpp:2028
std::string_view name() const override
Definition expression.hpp:2030
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:2032
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:2023
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:2039
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:2020
Definition expression.hpp:701
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:705
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:718
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:714
ExpressionType type() const override
Definition expression.hpp:710
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:708
std::string_view name() const override
Definition expression.hpp:712