45template <
typename T,
typename... Args>
49 std::forward<Args>(args)...);
55template <ExpressionType Type>
56struct BinaryMinusExpression;
58template <ExpressionType Type>
59struct BinaryPlusExpression;
61struct ConstExpression;
63template <ExpressionType Type>
66template <ExpressionType Type>
69template <ExpressionType Type>
70struct UnaryMinusExpression;
98 std::array<ExpressionPtr, 2>
args{
nullptr,
nullptr};
151 if (
lhs->IsConstant(0.0)) {
154 }
else if (
rhs->IsConstant(0.0)) {
157 }
else if (
lhs->IsConstant(1.0)) {
159 }
else if (
rhs->IsConstant(1.0)) {
203 if (
lhs->IsConstant(0.0)) {
206 }
else if (
rhs->IsConstant(1.0)) {
240 if (
lhs ==
nullptr ||
lhs->IsConstant(0.0)) {
242 }
else if (
rhs ==
nullptr ||
rhs->IsConstant(0.0)) {
252 if (
lhs->Type() <
rhs->Type()) {
282 if (
lhs->IsConstant(0.0)) {
283 if (
rhs->IsConstant(0.0)) {
289 }
else if (
rhs->IsConstant(0.0)) {
299 if (
lhs->Type() <
rhs->Type()) {
327 if (
lhs->IsConstant(0.0)) {
425template <ExpressionType T>
461template <ExpressionType T>
510 double Value(
double,
double)
const override {
return value; }
529 double Value(
double,
double)
const override {
return value; }
534template <ExpressionType T>
572template <ExpressionType T>
612template <ExpressionType T>
665 while (!
stack.empty()) {
671 if (--
elem->refCount == 0) {
672 if (
elem->adjointExpr !=
nullptr) {
673 stack.emplace_back(
elem->adjointExpr.Get());
676 if (
arg !=
nullptr) {
703 double Value(
double x,
double)
const override {
return std::abs(x); }
711 }
else if (x > 0.0) {
720 if (x->value < 0.0) {
722 }
else if (x->value > 0.0) {
763 double Value(
double x,
double)
const override {
return std::acos(x); }
811 double Value(
double x,
double)
const override {
return std::asin(x); }
859 double Value(
double x,
double)
const override {
return std::atan(x); }
908 double Value(
double y,
double x)
const override {
return std::atan2(
y, x); }
943 if (
y->IsConstant(0.0)) {
968 double Value(
double x,
double)
const override {
return std::cos(x); }
1014 double Value(
double x,
double)
const override {
return std::cosh(x); }
1060 double Value(
double x,
double)
const override {
return std::erf(x); }
1066 return parentAdjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1109 double Value(
double x,
double)
const override {
return std::exp(x); }
1159 double Value(
double x,
double y)
const override {
return std::hypot(x,
y); }
1196 }
else if (
y->IsConstant(0.0)) {
1218 double Value(
double x,
double)
const override {
return std::log(x); }
1265 double Value(
double x,
double)
const override {
return std::log10(x); }
1305template <ExpressionType T>
1319 return std::pow(base,
power);
1352 if (base->value == 0.0) {
1381 if (
power->IsConstant(0.0)) {
1383 }
else if (
power->IsConstant(1.0)) {
1417 double Value(
double x,
double)
const override {
1420 }
else if (x == 0.0) {
1448 if (x->
value < 0.0) {
1450 }
else if (x->
value == 0.0) {
1471 double Value(
double x,
double)
const override {
return std::sin(x); }
1518 double Value(
double x,
double)
const override {
return std::sinh(x); }
1565 double Value(
double x,
double)
const override {
return std::sqrt(x); }
1591 if (x->
value == 0.0) {
1594 }
else if (x->
value == 1.0) {
1614 double Value(
double x,
double)
const override {
return std::tan(x); }
1662 double Value(
double x,
double)
const override {
return std::tanh(x); }
Definition small_vector.hpp:3616
::value &&MoveInsertable constexpr reference emplace_back(Args &&... args)
Definition small_vector.hpp:4071
Definition Expression.hpp:18
ExpressionPtr abs(const ExpressionPtr &x)
Definition Expression.hpp:736
ExpressionPtr log(const ExpressionPtr &x)
Definition Expression.hpp:1238
constexpr void IntrusiveSharedPtrIncRefCount(Expression *expr)
Definition Expression.hpp:648
static ExpressionPtr MakeExpressionPtr(Args &&... args)
Definition Expression.hpp:46
ExpressionPtr exp(const ExpressionPtr &x)
Definition Expression.hpp:1129
ExpressionPtr tanh(const ExpressionPtr &x)
Definition Expression.hpp:1683
IntrusiveSharedPtr< Expression > ExpressionPtr
Definition Expression.hpp:36
ExpressionPtr atan(const ExpressionPtr &x)
Definition Expression.hpp:879
ExpressionPtr sinh(const ExpressionPtr &x)
Definition Expression.hpp:1538
ExpressionPtr log10(const ExpressionPtr &x)
Definition Expression.hpp:1286
ExpressionPtr pow(const ExpressionPtr &base, const ExpressionPtr &power)
Definition Expression.hpp:1370
ExpressionPtr sign(const ExpressionPtr &x)
Definition Expression.hpp:1443
ExpressionPtr atan2(const ExpressionPtr &y, const ExpressionPtr &x)
Definition Expression.hpp:939
ExpressionPtr asin(const ExpressionPtr &x)
Definition Expression.hpp:832
ExpressionPtr erf(const ExpressionPtr &x)
Definition Expression.hpp:1082
constexpr void IntrusiveSharedPtrDecRefCount(Expression *expr)
Definition Expression.hpp:657
constexpr bool kUsePoolAllocator
Definition Expression.hpp:25
ExpressionPtr sin(const ExpressionPtr &x)
Definition Expression.hpp:1491
ExpressionPtr tan(const ExpressionPtr &x)
Definition Expression.hpp:1635
ExpressionPtr acos(const ExpressionPtr &x)
Definition Expression.hpp:785
ExpressionPtr cosh(const ExpressionPtr &x)
Definition Expression.hpp:1034
ExpressionPtr cos(const ExpressionPtr &x)
Definition Expression.hpp:988
ExpressionPtr sqrt(const ExpressionPtr &x)
Definition Expression.hpp:1586
ExpressionPtr hypot(const ExpressionPtr &x, const ExpressionPtr &y)
Definition Expression.hpp:1190
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Definition IntrusiveSharedPtr.hpp:275
ExpressionType
Definition ExpressionType.hpp:14
@ kConstant
The expression is a constant.
@ kLinear
The expression is composed of linear and lower-order operators.
@ kNonlinear
The expression is composed of nonlinear and lower-order operators.
@ kQuadratic
The expression is composed of quadratic and lower-order operators.
Definition Expression.hpp:692
constexpr AbsExpression(ExpressionPtr lhs)
Definition Expression.hpp:698
ExpressionType Type() const override
Definition Expression.hpp:705
double Value(double x, double) const override
Definition Expression.hpp:703
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:718
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:707
Definition Expression.hpp:753
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:772
AcosExpression(ExpressionPtr lhs)
Definition Expression.hpp:759
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:767
ExpressionType Type() const override
Definition Expression.hpp:765
double Value(double x, double) const override
Definition Expression.hpp:763
Definition Expression.hpp:801
ExpressionType Type() const override
Definition Expression.hpp:813
AsinExpression(ExpressionPtr lhs)
Definition Expression.hpp:807
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:820
double Value(double x, double) const override
Definition Expression.hpp:811
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:815
Definition Expression.hpp:896
ExpressionType Type() const override
Definition Expression.hpp:910
double GradientValueRhs(double y, double x, double parentAdjoint) const override
Definition Expression.hpp:917
Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:903
ExpressionPtr GradientRhs(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:927
ExpressionPtr GradientLhs(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:922
double GradientValueLhs(double y, double x, double parentAdjoint) const override
Definition Expression.hpp:912
double Value(double y, double x) const override
Definition Expression.hpp:908
Definition Expression.hpp:849
ExpressionType Type() const override
Definition Expression.hpp:861
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:868
AtanExpression(ExpressionPtr lhs)
Definition Expression.hpp:855
double Value(double x, double) const override
Definition Expression.hpp:859
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:863
Definition Expression.hpp:426
ExpressionPtr GradientLhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:450
ExpressionType Type() const override
Definition Expression.hpp:440
ExpressionPtr GradientRhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:455
double GradientValueRhs(double, double, double parentAdjoint) const override
Definition Expression.hpp:446
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:433
double GradientValueLhs(double, double, double parentAdjoint) const override
Definition Expression.hpp:442
double Value(double lhs, double rhs) const override
Definition Expression.hpp:438
Definition Expression.hpp:462
double GradientValueLhs(double, double, double parentAdjoint) const override
Definition Expression.hpp:478
ExpressionType Type() const override
Definition Expression.hpp:476
ExpressionPtr GradientLhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:486
double Value(double lhs, double rhs) const override
Definition Expression.hpp:474
double GradientValueRhs(double, double, double parentAdjoint) const override
Definition Expression.hpp:482
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:469
ExpressionPtr GradientRhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:491
Definition Expression.hpp:497
constexpr ConstExpression()=default
constexpr ConstExpression(double value)
Definition Expression.hpp:508
double Value(double, double) const override
Definition Expression.hpp:510
ExpressionType Type() const override
Definition Expression.hpp:512
Definition Expression.hpp:958
CosExpression(ExpressionPtr lhs)
Definition Expression.hpp:964
double Value(double x, double) const override
Definition Expression.hpp:968
ExpressionType Type() const override
Definition Expression.hpp:970
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:977
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:972
Definition Expression.hpp:1004
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1018
ExpressionType Type() const override
Definition Expression.hpp:1016
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1023
CoshExpression(ExpressionPtr lhs)
Definition Expression.hpp:1010
double Value(double x, double) const override
Definition Expression.hpp:1014
Definition Expression.hpp:515
constexpr DecisionVariableExpression(double value)
Definition Expression.hpp:526
constexpr DecisionVariableExpression()=default
double Value(double, double) const override
Definition Expression.hpp:529
ExpressionType Type() const override
Definition Expression.hpp:531
Definition Expression.hpp:535
ExpressionPtr GradientRhs(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:566
ExpressionType Type() const override
Definition Expression.hpp:549
double GradientValueLhs(double, double rhs, double parentAdjoint) const override
Definition Expression.hpp:551
ExpressionPtr GradientLhs(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:561
double GradientValueRhs(double lhs, double rhs, double parentAdjoint) const override
Definition Expression.hpp:556
double Value(double lhs, double rhs) const override
Definition Expression.hpp:547
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:542
Definition Expression.hpp:1050
ErfExpression(ExpressionPtr lhs)
Definition Expression.hpp:1056
double Value(double x, double) const override
Definition Expression.hpp:1060
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1069
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1064
ExpressionType Type() const override
Definition Expression.hpp:1062
Definition Expression.hpp:1099
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1118
ExpExpression(ExpressionPtr lhs)
Definition Expression.hpp:1105
double Value(double x, double) const override
Definition Expression.hpp:1109
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1113
ExpressionType Type() const override
Definition Expression.hpp:1111
Definition Expression.hpp:75
virtual ExpressionType Type() const =0
double adjoint
The adjoint of the expression node used during autodiff.
Definition Expression.hpp:80
constexpr Expression(double value)
Definition Expression.hpp:110
virtual ExpressionPtr GradientRhs(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const
Definition Expression.hpp:417
uint32_t duplications
Definition Expression.hpp:84
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition Expression.hpp:98
constexpr Expression()=default
constexpr Expression(ExpressionPtr lhs)
Definition Expression.hpp:117
virtual double GradientValueLhs(double lhs, double rhs, double parentAdjoint) const
Definition Expression.hpp:377
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition Expression.hpp:351
constexpr bool IsConstant(double constant) const
Definition Expression.hpp:136
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition Expression.hpp:146
virtual double GradientValueRhs(double lhs, double rhs, double parentAdjoint) const
Definition Expression.hpp:390
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition Expression.hpp:323
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition Expression.hpp:277
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition Expression.hpp:198
virtual ~Expression()=default
virtual double Value(double lhs, double rhs) const =0
ExpressionPtr adjointExpr
Definition Expression.hpp:92
uint32_t refCount
Reference count for intrusive shared pointer.
Definition Expression.hpp:95
virtual ExpressionPtr GradientLhs(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const
Definition Expression.hpp:403
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:126
double value
The value of the expression node.
Definition Expression.hpp:77
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition Expression.hpp:235
int32_t row
Definition Expression.hpp:88
Definition Expression.hpp:1147
ExpressionPtr GradientRhs(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1178
HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:1154
double GradientValueRhs(double x, double y, double parentAdjoint) const override
Definition Expression.hpp:1168
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1173
double GradientValueLhs(double x, double y, double parentAdjoint) const override
Definition Expression.hpp:1163
double Value(double x, double y) const override
Definition Expression.hpp:1159
ExpressionType Type() const override
Definition Expression.hpp:1161
Definition Expression.hpp:1255
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1269
Log10Expression(ExpressionPtr lhs)
Definition Expression.hpp:1261
ExpressionType Type() const override
Definition Expression.hpp:1267
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1274
double Value(double x, double) const override
Definition Expression.hpp:1265
Definition Expression.hpp:1208
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1222
ExpressionType Type() const override
Definition Expression.hpp:1220
LogExpression(ExpressionPtr lhs)
Definition Expression.hpp:1214
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1227
double Value(double x, double) const override
Definition Expression.hpp:1218
Definition Expression.hpp:573
double GradientValueLhs(double lhs, double rhs, double parentAdjoint) const override
Definition Expression.hpp:589
double GradientValueRhs(double lhs, double rhs, double parentAdjoint) const override
Definition Expression.hpp:594
ExpressionPtr GradientRhs(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:605
ExpressionType Type() const override
Definition Expression.hpp:587
ExpressionPtr GradientLhs(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:599
double Value(double lhs, double rhs) const override
Definition Expression.hpp:585
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:580
Definition Expression.hpp:1306
double GradientValueLhs(double base, double power, double parentAdjoint) const override
Definition Expression.hpp:1324
double Value(double base, double power) const override
Definition Expression.hpp:1318
ExpressionType Type() const override
Definition Expression.hpp:1322
ExpressionPtr GradientLhs(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1339
ExpressionPtr GradientRhs(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1348
PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition Expression.hpp:1313
double GradientValueRhs(double base, double power, double parentAdjoint) const override
Definition Expression.hpp:1329
Definition Expression.hpp:1400
ExpressionType Type() const override
Definition Expression.hpp:1427
ExpressionPtr GradientLhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition Expression.hpp:1431
double Value(double x, double) const override
Definition Expression.hpp:1417
constexpr SignExpression(ExpressionPtr lhs)
Definition Expression.hpp:1406
double GradientValueLhs(double, double, double) const override
Definition Expression.hpp:1429
Definition Expression.hpp:1461
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1480
double Value(double x, double) const override
Definition Expression.hpp:1471
SinExpression(ExpressionPtr lhs)
Definition Expression.hpp:1467
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1475
ExpressionType Type() const override
Definition Expression.hpp:1473
Definition Expression.hpp:1508
SinhExpression(ExpressionPtr lhs)
Definition Expression.hpp:1514
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1522
double Value(double x, double) const override
Definition Expression.hpp:1518
ExpressionType Type() const override
Definition Expression.hpp:1520
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1527
Definition Expression.hpp:1555
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1574
ExpressionType Type() const override
Definition Expression.hpp:1567
double Value(double x, double) const override
Definition Expression.hpp:1565
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1569
SqrtExpression(ExpressionPtr lhs)
Definition Expression.hpp:1561
Definition Expression.hpp:1604
ExpressionType Type() const override
Definition Expression.hpp:1616
TanExpression(ExpressionPtr lhs)
Definition Expression.hpp:1610
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1623
double Value(double x, double) const override
Definition Expression.hpp:1614
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1618
Definition Expression.hpp:1652
TanhExpression(ExpressionPtr lhs)
Definition Expression.hpp:1658
ExpressionType Type() const override
Definition Expression.hpp:1664
ExpressionPtr GradientLhs(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:1671
double Value(double x, double) const override
Definition Expression.hpp:1662
double GradientValueLhs(double x, double, double parentAdjoint) const override
Definition Expression.hpp:1666
Definition Expression.hpp:613
ExpressionPtr GradientLhs(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parentAdjoint) const override
Definition Expression.hpp:632
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition Expression.hpp:619
ExpressionType Type() const override
Definition Expression.hpp:626
double Value(double lhs, double) const override
Definition Expression.hpp:624
double GradientValueLhs(double, double, double parentAdjoint) const override
Definition Expression.hpp:628