Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <string_view>
13#include <utility>
14
15#include <gch/small_vector.hpp>
16
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
20
21namespace slp::detail {
22
23// The global pool allocator uses a thread-local static pool resource, which
24// isn't guaranteed to be initialized properly across DLL boundaries on Windows
25#ifdef _WIN32
26inline constexpr bool USE_POOL_ALLOCATOR = false;
27#else
28inline constexpr bool USE_POOL_ALLOCATOR = true;
29#endif
30
31template <typename Scalar>
32struct Expression;
33
34template <typename Scalar>
35constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <typename Scalar>
37constexpr void dec_ref_count(Expression<Scalar>* expr);
38
42template <typename Scalar>
43using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
44
50template <typename T, typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
55 } else {
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57 }
58}
59
60template <typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
62
63template <typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
65
66template <typename Scalar>
67struct ConstantExpression;
68
69template <typename Scalar, ExpressionType T>
70struct DivExpression;
71
72template <typename Scalar, ExpressionType T>
73struct MultExpression;
74
75template <typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
77
82template <typename Scalar>
83ExpressionPtr<Scalar> constant_ptr(Scalar value);
84
88template <typename Scalar_>
89struct Expression {
91 using Scalar = Scalar_;
92
95
98
102
104 std::array<ExpressionPtr<Scalar>, 2> args{nullptr, nullptr};
105
116
119
121 constexpr Expression() = default;
122
126 explicit constexpr Expression(Scalar value) : val{value} {}
127
132 : args{std::move(lhs), nullptr} {}
133
140
141 virtual ~Expression() = default;
142
147 constexpr bool is_constant(Scalar constant) const {
148 return type() == ExpressionType::CONSTANT && val == constant;
149 }
150
156 const ExpressionPtr<Scalar>& rhs) {
157 using enum ExpressionType;
158
159 // Prune expression
160 if (lhs->is_constant(Scalar(0))) {
161 // Return zero, which lhs currently is
162 return lhs;
163 } else if (rhs->is_constant(Scalar(0))) {
164 // Return zero, which rhs currently is
165 return rhs;
166 } else if (lhs->is_constant(Scalar(1))) {
167 // Return rhs unmodified
168 return rhs;
169 } else if (rhs->is_constant(Scalar(1))) {
170 // Return lhs unmodified
171 return lhs;
172 }
173
174 // Evaluate constant
175 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
176 return constant_ptr(lhs->val * rhs->val);
177 }
178
179 // Evaluate expression type
180 if (lhs->type() == CONSTANT) {
181 if (rhs->type() == LINEAR) {
183 } else if (rhs->type() == QUADRATIC) {
185 } else {
187 }
188 } else if (rhs->type() == CONSTANT) {
189 if (lhs->type() == LINEAR) {
191 } else if (lhs->type() == QUADRATIC) {
193 } else {
195 }
196 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
198 } else {
200 }
201 }
202
208 const ExpressionPtr<Scalar>& rhs) {
209 using enum ExpressionType;
210
211 // Prune expression
212 if (lhs->is_constant(Scalar(0))) {
213 // Return zero, which lhs currently is
214 return lhs;
215 } else if (rhs->is_constant(Scalar(1))) {
216 // Return lhs unmodified
217 return lhs;
218 }
219
220 // Evaluate constant
221 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
222 return constant_ptr(lhs->val / rhs->val);
223 }
224
225 // Evaluate expression type
226 if (rhs->type() == CONSTANT) {
227 if (lhs->type() == LINEAR) {
229 } else if (lhs->type() == QUADRATIC) {
231 } else {
233 }
234 } else {
236 }
237 }
238
244 const ExpressionPtr<Scalar>& rhs) {
245 using enum ExpressionType;
246
247 // Prune expression. We check for nullptr because operator+ is used in
248 // adjoint accumulation, and child nodes can be null.
249 if (lhs == nullptr || lhs->is_constant(Scalar(0))) {
250 // Return rhs unmodified
251 return rhs;
252 } else if (rhs == nullptr || rhs->is_constant(Scalar(0))) {
253 // Return lhs unmodified
254 return lhs;
255 }
256
257 // Evaluate constant
258 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
259 return constant_ptr(lhs->val + rhs->val);
260 }
261
262 auto type = std::max(lhs->type(), rhs->type());
263 if (type == LINEAR) {
265 rhs);
266 } else if (type == QUADRATIC) {
268 rhs);
269 } else {
271 rhs);
272 }
273 }
274
283
289 const ExpressionPtr<Scalar>& rhs) {
290 using enum ExpressionType;
291
292 // Prune expression
293 if (lhs->is_constant(Scalar(0))) {
294 if (rhs->is_constant(Scalar(0))) {
295 // Return zero, which rhs currently is
296 return rhs;
297 } else {
298 // Return rhs negated
299 return -rhs;
300 }
301 } else if (rhs->is_constant(Scalar(0))) {
302 // Return lhs unmodified
303 return lhs;
304 }
305
306 // Evaluate constant
307 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
308 return constant_ptr(lhs->val - rhs->val);
309 }
310
311 auto type = std::max(lhs->type(), rhs->type());
312 if (type == LINEAR) {
314 rhs);
315 } else if (type == QUADRATIC) {
317 rhs);
318 } else {
320 rhs);
321 }
322 }
323
328 using enum ExpressionType;
329
330 // Prune expression
331 if (lhs->is_constant(Scalar(0))) {
332 // Return zero, which lhs currently is
333 return lhs;
334 }
335
336 // Evaluate constant
337 if (lhs->type() == CONSTANT) {
338 return constant_ptr(-lhs->val);
339 }
340
341 if (lhs->type() == LINEAR) {
343 } else if (lhs->type() == QUADRATIC) {
345 } else {
347 }
348 }
349
354 return lhs;
355 }
356
365 [[maybe_unused]] Scalar rhs) const = 0;
366
371 virtual ExpressionType type() const = 0;
372
376 virtual std::string_view name() const = 0;
377
384 [[maybe_unused]] Scalar rhs) const {
385 return Scalar(0);
386 }
387
394 [[maybe_unused]] Scalar rhs) const {
395 return Scalar(0);
396 }
397
405 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
406 return constant_ptr(Scalar(0));
407 }
408
416 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
417 return constant_ptr(Scalar(0));
418 }
419};
420
421template <typename Scalar>
422ExpressionPtr<Scalar> constant_ptr(Scalar value) {
424}
425
426template <typename Scalar>
427ExpressionPtr<Scalar> cbrt(const ExpressionPtr<Scalar>& x);
428template <typename Scalar>
429ExpressionPtr<Scalar> exp(const ExpressionPtr<Scalar>& x);
430template <typename Scalar>
431ExpressionPtr<Scalar> sin(const ExpressionPtr<Scalar>& x);
432template <typename Scalar>
433ExpressionPtr<Scalar> sinh(const ExpressionPtr<Scalar>& x);
434template <typename Scalar>
435ExpressionPtr<Scalar> sqrt(const ExpressionPtr<Scalar>& x);
436
441template <typename Scalar, ExpressionType T>
450
451 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs - rhs; }
452
453 ExpressionType type() const override { return T; }
454
455 std::string_view name() const override { return "binary minus"; }
456
457 Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
458
459 Scalar grad_r(Scalar, Scalar) const override { return -this->adjoint; }
460
463 const ExpressionPtr<Scalar>&) const override {
464 return this->adjoint_expr;
465 }
466
469 const ExpressionPtr<Scalar>&) const override {
470 return -this->adjoint_expr;
471 }
472};
473
478template <typename Scalar, ExpressionType T>
487
488 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs + rhs; }
489
490 ExpressionType type() const override { return T; }
491
492 std::string_view name() const override { return "binary plus"; }
493
494 Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
495
496 Scalar grad_r(Scalar, Scalar) const override { return this->adjoint; }
497
500 const ExpressionPtr<Scalar>&) const override {
501 return this->adjoint_expr;
502 }
503
506 const ExpressionPtr<Scalar>&) const override {
507 return this->adjoint_expr;
508 }
509};
510
514template <typename Scalar>
520 : Expression<Scalar>{std::move(lhs)} {}
521
522 Scalar value(Scalar x, Scalar) const override {
523 using std::cbrt;
524 return cbrt(x);
525 }
526
527 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
528
529 std::string_view name() const override { return "cbrt"; }
530
531 Scalar grad_l(Scalar x, Scalar) const override {
532 using std::cbrt;
533
534 Scalar c = cbrt(x);
535 return this->adjoint / (Scalar(3) * c * c);
536 }
537
539 const ExpressionPtr<Scalar>& x,
540 const ExpressionPtr<Scalar>&) const override {
541 auto c = cbrt(x);
542 return this->adjoint_expr / (constant_ptr(Scalar(3)) * c * c);
543 }
544};
545
550template <typename Scalar>
552 using enum ExpressionType;
553 using std::cbrt;
554
555 // Evaluate constant
556 if (x->type() == CONSTANT) {
557 if (x->val == Scalar(0)) {
558 // Return zero
559 return x;
560 } else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
561 return x;
562 } else {
563 return constant_ptr(cbrt(x->val));
564 }
565 }
566
567 return make_expression_ptr<CbrtExpression<Scalar>>(x);
568}
569
573template <typename Scalar>
578 explicit constexpr ConstantExpression(Scalar value)
579 : Expression<Scalar>{value} {}
580
581 Scalar value(Scalar, Scalar) const override { return this->val; }
582
583 ExpressionType type() const override { return ExpressionType::CONSTANT; }
584
585 std::string_view name() const override { return "constant"; }
586};
587
591template <typename Scalar>
594 constexpr DecisionVariableExpression() = default;
595
601
602 Scalar value(Scalar, Scalar) const override { return this->val; }
603
604 ExpressionType type() const override { return ExpressionType::LINEAR; }
605
606 std::string_view name() const override { return "decision variable"; }
607};
608
613template <typename Scalar, ExpressionType T>
621
622 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs / rhs; }
623
624 ExpressionType type() const override { return T; }
625
626 std::string_view name() const override { return "division"; }
627
628 Scalar grad_l(Scalar, Scalar rhs) const override {
629 return this->adjoint / rhs;
630 };
631
632 Scalar grad_r(Scalar lhs, Scalar rhs) const override {
633 return this->adjoint * -lhs / (rhs * rhs);
634 }
635
638 const ExpressionPtr<Scalar>& rhs) const override {
639 return this->adjoint_expr / rhs;
640 }
641
644 const ExpressionPtr<Scalar>& rhs) const override {
645 return this->adjoint_expr * -lhs / (rhs * rhs);
646 }
647};
648
653template <typename Scalar, ExpressionType T>
661
662 Scalar value(Scalar lhs, Scalar rhs) const override { return lhs * rhs; }
663
664 ExpressionType type() const override { return T; }
665
666 std::string_view name() const override { return "multiplication"; }
667
669 return this->adjoint * rhs;
670 }
671
673 return this->adjoint * lhs;
674 }
675
678 const ExpressionPtr<Scalar>& rhs) const override {
679 return this->adjoint_expr * rhs;
680 }
681
684 [[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const override {
685 return this->adjoint_expr * lhs;
686 }
687};
688
693template <typename Scalar, ExpressionType T>
699 : Expression<Scalar>{std::move(lhs)} {}
700
701 Scalar value(Scalar lhs, Scalar) const override { return -lhs; }
702
703 ExpressionType type() const override { return T; }
704
705 std::string_view name() const override { return "unary minus"; }
706
707 Scalar grad_l(Scalar, Scalar) const override { return -this->adjoint; }
708
711 const ExpressionPtr<Scalar>&) const override {
712 return -this->adjoint_expr;
713 }
714};
715
720template <typename Scalar>
721constexpr void inc_ref_count(Expression<Scalar>* expr) {
722 ++expr->ref_count;
723}
724
729template <typename Scalar>
730constexpr void dec_ref_count(Expression<Scalar>* expr) {
731 // If a deeply nested tree is being deallocated all at once, calling the
732 // Expression destructor when expr's refcount reaches zero can cause a stack
733 // overflow. Instead, we iterate over its children to decrement their
734 // refcounts and deallocate them.
735 gch::small_vector<Expression<Scalar>*> stack;
736 stack.emplace_back(expr);
737
738 while (!stack.empty()) {
739 auto elem = stack.back();
740 stack.pop_back();
741
742 // Decrement the current node's refcount. If it reaches zero, deallocate the
743 // node and enqueue its children so their refcounts are decremented too.
744 if (--elem->ref_count == 0) {
745 if (elem->adjoint_expr != nullptr) {
746 stack.emplace_back(elem->adjoint_expr.get());
747 }
748 for (auto& arg : elem->args) {
749 if (arg != nullptr) {
750 stack.emplace_back(arg.get());
751 }
752 }
753
754 // Not calling the destructor here is safe because it only decrements
755 // refcounts, which was already done above.
756 if constexpr (USE_POOL_ALLOCATOR) {
757 auto alloc = global_pool_allocator<Expression<Scalar>>();
758 std::allocator_traits<decltype(alloc)>::deallocate(
759 alloc, elem, sizeof(Expression<Scalar>));
760 } else {
761 operator delete(elem);
762 }
763 }
764 }
765}
766
770template <typename Scalar>
776 : Expression<Scalar>{std::move(lhs)} {}
777
778 Scalar value(Scalar x, Scalar) const override {
779 using std::abs;
780 return abs(x);
781 }
782
783 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
784
785 std::string_view name() const override { return "abs"; }
786
787 Scalar grad_l(Scalar x, Scalar) const override {
788 if (x < Scalar(0)) {
789 return -this->adjoint;
790 } else if (x > Scalar(0)) {
791 return this->adjoint;
792 } else {
793 return Scalar(0);
794 }
795 }
796
798 const ExpressionPtr<Scalar>& x,
799 const ExpressionPtr<Scalar>&) const override {
800 if (x->val < Scalar(0)) {
801 return -this->adjoint_expr;
802 } else if (x->val > Scalar(0)) {
803 return this->adjoint_expr;
804 } else {
805 return constant_ptr(Scalar(0));
806 }
807 }
808};
809
814template <typename Scalar>
816 using enum ExpressionType;
817 using std::abs;
818
819 // Prune expression
820 if (x->is_constant(Scalar(0))) {
821 // Return zero, which x currently is
822 return x;
823 }
824
825 // Evaluate constant
826 if (x->type() == CONSTANT) {
827 return constant_ptr(abs(x->val));
828 }
829
830 return make_expression_ptr<AbsExpression<Scalar>>(x);
831}
832
836template <typename Scalar>
842 : Expression<Scalar>{std::move(lhs)} {}
843
844 Scalar value(Scalar x, Scalar) const override {
845 using std::acos;
846 return acos(x);
847 }
848
849 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
850
851 std::string_view name() const override { return "acos"; }
852
853 Scalar grad_l(Scalar x, Scalar) const override {
854 using std::sqrt;
855 return -this->adjoint / sqrt(Scalar(1) - x * x);
856 }
857
859 const ExpressionPtr<Scalar>& x,
860 const ExpressionPtr<Scalar>&) const override {
861 return -this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
862 }
863};
864
869template <typename Scalar>
871 using enum ExpressionType;
872 using std::acos;
873
874 // Prune expression
875 if (x->is_constant(Scalar(0))) {
876 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
877 }
878
879 // Evaluate constant
880 if (x->type() == CONSTANT) {
881 return constant_ptr(acos(x->val));
882 }
883
884 return make_expression_ptr<AcosExpression<Scalar>>(x);
885}
886
890template <typename Scalar>
896 : Expression<Scalar>{std::move(lhs)} {}
897
898 Scalar value(Scalar x, Scalar) const override {
899 using std::asin;
900 return asin(x);
901 }
902
903 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
904
905 std::string_view name() const override { return "asin"; }
906
907 Scalar grad_l(Scalar x, Scalar) const override {
908 using std::sqrt;
909 return this->adjoint / sqrt(Scalar(1) - x * x);
910 }
911
913 const ExpressionPtr<Scalar>& x,
914 const ExpressionPtr<Scalar>&) const override {
915 return this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
916 }
917};
918
923template <typename Scalar>
925 using enum ExpressionType;
926 using std::asin;
927
928 // Prune expression
929 if (x->is_constant(Scalar(0))) {
930 // Return zero, which x currently is
931 return x;
932 }
933
934 // Evaluate constant
935 if (x->type() == CONSTANT) {
936 return constant_ptr(asin(x->val));
937 }
938
939 return make_expression_ptr<AsinExpression<Scalar>>(x);
940}
941
945template <typename Scalar>
951 : Expression<Scalar>{std::move(lhs)} {}
952
953 Scalar value(Scalar x, Scalar) const override {
954 using std::atan;
955 return atan(x);
956 }
957
958 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
959
960 std::string_view name() const override { return "atan"; }
961
962 Scalar grad_l(Scalar x, Scalar) const override {
963 return this->adjoint / (Scalar(1) + x * x);
964 }
965
967 const ExpressionPtr<Scalar>& x,
968 const ExpressionPtr<Scalar>&) const override {
969 return this->adjoint_expr / (constant_ptr(Scalar(1)) + x * x);
970 }
971};
972
977template <typename Scalar>
979 using enum ExpressionType;
980 using std::atan;
981
982 // Prune expression
983 if (x->is_constant(Scalar(0))) {
984 // Return zero, which x currently is
985 return x;
986 }
987
988 // Evaluate constant
989 if (x->type() == CONSTANT) {
990 return constant_ptr(atan(x->val));
991 }
992
993 return make_expression_ptr<AtanExpression<Scalar>>(x);
994}
995
999template <typename Scalar>
1007 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1008
1009 Scalar value(Scalar y, Scalar x) const override {
1010 using std::atan2;
1011 return atan2(y, x);
1012 }
1013
1014 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1015
1016 std::string_view name() const override { return "atan2"; }
1017
1018 Scalar grad_l(Scalar y, Scalar x) const override {
1019 return this->adjoint * x / (y * y + x * x);
1020 }
1021
1022 Scalar grad_r(Scalar y, Scalar x) const override {
1023 return this->adjoint * -y / (y * y + x * x);
1024 }
1025
1027 const ExpressionPtr<Scalar>& y,
1028 const ExpressionPtr<Scalar>& x) const override {
1029 return this->adjoint_expr * x / (y * y + x * x);
1030 }
1031
1033 const ExpressionPtr<Scalar>& y,
1034 const ExpressionPtr<Scalar>& x) const override {
1035 return this->adjoint_expr * -y / (y * y + x * x);
1036 }
1037};
1038
1044template <typename Scalar>
1046 const ExpressionPtr<Scalar>& x) {
1047 using enum ExpressionType;
1048 using std::atan2;
1049
1050 // Prune expression
1051 if (y->is_constant(Scalar(0))) {
1052 // Return zero, which y currently is
1053 return y;
1054 } else if (x->is_constant(Scalar(0))) {
1055 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1056 }
1057
1058 // Evaluate constant
1059 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1060 return constant_ptr(atan2(y->val, x->val));
1061 }
1062
1063 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1064}
1065
1069template <typename Scalar>
1075 : Expression<Scalar>{std::move(lhs)} {}
1076
1077 Scalar value(Scalar x, Scalar) const override {
1078 using std::cos;
1079 return cos(x);
1080 }
1081
1082 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1083
1084 std::string_view name() const override { return "cos"; }
1085
1086 Scalar grad_l(Scalar x, Scalar) const override {
1087 using std::sin;
1088 return this->adjoint * -sin(x);
1089 }
1090
1092 const ExpressionPtr<Scalar>& x,
1093 const ExpressionPtr<Scalar>&) const override {
1094 return this->adjoint_expr * -sin(x);
1095 }
1096};
1097
1102template <typename Scalar>
1104 using enum ExpressionType;
1105 using std::cos;
1106
1107 // Prune expression
1108 if (x->is_constant(Scalar(0))) {
1109 return constant_ptr(Scalar(1));
1110 }
1111
1112 // Evaluate constant
1113 if (x->type() == CONSTANT) {
1114 return constant_ptr(cos(x->val));
1115 }
1116
1117 return make_expression_ptr<CosExpression<Scalar>>(x);
1118}
1119
1123template <typename Scalar>
1129 : Expression<Scalar>{std::move(lhs)} {}
1130
1131 Scalar value(Scalar x, Scalar) const override {
1132 using std::cosh;
1133 return cosh(x);
1134 }
1135
1136 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1137
1138 std::string_view name() const override { return "cosh"; }
1139
1140 Scalar grad_l(Scalar x, Scalar) const override {
1141 using std::sinh;
1142 return this->adjoint * sinh(x);
1143 }
1144
1146 const ExpressionPtr<Scalar>& x,
1147 const ExpressionPtr<Scalar>&) const override {
1148 return this->adjoint_expr * sinh(x);
1149 }
1150};
1151
1156template <typename Scalar>
1158 using enum ExpressionType;
1159 using std::cosh;
1160
1161 // Prune expression
1162 if (x->is_constant(Scalar(0))) {
1163 return constant_ptr(Scalar(1));
1164 }
1165
1166 // Evaluate constant
1167 if (x->type() == CONSTANT) {
1168 return constant_ptr(cosh(x->val));
1169 }
1170
1171 return make_expression_ptr<CoshExpression<Scalar>>(x);
1172}
1173
1177template <typename Scalar>
1183 : Expression<Scalar>{std::move(lhs)} {}
1184
1185 Scalar value(Scalar x, Scalar) const override {
1186 using std::erf;
1187 return erf(x);
1188 }
1189
1190 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1191
1192 std::string_view name() const override { return "erf"; }
1193
1194 Scalar grad_l(Scalar x, Scalar) const override {
1195 using std::exp;
1196 return this->adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) * exp(-x * x);
1197 }
1198
1200 const ExpressionPtr<Scalar>& x,
1201 const ExpressionPtr<Scalar>&) const override {
1202 return this->adjoint_expr *
1203 constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1204 }
1205};
1206
1211template <typename Scalar>
1213 using enum ExpressionType;
1214 using std::erf;
1215
1216 // Prune expression
1217 if (x->is_constant(Scalar(0))) {
1218 // Return zero, which x currently is
1219 return x;
1220 }
1221
1222 // Evaluate constant
1223 if (x->type() == CONSTANT) {
1224 return constant_ptr(erf(x->val));
1225 }
1226
1227 return make_expression_ptr<ErfExpression<Scalar>>(x);
1228}
1229
1233template <typename Scalar>
1239 : Expression<Scalar>{std::move(lhs)} {}
1240
1241 Scalar value(Scalar x, Scalar) const override {
1242 using std::exp;
1243 return exp(x);
1244 }
1245
1246 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1247
1248 std::string_view name() const override { return "exp"; }
1249
1250 Scalar grad_l(Scalar x, Scalar) const override {
1251 using std::exp;
1252 return this->adjoint * exp(x);
1253 }
1254
1256 const ExpressionPtr<Scalar>& x,
1257 const ExpressionPtr<Scalar>&) const override {
1258 return this->adjoint_expr * exp(x);
1259 }
1260};
1261
1266template <typename Scalar>
1268 using enum ExpressionType;
1269 using std::exp;
1270
1271 // Prune expression
1272 if (x->is_constant(Scalar(0))) {
1273 return constant_ptr(Scalar(1));
1274 }
1275
1276 // Evaluate constant
1277 if (x->type() == CONSTANT) {
1278 return constant_ptr(exp(x->val));
1279 }
1280
1281 return make_expression_ptr<ExpExpression<Scalar>>(x);
1282}
1283
1284template <typename Scalar>
1285ExpressionPtr<Scalar> hypot(const ExpressionPtr<Scalar>& x,
1286 const ExpressionPtr<Scalar>& y);
1287
1291template <typename Scalar>
1299 : Expression<Scalar>{std::move(lhs), std::move(rhs)} {}
1300
1301 Scalar value(Scalar x, Scalar y) const override {
1302 using std::hypot;
1303 return hypot(x, y);
1304 }
1305
1306 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1307
1308 std::string_view name() const override { return "hypot"; }
1309
1310 Scalar grad_l(Scalar x, Scalar y) const override {
1311 using std::hypot;
1312 return this->adjoint * x / hypot(x, y);
1313 }
1314
1315 Scalar grad_r(Scalar x, Scalar y) const override {
1316 using std::hypot;
1317 return this->adjoint * y / hypot(x, y);
1318 }
1319
1321 const ExpressionPtr<Scalar>& x,
1322 const ExpressionPtr<Scalar>& y) const override {
1323 return this->adjoint_expr * x / hypot(x, y);
1324 }
1325
1327 const ExpressionPtr<Scalar>& x,
1328 const ExpressionPtr<Scalar>& y) const override {
1329 return this->adjoint_expr * y / hypot(x, y);
1330 }
1331};
1332
1338template <typename Scalar>
1340 const ExpressionPtr<Scalar>& y) {
1341 using enum ExpressionType;
1342 using std::hypot;
1343
1344 // Prune expression
1345 if (x->is_constant(Scalar(0))) {
1346 return y;
1347 } else if (y->is_constant(Scalar(0))) {
1348 return x;
1349 }
1350
1351 // Evaluate constant
1352 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1353 return constant_ptr(hypot(x->val, y->val));
1354 }
1355
1356 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1357}
1358
1362template <typename Scalar>
1368 : Expression<Scalar>{std::move(lhs)} {}
1369
1370 Scalar value(Scalar x, Scalar) const override {
1371 using std::log;
1372 return log(x);
1373 }
1374
1375 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1376
1377 std::string_view name() const override { return "log"; }
1378
1379 Scalar grad_l(Scalar x, Scalar) const override { return this->adjoint / x; }
1380
1382 const ExpressionPtr<Scalar>& x,
1383 const ExpressionPtr<Scalar>&) const override {
1384 return this->adjoint_expr / x;
1385 }
1386};
1387
1392template <typename Scalar>
1394 using enum ExpressionType;
1395 using std::log;
1396
1397 // Prune expression
1398 if (x->is_constant(Scalar(0))) {
1399 // Return zero, which x currently is
1400 return x;
1401 }
1402
1403 // Evaluate constant
1404 if (x->type() == CONSTANT) {
1405 return constant_ptr(log(x->val));
1406 }
1407
1408 return make_expression_ptr<LogExpression<Scalar>>(x);
1409}
1410
1414template <typename Scalar>
1420 : Expression<Scalar>{std::move(lhs)} {}
1421
1422 Scalar value(Scalar x, Scalar) const override {
1423 using std::log10;
1424 return log10(x);
1425 }
1426
1427 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1428
1429 std::string_view name() const override { return "log10"; }
1430
1431 Scalar grad_l(Scalar x, Scalar) const override {
1432 return this->adjoint / (Scalar(std::numbers::ln10) * x);
1433 }
1434
1436 const ExpressionPtr<Scalar>& x,
1437 const ExpressionPtr<Scalar>&) const override {
1438 return this->adjoint_expr / (constant_ptr(Scalar(std::numbers::ln10)) * x);
1439 }
1440};
1441
1446template <typename Scalar>
1448 using enum ExpressionType;
1449 using std::log10;
1450
1451 // Prune expression
1452 if (x->is_constant(Scalar(0))) {
1453 // Return zero, which x currently is
1454 return x;
1455 }
1456
1457 // Evaluate constant
1458 if (x->type() == CONSTANT) {
1459 return constant_ptr(log10(x->val));
1460 }
1461
1462 return make_expression_ptr<Log10Expression<Scalar>>(x);
1463}
1464
1470template <typename Scalar>
1478
1479 Scalar value(Scalar a, Scalar b) const override {
1480 using std::max;
1481 return max(a, b);
1482 }
1483
1484 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1485
1486 std::string_view name() const override { return "max"; }
1487
1488 Scalar grad_l(Scalar a, Scalar b) const override {
1489 if (a >= b) {
1490 return this->adjoint;
1491 } else {
1492 return Scalar(0);
1493 }
1494 }
1495
1496 Scalar grad_r(Scalar a, Scalar b) const override {
1497 if (b > a) {
1498 return this->adjoint;
1499 } else {
1500 return Scalar(0);
1501 }
1502 }
1503
1505 const ExpressionPtr<Scalar>& a,
1506 const ExpressionPtr<Scalar>& b) const override {
1507 if (a->val >= b->val) {
1508 return this->adjoint_expr;
1509 } else {
1510 return constant_ptr(Scalar(0));
1511 }
1512 }
1513
1515 const ExpressionPtr<Scalar>& a,
1516 const ExpressionPtr<Scalar>& b) const override {
1517 if (b->val > a->val) {
1518 return this->adjoint_expr;
1519 } else {
1520 return constant_ptr(Scalar(0));
1521 }
1522 }
1523};
1524
1530template <typename Scalar>
1532 const ExpressionPtr<Scalar>& b) {
1533 using enum ExpressionType;
1534 using std::max;
1535
1536 // Evaluate constant
1537 if (a->type() == CONSTANT && b->type() == CONSTANT) {
1538 return constant_ptr(max(a->val, b->val));
1539 }
1540
1541 return make_expression_ptr<MaxExpression<Scalar>>(a, b);
1542}
1543
1549template <typename Scalar>
1557
1558 Scalar value(Scalar a, Scalar b) const override {
1559 using std::min;
1560 return min(a, b);
1561 }
1562
1563 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1564
1565 std::string_view name() const override { return "min"; }
1566
1567 Scalar grad_l(Scalar a, Scalar b) const override {
1568 if (a <= b) {
1569 return this->adjoint;
1570 } else {
1571 return Scalar(0);
1572 }
1573 }
1574
1576 [[maybe_unused]] Scalar b) const override {
1577 if (b < a) {
1578 return this->adjoint;
1579 } else {
1580 return Scalar(0);
1581 }
1582 }
1583
1585 const ExpressionPtr<Scalar>& a,
1586 const ExpressionPtr<Scalar>& b) const override {
1587 if (a->val <= b->val) {
1588 return this->adjoint_expr;
1589 } else {
1590 return constant_ptr(Scalar(0));
1591 }
1592 }
1593
1595 const ExpressionPtr<Scalar>& a,
1596 const ExpressionPtr<Scalar>& b) const override {
1597 if (b->val < a->val) {
1598 return this->adjoint_expr;
1599 } else {
1600 return constant_ptr(Scalar(0));
1601 }
1602 }
1603};
1604
1610template <typename Scalar>
1612 const ExpressionPtr<Scalar>& b) {
1613 using enum ExpressionType;
1614 using std::min;
1615
1616 // Evaluate constant
1617 if (a->type() == CONSTANT && b->type() == CONSTANT) {
1618 return constant_ptr(min(a->val, b->val));
1619 }
1620
1621 return make_expression_ptr<MinExpression<Scalar>>(a, b);
1622}
1623
1624template <typename Scalar>
1625ExpressionPtr<Scalar> pow(const ExpressionPtr<Scalar>& base,
1626 const ExpressionPtr<Scalar>& power);
1627
1632template <typename Scalar, ExpressionType T>
1640
1641 Scalar value(Scalar base, Scalar power) const override {
1642 using std::pow;
1643 return pow(base, power);
1644 }
1645
1646 ExpressionType type() const override { return T; }
1647
1648 std::string_view name() const override { return "pow"; }
1649
1651 using std::pow;
1652 return this->adjoint * pow(base, power - Scalar(1)) * power;
1653 }
1654
1656 using std::log;
1657 using std::pow;
1658
1659 // Since x log(x) -> 0 as x -> 0
1660 if (base == Scalar(0)) {
1661 return Scalar(0);
1662 } else {
1663 return this->adjoint * pow(base, power) * log(base);
1664 }
1665 }
1666
1669 const ExpressionPtr<Scalar>& power) const override {
1670 return this->adjoint_expr * pow(base, power - constant_ptr(Scalar(1))) *
1671 power;
1672 }
1673
1676 const ExpressionPtr<Scalar>& power) const override {
1677 // Since x log(x) -> 0 as x -> 0
1678 if (base->val == Scalar(0)) {
1679 // Return zero
1680 return base;
1681 } else {
1682 return this->adjoint_expr * pow(base, power) * log(base);
1683 }
1684 }
1685};
1686
1692template <typename Scalar>
1695 using enum ExpressionType;
1696 using std::pow;
1697
1698 // Prune expression
1699 if (base->is_constant(Scalar(0))) {
1700 // Return zero, which base currently is
1701 return base;
1702 } else if (base->is_constant(Scalar(1))) {
1703 // Return one, which base currently is
1704 return base;
1705 }
1706 if (power->is_constant(Scalar(0))) {
1707 return constant_ptr(Scalar(1));
1708 } else if (power->is_constant(Scalar(1))) {
1709 // Return base unmodified
1710 return base;
1711 }
1712
1713 // Evaluate constant
1714 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1715 return constant_ptr(pow(base->val, power->val));
1716 }
1717
1718 if (power->is_constant(Scalar(2))) {
1719 if (base->type() == LINEAR) {
1720 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1721 } else {
1722 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1723 }
1724 }
1725
1726 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1727}
1728
1732template <typename Scalar>
1738 : Expression<Scalar>{std::move(lhs)} {}
1739
/// Computes the sign of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return -1 if x is less than zero, 0 if x equals zero, and 1 otherwise.
///
/// NOTE(review): any x that is neither less than nor equal to zero lands in
/// the final branch; for a floating-point Scalar that would include NaN —
/// confirm whether that is intended.
1740 Scalar value(Scalar x, Scalar) const override {
1741 if (x < Scalar(0)) {
1742 return Scalar(-1);
1743 } else if (x == Scalar(0)) {
1744 return Scalar(0);
1745 } else {
1746 return Scalar(1);
1747 }
1748 }
1749
/// Reports this node as ExpressionType::NONLINEAR.
1750 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1751
/// Returns the fixed operator name "sign" for this node.
1752 std::string_view name() const override { return "sign"; }
1753};
1754
1759template <typename Scalar>
1761 using enum ExpressionType;
1762
1763 // Evaluate constant
1764 if (x->type() == CONSTANT) {
1765 if (x->val < Scalar(0)) {
1766 return constant_ptr(Scalar(-1));
1767 } else if (x->val == Scalar(0)) {
1768 // Return zero
1769 return x;
1770 } else {
1771 return constant_ptr(Scalar(1));
1772 }
1773 }
1774
1775 return make_expression_ptr<SignExpression<Scalar>>(x);
1776}
1777
1781template <typename Scalar>
1787 : Expression<Scalar>{std::move(lhs)} {}
1788
/// Evaluates the sine of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return sin(x).
1789 Scalar value(Scalar x, Scalar) const override {
// The using-declaration makes std::sin a candidate while the call stays
// unqualified, so an overload for a non-builtin Scalar can also be selected.
1790 using std::sin;
1791 return sin(x);
1792 }
1793
/// Reports this node as ExpressionType::NONLINEAR.
1794 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1795
/// Returns the fixed operator name "sin" for this node.
1796 std::string_view name() const override { return "sin"; }
1797
/// Returns this node's adjoint multiplied by the derivative of sin at x.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return adjoint * cos(x), since d/dx sin(x) = cos(x).
1798 Scalar grad_l(Scalar x, Scalar) const override {
1799 using std::cos;
1800 return this->adjoint * cos(x);
1801 }
1802
1804 const ExpressionPtr<Scalar>& x,
1805 const ExpressionPtr<Scalar>&) const override {
1806 return this->adjoint_expr * cos(x);
1807 }
1808};
1809
1814template <typename Scalar>
1816 using enum ExpressionType;
1817 using std::sin;
1818
1819 // Prune expression
1820 if (x->is_constant(Scalar(0))) {
1821 // Return zero, which x currently is
1822 return x;
1823 }
1824
1825 // Evaluate constant
1826 if (x->type() == CONSTANT) {
1827 return constant_ptr(sin(x->val));
1828 }
1829
1830 return make_expression_ptr<SinExpression<Scalar>>(x);
1831}
1832
1836template <typename Scalar>
1842 : Expression<Scalar>{std::move(lhs)} {}
1843
/// Evaluates the hyperbolic sine of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return sinh(x).
1844 Scalar value(Scalar x, Scalar) const override {
// Unqualified call after the using-declaration: std::sinh is a candidate,
// and an overload for a non-builtin Scalar can also be selected.
1845 using std::sinh;
1846 return sinh(x);
1847 }
1848
/// Reports this node as ExpressionType::NONLINEAR.
1849 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1850
/// Returns the fixed operator name "sinh" for this node.
1851 std::string_view name() const override { return "sinh"; }
1852
/// Returns this node's adjoint multiplied by the derivative of sinh at x.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return adjoint * cosh(x), since d/dx sinh(x) = cosh(x).
1853 Scalar grad_l(Scalar x, Scalar) const override {
1854 using std::cosh;
1855 return this->adjoint * cosh(x);
1856 }
1857
1859 const ExpressionPtr<Scalar>& x,
1860 const ExpressionPtr<Scalar>&) const override {
1861 return this->adjoint_expr * cosh(x);
1862 }
1863};
1864
1869template <typename Scalar>
1871 using enum ExpressionType;
1872 using std::sinh;
1873
1874 // Prune expression
1875 if (x->is_constant(Scalar(0))) {
1876 // Return zero, which x currently is
1877 return x;
1878 }
1879
1880 // Evaluate constant
1881 if (x->type() == CONSTANT) {
1882 return constant_ptr(sinh(x->val));
1883 }
1884
1885 return make_expression_ptr<SinhExpression<Scalar>>(x);
1886}
1887
1891template <typename Scalar>
1897 : Expression<Scalar>{std::move(lhs)} {}
1898
/// Evaluates the square root of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return sqrt(x). No domain check on x is performed here.
1899 Scalar value(Scalar x, Scalar) const override {
// Unqualified call after the using-declaration: std::sqrt is a candidate,
// and an overload for a non-builtin Scalar can also be selected.
1900 using std::sqrt;
1901 return sqrt(x);
1902 }
1903
/// Reports this node as ExpressionType::NONLINEAR.
1904 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1905
/// Returns the fixed operator name "sqrt" for this node.
1906 std::string_view name() const override { return "sqrt"; }
1907
/// Returns this node's adjoint multiplied by the derivative of sqrt at x.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return adjoint / (2 * sqrt(x)), since d/dx sqrt(x) = 1 / (2 sqrt(x)).
///
/// NOTE(review): there is no guard for x == 0, where the divisor is zero —
/// confirm that callers tolerate the resulting non-finite value.
1908 Scalar grad_l(Scalar x, Scalar) const override {
1909 using std::sqrt;
1910 return this->adjoint / (Scalar(2) * sqrt(x));
1911 }
1912
1914 const ExpressionPtr<Scalar>& x,
1915 const ExpressionPtr<Scalar>&) const override {
1916 return this->adjoint_expr / (constant_ptr(Scalar(2)) * sqrt(x));
1917 }
1918};
1919
1924template <typename Scalar>
1926 using enum ExpressionType;
1927 using std::sqrt;
1928
1929 // Evaluate constant
1930 if (x->type() == CONSTANT) {
1931 if (x->val == Scalar(0)) {
1932 // Return zero
1933 return x;
1934 } else if (x->val == Scalar(1)) {
1935 return x;
1936 } else {
1937 return constant_ptr(sqrt(x->val));
1938 }
1939 }
1940
1941 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1942}
1943
1947template <typename Scalar>
1953 : Expression<Scalar>{std::move(lhs)} {}
1954
/// Evaluates the tangent of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return tan(x).
1955 Scalar value(Scalar x, Scalar) const override {
// Unqualified call after the using-declaration: std::tan is a candidate,
// and an overload for a non-builtin Scalar can also be selected.
1956 using std::tan;
1957 return tan(x);
1958 }
1959
/// Reports this node as ExpressionType::NONLINEAR.
1960 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1961
/// Returns the fixed operator name "tan" for this node.
1962 std::string_view name() const override { return "tan"; }
1963
/// Returns this node's adjoint multiplied by the derivative of tan at x.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return adjoint / cos(x)^2, since d/dx tan(x) = 1 / cos^2(x).
1964 Scalar grad_l(Scalar x, Scalar) const override {
1965 using std::cos;
1966
// Evaluate cos(x) once and square it by multiplication.
1967 auto c = cos(x);
1968 return this->adjoint / (c * c);
1969 }
1970
1972 const ExpressionPtr<Scalar>& x,
1973 const ExpressionPtr<Scalar>&) const override {
1974 auto c = cos(x);
1975 return this->adjoint_expr / (c * c);
1976 }
1977};
1978
1983template <typename Scalar>
1985 using enum ExpressionType;
1986 using std::tan;
1987
1988 // Prune expression
1989 if (x->is_constant(Scalar(0))) {
1990 // Return zero, which x currently is
1991 return x;
1992 }
1993
1994 // Evaluate constant
1995 if (x->type() == CONSTANT) {
1996 return constant_ptr(tan(x->val));
1997 }
1998
1999 return make_expression_ptr<TanExpression<Scalar>>(x);
2000}
2001
2005template <typename Scalar>
2011 : Expression<Scalar>{std::move(lhs)} {}
2012
/// Evaluates the hyperbolic tangent of the operand.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return tanh(x).
2013 Scalar value(Scalar x, Scalar) const override {
// Unqualified call after the using-declaration: std::tanh is a candidate,
// and an overload for a non-builtin Scalar can also be selected.
2014 using std::tanh;
2015 return tanh(x);
2016 }
2017
/// Reports this node as ExpressionType::NONLINEAR.
2018 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
2019
/// Returns the fixed operator name "tanh" for this node.
2020 std::string_view name() const override { return "tanh"; }
2021
/// Returns this node's adjoint multiplied by the derivative of tanh at x.
///
/// @param x Operand value. The second, unnamed argument is ignored.
/// @return adjoint / cosh(x)^2, since d/dx tanh(x) = 1 / cosh^2(x).
2022 Scalar grad_l(Scalar x, Scalar) const override {
2023 using std::cosh;
2024
// Evaluate cosh(x) once and square it by multiplication.
2025 auto c = cosh(x);
2026 return this->adjoint / (c * c);
2027 }
2028
2030 const ExpressionPtr<Scalar>& x,
2031 const ExpressionPtr<Scalar>&) const override {
2032 auto c = cosh(x);
2033 return this->adjoint_expr / (c * c);
2034 }
2035};
2036
2041template <typename Scalar>
2043 using enum ExpressionType;
2044 using std::tanh;
2045
2046 // Prune expression
2047 if (x->is_constant(Scalar(0))) {
2048 // Return zero, which x currently is
2049 return x;
2050 }
2051
2052 // Evaluate constant
2053 if (x->type() == CONSTANT) {
2054 return constant_ptr(tanh(x->val));
2055 }
2056
2057 return make_expression_ptr<TanhExpression<Scalar>>(x);
2058}
2059
2060} // namespace slp::detail
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:771
std::string_view name() const override
Definition expression.hpp:785
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:797
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:775
ExpressionType type() const override
Definition expression.hpp:783
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:787
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:778
Definition expression.hpp:837
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:844
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:841
ExpressionType type() const override
Definition expression.hpp:849
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:853
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:858
std::string_view name() const override
Definition expression.hpp:851
Definition expression.hpp:891
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:907
std::string_view name() const override
Definition expression.hpp:905
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:912
ExpressionType type() const override
Definition expression.hpp:903
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:895
Definition expression.hpp:1000
std::string_view name() const override
Definition expression.hpp:1016
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1009
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x) const override
Definition expression.hpp:1026
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1005
Scalar grad_l(Scalar y, Scalar x) const override
Definition expression.hpp:1018
Scalar grad_r(Scalar y, Scalar x) const override
Definition expression.hpp:1022
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x) const override
Definition expression.hpp:1032
ExpressionType type() const override
Definition expression.hpp:1014
Definition expression.hpp:946
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:966
std::string_view name() const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:962
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:950
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:953
ExpressionType type() const override
Definition expression.hpp:958
Definition expression.hpp:442
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:447
std::string_view name() const override
Definition expression.hpp:455
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:467
Scalar grad_r(Scalar, Scalar) const override
Definition expression.hpp:459
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:461
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:451
ExpressionType type() const override
Definition expression.hpp:453
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:457
Definition expression.hpp:479
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:488
Scalar grad_r(Scalar, Scalar) const override
Definition expression.hpp:496
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:494
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:484
ExpressionType type() const override
Definition expression.hpp:490
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:504
std::string_view name() const override
Definition expression.hpp:492
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:498
Definition expression.hpp:515
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:538
std::string_view name() const override
Definition expression.hpp:529
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:519
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:531
ExpressionType type() const override
Definition expression.hpp:527
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:522
Definition expression.hpp:574
ExpressionType type() const override
Definition expression.hpp:583
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:581
std::string_view name() const override
Definition expression.hpp:585
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:578
Definition expression.hpp:1070
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1077
ExpressionType type() const override
Definition expression.hpp:1082
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1074
std::string_view name() const override
Definition expression.hpp:1084
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1086
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1091
Definition expression.hpp:1124
ExpressionType type() const override
Definition expression.hpp:1136
std::string_view name() const override
Definition expression.hpp:1138
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1131
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1145
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1128
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1140
Definition expression.hpp:592
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:606
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:602
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:599
ExpressionType type() const override
Definition expression.hpp:604
Definition expression.hpp:614
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:619
ExpressionType type() const override
Definition expression.hpp:624
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:642
std::string_view name() const override
Definition expression.hpp:626
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:622
Scalar grad_l(Scalar, Scalar rhs) const override
Definition expression.hpp:628
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:636
Scalar grad_r(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:632
Definition expression.hpp:1178
std::string_view name() const override
Definition expression.hpp:1192
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1182
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1194
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1199
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1185
ExpressionType type() const override
Definition expression.hpp:1190
Definition expression.hpp:1234
std::string_view name() const override
Definition expression.hpp:1248
ExpressionType type() const override
Definition expression.hpp:1246
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1238
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1250
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1255
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1241
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:327
virtual Scalar grad_r(Scalar lhs, Scalar rhs) const
Definition expression.hpp:393
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:104
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:118
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:131
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:243
int32_t scratch
Definition expression.hpp:115
virtual Scalar grad_l(Scalar lhs, Scalar rhs) const
Definition expression.hpp:383
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:147
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:138
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const
Definition expression.hpp:403
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:101
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:207
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:288
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:155
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const
Definition expression.hpp:414
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:279
constexpr Expression(Scalar value)
Definition expression.hpp:126
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:353
Definition expression.hpp:1292
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y) const override
Definition expression.hpp:1320
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1297
ExpressionType type() const override
Definition expression.hpp:1306
std::string_view name() const override
Definition expression.hpp:1308
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y) const override
Definition expression.hpp:1326
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1301
Scalar grad_r(Scalar x, Scalar y) const override
Definition expression.hpp:1315
Scalar grad_l(Scalar x, Scalar y) const override
Definition expression.hpp:1310
Definition expression.hpp:1415
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1422
std::string_view name() const override
Definition expression.hpp:1429
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1419
ExpressionType type() const override
Definition expression.hpp:1427
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1435
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1431
Definition expression.hpp:1363
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1381
ExpressionType type() const override
Definition expression.hpp:1375
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1379
std::string_view name() const override
Definition expression.hpp:1377
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1367
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1370
Definition expression.hpp:1471
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1504
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1479
Scalar grad_r(Scalar a, Scalar b) const override
Definition expression.hpp:1496
constexpr MaxExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1476
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1514
Scalar grad_l(Scalar a, Scalar b) const override
Definition expression.hpp:1488
ExpressionType type() const override
Definition expression.hpp:1484
std::string_view name() const override
Definition expression.hpp:1486
Definition expression.hpp:1550
std::string_view name() const override
Definition expression.hpp:1565
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1558
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1584
ExpressionType type() const override
Definition expression.hpp:1563
constexpr MinExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1555
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1594
Scalar grad_l(Scalar a, Scalar b) const override
Definition expression.hpp:1567
Scalar grad_r(Scalar a, Scalar b) const override
Definition expression.hpp:1575
Definition expression.hpp:654
Scalar grad_l(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:668
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:662
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:676
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:659
ExpressionType type() const override
Definition expression.hpp:664
std::string_view name() const override
Definition expression.hpp:666
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:682
Scalar grad_r(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:672
Definition expression.hpp:1633
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power) const override
Definition expression.hpp:1667
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1641
ExpressionType type() const override
Definition expression.hpp:1646
Scalar grad_l(Scalar base, Scalar power) const override
Definition expression.hpp:1650
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power) const override
Definition expression.hpp:1674
Scalar grad_r(Scalar base, Scalar power) const override
Definition expression.hpp:1655
std::string_view name() const override
Definition expression.hpp:1648
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1638
Definition expression.hpp:1733
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1737
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1740
ExpressionType type() const override
Definition expression.hpp:1750
std::string_view name() const override
Definition expression.hpp:1752
Definition expression.hpp:1782
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1789
std::string_view name() const override
Definition expression.hpp:1796
ExpressionType type() const override
Definition expression.hpp:1794
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1786
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1803
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1798
Definition expression.hpp:1837
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1853
std::string_view name() const override
Definition expression.hpp:1851
ExpressionType type() const override
Definition expression.hpp:1849
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1858
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1841
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1844
Definition expression.hpp:1892
ExpressionType type() const override
Definition expression.hpp:1904
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1908
std::string_view name() const override
Definition expression.hpp:1906
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1896
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1899
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1913
Definition expression.hpp:1948
std::string_view name() const override
Definition expression.hpp:1962
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1952
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1964
ExpressionType type() const override
Definition expression.hpp:1960
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1971
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1955
Definition expression.hpp:2006
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:2022
ExpressionType type() const override
Definition expression.hpp:2018
std::string_view name() const override
Definition expression.hpp:2020
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:2013
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:2029
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:2010
Definition expression.hpp:694
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:707
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:698
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:709
ExpressionType type() const override
Definition expression.hpp:703
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:701
std::string_view name() const override
Definition expression.hpp:705