14#include <gch/small_vector.hpp>
16#include "sleipnir/autodiff/expression_type.hpp"
17#include "sleipnir/util/intrusive_shared_ptr.hpp"
18#include "sleipnir/util/pool.hpp"
20namespace slp::detail {
25inline constexpr bool USE_POOL_ALLOCATOR =
false;
27inline constexpr bool USE_POOL_ALLOCATOR =
true;
30template <
typename Scalar>
33template <
typename Scalar>
34constexpr void inc_ref_count(Expression<Scalar>* expr);
35template <
typename Scalar>
36constexpr void dec_ref_count(Expression<Scalar>* expr);
43template <
typename Scalar>
44using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
53template <
typename T,
typename... Args>
54static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
55 if constexpr (USE_POOL_ALLOCATOR) {
56 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
57 std::forward<Args>(args)...);
59 return make_intrusive_shared<T>(std::forward<Args>(args)...);
63template <
typename Scalar, ExpressionType T>
64struct BinaryMinusExpression;
66template <
typename Scalar, ExpressionType T>
67struct BinaryPlusExpression;
69template <
typename Scalar>
70struct ConstExpression;
72template <
typename Scalar, ExpressionType T>
75template <
typename Scalar, ExpressionType T>
78template <
typename Scalar, ExpressionType T>
79struct UnaryMinusExpression;
87template <
typename Scalar>
88ExpressionPtr<Scalar> constant_ptr(Scalar value);
95template <
typename Scalar_>
122 std::array<ExpressionPtr<Scalar>, 2>
args{
nullptr,
nullptr};
174 using enum ExpressionType;
180 }
else if (
rhs->is_constant(
Scalar(0))) {
183 }
else if (
lhs->is_constant(
Scalar(1))) {
185 }
else if (
rhs->is_constant(
Scalar(1))) {
190 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
191 return constant_ptr(
lhs->val *
rhs->val);
195 if (
lhs->type() == CONSTANT) {
196 if (
rhs->type() == LINEAR) {
198 }
else if (
rhs->type() == QUADRATIC) {
203 }
else if (
rhs->type() == CONSTANT) {
204 if (
lhs->type() == LINEAR) {
206 }
else if (
lhs->type() == QUADRATIC) {
211 }
else if (
lhs->type() == LINEAR &&
rhs->type() == LINEAR) {
226 using enum ExpressionType;
232 }
else if (
rhs->is_constant(
Scalar(1))) {
237 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
238 return constant_ptr(
lhs->val /
rhs->val);
242 if (
rhs->type() == CONSTANT) {
243 if (
lhs->type() == LINEAR) {
245 }
else if (
lhs->type() == QUADRATIC) {
263 using enum ExpressionType;
268 }
else if (
rhs ==
nullptr ||
rhs->is_constant(
Scalar(0))) {
273 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
274 return constant_ptr(
lhs->val +
rhs->val);
277 auto type = std::max(
lhs->type(),
rhs->type());
278 if (
type == LINEAR) {
281 }
else if (
type == QUADRATIC) {
309 using enum ExpressionType;
319 }
else if (
rhs->is_constant(
Scalar(0))) {
324 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
325 return constant_ptr(
lhs->val -
rhs->val);
328 auto type = std::max(
lhs->type(),
rhs->type());
329 if (
type == LINEAR) {
332 }
else if (
type == QUADRATIC) {
347 using enum ExpressionType;
356 if (
lhs->type() == CONSTANT) {
357 return constant_ptr(-
lhs->val);
360 if (
lhs->type() == LINEAR) {
362 }
else if (
lhs->type() == QUADRATIC) {
396 virtual ExpressionType
type()
const = 0;
438 return constant_ptr(
Scalar(0));
453 return constant_ptr(
Scalar(0));
457template <
typename Scalar>
462template <
typename Scalar>
463ExpressionPtr<Scalar> cbrt(
const ExpressionPtr<Scalar>& x);
464template <
typename Scalar>
465ExpressionPtr<Scalar> exp(
const ExpressionPtr<Scalar>& x);
466template <
typename Scalar>
467ExpressionPtr<Scalar> sin(
const ExpressionPtr<Scalar>& x);
468template <
typename Scalar>
469ExpressionPtr<Scalar> sinh(
const ExpressionPtr<Scalar>& x);
470template <
typename Scalar>
471ExpressionPtr<Scalar> sqrt(
const ExpressionPtr<Scalar>& x);
479template <
typename Scalar, ExpressionType T>
493 ExpressionType
type()
const override {
return T; }
522template <
typename Scalar, ExpressionType T>
536 ExpressionType
type()
const override {
return T; }
564template <
typename Scalar>
579 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
602template <
typename Scalar>
604 using enum ExpressionType;
608 if (x->type() == CONSTANT) {
609 if (x->val == Scalar(0)) {
612 }
else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
615 return constant_ptr(cbrt(x->val));
619 return make_expression_ptr<CbrtExpression<Scalar>>(x);
627template <
typename Scalar>
639 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
647template <
typename Scalar>
664 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
673template <
typename Scalar, ExpressionType T>
686 ExpressionType
type()
const override {
return T; }
715template <
typename Scalar, ExpressionType T>
728 ExpressionType
type()
const override {
return T; }
761template <
typename Scalar, ExpressionType T>
773 ExpressionType
type()
const override {
return T; }
792template <
typename Scalar>
803template <
typename Scalar>
804constexpr void dec_ref_count(Expression<Scalar>* expr) {
809 gch::small_vector<Expression<Scalar>*> stack;
810 stack.emplace_back(expr);
812 while (!stack.empty()) {
813 auto elem = stack.back();
818 if (--elem->ref_count == 0) {
819 if (elem->adjoint_expr !=
nullptr) {
820 stack.emplace_back(elem->adjoint_expr.get());
822 for (
auto& arg : elem->args) {
823 if (arg !=
nullptr) {
824 stack.emplace_back(arg.get());
830 if constexpr (USE_POOL_ALLOCATOR) {
831 auto alloc = global_pool_allocator<Expression<Scalar>>();
832 std::allocator_traits<
decltype(alloc)>::deallocate(
833 alloc, elem,
sizeof(Expression<Scalar>));
844template <
typename Scalar>
859 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
864 }
else if (x >
Scalar(0)) {
876 }
else if (x->val >
Scalar(0)) {
879 return constant_ptr(
Scalar(0));
890template <
typename Scalar>
892 using enum ExpressionType;
896 if (x->is_constant(Scalar(0))) {
902 if (x->type() == CONSTANT) {
903 return constant_ptr(abs(x->val));
906 return make_expression_ptr<AbsExpression<Scalar>>(x);
914template <
typename Scalar>
929 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
949template <
typename Scalar>
951 using enum ExpressionType;
955 if (x->is_constant(Scalar(0))) {
956 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
960 if (x->type() == CONSTANT) {
961 return constant_ptr(acos(x->val));
964 return make_expression_ptr<AcosExpression<Scalar>>(x);
972template <
typename Scalar>
987 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1007template <
typename Scalar>
1009 using enum ExpressionType;
1013 if (x->is_constant(Scalar(0))) {
1019 if (x->type() == CONSTANT) {
1020 return constant_ptr(asin(x->val));
1023 return make_expression_ptr<AsinExpression<Scalar>>(x);
1031template <
typename Scalar>
1046 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1065template <
typename Scalar>
1067 using enum ExpressionType;
1071 if (x->is_constant(Scalar(0))) {
1077 if (x->type() == CONSTANT) {
1078 return constant_ptr(atan(x->val));
1081 return make_expression_ptr<AtanExpression<Scalar>>(x);
1089template <
typename Scalar>
1106 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1136template <
typename Scalar>
1139 using enum ExpressionType;
1143 if (
y->is_constant(Scalar(0))) {
1146 }
else if (x->is_constant(Scalar(0))) {
1147 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1151 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1152 return constant_ptr(atan2(y->val, x->val));
1155 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1163template <
typename Scalar>
1178 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1198template <
typename Scalar>
1200 using enum ExpressionType;
1204 if (x->is_constant(Scalar(0))) {
1205 return constant_ptr(Scalar(1));
1209 if (x->type() == CONSTANT) {
1210 return constant_ptr(cos(x->val));
1213 return make_expression_ptr<CosExpression<Scalar>>(x);
1221template <
typename Scalar>
1236 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1256template <
typename Scalar>
1258 using enum ExpressionType;
1262 if (x->is_constant(Scalar(0))) {
1263 return constant_ptr(Scalar(1));
1267 if (x->type() == CONSTANT) {
1268 return constant_ptr(cosh(x->val));
1271 return make_expression_ptr<CoshExpression<Scalar>>(x);
1279template <
typename Scalar>
1294 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1306 constant_ptr(
Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1316template <
typename Scalar>
1318 using enum ExpressionType;
1322 if (x->is_constant(Scalar(0))) {
1328 if (x->type() == CONSTANT) {
1329 return constant_ptr(erf(x->val));
1332 return make_expression_ptr<ErfExpression<Scalar>>(x);
1340template <
typename Scalar>
1355 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1375template <
typename Scalar>
1377 using enum ExpressionType;
1381 if (x->is_constant(Scalar(0))) {
1382 return constant_ptr(Scalar(1));
1386 if (x->type() == CONSTANT) {
1387 return constant_ptr(exp(x->val));
1390 return make_expression_ptr<ExpExpression<Scalar>>(x);
1393template <
typename Scalar>
1394ExpressionPtr<Scalar> hypot(
const ExpressionPtr<Scalar>& x,
1395 const ExpressionPtr<Scalar>& y);
1402template <
typename Scalar>
1419 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1451template <
typename Scalar>
1454 using enum ExpressionType;
1458 if (x->is_constant(Scalar(0))) {
1460 }
else if (y->is_constant(Scalar(0))) {
1465 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1466 return constant_ptr(hypot(x->val, y->val));
1469 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1477template <
typename Scalar>
1492 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1511template <
typename Scalar>
1513 using enum ExpressionType;
1517 if (x->is_constant(Scalar(0))) {
1523 if (x->type() == CONSTANT) {
1524 return constant_ptr(log(x->val));
1527 return make_expression_ptr<LogExpression<Scalar>>(x);
1535template <
typename Scalar>
1550 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1569template <
typename Scalar>
1571 using enum ExpressionType;
1575 if (x->is_constant(Scalar(0))) {
1581 if (x->type() == CONSTANT) {
1582 return constant_ptr(log10(x->val));
1585 return make_expression_ptr<Log10Expression<Scalar>>(x);
1588template <
typename Scalar>
1589ExpressionPtr<Scalar> pow(
const ExpressionPtr<Scalar>& base,
1590 const ExpressionPtr<Scalar>& power);
1598template <
typename Scalar, ExpressionType T>
1614 ExpressionType
type()
const override {
return T; }
1661template <
typename Scalar>
1664 using enum ExpressionType;
1668 if (
base->is_constant(Scalar(0))) {
1671 }
else if (base->is_constant(Scalar(1))) {
1675 if (power->is_constant(Scalar(0))) {
1676 return constant_ptr(Scalar(1));
1677 }
else if (power->is_constant(Scalar(1))) {
1682 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1683 return constant_ptr(pow(base->val, power->val));
1686 if (power->is_constant(Scalar(2))) {
1687 if (base->type() == LINEAR) {
1688 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1690 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1694 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1702template <
typename Scalar>
1715 }
else if (x ==
Scalar(0)) {
1722 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1731template <
typename Scalar>
1733 using enum ExpressionType;
1736 if (x->type() == CONSTANT) {
1737 if (x->val < Scalar(0)) {
1738 return constant_ptr(Scalar(-1));
1739 }
else if (x->val == Scalar(0)) {
1743 return constant_ptr(Scalar(1));
1747 return make_expression_ptr<SignExpression<Scalar>>(x);
1755template <
typename Scalar>
1770 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1790template <
typename Scalar>
1792 using enum ExpressionType;
1796 if (x->is_constant(Scalar(0))) {
1802 if (x->type() == CONSTANT) {
1803 return constant_ptr(sin(x->val));
1806 return make_expression_ptr<SinExpression<Scalar>>(x);
1814template <
typename Scalar>
1829 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1849template <
typename Scalar>
1851 using enum ExpressionType;
1855 if (x->is_constant(Scalar(0))) {
1861 if (x->type() == CONSTANT) {
1862 return constant_ptr(sinh(x->val));
1865 return make_expression_ptr<SinhExpression<Scalar>>(x);
1873template <
typename Scalar>
1888 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1908template <
typename Scalar>
1910 using enum ExpressionType;
1914 if (x->type() == CONSTANT) {
1915 if (x->val == Scalar(0)) {
1918 }
else if (x->val == Scalar(1)) {
1921 return constant_ptr(sqrt(x->val));
1925 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1933template <
typename Scalar>
1948 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1971template <
typename Scalar>
1973 using enum ExpressionType;
1977 if (x->is_constant(Scalar(0))) {
1983 if (x->type() == CONSTANT) {
1984 return constant_ptr(tan(x->val));
1987 return make_expression_ptr<TanExpression<Scalar>>(x);
1995template <
typename Scalar>
2010 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
2033template <
typename Scalar>
2035 using enum ExpressionType;
2039 if (x->is_constant(Scalar(0))) {
2045 if (x->type() == CONSTANT) {
2046 return constant_ptr(tanh(x->val));
2049 return make_expression_ptr<TanhExpression<Scalar>>(x);
Definition intrusive_shared_ptr.hpp:29
Definition expression.hpp:845
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:851
ExpressionType type() const override
Definition expression.hpp:859
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:861
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:871
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:854
Definition expression.hpp:915
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:936
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:924
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:921
ExpressionType type() const override
Definition expression.hpp:929
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:931
Definition expression.hpp:973
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:989
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:982
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:994
ExpressionType type() const override
Definition expression.hpp:987
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:979
Definition expression.hpp:1090
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1116
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1101
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1097
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1112
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1122
ExpressionType type() const override
Definition expression.hpp:1106
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1108
Definition expression.hpp:1032
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1052
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1038
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1041
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1048
ExpressionType type() const override
Definition expression.hpp:1046
Definition expression.hpp:480
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:503
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:509
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:499
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:491
ExpressionType type() const override
Definition expression.hpp:493
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:495
Definition expression.hpp:523
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:546
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:552
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:530
ExpressionType type() const override
Definition expression.hpp:536
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:542
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:538
Definition expression.hpp:565
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:571
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:581
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:588
ExpressionType type() const override
Definition expression.hpp:579
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:574
Definition expression.hpp:628
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:637
ExpressionType type() const override
Definition expression.hpp:639
constexpr ConstExpression(Scalar value)
Definition expression.hpp:634
Definition expression.hpp:1164
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1173
ExpressionType type() const override
Definition expression.hpp:1178
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1170
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1185
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1180
Definition expression.hpp:1222
ExpressionType type() const override
Definition expression.hpp:1236
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1231
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1243
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1228
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1238
Definition expression.hpp:648
constexpr DecisionVariableExpression()=default
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:662
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:659
ExpressionType type() const override
Definition expression.hpp:664
Definition expression.hpp:674
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:681
ExpressionType type() const override
Definition expression.hpp:686
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:692
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:702
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:688
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:684
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:696
Definition expression.hpp:1280
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1286
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1296
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1289
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1302
ExpressionType type() const override
Definition expression.hpp:1294
Definition expression.hpp:1341
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1357
ExpressionType type() const override
Definition expression.hpp:1355
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1347
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1362
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1350
Definition expression.hpp:96
Scalar val
The value of the expression node.
Definition expression.hpp:103
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:346
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:112
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:420
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:122
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:119
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:141
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:434
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:449
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:106
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:261
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:162
constexpr Expression()=default
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:150
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:116
Scalar_ Scalar
Definition expression.hpp:100
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:224
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:109
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:307
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:172
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:406
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:296
constexpr Expression(Scalar value)
Definition expression.hpp:134
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:374
Definition expression.hpp:1403
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1410
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1437
ExpressionType type() const override
Definition expression.hpp:1419
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1431
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1421
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1414
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1426
Definition expression.hpp:1536
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1545
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1542
ExpressionType type() const override
Definition expression.hpp:1550
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1556
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1552
Definition expression.hpp:1478
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1498
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1494
ExpressionType type() const override
Definition expression.hpp:1492
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1484
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1487
Definition expression.hpp:716
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:726
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:740
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:730
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:723
ExpressionType type() const override
Definition expression.hpp:728
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:735
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:747
Definition expression.hpp:1599
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1609
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1616
ExpressionType type() const override
Definition expression.hpp:1614
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1635
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1622
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1641
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1606
Definition expression.hpp:1703
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1709
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1712
ExpressionType type() const override
Definition expression.hpp:1722
Definition expression.hpp:1756
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1772
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1777
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1765
ExpressionType type() const override
Definition expression.hpp:1770
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1762
Definition expression.hpp:1815
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1831
ExpressionType type() const override
Definition expression.hpp:1829
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1836
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1821
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1824
Definition expression.hpp:1874
ExpressionType type() const override
Definition expression.hpp:1888
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1895
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1880
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1890
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1883
Definition expression.hpp:1934
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1957
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1940
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1950
ExpressionType type() const override
Definition expression.hpp:1948
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1943
Definition expression.hpp:1996
ExpressionType type() const override
Definition expression.hpp:2010
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:2012
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:2005
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:2019
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:2002
Definition expression.hpp:762
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:768
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:779
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:775
ExpressionType type() const override
Definition expression.hpp:773
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:771