Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <string_view>
13#include <utility>
14
15#include <gch/small_vector.hpp>
16
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
20
21namespace slp::detail {
22
23// The global pool allocator uses a thread-local static pool resource, which
24// isn't guaranteed to be initialized properly across DLL boundaries on Windows
#if !defined(_WIN32)
// Non-Windows builds allocate expression nodes from the global pool allocator.
inline constexpr bool USE_POOL_ALLOCATOR = true;
#else
// Windows builds fall back to non-pooled allocation, because the thread-local
// pool resource may not be initialized correctly across DLL boundaries.
inline constexpr bool USE_POOL_ALLOCATOR = false;
#endif
30
// Expression is declared ahead of its definition so the reference-count hooks
// and the ExpressionPtr alias can name it.
template <typename Scalar>
struct Expression;

// Reference-count hooks for Expression nodes; their definitions appear later in
// this header.
// NOTE(review): presumably these are the functions IntrusiveSharedPtr calls to
// manage ExpressionPtr lifetimes; confirm against intrusive_shared_ptr.hpp.
template <typename Scalar>
constexpr void inc_ref_count(Expression<Scalar>* expr);

template <typename Scalar>
constexpr void dec_ref_count(Expression<Scalar>* expr);
38
/// Reference-counted handle to an Expression node. The count lives in the node
/// itself (Expression::ref_count) rather than in a separate control block.
template <typename Scalar>
using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
44
50template <typename T, typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
55 } else {
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57 }
58}
59
60template <typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
62
63template <typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
65
66template <typename Scalar>
67struct ConstantExpression;
68
69template <typename Scalar, ExpressionType T>
70struct DivExpression;
71
72template <typename Scalar, ExpressionType T>
73struct MultExpression;
74
75template <typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
77
/// Creates a constant expression node holding @p value.
///
/// This is declared before Expression so that Expression's own arithmetic
/// helpers can call it when folding constants.
template <typename Scalar>
ExpressionPtr<Scalar> constant_ptr(Scalar value);
84
88template <typename Scalar_>
89struct Expression {
91 using Scalar = Scalar_;
92
95
98
101
104
108
111
113 std::array<ExpressionPtr<Scalar>, 2> args{nullptr, nullptr};
114
116 constexpr Expression() = default;
117
121 explicit constexpr Expression(Scalar value) : val{value} {}
122
127 : args{std::move(lhs), nullptr} {}
128
135
136 virtual ~Expression() = default;
137
142 constexpr bool is_constant(Scalar constant) const {
143 return type() == ExpressionType::CONSTANT && val == constant;
144 }
145
151 const ExpressionPtr<Scalar>& rhs) {
152 using enum ExpressionType;
153
154 // Prune expression
155 if (lhs->is_constant(Scalar(0))) {
156 // Return zero
157 return lhs;
158 } else if (rhs->is_constant(Scalar(0))) {
159 // Return zero
160 return rhs;
161 } else if (lhs->is_constant(Scalar(1))) {
162 return rhs;
163 } else if (rhs->is_constant(Scalar(1))) {
164 return lhs;
165 }
166
167 // Evaluate constant
168 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
169 return constant_ptr(lhs->val * rhs->val);
170 }
171
172 // Evaluate expression type
173 if (lhs->type() == CONSTANT) {
174 if (rhs->type() == LINEAR) {
176 } else if (rhs->type() == QUADRATIC) {
178 } else {
180 }
181 } else if (rhs->type() == CONSTANT) {
182 if (lhs->type() == LINEAR) {
184 } else if (lhs->type() == QUADRATIC) {
186 } else {
188 }
189 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
191 } else {
193 }
194 }
195
201 const ExpressionPtr<Scalar>& rhs) {
202 using enum ExpressionType;
203
204 // Prune expression
205 if (lhs->is_constant(Scalar(0))) {
206 // Return zero
207 return lhs;
208 } else if (rhs->is_constant(Scalar(1))) {
209 return lhs;
210 }
211
212 // Evaluate constant
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return constant_ptr(lhs->val / rhs->val);
215 }
216
217 // Evaluate expression type
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
221 } else if (lhs->type() == QUADRATIC) {
223 } else {
225 }
226 } else {
228 }
229 }
230
236 const ExpressionPtr<Scalar>& rhs) {
237 using enum ExpressionType;
238
239 // Prune expression
240 if (lhs == nullptr || lhs->is_constant(Scalar(0))) {
241 return rhs;
242 } else if (rhs == nullptr || rhs->is_constant(Scalar(0))) {
243 return lhs;
244 }
245
246 // Evaluate constant
247 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
248 return constant_ptr(lhs->val + rhs->val);
249 }
250
251 auto type = std::max(lhs->type(), rhs->type());
252 if (type == LINEAR) {
254 rhs);
255 } else if (type == QUADRATIC) {
257 rhs);
258 } else {
260 rhs);
261 }
262 }
263
272
278 const ExpressionPtr<Scalar>& rhs) {
279 using enum ExpressionType;
280
281 // Prune expression
282 if (lhs->is_constant(Scalar(0))) {
283 if (rhs->is_constant(Scalar(0))) {
284 // Return zero
285 return rhs;
286 } else {
287 return -rhs;
288 }
289 } else if (rhs->is_constant(Scalar(0))) {
290 return lhs;
291 }
292
293 // Evaluate constant
294 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
295 return constant_ptr(lhs->val - rhs->val);
296 }
297
298 auto type = std::max(lhs->type(), rhs->type());
299 if (type == LINEAR) {
301 rhs);
302 } else if (type == QUADRATIC) {
304 rhs);
305 } else {
307 rhs);
308 }
309 }
310
315 using enum ExpressionType;
316
317 // Prune expression
318 if (lhs->is_constant(Scalar(0))) {
319 // Return zero
320 return lhs;
321 }
322
323 // Evaluate constant
324 if (lhs->type() == CONSTANT) {
325 return constant_ptr(-lhs->val);
326 }
327
328 if (lhs->type() == LINEAR) {
330 } else if (lhs->type() == QUADRATIC) {
332 } else {
334 }
335 }
336
341 return lhs;
342 }
343
352 [[maybe_unused]] Scalar rhs) const = 0;
353
358 virtual ExpressionType type() const = 0;
359
363 virtual std::string_view name() const = 0;
364
374 return Scalar(0);
375 }
376
386 return Scalar(0);
387 }
388
399 return constant_ptr(Scalar(0));
400 }
401
412 return constant_ptr(Scalar(0));
413 }
414};
415
416template <typename Scalar>
417ExpressionPtr<Scalar> constant_ptr(Scalar value) {
419}
420
421template <typename Scalar>
422ExpressionPtr<Scalar> cbrt(const ExpressionPtr<Scalar>& x);
423template <typename Scalar>
424ExpressionPtr<Scalar> exp(const ExpressionPtr<Scalar>& x);
425template <typename Scalar>
426ExpressionPtr<Scalar> sin(const ExpressionPtr<Scalar>& x);
427template <typename Scalar>
428ExpressionPtr<Scalar> sinh(const ExpressionPtr<Scalar>& x);
429template <typename Scalar>
430ExpressionPtr<Scalar> sqrt(const ExpressionPtr<Scalar>& x);
431
436template <typename Scalar, ExpressionType T>
445
446 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs - rhs; }
447
448 ExpressionType type() const override { return T; }
449
450 std::string_view name() const override { return "binary minus"; }
451
453 return parent_adjoint;
454 }
455
457 return -parent_adjoint;
458 }
459
465
471};
472
477template <typename Scalar, ExpressionType T>
486
487 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs + rhs; }
488
489 ExpressionType type() const override { return T; }
490
491 std::string_view name() const override { return "binary plus"; }
492
494 return parent_adjoint;
495 }
496
498 return parent_adjoint;
499 }
500
506
512};
513
517template <typename Scalar>
523 : Expression<Scalar>{std::move(lhs)} {}
524
525 Scalar value(Scalar x, Scalar) const override {
526 using std::cbrt;
527 return cbrt(x);
528 }
529
530 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
531
532 std::string_view name() const override { return "cbrt"; }
533
535 using std::cbrt;
536
537 Scalar c = cbrt(x);
538 return parent_adjoint / (Scalar(3) * c * c);
539 }
540
543 const ExpressionPtr<Scalar>& parent_adjoint) const override {
544 auto c = cbrt(x);
545 return parent_adjoint / (constant_ptr(Scalar(3)) * c * c);
546 }
547};
548
553template <typename Scalar>
555 using enum ExpressionType;
556 using std::cbrt;
557
558 // Evaluate constant
559 if (x->type() == CONSTANT) {
560 if (x->val == Scalar(0)) {
561 // Return zero
562 return x;
563 } else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
564 return x;
565 } else {
566 return constant_ptr(cbrt(x->val));
567 }
568 }
569
570 return make_expression_ptr<CbrtExpression<Scalar>>(x);
571}
572
576template <typename Scalar>
581 explicit constexpr ConstantExpression(Scalar value)
582 : Expression<Scalar>{value} {}
583
584 Scalar value(Scalar, Scalar) const override { return this->val; }
585
586 ExpressionType type() const override { return ExpressionType::CONSTANT; }
587
588 std::string_view name() const override { return "constant"; }
589};
590
594template <typename Scalar>
597 constexpr DecisionVariableExpression() = default;
598
604
605 Scalar value(Scalar, Scalar) const override { return this->val; }
606
607 ExpressionType type() const override { return ExpressionType::LINEAR; }
608
609 std::string_view name() const override { return "decision variable"; }
610};
611
616template <typename Scalar, ExpressionType T>
624
625 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs / rhs; }
626
627 ExpressionType type() const override { return T; }
628
629 std::string_view name() const override { return "division"; }
630
632 return parent_adjoint / rhs;
633 };
634
636 return parent_adjoint * -lhs / (rhs * rhs);
637 }
638
644
650};
651
656template <typename Scalar, ExpressionType T>
664
665 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs * rhs; }
666
667 ExpressionType type() const override { return T; }
668
669 std::string_view name() const override { return "multiplication"; }
670
672 Scalar parent_adjoint) const override {
673 return parent_adjoint * rhs;
674 }
675
677 Scalar parent_adjoint) const override {
678 return parent_adjoint * lhs;
679 }
680
687
694};
695
700template <typename Scalar, ExpressionType T>
706 : Expression<Scalar>{std::move(lhs)} {}
707
708 Scalar value(Scalar lhs, Scalar) const override { return -lhs; }
709
710 ExpressionType type() const override { return T; }
711
712 std::string_view name() const override { return "unary minus"; }
713
715 return -parent_adjoint;
716 }
717
723};
724
729template <typename Scalar>
730constexpr void inc_ref_count(Expression<Scalar>* expr) {
731 ++expr->ref_count;
732}
733
/// Drops one owner of @p expr.
///
/// For every node whose reference count reaches zero, this function:
/// - decrements the counts of that node's adjoint expression and arguments, and
/// - when USE_POOL_ALLOCATOR is true, returns the node's storage to the global
///   pool allocator.
///
/// The traversal uses an explicit stack instead of recursion, so releasing a
/// very deep expression tree cannot overflow the call stack.
///
/// @param expr Node losing an owner; dereferenced unconditionally, so it must
///   be non-null.
template <typename Scalar>
constexpr void dec_ref_count(Expression<Scalar>* expr) {
  // If a deeply nested tree is being deallocated all at once, calling the
  // Expression destructor when expr's refcount reaches zero can cause a stack
  // overflow. Instead, we iterate over its children to decrement their
  // refcounts and deallocate them.
  gch::small_vector<Expression<Scalar>*> stack;
  stack.emplace_back(expr);

  while (!stack.empty()) {
    auto elem = stack.back();
    stack.pop_back();

    // Decrement the current node's refcount. If it reaches zero, deallocate the
    // node and enqueue its children so their refcounts are decremented too.
    if (--elem->ref_count == 0) {
      // The dead node held one reference each to its adjoint expression and to
      // its (at most two) non-null arguments. Queue them so each loses that
      // reference on a later iteration.
      if (elem->adjoint_expr != nullptr) {
        stack.emplace_back(elem->adjoint_expr.get());
      }
      for (auto& arg : elem->args) {
        if (arg != nullptr) {
          stack.emplace_back(arg.get());
        }
      }

      // Not calling the destructor here is safe because it only decrements
      // refcounts, which was already done above.
      // NOTE(review): storage is only released when USE_POOL_ALLOCATOR is true.
      // In the non-pool build (_WIN32), nodes created via make_intrusive_shared
      // appear never to be freed here. Confirm whether this leak is intended.
      // NOTE(review): std::allocator_traits::deallocate's third argument is an
      // element count, yet sizeof(Expression<Scalar>) is passed. The node may
      // also have been allocated as a derived type through
      // global_pool_allocator<T>. Verify both against pool.hpp.
      if constexpr (USE_POOL_ALLOCATOR) {
        auto alloc = global_pool_allocator<Expression<Scalar>>();
        std::allocator_traits<decltype(alloc)>::deallocate(
            alloc, elem, sizeof(Expression<Scalar>));
      }
    }
  }
}
773
777template <typename Scalar>
783 : Expression<Scalar>{std::move(lhs)} {}
784
785 Scalar value(Scalar x, Scalar) const override {
786 using std::abs;
787 return abs(x);
788 }
789
790 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
791
792 std::string_view name() const override { return "abs"; }
793
795 if (x < Scalar(0)) {
796 return -parent_adjoint;
797 } else if (x > Scalar(0)) {
798 return parent_adjoint;
799 } else {
800 return Scalar(0);
801 }
802 }
803
806 const ExpressionPtr<Scalar>& parent_adjoint) const override {
807 if (x->val < Scalar(0)) {
808 return -parent_adjoint;
809 } else if (x->val > Scalar(0)) {
810 return parent_adjoint;
811 } else {
812 return constant_ptr(Scalar(0));
813 }
814 }
815};
816
821template <typename Scalar>
823 using enum ExpressionType;
824 using std::abs;
825
826 // Prune expression
827 if (x->is_constant(Scalar(0))) {
828 // Return zero
829 return x;
830 }
831
832 // Evaluate constant
833 if (x->type() == CONSTANT) {
834 return constant_ptr(abs(x->val));
835 }
836
837 return make_expression_ptr<AbsExpression<Scalar>>(x);
838}
839
843template <typename Scalar>
849 : Expression<Scalar>{std::move(lhs)} {}
850
851 Scalar value(Scalar x, Scalar) const override {
852 using std::acos;
853 return acos(x);
854 }
855
856 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
857
858 std::string_view name() const override { return "acos"; }
859
861 using std::sqrt;
862 return -parent_adjoint / sqrt(Scalar(1) - x * x);
863 }
864
867 const ExpressionPtr<Scalar>& parent_adjoint) const override {
868 return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
869 }
870};
871
876template <typename Scalar>
878 using enum ExpressionType;
879 using std::acos;
880
881 // Prune expression
882 if (x->is_constant(Scalar(0))) {
883 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
884 }
885
886 // Evaluate constant
887 if (x->type() == CONSTANT) {
888 return constant_ptr(acos(x->val));
889 }
890
891 return make_expression_ptr<AcosExpression<Scalar>>(x);
892}
893
897template <typename Scalar>
903 : Expression<Scalar>{std::move(lhs)} {}
904
905 Scalar value(Scalar x, Scalar) const override {
906 using std::asin;
907 return asin(x);
908 }
909
910 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
911
912 std::string_view name() const override { return "asin"; }
913
915 using std::sqrt;
916 return parent_adjoint / sqrt(Scalar(1) - x * x);
917 }
918
921 const ExpressionPtr<Scalar>& parent_adjoint) const override {
922 return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
923 }
924};
925
930template <typename Scalar>
932 using enum ExpressionType;
933 using std::asin;
934
935 // Prune expression
936 if (x->is_constant(Scalar(0))) {
937 // Return zero
938 return x;
939 }
940
941 // Evaluate constant
942 if (x->type() == CONSTANT) {
943 return constant_ptr(asin(x->val));
944 }
945
946 return make_expression_ptr<AsinExpression<Scalar>>(x);
947}
948
952template <typename Scalar>
958 : Expression<Scalar>{std::move(lhs)} {}
959
960 Scalar value(Scalar x, Scalar) const override {
961 using std::atan;
962 return atan(x);
963 }
964
965 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
966
967 std::string_view name() const override { return "atan"; }
968
970 return parent_adjoint / (Scalar(1) + x * x);
971 }
972
975 const ExpressionPtr<Scalar>& parent_adjoint) const override {
976 return parent_adjoint / (constant_ptr(Scalar(1)) + x * x);
977 }
978};
979
984template <typename Scalar>
986 using enum ExpressionType;
987 using std::atan;
988
989 // Prune expression
990 if (x->is_constant(Scalar(0))) {
991 // Return zero
992 return x;
993 }
994
995 // Evaluate constant
996 if (x->type() == CONSTANT) {
997 return constant_ptr(atan(x->val));
998 }
999
1000 return make_expression_ptr<AtanExpression<Scalar>>(x);
1001}
1002
1006template <typename Scalar>
1014 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1015
1016 Scalar value(Scalar y, Scalar x) const override {
1017 using std::atan2;
1018 return atan2(y, x);
1019 }
1020
1021 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1022
1023 std::string_view name() const override { return "atan2"; }
1024
1026 return parent_adjoint * x / (y * y + x * x);
1027 }
1028
1030 return parent_adjoint * -y / (y * y + x * x);
1031 }
1032
1035 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1036 return parent_adjoint * x / (y * y + x * x);
1037 }
1038
1041 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1042 return parent_adjoint * -y / (y * y + x * x);
1043 }
1044};
1045
1051template <typename Scalar>
1053 const ExpressionPtr<Scalar>& x) {
1054 using enum ExpressionType;
1055 using std::atan2;
1056
1057 // Prune expression
1058 if (y->is_constant(Scalar(0))) {
1059 // Return zero
1060 return y;
1061 } else if (x->is_constant(Scalar(0))) {
1062 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1063 }
1064
1065 // Evaluate constant
1066 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1067 return constant_ptr(atan2(y->val, x->val));
1068 }
1069
1070 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1071}
1072
1076template <typename Scalar>
1082 : Expression<Scalar>{std::move(lhs)} {}
1083
1084 Scalar value(Scalar x, Scalar) const override {
1085 using std::cos;
1086 return cos(x);
1087 }
1088
1089 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1090
1091 std::string_view name() const override { return "cos"; }
1092
1094 using std::sin;
1095 return parent_adjoint * -sin(x);
1096 }
1097
1100 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1101 return parent_adjoint * -sin(x);
1102 }
1103};
1104
1109template <typename Scalar>
1111 using enum ExpressionType;
1112 using std::cos;
1113
1114 // Prune expression
1115 if (x->is_constant(Scalar(0))) {
1116 return constant_ptr(Scalar(1));
1117 }
1118
1119 // Evaluate constant
1120 if (x->type() == CONSTANT) {
1121 return constant_ptr(cos(x->val));
1122 }
1123
1124 return make_expression_ptr<CosExpression<Scalar>>(x);
1125}
1126
1130template <typename Scalar>
1136 : Expression<Scalar>{std::move(lhs)} {}
1137
1138 Scalar value(Scalar x, Scalar) const override {
1139 using std::cosh;
1140 return cosh(x);
1141 }
1142
1143 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1144
1145 std::string_view name() const override { return "cosh"; }
1146
1148 using std::sinh;
1149 return parent_adjoint * sinh(x);
1150 }
1151
1154 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1155 return parent_adjoint * sinh(x);
1156 }
1157};
1158
1163template <typename Scalar>
1165 using enum ExpressionType;
1166 using std::cosh;
1167
1168 // Prune expression
1169 if (x->is_constant(Scalar(0))) {
1170 return constant_ptr(Scalar(1));
1171 }
1172
1173 // Evaluate constant
1174 if (x->type() == CONSTANT) {
1175 return constant_ptr(cosh(x->val));
1176 }
1177
1178 return make_expression_ptr<CoshExpression<Scalar>>(x);
1179}
1180
1184template <typename Scalar>
1190 : Expression<Scalar>{std::move(lhs)} {}
1191
1192 Scalar value(Scalar x, Scalar) const override {
1193 using std::erf;
1194 return erf(x);
1195 }
1196
1197 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1198
1199 std::string_view name() const override { return "erf"; }
1200
1202 using std::exp;
1203 return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) *
1204 exp(-x * x);
1205 }
1206
1209 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1210 return parent_adjoint *
1211 constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1212 }
1213};
1214
1219template <typename Scalar>
1221 using enum ExpressionType;
1222 using std::erf;
1223
1224 // Prune expression
1225 if (x->is_constant(Scalar(0))) {
1226 // Return zero
1227 return x;
1228 }
1229
1230 // Evaluate constant
1231 if (x->type() == CONSTANT) {
1232 return constant_ptr(erf(x->val));
1233 }
1234
1235 return make_expression_ptr<ErfExpression<Scalar>>(x);
1236}
1237
1241template <typename Scalar>
1247 : Expression<Scalar>{std::move(lhs)} {}
1248
1249 Scalar value(Scalar x, Scalar) const override {
1250 using std::exp;
1251 return exp(x);
1252 }
1253
1254 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1255
1256 std::string_view name() const override { return "exp"; }
1257
1259 using std::exp;
1260 return parent_adjoint * exp(x);
1261 }
1262
1265 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1266 return parent_adjoint * exp(x);
1267 }
1268};
1269
1274template <typename Scalar>
1276 using enum ExpressionType;
1277 using std::exp;
1278
1279 // Prune expression
1280 if (x->is_constant(Scalar(0))) {
1281 return constant_ptr(Scalar(1));
1282 }
1283
1284 // Evaluate constant
1285 if (x->type() == CONSTANT) {
1286 return constant_ptr(exp(x->val));
1287 }
1288
1289 return make_expression_ptr<ExpExpression<Scalar>>(x);
1290}
1291
// Forward declaration: HypotExpression's grad_expr_l/grad_expr_r overrides call
// hypot(x, y) on ExpressionPtrs before this function's definition below.
template <typename Scalar>
ExpressionPtr<Scalar> hypot(const ExpressionPtr<Scalar>& x,
                            const ExpressionPtr<Scalar>& y);
1295
1299template <typename Scalar>
1307 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1308
1309 Scalar value(Scalar x, Scalar y) const override {
1310 using std::hypot;
1311 return hypot(x, y);
1312 }
1313
1314 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1315
1316 std::string_view name() const override { return "hypot"; }
1317
1319 using std::hypot;
1320 return parent_adjoint * x / hypot(x, y);
1321 }
1322
1324 using std::hypot;
1325 return parent_adjoint * y / hypot(x, y);
1326 }
1327
1330 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1331 return parent_adjoint * x / hypot(x, y);
1332 }
1333
1336 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1337 return parent_adjoint * y / hypot(x, y);
1338 }
1339};
1340
1346template <typename Scalar>
1348 const ExpressionPtr<Scalar>& y) {
1349 using enum ExpressionType;
1350 using std::hypot;
1351
1352 // Prune expression
1353 if (x->is_constant(Scalar(0))) {
1354 return y;
1355 } else if (y->is_constant(Scalar(0))) {
1356 return x;
1357 }
1358
1359 // Evaluate constant
1360 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1361 return constant_ptr(hypot(x->val, y->val));
1362 }
1363
1364 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1365}
1366
1370template <typename Scalar>
1376 : Expression<Scalar>{std::move(lhs)} {}
1377
1378 Scalar value(Scalar x, Scalar) const override {
1379 using std::log;
1380 return log(x);
1381 }
1382
1383 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1384
1385 std::string_view name() const override { return "log"; }
1386
1388 return parent_adjoint / x;
1389 }
1390
1393 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1394 return parent_adjoint / x;
1395 }
1396};
1397
1402template <typename Scalar>
1404 using enum ExpressionType;
1405 using std::log;
1406
1407 // Prune expression
1408 if (x->is_constant(Scalar(0))) {
1409 // Return zero
1410 return x;
1411 }
1412
1413 // Evaluate constant
1414 if (x->type() == CONSTANT) {
1415 return constant_ptr(log(x->val));
1416 }
1417
1418 return make_expression_ptr<LogExpression<Scalar>>(x);
1419}
1420
1424template <typename Scalar>
1430 : Expression<Scalar>{std::move(lhs)} {}
1431
1432 Scalar value(Scalar x, Scalar) const override {
1433 using std::log10;
1434 return log10(x);
1435 }
1436
1437 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1438
1439 std::string_view name() const override { return "log10"; }
1440
1442 return parent_adjoint / (Scalar(std::numbers::ln10) * x);
1443 }
1444
1447 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1448 return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x);
1449 }
1450};
1451
1456template <typename Scalar>
1458 using enum ExpressionType;
1459 using std::log10;
1460
1461 // Prune expression
1462 if (x->is_constant(Scalar(0))) {
1463 // Return zero
1464 return x;
1465 }
1466
1467 // Evaluate constant
1468 if (x->type() == CONSTANT) {
1469 return constant_ptr(log10(x->val));
1470 }
1471
1472 return make_expression_ptr<Log10Expression<Scalar>>(x);
1473}
1474
// Forward declaration: PowExpression's grad_expr_l/grad_expr_r overrides call
// pow(base, power) on ExpressionPtrs before this function's definition below.
template <typename Scalar>
ExpressionPtr<Scalar> pow(const ExpressionPtr<Scalar>& base,
                          const ExpressionPtr<Scalar>& power);
1478
1483template <typename Scalar, ExpressionType T>
1491
1492 Scalar value(Scalar base, Scalar power) const override {
1493 using std::pow;
1494 return pow(base, power);
1495 }
1496
1497 ExpressionType type() const override { return T; }
1498
1499 std::string_view name() const override { return "pow"; }
1500
1502 Scalar parent_adjoint) const override {
1503 using std::pow;
1504 return parent_adjoint * pow(base, power - Scalar(1)) * power;
1505 }
1506
1508 Scalar parent_adjoint) const override {
1509 using std::log;
1510 using std::pow;
1511
1512 // Since x log(x) -> 0 as x -> 0
1513 if (base == Scalar(0)) {
1514 return Scalar(0);
1515 } else {
1516 return parent_adjoint * pow(base, power) * log(base);
1517 }
1518 }
1519
1522 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1523 return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power;
1524 }
1525
1528 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1529 // Since x log(x) -> 0 as x -> 0
1530 if (base->val == Scalar(0)) {
1531 // Return zero
1532 return base;
1533 } else {
1534 return parent_adjoint * pow(base, power) * log(base);
1535 }
1536 }
1537};
1538
1544template <typename Scalar>
1547 using enum ExpressionType;
1548 using std::pow;
1549
1550 // Prune expression
1551 if (base->is_constant(Scalar(0))) {
1552 // Return zero
1553 return base;
1554 } else if (base->is_constant(Scalar(1))) {
1555 // Return one
1556 return base;
1557 }
1558 if (power->is_constant(Scalar(0))) {
1559 return constant_ptr(Scalar(1));
1560 } else if (power->is_constant(Scalar(1))) {
1561 return base;
1562 }
1563
1564 // Evaluate constant
1565 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1566 return constant_ptr(pow(base->val, power->val));
1567 }
1568
1569 if (power->is_constant(Scalar(2))) {
1570 if (base->type() == LINEAR) {
1571 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1572 } else {
1573 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1574 }
1575 }
1576
1577 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1578}
1579
1583template <typename Scalar>
1589 : Expression<Scalar>{std::move(lhs)} {}
1590
1591 Scalar value(Scalar x, Scalar) const override {
1592 if (x < Scalar(0)) {
1593 return Scalar(-1);
1594 } else if (x == Scalar(0)) {
1595 return Scalar(0);
1596 } else {
1597 return Scalar(1);
1598 }
1599 }
1600
1601 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1602
1603 std::string_view name() const override { return "sign"; }
1604};
1605
1610template <typename Scalar>
1612 using enum ExpressionType;
1613
1614 // Evaluate constant
1615 if (x->type() == CONSTANT) {
1616 if (x->val < Scalar(0)) {
1617 return constant_ptr(Scalar(-1));
1618 } else if (x->val == Scalar(0)) {
1619 // Return zero
1620 return x;
1621 } else {
1622 return constant_ptr(Scalar(1));
1623 }
1624 }
1625
1626 return make_expression_ptr<SignExpression<Scalar>>(x);
1627}
1628
1632template <typename Scalar>
1638 : Expression<Scalar>{std::move(lhs)} {}
1639
1640 Scalar value(Scalar x, Scalar) const override {
1641 using std::sin;
1642 return sin(x);
1643 }
1644
1645 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1646
1647 std::string_view name() const override { return "sin"; }
1648
1650 using std::cos;
1651 return parent_adjoint * cos(x);
1652 }
1653
1656 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1657 return parent_adjoint * cos(x);
1658 }
1659};
1660
1665template <typename Scalar>
1667 using enum ExpressionType;
1668 using std::sin;
1669
1670 // Prune expression
1671 if (x->is_constant(Scalar(0))) {
1672 // Return zero
1673 return x;
1674 }
1675
1676 // Evaluate constant
1677 if (x->type() == CONSTANT) {
1678 return constant_ptr(sin(x->val));
1679 }
1680
1681 return make_expression_ptr<SinExpression<Scalar>>(x);
1682}
1683
1687template <typename Scalar>
1693 : Expression<Scalar>{std::move(lhs)} {}
1694
1695 Scalar value(Scalar x, Scalar) const override {
1696 using std::sinh;
1697 return sinh(x);
1698 }
1699
1700 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1701
1702 std::string_view name() const override { return "sinh"; }
1703
1705 using std::cosh;
1706 return parent_adjoint * cosh(x);
1707 }
1708
1711 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1712 return parent_adjoint * cosh(x);
1713 }
1714};
1715
1720template <typename Scalar>
1722 using enum ExpressionType;
1723 using std::sinh;
1724
1725 // Prune expression
1726 if (x->is_constant(Scalar(0))) {
1727 // Return zero
1728 return x;
1729 }
1730
1731 // Evaluate constant
1732 if (x->type() == CONSTANT) {
1733 return constant_ptr(sinh(x->val));
1734 }
1735
1736 return make_expression_ptr<SinhExpression<Scalar>>(x);
1737}
1738
1742template <typename Scalar>
1748 : Expression<Scalar>{std::move(lhs)} {}
1749
1750 Scalar value(Scalar x, Scalar) const override {
1751 using std::sqrt;
1752 return sqrt(x);
1753 }
1754
1755 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1756
1757 std::string_view name() const override { return "sqrt"; }
1758
1760 using std::sqrt;
1761 return parent_adjoint / (Scalar(2) * sqrt(x));
1762 }
1763
1766 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1767 return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x));
1768 }
1769};
1770
1775template <typename Scalar>
1777 using enum ExpressionType;
1778 using std::sqrt;
1779
1780 // Evaluate constant
1781 if (x->type() == CONSTANT) {
1782 if (x->val == Scalar(0)) {
1783 // Return zero
1784 return x;
1785 } else if (x->val == Scalar(1)) {
1786 return x;
1787 } else {
1788 return constant_ptr(sqrt(x->val));
1789 }
1790 }
1791
1792 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1793}
1794
1798template <typename Scalar>
1804 : Expression<Scalar>{std::move(lhs)} {}
1805
1806 Scalar value(Scalar x, Scalar) const override {
1807 using std::tan;
1808 return tan(x);
1809 }
1810
1811 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1812
1813 std::string_view name() const override { return "tan"; }
1814
1816 using std::cos;
1817
1818 auto c = cos(x);
1819 return parent_adjoint / (c * c);
1820 }
1821
1824 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1825 auto c = cos(x);
1826 return parent_adjoint / (c * c);
1827 }
1828};
1829
1834template <typename Scalar>
1836 using enum ExpressionType;
1837 using std::tan;
1838
1839 // Prune expression
1840 if (x->is_constant(Scalar(0))) {
1841 // Return zero
1842 return x;
1843 }
1844
1845 // Evaluate constant
1846 if (x->type() == CONSTANT) {
1847 return constant_ptr(tan(x->val));
1848 }
1849
1850 return make_expression_ptr<TanExpression<Scalar>>(x);
1851}
1852
1856template <typename Scalar>
1862 : Expression<Scalar>{std::move(lhs)} {}
1863
1864 Scalar value(Scalar x, Scalar) const override {
1865 using std::tanh;
1866 return tanh(x);
1867 }
1868
1869 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1870
1871 std::string_view name() const override { return "tanh"; }
1872
1874 using std::cosh;
1875
1876 auto c = cosh(x);
1877 return parent_adjoint / (c * c);
1878 }
1879
1882 const ExpressionPtr<Scalar>& parent_adjoint) const override {
1883 auto c = cosh(x);
1884 return parent_adjoint / (c * c);
1885 }
1886};
1887
1892template <typename Scalar>
1894 using enum ExpressionType;
1895 using std::tanh;
1896
1897 // Prune expression
1898 if (x->is_constant(Scalar(0))) {
1899 // Return zero
1900 return x;
1901 }
1902
1903 // Evaluate constant
1904 if (x->type() == CONSTANT) {
1905 return constant_ptr(tanh(x->val));
1906 }
1907
1908 return make_expression_ptr<TanhExpression<Scalar>>(x);
1909}
1910
1911} // namespace slp::detail
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:778
std::string_view name() const override
Definition expression.hpp:792
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:782
ExpressionType type() const override
Definition expression.hpp:790
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:794
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:804
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:785
Definition expression.hpp:844
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:865
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:851
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:848
ExpressionType type() const override
Definition expression.hpp:856
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:860
std::string_view name() const override
Definition expression.hpp:858
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:914
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:905
std::string_view name() const override
Definition expression.hpp:912
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:919
ExpressionType type() const override
Definition expression.hpp:910
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:902
Definition expression.hpp:1007
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1033
std::string_view name() const override
Definition expression.hpp:1023
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1016
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1012
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1029
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1039
ExpressionType type() const override
Definition expression.hpp:1021
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override
Definition expression.hpp:1025
Definition expression.hpp:953
std::string_view name() const override
Definition expression.hpp:967
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:973
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:957
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:969
ExpressionType type() const override
Definition expression.hpp:965
Definition expression.hpp:437
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:442
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:466
std::string_view name() const override
Definition expression.hpp:450
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:456
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:446
ExpressionType type() const override
Definition expression.hpp:448
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:452
Definition expression.hpp:478
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:501
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:487
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:507
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:483
ExpressionType type() const override
Definition expression.hpp:489
std::string_view name() const override
Definition expression.hpp:491
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:497
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:493
Definition expression.hpp:518
std::string_view name() const override
Definition expression.hpp:532
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:522
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:534
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:541
ExpressionType type() const override
Definition expression.hpp:530
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:525
Definition expression.hpp:577
ExpressionType type() const override
Definition expression.hpp:586
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:584
std::string_view name() const override
Definition expression.hpp:588
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:581
Definition expression.hpp:1077
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1084
ExpressionType type() const override
Definition expression.hpp:1089
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1081
std::string_view name() const override
Definition expression.hpp:1091
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1098
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1093
Definition expression.hpp:1131
ExpressionType type() const override
Definition expression.hpp:1143
std::string_view name() const override
Definition expression.hpp:1145
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1138
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1152
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1135
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1147
Definition expression.hpp:595
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:609
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:605
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:602
ExpressionType type() const override
Definition expression.hpp:607
Definition expression.hpp:617
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:622
ExpressionType type() const override
Definition expression.hpp:627
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:635
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:645
std::string_view name() const override
Definition expression.hpp:629
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:631
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:625
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:639
Definition expression.hpp:1185
std::string_view name() const override
Definition expression.hpp:1199
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1189
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1201
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1192
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1207
ExpressionType type() const override
Definition expression.hpp:1197
Definition expression.hpp:1242
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1258
std::string_view name() const override
Definition expression.hpp:1256
ExpressionType type() const override
Definition expression.hpp:1254
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1246
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1263
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1249
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:314
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:103
virtual Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:383
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:113
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:110
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:126
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:395
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const
Definition expression.hpp:408
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:235
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:142
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:133
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:107
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:200
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:100
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:277
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:150
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const
Definition expression.hpp:371
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:268
constexpr Expression(Scalar value)
Definition expression.hpp:121
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:340
Definition expression.hpp:1300
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1305
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1334
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1328
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1318
std::string_view name() const override
Definition expression.hpp:1316
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1309
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override
Definition expression.hpp:1323
Definition expression.hpp:1425
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1432
std::string_view name() const override
Definition expression.hpp:1439
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1429
ExpressionType type() const override
Definition expression.hpp:1437
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1445
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1441
Definition expression.hpp:1371
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1391
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1387
ExpressionType type() const override
Definition expression.hpp:1383
std::string_view name() const override
Definition expression.hpp:1385
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1375
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1378
Definition expression.hpp:657
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:665
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:681
Scalar grad_l(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:671
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override
Definition expression.hpp:676
std::string_view name() const override
Definition expression.hpp:669
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:688
Definition expression.hpp:1484
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1492
Scalar grad_l(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1501
ExpressionType type() const override
Definition expression.hpp:1497
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1520
Scalar grad_r(Scalar base, Scalar power, Scalar parent_adjoint) const override
Definition expression.hpp:1507
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1526
std::string_view name() const override
Definition expression.hpp:1499
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1489
Definition expression.hpp:1584
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1588
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1591
ExpressionType type() const override
Definition expression.hpp:1601
std::string_view name() const override
Definition expression.hpp:1603
Definition expression.hpp:1633
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1649
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1654
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1640
std::string_view name() const override
Definition expression.hpp:1647
ExpressionType type() const override
Definition expression.hpp:1645
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1637
Definition expression.hpp:1688
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1704
std::string_view name() const override
Definition expression.hpp:1702
ExpressionType type() const override
Definition expression.hpp:1700
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1709
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1692
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1695
Definition expression.hpp:1743
ExpressionType type() const override
Definition expression.hpp:1755
std::string_view name() const override
Definition expression.hpp:1757
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1764
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1747
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1759
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1750
Definition expression.hpp:1799
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1822
std::string_view name() const override
Definition expression.hpp:1813
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1803
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1815
ExpressionType type() const override
Definition expression.hpp:1811
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1806
Definition expression.hpp:1857
ExpressionType type() const override
Definition expression.hpp:1869
std::string_view name() const override
Definition expression.hpp:1871
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:1873
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1864
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:1880
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1861
Definition expression.hpp:701
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:705
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &parent_adjoint) const override
Definition expression.hpp:718
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override
Definition expression.hpp:714
ExpressionType type() const override
Definition expression.hpp:710
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:708
std::string_view name() const override
Definition expression.hpp:712