Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <utility>
13
14#include <gch/small_vector.hpp>
15
16#include "sleipnir/autodiff/expression_type.hpp"
17#include "sleipnir/util/intrusive_shared_ptr.hpp"
18#include "sleipnir/util/pool.hpp"
19
20namespace slp::detail {
21
// The global pool allocator is backed by a thread-local static pool resource.
// That resource isn't guaranteed to be initialized correctly across DLL
// boundaries on Windows, so pooling is turned off there.
#if defined(_WIN32)
inline constexpr bool USE_POOL_ALLOCATOR = false;
#else
inline constexpr bool USE_POOL_ALLOCATOR = true;
#endif
29
30struct Expression;
31
32inline constexpr void inc_ref_count(Expression* expr);
33inline constexpr void dec_ref_count(Expression* expr);
34
38using ExpressionPtr = IntrusiveSharedPtr<Expression>;
39
47template <typename T, typename... Args>
48static ExpressionPtr make_expression_ptr(Args&&... args) {
49 if constexpr (USE_POOL_ALLOCATOR) {
50 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
51 std::forward<Args>(args)...);
52 } else {
53 return make_intrusive_shared<T>(std::forward<Args>(args)...);
54 }
55}
56
57template <ExpressionType T>
58struct BinaryMinusExpression;
59
60template <ExpressionType T>
61struct BinaryPlusExpression;
62
63struct ConstExpression;
64
65template <ExpressionType T>
66struct DivExpression;
67
68template <ExpressionType T>
69struct MultExpression;
70
71template <ExpressionType T>
72struct UnaryMinusExpression;
73
77struct Expression {
79 double val = 0.0;
80
82 double adjoint = 0.0;
83
85 uint32_t incoming_edges = 0;
86
88 int32_t col = -1;
89
93
95 uint32_t ref_count = 0;
96
98 std::array<ExpressionPtr, 2> args{nullptr, nullptr};
99
103 constexpr Expression() = default;
104
110 explicit constexpr Expression(double value) : val{value} {}
111
117 explicit constexpr Expression(ExpressionPtr lhs)
118 : args{std::move(lhs), nullptr} {}
119
127 : args{std::move(lhs), std::move(rhs)} {}
128
129 virtual ~Expression() = default;
130
138 constexpr bool is_constant(double constant) const {
139 return type() == ExpressionType::CONSTANT && val == constant;
140 }
141
149 const ExpressionPtr& rhs) {
150 using enum ExpressionType;
151
152 // Prune expression
153 if (lhs->is_constant(0.0)) {
154 // Return zero
155 return lhs;
156 } else if (rhs->is_constant(0.0)) {
157 // Return zero
158 return rhs;
159 } else if (lhs->is_constant(1.0)) {
160 return rhs;
161 } else if (rhs->is_constant(1.0)) {
162 return lhs;
163 }
164
165 // Evaluate constant
166 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
167 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
168 }
169
170 // Evaluate expression type
171 if (lhs->type() == CONSTANT) {
172 if (rhs->type() == LINEAR) {
173 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
174 } else if (rhs->type() == QUADRATIC) {
175 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
176 } else {
177 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
178 }
179 } else if (rhs->type() == CONSTANT) {
180 if (lhs->type() == LINEAR) {
181 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
182 } else if (lhs->type() == QUADRATIC) {
183 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
184 } else {
185 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
186 }
187 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
188 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
189 } else {
190 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
191 }
192 }
193
201 const ExpressionPtr& rhs) {
202 using enum ExpressionType;
203
204 // Prune expression
205 if (lhs->is_constant(0.0)) {
206 // Return zero
207 return lhs;
208 } else if (rhs->is_constant(1.0)) {
209 return lhs;
210 }
211
212 // Evaluate constant
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
215 }
216
217 // Evaluate expression type
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
220 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
221 } else if (lhs->type() == QUADRATIC) {
222 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
223 } else {
224 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
225 }
226 } else {
227 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
228 }
229 }
230
238 const ExpressionPtr& rhs) {
239 using enum ExpressionType;
240
241 // Prune expression
242 if (lhs == nullptr || lhs->is_constant(0.0)) {
243 return rhs;
244 } else if (rhs == nullptr || rhs->is_constant(0.0)) {
245 return lhs;
246 }
247
248 // Evaluate constant
249 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
250 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
251 }
252
253 auto type = std::max(lhs->type(), rhs->type());
254 if (type == LINEAR) {
255 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
256 } else if (type == QUADRATIC) {
257 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
258 } else {
259 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
260 }
261 }
262
270 const ExpressionPtr& rhs) {
271 return lhs = lhs + rhs;
272 }
273
281 const ExpressionPtr& rhs) {
282 using enum ExpressionType;
283
284 // Prune expression
285 if (lhs->is_constant(0.0)) {
286 if (rhs->is_constant(0.0)) {
287 // Return zero
288 return rhs;
289 } else {
290 return -rhs;
291 }
292 } else if (rhs->is_constant(0.0)) {
293 return lhs;
294 }
295
296 // Evaluate constant
297 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
298 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
299 }
300
301 auto type = std::max(lhs->type(), rhs->type());
302 if (type == LINEAR) {
303 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
304 } else if (type == QUADRATIC) {
305 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
306 } else {
307 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
308 }
309 }
310
317 using enum ExpressionType;
318
319 // Prune expression
320 if (lhs->is_constant(0.0)) {
321 // Return zero
322 return lhs;
323 }
324
325 // Evaluate constant
326 if (lhs->type() == CONSTANT) {
327 return make_expression_ptr<ConstExpression>(-lhs->val);
328 }
329
330 if (lhs->type() == LINEAR) {
331 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
332 } else if (lhs->type() == QUADRATIC) {
333 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
334 } else {
335 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
336 }
337 }
338
344 friend ExpressionPtr operator+(const ExpressionPtr& lhs) { return lhs; }
345
355 virtual double value([[maybe_unused]] double lhs,
356 [[maybe_unused]] double rhs) const = 0;
357
364 virtual ExpressionType type() const = 0;
365
374 virtual double grad_l([[maybe_unused]] double lhs,
375 [[maybe_unused]] double rhs,
376 [[maybe_unused]] double parent_adjoint) const {
377 return 0.0;
378 }
379
388 virtual double grad_r([[maybe_unused]] double lhs,
389 [[maybe_unused]] double rhs,
390 [[maybe_unused]] double parent_adjoint) const {
391 return 0.0;
392 }
393
403 [[maybe_unused]] const ExpressionPtr& lhs,
404 [[maybe_unused]] const ExpressionPtr& rhs,
405 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
406 return make_expression_ptr<ConstExpression>();
407 }
408
418 [[maybe_unused]] const ExpressionPtr& lhs,
419 [[maybe_unused]] const ExpressionPtr& rhs,
420 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
421 return make_expression_ptr<ConstExpression>();
422 }
423};
424
430template <ExpressionType T>
439 : Expression{std::move(lhs), std::move(rhs)} {}
440
441 double value(double lhs, double rhs) const override { return lhs - rhs; }
442
443 ExpressionType type() const override { return T; }
444
445 double grad_l(double, double, double parent_adjoint) const override {
446 return parent_adjoint;
447 }
448
449 double grad_r(double, double, double parent_adjoint) const override {
450 return -parent_adjoint;
451 }
452
454 const ExpressionPtr&, const ExpressionPtr&,
455 const ExpressionPtr& parent_adjoint) const override {
456 return parent_adjoint;
457 }
458
460 const ExpressionPtr&, const ExpressionPtr&,
461 const ExpressionPtr& parent_adjoint) const override {
462 return -parent_adjoint;
463 }
464};
465
471template <ExpressionType T>
480 : Expression{std::move(lhs), std::move(rhs)} {}
481
482 double value(double lhs, double rhs) const override { return lhs + rhs; }
483
484 ExpressionType type() const override { return T; }
485
486 double grad_l(double, double, double parent_adjoint) const override {
487 return parent_adjoint;
488 }
489
490 double grad_r(double, double, double parent_adjoint) const override {
491 return parent_adjoint;
492 }
493
495 const ExpressionPtr&, const ExpressionPtr&,
496 const ExpressionPtr& parent_adjoint) const override {
497 return parent_adjoint;
498 }
499
501 const ExpressionPtr&, const ExpressionPtr&,
502 const ExpressionPtr& parent_adjoint) const override {
503 return parent_adjoint;
504 }
505};
506
514 constexpr ConstExpression() = default;
515
521 explicit constexpr ConstExpression(double value) : Expression{value} {}
522
523 double value(double, double) const override { return val; }
524
525 ExpressionType type() const override { return ExpressionType::CONSTANT; }
526};
527
535 constexpr DecisionVariableExpression() = default;
536
542 explicit constexpr DecisionVariableExpression(double value)
543 : Expression{value} {}
544
545 double value(double, double) const override { return val; }
546
547 ExpressionType type() const override { return ExpressionType::LINEAR; }
548};
549
555template <ExpressionType T>
556struct DivExpression final : Expression {
564 : Expression{std::move(lhs), std::move(rhs)} {}
565
566 double value(double lhs, double rhs) const override { return lhs / rhs; }
567
568 ExpressionType type() const override { return T; }
569
570 double grad_l(double, double rhs, double parent_adjoint) const override {
571 return parent_adjoint / rhs;
572 };
573
574 double grad_r(double lhs, double rhs, double parent_adjoint) const override {
575 return parent_adjoint * -lhs / (rhs * rhs);
576 }
577
579 const ExpressionPtr&, const ExpressionPtr& rhs,
580 const ExpressionPtr& parent_adjoint) const override {
581 return parent_adjoint / rhs;
582 }
583
585 const ExpressionPtr& lhs, const ExpressionPtr& rhs,
586 const ExpressionPtr& parent_adjoint) const override {
587 return parent_adjoint * -lhs / (rhs * rhs);
588 }
589};
590
596template <ExpressionType T>
605 : Expression{std::move(lhs), std::move(rhs)} {}
606
607 double value(double lhs, double rhs) const override { return lhs * rhs; }
608
609 ExpressionType type() const override { return T; }
610
611 double grad_l([[maybe_unused]] double lhs, double rhs,
612 double parent_adjoint) const override {
613 return parent_adjoint * rhs;
614 }
615
616 double grad_r(double lhs, [[maybe_unused]] double rhs,
617 double parent_adjoint) const override {
618 return parent_adjoint * lhs;
619 }
620
622 [[maybe_unused]] const ExpressionPtr& lhs, const ExpressionPtr& rhs,
623 const ExpressionPtr& parent_adjoint) const override {
624 return parent_adjoint * rhs;
625 }
626
628 const ExpressionPtr& lhs, [[maybe_unused]] const ExpressionPtr& rhs,
629 const ExpressionPtr& parent_adjoint) const override {
630 return parent_adjoint * lhs;
631 }
632};
633
639template <ExpressionType T>
646 explicit constexpr UnaryMinusExpression(ExpressionPtr lhs)
647 : Expression{std::move(lhs)} {}
648
649 double value(double lhs, double) const override { return -lhs; }
650
651 ExpressionType type() const override { return T; }
652
653 double grad_l(double, double, double parent_adjoint) const override {
654 return -parent_adjoint;
655 }
656
658 const ExpressionPtr&, const ExpressionPtr&,
659 const ExpressionPtr& parent_adjoint) const override {
660 return -parent_adjoint;
661 }
662};
663
664inline ExpressionPtr exp(const ExpressionPtr& x);
665inline ExpressionPtr sin(const ExpressionPtr& x);
666inline ExpressionPtr sinh(const ExpressionPtr& x);
667inline ExpressionPtr sqrt(const ExpressionPtr& x);
668
674inline constexpr void inc_ref_count(Expression* expr) {
675 ++expr->ref_count;
676}
677
683inline constexpr void dec_ref_count(Expression* expr) {
684 // If a deeply nested tree is being deallocated all at once, calling the
685 // Expression destructor when expr's refcount reaches zero can cause a stack
686 // overflow. Instead, we iterate over its children to decrement their
687 // refcounts and deallocate them.
688 gch::small_vector<Expression*> stack;
689 stack.emplace_back(expr);
690
691 while (!stack.empty()) {
692 auto elem = stack.back();
693 stack.pop_back();
694
695 // Decrement the current node's refcount. If it reaches zero, deallocate the
696 // node and enqueue its children so their refcounts are decremented too.
697 if (--elem->ref_count == 0) {
698 if (elem->adjoint_expr != nullptr) {
699 stack.emplace_back(elem->adjoint_expr.get());
700 }
701 for (auto& arg : elem->args) {
702 if (arg != nullptr) {
703 stack.emplace_back(arg.get());
704 }
705 }
706
707 // Not calling the destructor here is safe because it only decrements
708 // refcounts, which was already done above.
709 if constexpr (USE_POOL_ALLOCATOR) {
710 auto alloc = global_pool_allocator<Expression>();
711 std::allocator_traits<decltype(alloc)>::deallocate(alloc, elem,
712 sizeof(Expression));
713 }
714 }
715 }
716}
717
721struct AbsExpression final : Expression {
727 explicit constexpr AbsExpression(ExpressionPtr lhs)
728 : Expression{std::move(lhs)} {}
729
730 double value(double x, double) const override { return std::abs(x); }
731
732 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
733
734 double grad_l(double x, double, double parent_adjoint) const override {
735 if (x < 0.0) {
736 return -parent_adjoint;
737 } else if (x > 0.0) {
738 return parent_adjoint;
739 } else {
740 return 0.0;
741 }
742 }
743
745 const ExpressionPtr& x, const ExpressionPtr&,
746 const ExpressionPtr& parent_adjoint) const override {
747 if (x->val < 0.0) {
748 return -parent_adjoint;
749 } else if (x->val > 0.0) {
750 return parent_adjoint;
751 } else {
752 // Return zero
753 return make_expression_ptr<ConstExpression>();
754 }
755 }
756};
757
763inline ExpressionPtr abs(const ExpressionPtr& x) {
764 using enum ExpressionType;
765
766 // Prune expression
767 if (x->is_constant(0.0)) {
768 // Return zero
769 return x;
770 }
771
772 // Evaluate constant
773 if (x->type() == CONSTANT) {
774 return make_expression_ptr<ConstExpression>(std::abs(x->val));
775 }
776
777 return make_expression_ptr<AbsExpression>(x);
778}
779
789 explicit constexpr AcosExpression(ExpressionPtr lhs)
790 : Expression{std::move(lhs)} {}
791
792 double value(double x, double) const override { return std::acos(x); }
793
794 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
795
796 double grad_l(double x, double, double parent_adjoint) const override {
797 return -parent_adjoint / std::sqrt(1.0 - x * x);
798 }
799
801 const ExpressionPtr& x, const ExpressionPtr&,
802 const ExpressionPtr& parent_adjoint) const override {
803 return -parent_adjoint /
804 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
805 }
806};
807
813inline ExpressionPtr acos(const ExpressionPtr& x) {
814 using enum ExpressionType;
815
816 // Prune expression
817 if (x->is_constant(0.0)) {
818 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
819 }
820
821 // Evaluate constant
822 if (x->type() == CONSTANT) {
823 return make_expression_ptr<ConstExpression>(std::acos(x->val));
824 }
825
826 return make_expression_ptr<AcosExpression>(x);
827}
828
838 explicit constexpr AsinExpression(ExpressionPtr lhs)
839 : Expression{std::move(lhs)} {}
840
841 double value(double x, double) const override { return std::asin(x); }
842
843 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
844
845 double grad_l(double x, double, double parent_adjoint) const override {
846 return parent_adjoint / std::sqrt(1.0 - x * x);
847 }
848
850 const ExpressionPtr& x, const ExpressionPtr&,
851 const ExpressionPtr& parent_adjoint) const override {
852 return parent_adjoint /
853 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
854 }
855};
856
862inline ExpressionPtr asin(const ExpressionPtr& x) {
863 using enum ExpressionType;
864
865 // Prune expression
866 if (x->is_constant(0.0)) {
867 // Return zero
868 return x;
869 }
870
871 // Evaluate constant
872 if (x->type() == CONSTANT) {
873 return make_expression_ptr<ConstExpression>(std::asin(x->val));
874 }
875
876 return make_expression_ptr<AsinExpression>(x);
877}
878
888 explicit constexpr AtanExpression(ExpressionPtr lhs)
889 : Expression{std::move(lhs)} {}
890
891 double value(double x, double) const override { return std::atan(x); }
892
893 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
894
895 double grad_l(double x, double, double parent_adjoint) const override {
896 return parent_adjoint / (1.0 + x * x);
897 }
898
900 const ExpressionPtr& x, const ExpressionPtr&,
901 const ExpressionPtr& parent_adjoint) const override {
902 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
903 }
904};
905
911inline ExpressionPtr atan(const ExpressionPtr& x) {
912 using enum ExpressionType;
913
914 // Prune expression
915 if (x->is_constant(0.0)) {
916 // Return zero
917 return x;
918 }
919
920 // Evaluate constant
921 if (x->type() == CONSTANT) {
922 return make_expression_ptr<ConstExpression>(std::atan(x->val));
923 }
924
925 return make_expression_ptr<AtanExpression>(x);
926}
927
939 : Expression{std::move(lhs), std::move(rhs)} {}
940
941 double value(double y, double x) const override { return std::atan2(y, x); }
942
943 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
944
945 double grad_l(double y, double x, double parent_adjoint) const override {
946 return parent_adjoint * x / (y * y + x * x);
947 }
948
949 double grad_r(double y, double x, double parent_adjoint) const override {
950 return parent_adjoint * -y / (y * y + x * x);
951 }
952
954 const ExpressionPtr& y, const ExpressionPtr& x,
955 const ExpressionPtr& parent_adjoint) const override {
956 return parent_adjoint * x / (y * y + x * x);
957 }
958
960 const ExpressionPtr& y, const ExpressionPtr& x,
961 const ExpressionPtr& parent_adjoint) const override {
962 return parent_adjoint * -y / (y * y + x * x);
963 }
964};
965
972inline ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) {
973 using enum ExpressionType;
974
975 // Prune expression
976 if (y->is_constant(0.0)) {
977 // Return zero
978 return y;
979 } else if (x->is_constant(0.0)) {
980 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
981 }
982
983 // Evaluate constant
984 if (y->type() == CONSTANT && x->type() == CONSTANT) {
985 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
986 }
987
988 return make_expression_ptr<Atan2Expression>(y, x);
989}
990
994struct CosExpression final : Expression {
1000 explicit constexpr CosExpression(ExpressionPtr lhs)
1001 : Expression{std::move(lhs)} {}
1002
1003 double value(double x, double) const override { return std::cos(x); }
1004
1005 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1006
1007 double grad_l(double x, double, double parent_adjoint) const override {
1008 return -parent_adjoint * std::sin(x);
1009 }
1010
1012 const ExpressionPtr& x, const ExpressionPtr&,
1013 const ExpressionPtr& parent_adjoint) const override {
1014 return parent_adjoint * -slp::detail::sin(x);
1015 }
1016};
1017
1023inline ExpressionPtr cos(const ExpressionPtr& x) {
1024 using enum ExpressionType;
1025
1026 // Prune expression
1027 if (x->is_constant(0.0)) {
1028 return make_expression_ptr<ConstExpression>(1.0);
1029 }
1030
1031 // Evaluate constant
1032 if (x->type() == CONSTANT) {
1033 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1034 }
1035
1036 return make_expression_ptr<CosExpression>(x);
1037}
1038
1048 explicit constexpr CoshExpression(ExpressionPtr lhs)
1049 : Expression{std::move(lhs)} {}
1050
1051 double value(double x, double) const override { return std::cosh(x); }
1052
1053 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1054
1055 double grad_l(double x, double, double parent_adjoint) const override {
1056 return parent_adjoint * std::sinh(x);
1057 }
1058
1060 const ExpressionPtr& x, const ExpressionPtr&,
1061 const ExpressionPtr& parent_adjoint) const override {
1062 return parent_adjoint * slp::detail::sinh(x);
1063 }
1064};
1065
1071inline ExpressionPtr cosh(const ExpressionPtr& x) {
1072 using enum ExpressionType;
1073
1074 // Prune expression
1075 if (x->is_constant(0.0)) {
1076 return make_expression_ptr<ConstExpression>(1.0);
1077 }
1078
1079 // Evaluate constant
1080 if (x->type() == CONSTANT) {
1081 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1082 }
1083
1084 return make_expression_ptr<CoshExpression>(x);
1085}
1086
1096 explicit constexpr ErfExpression(ExpressionPtr lhs)
1097 : Expression{std::move(lhs)} {}
1098
1099 double value(double x, double) const override { return std::erf(x); }
1100
1101 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1102
1103 double grad_l(double x, double, double parent_adjoint) const override {
1104 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1105 }
1106
1108 const ExpressionPtr& x, const ExpressionPtr&,
1109 const ExpressionPtr& parent_adjoint) const override {
1110 return parent_adjoint *
1111 make_expression_ptr<ConstExpression>(2.0 *
1112 std::numbers::inv_sqrtpi) *
1113 slp::detail::exp(-x * x);
1114 }
1115};
1116
1122inline ExpressionPtr erf(const ExpressionPtr& x) {
1123 using enum ExpressionType;
1124
1125 // Prune expression
1126 if (x->is_constant(0.0)) {
1127 // Return zero
1128 return x;
1129 }
1130
1131 // Evaluate constant
1132 if (x->type() == CONSTANT) {
1133 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1134 }
1135
1136 return make_expression_ptr<ErfExpression>(x);
1137}
1138
1148 explicit constexpr ExpExpression(ExpressionPtr lhs)
1149 : Expression{std::move(lhs)} {}
1150
1151 double value(double x, double) const override { return std::exp(x); }
1152
1153 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1154
1155 double grad_l(double x, double, double parent_adjoint) const override {
1156 return parent_adjoint * std::exp(x);
1157 }
1158
1160 const ExpressionPtr& x, const ExpressionPtr&,
1161 const ExpressionPtr& parent_adjoint) const override {
1162 return parent_adjoint * slp::detail::exp(x);
1163 }
1164};
1165
1171inline ExpressionPtr exp(const ExpressionPtr& x) {
1172 using enum ExpressionType;
1173
1174 // Prune expression
1175 if (x->is_constant(0.0)) {
1176 return make_expression_ptr<ConstExpression>(1.0);
1177 }
1178
1179 // Evaluate constant
1180 if (x->type() == CONSTANT) {
1181 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1182 }
1183
1184 return make_expression_ptr<ExpExpression>(x);
1185}
1186
1187inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y);
1188
1200 : Expression{std::move(lhs), std::move(rhs)} {}
1201
1202 double value(double x, double y) const override { return std::hypot(x, y); }
1203
1204 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1205
1206 double grad_l(double x, double y, double parent_adjoint) const override {
1207 return parent_adjoint * x / std::hypot(x, y);
1208 }
1209
1210 double grad_r(double x, double y, double parent_adjoint) const override {
1211 return parent_adjoint * y / std::hypot(x, y);
1212 }
1213
1215 const ExpressionPtr& x, const ExpressionPtr& y,
1216 const ExpressionPtr& parent_adjoint) const override {
1217 return parent_adjoint * x / slp::detail::hypot(x, y);
1218 }
1219
1221 const ExpressionPtr& x, const ExpressionPtr& y,
1222 const ExpressionPtr& parent_adjoint) const override {
1223 return parent_adjoint * y / slp::detail::hypot(x, y);
1224 }
1225};
1226
1233inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) {
1234 using enum ExpressionType;
1235
1236 // Prune expression
1237 if (x->is_constant(0.0)) {
1238 return y;
1239 } else if (y->is_constant(0.0)) {
1240 return x;
1241 }
1242
1243 // Evaluate constant
1244 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1245 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1246 }
1247
1248 return make_expression_ptr<HypotExpression>(x, y);
1249}
1250
1260 explicit constexpr LogExpression(ExpressionPtr lhs)
1261 : Expression{std::move(lhs)} {}
1262
1263 double value(double x, double) const override { return std::log(x); }
1264
1265 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1266
1267 double grad_l(double x, double, double parent_adjoint) const override {
1268 return parent_adjoint / x;
1269 }
1270
1272 const ExpressionPtr& x, const ExpressionPtr&,
1273 const ExpressionPtr& parent_adjoint) const override {
1274 return parent_adjoint / x;
1275 }
1276};
1277
1283inline ExpressionPtr log(const ExpressionPtr& x) {
1284 using enum ExpressionType;
1285
1286 // Prune expression
1287 if (x->is_constant(0.0)) {
1288 // Return zero
1289 return x;
1290 }
1291
1292 // Evaluate constant
1293 if (x->type() == CONSTANT) {
1294 return make_expression_ptr<ConstExpression>(std::log(x->val));
1295 }
1296
1297 return make_expression_ptr<LogExpression>(x);
1298}
1299
1309 explicit constexpr Log10Expression(ExpressionPtr lhs)
1310 : Expression{std::move(lhs)} {}
1311
1312 double value(double x, double) const override { return std::log10(x); }
1313
1314 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1315
1316 double grad_l(double x, double, double parent_adjoint) const override {
1317 return parent_adjoint / (std::numbers::ln10 * x);
1318 }
1319
1321 const ExpressionPtr& x, const ExpressionPtr&,
1322 const ExpressionPtr& parent_adjoint) const override {
1323 return parent_adjoint /
1324 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1325 }
1326};
1327
1333inline ExpressionPtr log10(const ExpressionPtr& x) {
1334 using enum ExpressionType;
1335
1336 // Prune expression
1337 if (x->is_constant(0.0)) {
1338 // Return zero
1339 return x;
1340 }
1341
1342 // Evaluate constant
1343 if (x->type() == CONSTANT) {
1344 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1345 }
1346
1347 return make_expression_ptr<Log10Expression>(x);
1348}
1349
1350inline ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& power);
1351
1357template <ExpressionType T>
1366 : Expression{std::move(lhs), std::move(rhs)} {}
1367
1368 double value(double base, double power) const override {
1369 return std::pow(base, power);
1370 }
1371
1372 ExpressionType type() const override { return T; }
1373
1374 double grad_l(double base, double power,
1375 double parent_adjoint) const override {
1376 return parent_adjoint * std::pow(base, power - 1) * power;
1377 }
1378
1379 double grad_r(double base, double power,
1380 double parent_adjoint) const override {
1381 // Since x * std::log(x) -> 0 as x -> 0
1382 if (base == 0.0) {
1383 return 0.0;
1384 } else {
1385 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1386 }
1387 }
1388
1390 const ExpressionPtr& base, const ExpressionPtr& power,
1391 const ExpressionPtr& parent_adjoint) const override {
1392 return parent_adjoint *
1393 slp::detail::pow(base,
1394 power - make_expression_ptr<ConstExpression>(1.0)) *
1395 power;
1396 }
1397
1399 const ExpressionPtr& base, const ExpressionPtr& power,
1400 const ExpressionPtr& parent_adjoint) const override {
1401 // Since x * std::log(x) -> 0 as x -> 0
1402 if (base->val == 0.0) {
1403 // Return zero
1404 return base;
1405 } else {
1406 return parent_adjoint *
1407 slp::detail::pow(
1408 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1409 base * slp::detail::log(base);
1410 }
1411 }
1412};
1413
1420inline ExpressionPtr pow(const ExpressionPtr& base,
1421 const ExpressionPtr& power) {
1422 using enum ExpressionType;
1423
1424 // Prune expression
1425 if (base->is_constant(0.0)) {
1426 // Return zero
1427 return base;
1428 } else if (base->is_constant(1.0)) {
1429 // Return one
1430 return base;
1431 }
1432 if (power->is_constant(0.0)) {
1433 return make_expression_ptr<ConstExpression>(1.0);
1434 } else if (power->is_constant(1.0)) {
1435 return base;
1436 }
1437
1438 // Evaluate constant
1439 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1440 return make_expression_ptr<ConstExpression>(
1441 std::pow(base->val, power->val));
1442 }
1443
1444 if (power->is_constant(2.0)) {
1445 if (base->type() == LINEAR) {
1446 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1447 } else {
1448 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1449 }
1450 }
1451
1452 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1453}
1454
1464 explicit constexpr SignExpression(ExpressionPtr lhs)
1465 : Expression{std::move(lhs)} {}
1466
1467 double value(double x, double) const override {
1468 if (x < 0.0) {
1469 return -1.0;
1470 } else if (x == 0.0) {
1471 return 0.0;
1472 } else {
1473 return 1.0;
1474 }
1475 }
1476
1477 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1478
1479 double grad_l(double, double, double) const override { return 0.0; }
1480
1482 const ExpressionPtr&) const override {
1483 // Return zero
1484 return make_expression_ptr<ConstExpression>();
1485 }
1486};
1487
1493inline ExpressionPtr sign(const ExpressionPtr& x) {
1494 using enum ExpressionType;
1495
1496 // Evaluate constant
1497 if (x->type() == CONSTANT) {
1498 if (x->val < 0.0) {
1499 return make_expression_ptr<ConstExpression>(-1.0);
1500 } else if (x->val == 0.0) {
1501 // Return zero
1502 return x;
1503 } else {
1504 return make_expression_ptr<ConstExpression>(1.0);
1505 }
1506 }
1507
1508 return make_expression_ptr<SignExpression>(x);
1509}
1510
1520 explicit constexpr SinExpression(ExpressionPtr lhs)
1521 : Expression{std::move(lhs)} {}
1522
1523 double value(double x, double) const override { return std::sin(x); }
1524
1525 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1526
1527 double grad_l(double x, double, double parent_adjoint) const override {
1528 return parent_adjoint * std::cos(x);
1529 }
1530
1532 const ExpressionPtr& x, const ExpressionPtr&,
1533 const ExpressionPtr& parent_adjoint) const override {
1534 return parent_adjoint * slp::detail::cos(x);
1535 }
1536};
1537
1543inline ExpressionPtr sin(const ExpressionPtr& x) {
1544 using enum ExpressionType;
1545
1546 // Prune expression
1547 if (x->is_constant(0.0)) {
1548 // Return zero
1549 return x;
1550 }
1551
1552 // Evaluate constant
1553 if (x->type() == CONSTANT) {
1554 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1555 }
1556
1557 return make_expression_ptr<SinExpression>(x);
1558}
1559
1569 explicit constexpr SinhExpression(ExpressionPtr lhs)
1570 : Expression{std::move(lhs)} {}
1571
1572 double value(double x, double) const override { return std::sinh(x); }
1573
1574 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1575
1576 double grad_l(double x, double, double parent_adjoint) const override {
1577 return parent_adjoint * std::cosh(x);
1578 }
1579
1581 const ExpressionPtr& x, const ExpressionPtr&,
1582 const ExpressionPtr& parent_adjoint) const override {
1583 return parent_adjoint * slp::detail::cosh(x);
1584 }
1585};
1586
1592inline ExpressionPtr sinh(const ExpressionPtr& x) {
1593 using enum ExpressionType;
1594
1595 // Prune expression
1596 if (x->is_constant(0.0)) {
1597 // Return zero
1598 return x;
1599 }
1600
1601 // Evaluate constant
1602 if (x->type() == CONSTANT) {
1603 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1604 }
1605
1606 return make_expression_ptr<SinhExpression>(x);
1607}
1608
1618 explicit constexpr SqrtExpression(ExpressionPtr lhs)
1619 : Expression{std::move(lhs)} {}
1620
1621 double value(double x, double) const override { return std::sqrt(x); }
1622
1623 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1624
1625 double grad_l(double x, double, double parent_adjoint) const override {
1626 return parent_adjoint / (2.0 * std::sqrt(x));
1627 }
1628
1630 const ExpressionPtr& x, const ExpressionPtr&,
1631 const ExpressionPtr& parent_adjoint) const override {
1632 return parent_adjoint /
1633 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1634 }
1635};
1636
1642inline ExpressionPtr sqrt(const ExpressionPtr& x) {
1643 using enum ExpressionType;
1644
1645 // Evaluate constant
1646 if (x->type() == CONSTANT) {
1647 if (x->val == 0.0) {
1648 // Return zero
1649 return x;
1650 } else if (x->val == 1.0) {
1651 return x;
1652 } else {
1653 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1654 }
1655 }
1656
1657 return make_expression_ptr<SqrtExpression>(x);
1658}
1659
1669 explicit constexpr TanExpression(ExpressionPtr lhs)
1670 : Expression{std::move(lhs)} {}
1671
1672 double value(double x, double) const override { return std::tan(x); }
1673
1674 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1675
1676 double grad_l(double x, double, double parent_adjoint) const override {
1677 return parent_adjoint / (std::cos(x) * std::cos(x));
1678 }
1679
1681 const ExpressionPtr& x, const ExpressionPtr&,
1682 const ExpressionPtr& parent_adjoint) const override {
1683 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1684 }
1685};
1686
1692inline ExpressionPtr tan(const ExpressionPtr& x) {
1693 using enum ExpressionType;
1694
1695 // Prune expression
1696 if (x->is_constant(0.0)) {
1697 // Return zero
1698 return x;
1699 }
1700
1701 // Evaluate constant
1702 if (x->type() == CONSTANT) {
1703 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1704 }
1705
1706 return make_expression_ptr<TanExpression>(x);
1707}
1708
1718 explicit constexpr TanhExpression(ExpressionPtr lhs)
1719 : Expression{std::move(lhs)} {}
1720
1721 double value(double x, double) const override { return std::tanh(x); }
1722
1723 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1724
1725 double grad_l(double x, double, double parent_adjoint) const override {
1726 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1727 }
1728
1730 const ExpressionPtr& x, const ExpressionPtr&,
1731 const ExpressionPtr& parent_adjoint) const override {
1732 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1733 }
1734};
1735
1741inline ExpressionPtr tanh(const ExpressionPtr& x) {
1742 using enum ExpressionType;
1743
1744 // Prune expression
1745 if (x->is_constant(0.0)) {
1746 // Return zero
1747 return x;
1748 }
1749
1750 // Evaluate constant
1751 if (x->type() == CONSTANT) {
1752 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1753 }
1754
1755 return make_expression_ptr<TanhExpression>(x);
1756}
1757
1758} // namespace slp::detail
Definition expression.hpp:721
ExpressionType type() const override
Definition expression.hpp:732
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:734
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:744
double value(double x, double) const override
Definition expression.hpp:730
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:727
Definition expression.hpp:783
ExpressionType type() const override
Definition expression.hpp:794
constexpr AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:789
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:796
double value(double x, double) const override
Definition expression.hpp:792
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:800
Definition expression.hpp:832
double value(double x, double) const override
Definition expression.hpp:841
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:849
ExpressionType type() const override
Definition expression.hpp:843
constexpr AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:838
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:845
Definition expression.hpp:931
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:959
double value(double y, double x) const override
Definition expression.hpp:941
ExpressionType type() const override
Definition expression.hpp:943
constexpr Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:938
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:945
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:949
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:953
Definition expression.hpp:882
constexpr AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:888
double value(double x, double) const override
Definition expression.hpp:891
ExpressionType type() const override
Definition expression.hpp:893
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:899
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:895
Definition expression.hpp:431
double value(double lhs, double rhs) const override
Definition expression.hpp:441
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:438
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:449
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:445
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:459
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:453
ExpressionType type() const override
Definition expression.hpp:443
Definition expression.hpp:472
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:490
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:486
double value(double lhs, double rhs) const override
Definition expression.hpp:482
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:500
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:479
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:494
ExpressionType type() const override
Definition expression.hpp:484
Definition expression.hpp:510
constexpr ConstExpression(double value)
Definition expression.hpp:521
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:523
ExpressionType type() const override
Definition expression.hpp:525
Definition expression.hpp:994
ExpressionType type() const override
Definition expression.hpp:1005
double value(double x, double) const override
Definition expression.hpp:1003
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1011
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1007
constexpr CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1000
Definition expression.hpp:1042
constexpr CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1048
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1055
double value(double x, double) const override
Definition expression.hpp:1051
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1059
ExpressionType type() const override
Definition expression.hpp:1053
Definition expression.hpp:531
double value(double, double) const override
Definition expression.hpp:545
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:542
ExpressionType type() const override
Definition expression.hpp:547
Definition expression.hpp:556
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:570
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:584
double value(double lhs, double rhs) const override
Definition expression.hpp:566
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:563
ExpressionType type() const override
Definition expression.hpp:568
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:578
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:574
Definition expression.hpp:1090
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1107
double value(double x, double) const override
Definition expression.hpp:1099
constexpr ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1096
ExpressionType type() const override
Definition expression.hpp:1101
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1103
Definition expression.hpp:1142
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1155
ExpressionType type() const override
Definition expression.hpp:1153
constexpr ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1148
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1159
double value(double x, double) const override
Definition expression.hpp:1151
Definition expression.hpp:77
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:88
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:374
constexpr bool is_constant(double constant) const
Definition expression.hpp:138
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:126
constexpr Expression(double value)
Definition expression.hpp:110
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:417
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:92
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:344
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:117
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:82
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:148
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:85
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:269
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:316
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:280
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:200
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:388
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:402
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:98
double val
The value of the expression node.
Definition expression.hpp:79
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:95
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:237
Definition expression.hpp:1192
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1214
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1210
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1206
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1220
double value(double x, double y) const override
Definition expression.hpp:1202
ExpressionType type() const override
Definition expression.hpp:1204
constexpr HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1199
Definition expression.hpp:1303
constexpr Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1309
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1316
double value(double x, double) const override
Definition expression.hpp:1312
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1320
Definition expression.hpp:1254
ExpressionType type() const override
Definition expression.hpp:1265
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1271
constexpr LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1260
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1267
double value(double x, double) const override
Definition expression.hpp:1263
Definition expression.hpp:597
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:621
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:604
ExpressionType type() const override
Definition expression.hpp:609
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:611
double value(double lhs, double rhs) const override
Definition expression.hpp:607
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:627
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:616
Definition expression.hpp:1358
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1389
double value(double base, double power) const override
Definition expression.hpp:1368
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1374
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1379
ExpressionType type() const override
Definition expression.hpp:1372
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1398
constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1365
Definition expression.hpp:1458
double grad_l(double, double, double) const override
Definition expression.hpp:1479
ExpressionType type() const override
Definition expression.hpp:1477
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1464
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1481
double value(double x, double) const override
Definition expression.hpp:1467
Definition expression.hpp:1514
double value(double x, double) const override
Definition expression.hpp:1523
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1527
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1531
ExpressionType type() const override
Definition expression.hpp:1525
constexpr SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1520
Definition expression.hpp:1563
ExpressionType type() const override
Definition expression.hpp:1574
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1576
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1580
constexpr SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1569
double value(double x, double) const override
Definition expression.hpp:1572
Definition expression.hpp:1612
ExpressionType type() const override
Definition expression.hpp:1623
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1625
constexpr SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1618
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1629
double value(double x, double) const override
Definition expression.hpp:1621
Definition expression.hpp:1663
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1676
double value(double x, double) const override
Definition expression.hpp:1672
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1680
ExpressionType type() const override
Definition expression.hpp:1674
constexpr TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1669
Definition expression.hpp:1712
constexpr TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1718
ExpressionType type() const override
Definition expression.hpp:1723
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1725
double value(double x, double) const override
Definition expression.hpp:1721
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1729
Definition expression.hpp:640
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:646
ExpressionType type() const override
Definition expression.hpp:651
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:657
double value(double lhs, double) const override
Definition expression.hpp:649
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:653