Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <string_view>
13#include <utility>
14
15#include <gch/small_vector.hpp>
16
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
20
21namespace slp::detail {
22
// Selects whether expression nodes are allocated from the global pool
// allocator.
//
// The global pool allocator uses a thread-local static pool resource, which
// isn't guaranteed to be initialized properly across DLL boundaries on
// Windows, so pooling is turned off there and make_expression_ptr() falls
// back to make_intrusive_shared().
#if defined(_WIN32)
inline constexpr bool USE_POOL_ALLOCATOR = false;
#else
inline constexpr bool USE_POOL_ALLOCATOR = true;
#endif
30
31template <typename Scalar>
32struct Expression;
33
34template <typename Scalar>
35constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <typename Scalar>
37constexpr void dec_ref_count(Expression<Scalar>* expr);
38
42template <typename Scalar>
43using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
44
50template <typename T, typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
55 } else {
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57 }
58}
59
60template <typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
62
63template <typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
65
66template <typename Scalar>
67struct ConstantExpression;
68
69template <typename Scalar, ExpressionType T>
70struct DivExpression;
71
72template <typename Scalar, ExpressionType T>
73struct MultExpression;
74
75template <typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
77
82template <typename Scalar>
83ExpressionPtr<Scalar> constant_ptr(Scalar value);
84
88template <typename Scalar_>
89struct Expression {
91 using Scalar = Scalar_;
92
95
98
101
104
108
111
113 std::array<ExpressionPtr<Scalar>, 2> args{nullptr, nullptr};
114
116 constexpr Expression() = default;
117
121 explicit constexpr Expression(Scalar value) : val{value} {}
122
127 : args{std::move(lhs), nullptr} {}
128
135
136 virtual ~Expression() = default;
137
142 constexpr bool is_constant(Scalar constant) const {
143 return type() == ExpressionType::CONSTANT && val == constant;
144 }
145
151 const ExpressionPtr<Scalar>& rhs) {
152 using enum ExpressionType;
153
154 // Prune expression
155 if (lhs->is_constant(Scalar(0))) {
156 // Return zero
157 return lhs;
158 } else if (rhs->is_constant(Scalar(0))) {
159 // Return zero
160 return rhs;
161 } else if (lhs->is_constant(Scalar(1))) {
162 return rhs;
163 } else if (rhs->is_constant(Scalar(1))) {
164 return lhs;
165 }
166
167 // Evaluate constant
168 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
169 return constant_ptr(lhs->val * rhs->val);
170 }
171
172 // Evaluate expression type
173 if (lhs->type() == CONSTANT) {
174 if (rhs->type() == LINEAR) {
176 } else if (rhs->type() == QUADRATIC) {
178 } else {
180 }
181 } else if (rhs->type() == CONSTANT) {
182 if (lhs->type() == LINEAR) {
184 } else if (lhs->type() == QUADRATIC) {
186 } else {
188 }
189 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
191 } else {
193 }
194 }
195
201 const ExpressionPtr<Scalar>& rhs) {
202 using enum ExpressionType;
203
204 // Prune expression
205 if (lhs->is_constant(Scalar(0))) {
206 // Return zero
207 return lhs;
208 } else if (rhs->is_constant(Scalar(1))) {
209 return lhs;
210 }
211
212 // Evaluate constant
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return constant_ptr(lhs->val / rhs->val);
215 }
216
217 // Evaluate expression type
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
221 } else if (lhs->type() == QUADRATIC) {
223 } else {
225 }
226 } else {
228 }
229 }
230
236 const ExpressionPtr<Scalar>& rhs) {
237 using enum ExpressionType;
238
239 // Prune expression
240 if (lhs == nullptr || lhs->is_constant(Scalar(0))) {
241 return rhs;
242 } else if (rhs == nullptr || rhs->is_constant(Scalar(0))) {
243 return lhs;
244 }
245
246 // Evaluate constant
247 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
248 return constant_ptr(lhs->val + rhs->val);
249 }
250
251 auto type = std::max(lhs->type(), rhs->type());
252 if (type == LINEAR) {
254 rhs);
255 } else if (type == QUADRATIC) {
257 rhs);
258 } else {
260 rhs);
261 }
262 }
263
272
278 const ExpressionPtr<Scalar>& rhs) {
279 using enum ExpressionType;
280
281 // Prune expression
282 if (lhs->is_constant(Scalar(0))) {
283 if (rhs->is_constant(Scalar(0))) {
284 // Return zero
285 return rhs;
286 } else {
287 return -rhs;
288 }
289 } else if (rhs->is_constant(Scalar(0))) {
290 return lhs;
291 }
292
293 // Evaluate constant
294 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
295 return constant_ptr(lhs->val - rhs->val);
296 }
297
298 auto type = std::max(lhs->type(), rhs->type());
299 if (type == LINEAR) {
301 rhs);
302 } else if (type == QUADRATIC) {
304 rhs);
305 } else {
307 rhs);
308 }
309 }
310
315 using enum ExpressionType;
316
317 // Prune expression
318 if (lhs->is_constant(Scalar(0))) {
319 // Return zero
320 return lhs;
321 }
322
323 // Evaluate constant
324 if (lhs->type() == CONSTANT) {
325 return constant_ptr(-lhs->val);
326 }
327
328 if (lhs->type() == LINEAR) {
330 } else if (lhs->type() == QUADRATIC) {
332 } else {
334 }
335 }
336
341 return lhs;
342 }
343
352 [[maybe_unused]] Scalar rhs) const = 0;
353
358 virtual ExpressionType type() const = 0;
359
363 virtual std::string_view name() const = 0;
364
374 return Scalar(0);
375 }
376
386 return Scalar(0);
387 }
388
399 return constant_ptr(Scalar(0));
400 }
401
412 return constant_ptr(Scalar(0));
413 }
414};
415
416template <typename Scalar>
417ExpressionPtr<Scalar> constant_ptr(Scalar value) {
419}
420
421template <typename Scalar>
422ExpressionPtr<Scalar> cbrt(const ExpressionPtr<Scalar>& x);
423template <typename Scalar>
424ExpressionPtr<Scalar> exp(const ExpressionPtr<Scalar>& x);
425template <typename Scalar>
426ExpressionPtr<Scalar> sin(const ExpressionPtr<Scalar>& x);
427template <typename Scalar>
428ExpressionPtr<Scalar> sinh(const ExpressionPtr<Scalar>& x);
429template <typename Scalar>
430ExpressionPtr<Scalar> sqrt(const ExpressionPtr<Scalar>& x);
431
436template <typename Scalar, ExpressionType T>
445
446 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs - rhs; }
447
448 ExpressionType type() const override { return T; }
449
450 std::string_view name() const override { return "binary minus"; }
451
453 return parent_adjoint;
454 }
455
457 return -parent_adjoint;
458 }
459
465
471};
472
477template <typename Scalar, ExpressionType T>
486
487 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs + rhs; }
488
489 ExpressionType type() const override { return T; }
490
491 std::string_view name() const override { return "binary plus"; }
492
494 return parent_adjoint;
495 }
496
498 return parent_adjoint;
499 }
500
506
512};
513
517template <typename Scalar>
523 : Expression<Scalar>{std::move(lhs)} {}
524
525 Scalar value(Scalar x, Scalar) const override {
526 using std::cbrt;
527 return cbrt(x);
528 }
529
530 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
531
532 std::string_view name() const override { return "cbrt"; }
533
535 using std::cbrt;
536
537 Scalar c = cbrt(x);
538 return parent_adjoint / (Scalar(3) * c * c);
539 }
540
543 const ExpressionPtr<Scalar>& parent_adjoint) const override {
544 auto c = cbrt(x);
545 return parent_adjoint / (constant_ptr(Scalar(3)) * c * c);
546 }
547};
548
553template <typename Scalar>
555 using enum ExpressionType;
556 using std::cbrt;
557
558 // Evaluate constant
559 if (x->type() == CONSTANT) {
560 if (x->val == Scalar(0)) {
561 // Return zero
562 return x;
563 } else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
564 return x;
565 } else {
566 return constant_ptr(cbrt(x->val));
567 }
568 }
569
570 return make_expression_ptr<CbrtExpression<Scalar>>(x);
571}
572
576template <typename Scalar>
581 explicit constexpr ConstantExpression(Scalar value)
582 : Expression<Scalar>{value} {}
583
584 Scalar value(Scalar, Scalar) const override { return this->val; }
585
586 ExpressionType type() const override { return ExpressionType::CONSTANT; }
587
588 std::string_view name() const override { return "constant"; }
589};
590
594template <typename Scalar>
597 constexpr DecisionVariableExpression() = default;
598
604
605 Scalar value(Scalar, Scalar) const override { return this->val; }
606
607 ExpressionType type() const override { return ExpressionType::LINEAR; }
608
609 std::string_view name() const override { return "decision variable"; }
610};
611
616template <typename Scalar, ExpressionType T>
624
625 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs / rhs; }
626
627 ExpressionType type() const override { return T; }
628
629 std::string_view name() const override { return "division"; }
630
632 return parent_adjoint / rhs;
633 };
634
636 return parent_adjoint * -lhs / (rhs * rhs);
637 }
638
644
650};
651
656template <typename Scalar, ExpressionType T>
664
665 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs * rhs; }
666
667 ExpressionType type() const override { return T; }
668
669 std::string_view name() const override { return "multiplication"; }
670
672 Scalar parent_adjoint) const override {
673 return parent_adjoint * rhs;
674 }
675
677 Scalar parent_adjoint) const override {
678 return parent_adjoint * lhs;
679 }
680
687
694};
695
700template <typename Scalar, ExpressionType T>
706 : Expression<Scalar>{std::move(lhs)} {}
707
708 Scalar value(Scalar lhs, Scalar) const override { return -lhs; }
709
710 ExpressionType type() const override { return T; }
711
712 std::string_view name() const override { return "unary minus"; }
713
715 return -parent_adjoint;
716 }
717
723};
724
729template <typename Scalar>
730constexpr void inc_ref_count(Expression<Scalar>* expr) {
731 ++expr->ref_count;
732}
733
738template <typename Scalar>
739constexpr void dec_ref_count(Expression<Scalar>* expr) {
740 // If a deeply nested tree is being deallocated all at once, calling the
741 // Expression destructor when expr's refcount reaches zero can cause a stack
742 // overflow. Instead, we iterate over its children to decrement their
743 // refcounts and deallocate them.
744 gch::small_vector<Expression<Scalar>*> stack;
745 stack.emplace_back(expr);
746
747 while (!stack.empty()) {
748 auto elem = stack.back();
749 stack.pop_back();
750
751 // Decrement the current node's refcount. If it reaches zero, deallocate the
752 // node and enqueue its children so their refcounts are decremented too.
753 if (--elem->ref_count == 0) {
754 if (elem->adjoint_expr != nullptr) {
755 stack.emplace_back(elem->adjoint_expr.get());
756 }
757 for (auto& arg : elem->args) {
758 if (arg != nullptr) {
759 stack.emplace_back(arg.get());
760 }
761 }
762
763 // Not calling the destructor here is safe because it only decrements
764 // refcounts, which was already done above.
765 if constexpr (USE_POOL_ALLOCATOR) {
766 auto alloc = global_pool_allocator<Expression<Scalar>>();
767 std::allocator_traits<decltype(alloc)>::deallocate(
768 alloc, elem, sizeof(Expression<Scalar>));
769 }
770 }
771 }
772}
773
777template <typename Scalar>
783 : Expression<Scalar>{std::move(lhs)} {}
784
785 Scalar value(Scalar x, Scalar) const override {
786 using std::abs;
787 return abs(x);
788 }
789
790 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
791
792 std::string_view name() const override { return "abs"; }
793
795 if (x < Scalar(0)) {
796 return -parent_adjoint;
797 } else if (x > Scalar(0)) {
798 return parent_adjoint;
799 } else {
800 return Scalar(0);
801 }
802 }
803
806 const ExpressionPtr<Scalar>& parent_adjoint) const override {
807 if (x->val < Scalar(0)) {
808 return -parent_adjoint;
809 } else if (x->val > Scalar(0)) {
810 return parent_adjoint;
811 } else {
812 return constant_ptr(Scalar(0));
813 }
814 }
815};
816
821template <typename Scalar>
823 using enum ExpressionType;
824 using std::abs;
825
826 // Prune expression
827 if (x->is_constant(Scalar(0))) {
828 // Return zero
829 return x;
830 }
831
832 // Evaluate constant
833 if (x->type() == CONSTANT) {
834 return constant_ptr(abs(x->val));
835 }
836
837 return make_expression_ptr<AbsExpression<Scalar>>(x);
838}
839
843template <typename Scalar>
849 : Expression<Scalar>{std::move(lhs)} {}
850
851 Scalar value(Scalar x, Scalar) const override {
852 using std::acos;
853 return acos(x);
854 }
855
856 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
857
858 std::string_view name() const override { return "acos"; }
859
861 using std::sqrt;
862 return -parent_adjoint / sqrt(Scalar(1) - x * x);
863 }
864
867 const ExpressionPtr<Scalar>& parent_adjoint) const override {
868 return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
869 }
870};
871
876template <typename Scalar>
878 using enum ExpressionType;
879 using std::acos;
880
881 // Prune expression
882 if (x->is_constant(Scalar(0))) {
883 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
884 }
885
886 // Evaluate constant
887 if (x->type() == CONSTANT) {
888 return constant_ptr(acos(x->val));
889 }
890
891 return make_expression_ptr<AcosExpression<Scalar>>(x);
892}
893
897template <typename Scalar>
903 : Expression<Scalar>{std::move(lhs)} {}
904
905 Scalar value(Scalar x, Scalar) const override {
906 using std::asin;
907 return asin(x);
908 }
909
910 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
911
912 std::string_view name() const override { return "asin"; }
913
915 using std::sqrt;
916 return parent_adjoint / sqrt(Scalar(1) - x * x);
917 }
918
921 const ExpressionPtr<Scalar>& parent_adjoint) const override {
922 return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
923 }
924};
925
930template <typename Scalar>
932 using enum ExpressionType;
933 using std::asin;
934
935 // Prune expression
936 if (x->is_constant(Scalar(0))) {
937 // Return zero
938 return x;
939 }
940
941 // Evaluate constant
942 if (x->type() == CONSTANT) {
943 return constant_ptr(asin(x->val));
944 }
945
946 return make_expression_ptr<AsinExpression<Scalar>>(x);
947}
948
952template <typename Scalar>
958 : Expression<Scalar>{std::move(lhs)} {}
959
960 Scalar value(Scalar x, Scalar) const override {
961 using std::atan;
962 return atan(x);
963 }
964
965 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
966
967 std::string_view name() const override { return "atan"; }
968
970 return parent_adjoint / (Scalar(1) + x * x);
971 }
972
975 const ExpressionPtr<Scalar>& parent_adjoint) const override {
976 return parent_adjoint / (constant_ptr(Scalar(1)) + x * x);
977 }
978};
979
984template <typename Scalar>
986 using enum ExpressionType;
987 using std::atan;
988
989 // Prune expression
990 if (x->is_constant(Scalar(0))) {
991 // Return zero
992 return x;
993 }
994
995 // Evaluate constant
996 if (x->type() == CONSTANT) {
997 return constant_ptr(atan(x->val));
998 }
999
1000 return make_expression_ptr<AtanExpression<Scalar>>(x);
1001}
1002
1006template <typename Scalar>
1014 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1015
1016 Scalar value(Scalar y, Scalar x) const override {
1017 using std::atan2;
1018 return atan2(y, x);
1019 }
1020
1021 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1022
1023 std::string_view name() const override { return "atan2"; }
1024
1026 return parent_adjoint * x / (y * y + x * x);
1027 }
1028
1030 return parent_adjoint * -y / (y * y + x * x);
1031 }
1032
1035 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1036 return parent_adjoint * x / (y * y + x * x);
1037 }
1038
1041 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1042 return parent_adjoint * -y / (y * y + x * x);
1043 }
1044};
1045
1051template <typename Scalar>
1053 const ExpressionPtr<Scalar>& x) {
1054 using enum ExpressionType;
1055 using std::atan2;
1056
1057 // Prune expression
1058 if (y->is_constant(Scalar(0))) {
1059 // Return zero
1060 return y;
1061 } else if (x->is_constant(Scalar(0))) {
1062 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1063 }
1064
1065 // Evaluate constant
1066 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1067 return constant_ptr(atan2(y->val, x->val));
1068 }
1069
1070 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1071}
1072
1076template <typename Scalar>
1082 : Expression<Scalar>{std::move(lhs)} {}
1083
1084 Scalar value(Scalar x, Scalar) const override {
1085 using std::cos;
1086 return cos(x);
1087 }
1088
1089 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1090
1091 std::string_view name() const override { return "cos"; }
1092
1094 using std::sin;
1095 return parent_adjoint * -sin(x);
1096 }
1097
1100 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1101 return parent_adjoint * -sin(x);
1102 }
1103};
1104
1109template <typename Scalar>
1111 using enum ExpressionType;
1112 using std::cos;
1113
1114 // Prune expression
1115 if (x->is_constant(Scalar(0))) {
1116 return constant_ptr(Scalar(1));
1117 }
1118
1119 // Evaluate constant
1120 if (x->type() == CONSTANT) {
1121 return constant_ptr(cos(x->val));
1122 }
1123
1124 return make_expression_ptr<CosExpression<Scalar>>(x);
1125}
1126
1130template <typename Scalar>
1136 : Expression<Scalar>{std::move(lhs)} {}
1137
1138 Scalar value(Scalar x, Scalar) const override {
1139 using std::cosh;
1140 return cosh(x);
1141 }
1142
1143 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1144
1145 std::string_view name() const override { return "cosh"; }
1146
1148 using std::sinh;
1149 return parent_adjoint * sinh(x);
1150 }
1151
1154 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1155 return parent_adjoint * sinh(x);
1156 }
1157};
1158
1163template <typename Scalar>
1165 using enum ExpressionType;
1166 using std::cosh;
1167
1168 // Prune expression
1169 if (x->is_constant(Scalar(0))) {
1170 return constant_ptr(Scalar(1));
1171 }
1172
1173 // Evaluate constant
1174 if (x->type() == CONSTANT) {
1175 return constant_ptr(cosh(x->val));
1176 }
1177
1178 return make_expression_ptr<CoshExpression<Scalar>>(x);
1179}
1180
1184template <typename Scalar>
1190 : Expression<Scalar>{std::move(lhs)} {}
1191
1192 Scalar value(Scalar x, Scalar) const override {
1193 using std::erf;
1194 return erf(x);
1195 }
1196
1197 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1198
1199 std::string_view name() const override { return "erf"; }
1200
1202 using std::exp;
1203 return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) *
1204 exp(-x * x);
1205 }
1206
1209 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1210 return parent_adjoint *
1211 constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1212 }
1213};
1214
1219template <typename Scalar>
1221 using enum ExpressionType;
1222 using std::erf;
1223
1224 // Prune expression
1225 if (x->is_constant(Scalar(0))) {
1226 // Return zero
1227 return x;
1228 }
1229
1230 // Evaluate constant
1231 if (x->type() == CONSTANT) {
1232 return constant_ptr(erf(x->val));
1233 }
1234
1235 return make_expression_ptr<ErfExpression<Scalar>>(x);
1236}
1237
1241template <typename Scalar>
1247 : Expression<Scalar>{std::move(lhs)} {}
1248
1249 Scalar value(Scalar x, Scalar) const override {
1250 using std::exp;
1251 return exp(x);
1252 }
1253
1254 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1255
1256 std::string_view name() const override { return "exp"; }
1257
1259 using std::exp;
1260 return parent_adjoint * exp(x);
1261 }
1262
1265 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1266 return parent_adjoint * exp(x);
1267 }
1268};
1269
1274template <typename Scalar>
1276 using enum ExpressionType;
1277 using std::exp;
1278
1279 // Prune expression
1280 if (x->is_constant(Scalar(0))) {
1281 return constant_ptr(Scalar(1));
1282 }
1283
1284 // Evaluate constant
1285 if (x->type() == CONSTANT) {
1286 return constant_ptr(exp(x->val));
1287 }
1288
1289 return make_expression_ptr<ExpExpression<Scalar>>(x);
1290}
1291
1292template <typename Scalar>
1293ExpressionPtr<Scalar> hypot(const ExpressionPtr<Scalar>& x,
1294 const ExpressionPtr<Scalar>& y);
1295
1299template <typename Scalar>
1307 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1308
1309 Scalar value(Scalar x, Scalar y) const override {
1310 using std::hypot;
1311 return hypot(x, y);
1312 }
1313
1314 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1315
1316 std::string_view name() const override { return "hypot"; }
1317
1319 using std::hypot;
1320 return parent_adjoint * x / hypot(x, y);
1321 }
1322
1324 using std::hypot;
1325 return parent_adjoint * y / hypot(x, y);
1326 }
1327
1330 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1331 return parent_adjoint * x / hypot(x, y);
1332 }
1333
1336 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1337 return parent_adjoint * y / hypot(x, y);
1338 }
1339};
1340
1346template <typename Scalar>
1348 const ExpressionPtr<Scalar>& y) {
1349 using enum ExpressionType;
1350 using std::hypot;
1351
1352 // Prune expression
1353 if (x->is_constant(Scalar(0))) {
1354 return y;
1355 } else if (y->is_constant(Scalar(0))) {
1356 return x;
1357 }
1358
1359 // Evaluate constant
1360 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1361 return constant_ptr(hypot(x->val, y->val));
1362 }
1363
1364 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1365}
1366
1370template <typename Scalar>
1376 : Expression<Scalar>{std::move(lhs)} {}
1377
1378 Scalar value(Scalar x, Scalar) const override {
1379 using std::log;
1380 return log(x);
1381 }
1382
1383 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1384
1385 std::string_view name() const override { return "log"; }
1386
1388 return parent_adjoint / x;
1389 }
1390
1393 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1394 return parent_adjoint / x;
1395 }
1396};
1397
1402template <typename Scalar>
1404 using enum ExpressionType;
1405 using std::log;
1406
1407 // Prune expression
1408 if (x->is_constant(Scalar(0))) {
1409 // Return zero
1410 return x;
1411 }
1412
1413 // Evaluate constant
1414 if (x->type() == CONSTANT) {
1415 return constant_ptr(log(x->val));
1416 }
1417
1418 return make_expression_ptr<LogExpression<Scalar>>(x);
1419}
1420
1424template <typename Scalar>
1430 : Expression<Scalar>{std::move(lhs)} {}
1431
1432 Scalar value(Scalar x, Scalar) const override {
1433 using std::log10;
1434 return log10(x);
1435 }
1436
1437 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1438
1439 std::string_view name() const override { return "log10"; }
1440
1442 return parent_adjoint / (Scalar(std::numbers::ln10) * x);
1443 }
1444
1447 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1448 return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x);
1449 }
1450};
1451
1456template <typename Scalar>
1458 using enum ExpressionType;
1459 using std::log10;
1460
1461 // Prune expression
1462 if (x->is_constant(Scalar(0))) {
1463 // Return zero
1464 return x;
1465 }
1466
1467 // Evaluate constant
1468 if (x->type() == CONSTANT) {
1469 return constant_ptr(log10(x->val));
1470 }
1471
1472 return make_expression_ptr<Log10Expression<Scalar>>(x);
1473}
1474
1480template <typename Scalar>
1488
1489 Scalar value(Scalar a, Scalar b) const override {
1490 using std::max;
1491 return max(a, b);
1492 }
1493
1494 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1495
1496 std::string_view name() const override { return "max"; }
1497
1499 if (a >= b) {
1500 return parent_adjoint;
1501 } else {
1502 return Scalar(0);
1503 }
1504 }
1505
1507 if (b > a) {
1508 return parent_adjoint;
1509 } else {
1510 return Scalar(0);
1511 }
1512 }
1513
1516 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1517 if (a->val >= b->val) {
1518 return parent_adjoint;
1519 } else {
1520 return constant_ptr(Scalar(0));
1521 }
1522 }
1523
1526 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1527 if (b->val > a->val) {
1528 return parent_adjoint;
1529 } else {
1530 return constant_ptr(Scalar(0));
1531 }
1532 }
1533};
1534
1540template <typename Scalar>
1542 const ExpressionPtr<Scalar>& b) {
1543 using enum ExpressionType;
1544 using std::max;
1545
1546 // Evaluate constant
1547 if (a->type() == CONSTANT && b->type() == CONSTANT) {
1548 return constant_ptr(max(a->val, b->val));
1549 }
1550
1551 return make_expression_ptr<MaxExpression<Scalar>>(a, b);
1552}
1553
1559template <typename Scalar>
1567
1568 Scalar value(Scalar a, Scalar b) const override {
1569 using std::min;
1570 return min(a, b);
1571 }
1572
1573 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1574
1575 std::string_view name() const override { return "min"; }
1576
1578 if (a <= b) {
1579 return parent_adjoint;
1580 } else {
1581 return Scalar(0);
1582 }
1583 }
1584
1586 Scalar parent_adjoint) const override {
1587 if (b < a) {
1588 return parent_adjoint;
1589 } else {
1590 return Scalar(0);
1591 }
1592 }
1593
1596 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1597 if (a->val <= b->val) {
1598 return parent_adjoint;
1599 } else {
1600 return constant_ptr(Scalar(0));
1601 }
1602 }
1603
1606 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1607 if (b->val < a->val) {
1608 return parent_adjoint;
1609 } else {
1610 return constant_ptr(Scalar(0));
1611 }
1612 }
1613};
1614
1620template <typename Scalar>
1622 const ExpressionPtr<Scalar>& b) {
1623 using enum ExpressionType;
1624 using std::min;
1625
1626 // Evaluate constant
1627 if (a->type() == CONSTANT && b->type() == CONSTANT) {
1628 return constant_ptr(min(a->val, b->val));
1629 }
1630
1631 return make_expression_ptr<MinExpression<Scalar>>(a, b);
1632}
1633
1634template <typename Scalar>
1635ExpressionPtr<Scalar> pow(const ExpressionPtr<Scalar>& base,
1636 const ExpressionPtr<Scalar>& power);
1637
1642template <typename Scalar, ExpressionType T>
1650
1651 Scalar value(Scalar base, Scalar power) const override {
1652 using std::pow;
1653 return pow(base, power);
1654 }
1655
1656 ExpressionType type() const override { return T; }
1657
1658 std::string_view name() const override { return "pow"; }
1659
1661 Scalar parent_adjoint) const override {
1662 using std::pow;
1663 return parent_adjoint * pow(base, power - Scalar(1)) * power;
1664 }
1665
1667 Scalar parent_adjoint) const override {
1668 using std::log;
1669 using std::pow;
1670
1671 // Since x log(x) -> 0 as x -> 0
1672 if (base == Scalar(0)) {
1673 return Scalar(0);
1674 } else {
1675 return parent_adjoint * pow(base, power) * log(base);
1676 }
1677 }
1678
1681 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1682 return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power;
1683 }
1684
1687 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1688 // Since x log(x) -> 0 as x -> 0
1689 if (base->val == Scalar(0)) {
1690 // Return zero
1691 return base;
1692 } else {
1693 return parent_adjoint * pow(base, power) * log(base);
1694 }
1695 }
1696};
1697
1703template <typename Scalar>
1706 using enum ExpressionType;
1707 using std::pow;
1708
1709 // Prune expression
1710 if (base->is_constant(Scalar(0))) {
1711 // Return zero
1712 return base;
1713 } else if (base->is_constant(Scalar(1))) {
1714 // Return one
1715 return base;
1716 }
1717 if (power->is_constant(Scalar(0))) {
1718 return constant_ptr(Scalar(1));
1719 } else if (power->is_constant(Scalar(1))) {
1720 return base;
1721 }
1722
1723 // Evaluate constant
1724 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1725 return constant_ptr(pow(base->val, power->val));
1726 }
1727
1728 if (power->is_constant(Scalar(2))) {
1729 if (base->type() == LINEAR) {
1730 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1731 } else {
1732 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1733 }
1734 }
1735
1736 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1737}
1738
1742template <typename Scalar>
1748 : Expression<Scalar>{std::move(lhs)} {}
1749
1750 Scalar value(Scalar x, Scalar) const override {
1751 if (x < Scalar(0)) {
1752 return Scalar(-1);
1753 } else if (x == Scalar(0)) {
1754 return Scalar(0);
1755 } else {
1756 return Scalar(1);
1757 }
1758 }
1759
1760 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1761
1762 std::string_view name() const override { return "sign"; }
1763};
1764
1769template <typename Scalar>
1771 using enum ExpressionType;
1772
1773 // Evaluate constant
1774 if (x->type() == CONSTANT) {
1775 if (x->val < Scalar(0)) {
1776 return constant_ptr(Scalar(-1));
1777 } else if (x->val == Scalar(0)) {
1778 // Return zero
1779 return x;
1780 } else {
1781 return constant_ptr(Scalar(1));
1782 }
1783 }
1784
1785 return make_expression_ptr<SignExpression<Scalar>>(x);
1786}
1787
1791template <typename Scalar>
1797 : Expression<Scalar>{std::move(lhs)} {}
1798
1799 Scalar value(Scalar x, Scalar) const override {
1800 using std::sin;
1801 return sin(x);
1802 }
1803
1804 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1805
1806 std::string_view name() const override { return "sin"; }
1807
1809 using std::cos;
1810 return parent_adjoint * cos(x);
1811 }
1812
1815 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1816 return parent_adjoint * cos(x);
1817 }
1818};
1819
1824template <typename Scalar>
1826 using enum ExpressionType;
1827 using std::sin;
1828
1829 // Prune expression
1830 if (x->is_constant(Scalar(0))) {
1831 // Return zero
1832 return x;
1833 }
1834
1835 // Evaluate constant
1836 if (x->type() == CONSTANT) {
1837 return constant_ptr(sin(x->val));
1838 }
1839
1840 return make_expression_ptr<SinExpression<Scalar>>(x);
1841}
1842
1846template <typename Scalar>
1852 : Expression<Scalar>{std::move(lhs)} {}
1853
1854 Scalar value(Scalar x, Scalar) const override {
1855 using std::sinh;
1856 return sinh(x);
1857 }
1858
1859 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1860
1861 std::string_view name() const override { return "sinh"; }
1862
1864 using std::cosh;
1865 return parent_adjoint * cosh(x);
1866 }
1867
1870 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1871 return parent_adjoint * cosh(x);
1872 }
1873};
1874
1879template <typename Scalar>
1881 using enum ExpressionType;
1882 using std::sinh;
1883
1884 // Prune expression
1885 if (x->is_constant(Scalar(0))) {
1886 // Return zero
1887 return x;
1888 }
1889
1890 // Evaluate constant
1891 if (x->type() == CONSTANT) {
1892 return constant_ptr(sinh(x->val));
1893 }
1894
1895 return make_expression_ptr<SinhExpression<Scalar>>(x);
1896}
1897
1901template <typename Scalar>
1907 : Expression<Scalar>{std::move(lhs)} {}
1908
1909 Scalar value(Scalar x, Scalar) const override {
1910 using std::sqrt;
1911 return sqrt(x);
1912 }
1913
1914 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1915
1916 std::string_view name() const override { return "sqrt"; }
1917
1919 using std::sqrt;
1920 return parent_adjoint / (Scalar(2) * sqrt(x));
1921 }
1922
1925 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1926 return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x));
1927 }
1928};
1929
1934template <typename Scalar>
1936 using enum ExpressionType;
1937 using std::sqrt;
1938
1939 // Evaluate constant
1940 if (x->type() == CONSTANT) {
1941 if (x->val == Scalar(0)) {
1942 // Return zero
1943 return x;
1944 } else if (x->val == Scalar(1)) {
1945 return x;
1946 } else {
1947 return constant_ptr(sqrt(x->val));
1948 }
1949 }
1950
1951 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1952}
1953
1957template <typename Scalar>
1963 : Expression<Scalar>{std::move(lhs)} {}
1964
1965 Scalar value(Scalar x, Scalar) const override {
1966 using std::tan;
1967 return tan(x);
1968 }
1969
1970 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1971
1972 std::string_view name() const override { return "tan"; }
1973
1975 using std::cos;
1976
1977 auto c = cos(x);
1978 return parent_adjoint / (c * c);
1979 }
1980
1983 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1984 auto c = cos(x);
1985 return parent_adjoint / (c * c);
1986 }
1987};
1988
1993template <typename Scalar>
1995 using enum ExpressionType;
1996 using std::tan;
1997
1998 // Prune expression
1999 if (x->is_constant(Scalar(0))) {
2000 // Return zero
2001 return x;
2002 }
2003
2004 // Evaluate constant
2005 if (x->type() == CONSTANT) {
2006 return constant_ptr(tan(x->val));
2007 }
2008
2009 return make_expression_ptr<TanExpression<Scalar>>(x);
2010}
2011
2015template <typename Scalar>
2021 : Expression<Scalar>{std::move(lhs)} {}
2022
2023 Scalar value(Scalar x, Scalar) const override {
2024 using std::tanh;
2025 return tanh(x);
2026 }
2027
2028 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
2029
2030 std::string_view name() const override { return "tanh"; }
2031
2033 using std::cosh;
2034
2035 auto c = cosh(x);
2036 return parent_adjoint / (c * c);
2037 }
2038
2041 const ExpressionPtr<Scalar>& parent_adjoint) const override {
2042 auto c = cosh(x);
2043 return parent_adjoint / (c * c);
2044 }
2045};
2046
2051template <typename Scalar>
2053 using enum ExpressionType;
2054 using std::tanh;
2055
2056 // Prune expression
2057 if (x->is_constant(Scalar(0))) {
2058 // Return zero
2059 return x;
2060 }
2061
2062 // Evaluate constant
2063 if (x->type() == CONSTANT) {
2064 return constant_ptr(tanh(x->val));
2065 }
2066
2067 return make_expression_ptr<TanhExpression<Scalar>>(x);
2068}
2069
2070} // namespace slp::detail
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:778
std::string_view name() const override
Definition expression.hpp:792
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:782
ExpressionType type() const override
Definition expression.hpp:790
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:794
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:804
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:785
Definition expression.hpp:844
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:865
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:851
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:848
ExpressionType type() const override
Definition expression.hpp:856
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:860
std::string_view name() const override
Definition expression.hpp:858
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:914
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:905
std::string_view name() const override
Definition expression.hpp:912
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:919
ExpressionType type() const override
Definition expression.hpp:910
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:902
Definition expression.hpp:1007
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1033
std::string_view name() const override
Definition expression.hpp:1023
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1016
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1012
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1029
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1039
ExpressionType type() const override
Definition expression.hpp:1021
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1025
Definition expression.hpp:953
std::string_view name() const override
Definition expression.hpp:967
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:973
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:957
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:969
ExpressionType type() const override
Definition expression.hpp:965
Definition expression.hpp:437
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:442
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:466
std::string_view name() const override
Definition expression.hpp:450
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:456
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:446
ExpressionType type() const override
Definition expression.hpp:448
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:452
Definition expression.hpp:478
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:501
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:507
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:483
ExpressionType type() const override
Definition expression.hpp:489
std::string_view name() const override
Definition expression.hpp:491
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:497
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:493
Definition expression.hpp:518
std::string_view name() const override
Definition expression.hpp:532
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:522
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:541
ExpressionType type() const override
Definition expression.hpp:530
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:525
Definition expression.hpp:577
ExpressionType type() const override
Definition expression.hpp:586
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:584
std::string_view name() const override
Definition expression.hpp:588
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:581
Definition expression.hpp:1077
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1084
ExpressionType type() const override
Definition expression.hpp:1089
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1081
std::string_view name() const override
Definition expression.hpp:1091
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1098
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1093
Definition expression.hpp:1131
ExpressionType type() const override
Definition expression.hpp:1143
std::string_view name() const override
Definition expression.hpp:1145
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1138
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1152
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1135
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1147
Definition expression.hpp:595
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:609
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:605
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:602
ExpressionType type() const override
Definition expression.hpp:607
Definition expression.hpp:617
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:622
ExpressionType type() const override
Definition expression.hpp:627
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:635
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:645
std::string_view name() const override
Definition expression.hpp:629
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:631
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:625
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:639
Definition expression.hpp:1185
std::string_view name() const override
Definition expression.hpp:1199
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1189
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1201
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1192
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1207
ExpressionType type() const override
Definition expression.hpp:1197
Definition expression.hpp:1242
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1258
std::string_view name() const override
Definition expression.hpp:1256
ExpressionType type() const override
Definition expression.hpp:1254
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1246
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1263
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1249
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:314
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:103
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:383
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:113
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:110
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:126
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:395
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:408
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:235
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:142
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:133
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:107
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:200
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:100
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:277
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:150
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:371
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:268
constexpr Expression(Scalar value)
Definition expression.hpp:121
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:340
Definition expression.hpp:1300
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1305
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1334
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1328
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1318
std::string_view name() const override
Definition expression.hpp:1316
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1309
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1323
Definition expression.hpp:1425
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1432
std::string_view name() const override
Definition expression.hpp:1439
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1429
ExpressionType type() const override
Definition expression.hpp:1437
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1445
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1441
Definition expression.hpp:1371
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1391
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1387
ExpressionType type() const override
Definition expression.hpp:1383
std::string_view name() const override
Definition expression.hpp:1385
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1375
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1378
Definition expression.hpp:1481
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1506
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1498
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1489
constexpr MaxExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1486
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1514
ExpressionType type() const override
Definition expression.hpp:1494
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1524
std::string_view name() const override
Definition expression.hpp:1496
Definition expression.hpp:1560
std::string_view name() const override
Definition expression.hpp:1575
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1568
ExpressionType type() const override
Definition expression.hpp:1573
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1585
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1604
constexpr MinExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1565
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1594
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override
Definition expression.hpp:1577
Definition expression.hpp:657
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:665
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:681
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:671
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:676
std::string_view name() const override
Definition expression.hpp:669
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:688
Definition expression.hpp:1643
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1651
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1660
ExpressionType type() const override
Definition expression.hpp:1656
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1679
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1666
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1685
std::string_view name() const override
Definition expression.hpp:1658
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1648
Definition expression.hpp:1743
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1747
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1750
ExpressionType type() const override
Definition expression.hpp:1760
std::string_view name() const override
Definition expression.hpp:1762
Definition expression.hpp:1792
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1808
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1813
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1799
std::string_view name() const override
Definition expression.hpp:1806
ExpressionType type() const override
Definition expression.hpp:1804
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1796
Definition expression.hpp:1847
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1863
std::string_view name() const override
Definition expression.hpp:1861
ExpressionType type() const override
Definition expression.hpp:1859
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1868
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1851
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1854
Definition expression.hpp:1902
ExpressionType type() const override
Definition expression.hpp:1914
std::string_view name() const override
Definition expression.hpp:1916
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1923
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1906
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1918
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1909
Definition expression.hpp:1958
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1981
std::string_view name() const override
Definition expression.hpp:1972
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1962
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1974
ExpressionType type() const override
Definition expression.hpp:1970
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1965
Definition expression.hpp:2016
ExpressionType type() const override
Definition expression.hpp:2028
std::string_view name() const override
Definition expression.hpp:2030
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:2032
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:2023
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:2039
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:2020
Definition expression.hpp:701
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:705
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:718
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:714
ExpressionType type() const override
Definition expression.hpp:710
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:708
std::string_view name() const override
Definition expression.hpp:712