Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <utility>
13
14#include "sleipnir/autodiff/expression_type.hpp"
15#include "sleipnir/util/intrusive_shared_ptr.hpp"
16#include "sleipnir/util/pool.hpp"
17#include "sleipnir/util/small_vector.hpp"
18
19namespace slp::detail {
20
// Whether expression nodes are allocated from the global pool allocator.
//
// The global pool allocator is backed by a thread-local static pool resource.
// On Windows that resource isn't guaranteed to be initialized properly across
// DLL boundaries, so pooling is disabled there and the plain allocation path
// is used instead.
#if defined(_WIN32)
inline constexpr bool USE_POOL_ALLOCATOR = false;
#else
inline constexpr bool USE_POOL_ALLOCATOR = true;
#endif
28
29struct Expression;
30
31inline constexpr void inc_ref_count(Expression* expr);
32inline constexpr void dec_ref_count(Expression* expr);
33
37using ExpressionPtr = IntrusiveSharedPtr<Expression>;
38
46template <typename T, typename... Args>
47static ExpressionPtr make_expression_ptr(Args&&... args) {
48 if constexpr (USE_POOL_ALLOCATOR) {
49 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
50 std::forward<Args>(args)...);
51 } else {
52 return make_intrusive_shared<T>(std::forward<Args>(args)...);
53 }
54}
55
56template <ExpressionType T>
57struct BinaryMinusExpression;
58
59template <ExpressionType T>
60struct BinaryPlusExpression;
61
62struct ConstExpression;
63
64template <ExpressionType T>
65struct DivExpression;
66
67template <ExpressionType T>
68struct MultExpression;
69
70template <ExpressionType T>
71struct UnaryMinusExpression;
72
76struct Expression {
78 double val = 0.0;
79
81 double adjoint = 0.0;
82
84 uint32_t incoming_edges = 0;
85
87 int32_t col = -1;
88
92
94 uint32_t ref_count = 0;
95
97 std::array<ExpressionPtr, 2> args{nullptr, nullptr};
98
102 constexpr Expression() = default;
103
109 explicit constexpr Expression(double value) : val{value} {}
110
116 explicit constexpr Expression(ExpressionPtr lhs)
117 : args{std::move(lhs), nullptr} {}
118
126 : args{std::move(lhs), std::move(rhs)} {}
127
128 virtual ~Expression() = default;
129
137 constexpr bool is_constant(double constant) const {
138 return type() == ExpressionType::CONSTANT && val == constant;
139 }
140
148 const ExpressionPtr& rhs) {
149 using enum ExpressionType;
150
151 // Prune expression
152 if (lhs->is_constant(0.0)) {
153 // Return zero
154 return lhs;
155 } else if (rhs->is_constant(0.0)) {
156 // Return zero
157 return rhs;
158 } else if (lhs->is_constant(1.0)) {
159 return rhs;
160 } else if (rhs->is_constant(1.0)) {
161 return lhs;
162 }
163
164 // Evaluate constant
165 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
166 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
167 }
168
169 // Evaluate expression type
170 if (lhs->type() == CONSTANT) {
171 if (rhs->type() == LINEAR) {
172 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
173 } else if (rhs->type() == QUADRATIC) {
174 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
175 } else {
176 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
177 }
178 } else if (rhs->type() == CONSTANT) {
179 if (lhs->type() == LINEAR) {
180 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
181 } else if (lhs->type() == QUADRATIC) {
182 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
183 } else {
184 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
185 }
186 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
187 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
188 } else {
189 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
190 }
191 }
192
200 const ExpressionPtr& rhs) {
201 using enum ExpressionType;
202
203 // Prune expression
204 if (lhs->is_constant(0.0)) {
205 // Return zero
206 return lhs;
207 } else if (rhs->is_constant(1.0)) {
208 return lhs;
209 }
210
211 // Evaluate constant
212 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
213 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
214 }
215
216 // Evaluate expression type
217 if (rhs->type() == CONSTANT) {
218 if (lhs->type() == LINEAR) {
219 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
220 } else if (lhs->type() == QUADRATIC) {
221 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
222 } else {
223 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
224 }
225 } else {
226 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
227 }
228 }
229
237 const ExpressionPtr& rhs) {
238 using enum ExpressionType;
239
240 // Prune expression
241 if (lhs == nullptr || lhs->is_constant(0.0)) {
242 return rhs;
243 } else if (rhs == nullptr || rhs->is_constant(0.0)) {
244 return lhs;
245 }
246
247 // Evaluate constant
248 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
249 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
250 }
251
252 auto type = std::max(lhs->type(), rhs->type());
253 if (type == LINEAR) {
254 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
255 } else if (type == QUADRATIC) {
256 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
257 } else {
258 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
259 }
260 }
261
269 const ExpressionPtr& rhs) {
270 return lhs = lhs + rhs;
271 }
272
280 const ExpressionPtr& rhs) {
281 using enum ExpressionType;
282
283 // Prune expression
284 if (lhs->is_constant(0.0)) {
285 if (rhs->is_constant(0.0)) {
286 // Return zero
287 return rhs;
288 } else {
289 return -rhs;
290 }
291 } else if (rhs->is_constant(0.0)) {
292 return lhs;
293 }
294
295 // Evaluate constant
296 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
297 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
298 }
299
300 auto type = std::max(lhs->type(), rhs->type());
301 if (type == LINEAR) {
302 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
303 } else if (type == QUADRATIC) {
304 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
305 } else {
306 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
307 }
308 }
309
316 using enum ExpressionType;
317
318 // Prune expression
319 if (lhs->is_constant(0.0)) {
320 // Return zero
321 return lhs;
322 }
323
324 // Evaluate constant
325 if (lhs->type() == CONSTANT) {
326 return make_expression_ptr<ConstExpression>(-lhs->val);
327 }
328
329 if (lhs->type() == LINEAR) {
330 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
331 } else if (lhs->type() == QUADRATIC) {
332 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
333 } else {
334 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
335 }
336 }
337
343 friend ExpressionPtr operator+(const ExpressionPtr& lhs) { return lhs; }
344
354 virtual double value([[maybe_unused]] double lhs,
355 [[maybe_unused]] double rhs) const = 0;
356
363 virtual ExpressionType type() const = 0;
364
373 virtual double grad_l([[maybe_unused]] double lhs,
374 [[maybe_unused]] double rhs,
375 [[maybe_unused]] double parent_adjoint) const {
376 return 0.0;
377 }
378
387 virtual double grad_r([[maybe_unused]] double lhs,
388 [[maybe_unused]] double rhs,
389 [[maybe_unused]] double parent_adjoint) const {
390 return 0.0;
391 }
392
402 [[maybe_unused]] const ExpressionPtr& lhs,
403 [[maybe_unused]] const ExpressionPtr& rhs,
404 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
405 return make_expression_ptr<ConstExpression>();
406 }
407
417 [[maybe_unused]] const ExpressionPtr& lhs,
418 [[maybe_unused]] const ExpressionPtr& rhs,
419 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
420 return make_expression_ptr<ConstExpression>();
421 }
422};
423
429template <ExpressionType T>
438 : Expression{std::move(lhs), std::move(rhs)} {
439 val = args[0]->val - args[1]->val;
440 }
441
442 double value(double lhs, double rhs) const override { return lhs - rhs; }
443
444 ExpressionType type() const override { return T; }
445
446 double grad_l(double, double, double parent_adjoint) const override {
447 return parent_adjoint;
448 }
449
450 double grad_r(double, double, double parent_adjoint) const override {
451 return -parent_adjoint;
452 }
453
455 const ExpressionPtr&, const ExpressionPtr&,
456 const ExpressionPtr& parent_adjoint) const override {
457 return parent_adjoint;
458 }
459
461 const ExpressionPtr&, const ExpressionPtr&,
462 const ExpressionPtr& parent_adjoint) const override {
463 return -parent_adjoint;
464 }
465};
466
472template <ExpressionType T>
481 : Expression{std::move(lhs), std::move(rhs)} {
482 val = args[0]->val + args[1]->val;
483 }
484
485 double value(double lhs, double rhs) const override { return lhs + rhs; }
486
487 ExpressionType type() const override { return T; }
488
489 double grad_l(double, double, double parent_adjoint) const override {
490 return parent_adjoint;
491 }
492
493 double grad_r(double, double, double parent_adjoint) const override {
494 return parent_adjoint;
495 }
496
498 const ExpressionPtr&, const ExpressionPtr&,
499 const ExpressionPtr& parent_adjoint) const override {
500 return parent_adjoint;
501 }
502
504 const ExpressionPtr&, const ExpressionPtr&,
505 const ExpressionPtr& parent_adjoint) const override {
506 return parent_adjoint;
507 }
508};
509
517 constexpr ConstExpression() = default;
518
524 explicit constexpr ConstExpression(double value) : Expression{value} {}
525
526 double value(double, double) const override { return val; }
527
528 ExpressionType type() const override { return ExpressionType::CONSTANT; }
529};
530
538 constexpr DecisionVariableExpression() = default;
539
545 explicit constexpr DecisionVariableExpression(double value)
546 : Expression{value} {}
547
548 double value(double, double) const override { return val; }
549
550 ExpressionType type() const override { return ExpressionType::LINEAR; }
551};
552
558template <ExpressionType T>
559struct DivExpression final : Expression {
567 : Expression{std::move(lhs), std::move(rhs)} {
568 val = args[0]->val / args[1]->val;
569 }
570
571 double value(double lhs, double rhs) const override { return lhs / rhs; }
572
573 ExpressionType type() const override { return T; }
574
575 double grad_l(double, double rhs, double parent_adjoint) const override {
576 return parent_adjoint / rhs;
577 };
578
579 double grad_r(double lhs, double rhs, double parent_adjoint) const override {
580 return parent_adjoint * -lhs / (rhs * rhs);
581 }
582
584 const ExpressionPtr&, const ExpressionPtr& rhs,
585 const ExpressionPtr& parent_adjoint) const override {
586 return parent_adjoint / rhs;
587 }
588
590 const ExpressionPtr& lhs, const ExpressionPtr& rhs,
591 const ExpressionPtr& parent_adjoint) const override {
592 return parent_adjoint * -lhs / (rhs * rhs);
593 }
594};
595
601template <ExpressionType T>
610 : Expression{std::move(lhs), std::move(rhs)} {
611 val = args[0]->val * args[1]->val;
612 }
613
614 double value(double lhs, double rhs) const override { return lhs * rhs; }
615
616 ExpressionType type() const override { return T; }
617
618 double grad_l([[maybe_unused]] double lhs, double rhs,
619 double parent_adjoint) const override {
620 return parent_adjoint * rhs;
621 }
622
623 double grad_r(double lhs, [[maybe_unused]] double rhs,
624 double parent_adjoint) const override {
625 return parent_adjoint * lhs;
626 }
627
629 [[maybe_unused]] const ExpressionPtr& lhs, const ExpressionPtr& rhs,
630 const ExpressionPtr& parent_adjoint) const override {
631 return parent_adjoint * rhs;
632 }
633
635 const ExpressionPtr& lhs, [[maybe_unused]] const ExpressionPtr& rhs,
636 const ExpressionPtr& parent_adjoint) const override {
637 return parent_adjoint * lhs;
638 }
639};
640
646template <ExpressionType T>
653 explicit constexpr UnaryMinusExpression(ExpressionPtr lhs)
654 : Expression{std::move(lhs)} {
655 val = -args[0]->val;
656 }
657
658 double value(double lhs, double) const override { return -lhs; }
659
660 ExpressionType type() const override { return T; }
661
662 double grad_l(double, double, double parent_adjoint) const override {
663 return -parent_adjoint;
664 }
665
667 const ExpressionPtr&, const ExpressionPtr&,
668 const ExpressionPtr& parent_adjoint) const override {
669 return -parent_adjoint;
670 }
671};
672
673inline ExpressionPtr exp(const ExpressionPtr& x);
674inline ExpressionPtr sin(const ExpressionPtr& x);
675inline ExpressionPtr sinh(const ExpressionPtr& x);
676inline ExpressionPtr sqrt(const ExpressionPtr& x);
677
683inline constexpr void inc_ref_count(Expression* expr) {
684 ++expr->ref_count;
685}
686
692inline constexpr void dec_ref_count(Expression* expr) {
693 // If a deeply nested tree is being deallocated all at once, calling the
694 // Expression destructor when expr's refcount reaches zero can cause a stack
695 // overflow. Instead, we iterate over its children to decrement their
696 // refcounts and deallocate them.
697 small_vector<Expression*> stack;
698 stack.emplace_back(expr);
699
700 while (!stack.empty()) {
701 auto elem = stack.back();
702 stack.pop_back();
703
704 // Decrement the current node's refcount. If it reaches zero, deallocate the
705 // node and enqueue its children so their refcounts are decremented too.
706 if (--elem->ref_count == 0) {
707 if (elem->adjoint_expr != nullptr) {
708 stack.emplace_back(elem->adjoint_expr.get());
709 }
710 for (auto& arg : elem->args) {
711 if (arg != nullptr) {
712 stack.emplace_back(arg.get());
713 }
714 }
715
716 // Not calling the destructor here is safe because it only decrements
717 // refcounts, which was already done above.
718 if constexpr (USE_POOL_ALLOCATOR) {
719 auto alloc = global_pool_allocator<Expression>();
720 std::allocator_traits<decltype(alloc)>::deallocate(alloc, elem,
721 sizeof(Expression));
722 }
723 }
724 }
725}
726
730struct AbsExpression final : Expression {
736 explicit constexpr AbsExpression(ExpressionPtr lhs)
737 : Expression{std::move(lhs)} {
738 val = args[0]->val < 0 ? -args[0]->val : args[0]->val;
739 }
740
741 double value(double x, double) const override { return std::abs(x); }
742
743 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
744
745 double grad_l(double x, double, double parent_adjoint) const override {
746 if (x < 0.0) {
747 return -parent_adjoint;
748 } else if (x > 0.0) {
749 return parent_adjoint;
750 } else {
751 return 0.0;
752 }
753 }
754
756 const ExpressionPtr& x, const ExpressionPtr&,
757 const ExpressionPtr& parent_adjoint) const override {
758 if (x->val < 0.0) {
759 return -parent_adjoint;
760 } else if (x->val > 0.0) {
761 return parent_adjoint;
762 } else {
763 // Return zero
764 return make_expression_ptr<ConstExpression>();
765 }
766 }
767};
768
774inline ExpressionPtr abs(const ExpressionPtr& x) {
775 using enum ExpressionType;
776
777 // Prune expression
778 if (x->is_constant(0.0)) {
779 // Return zero
780 return x;
781 }
782
783 // Evaluate constant
784 if (x->type() == CONSTANT) {
785 return make_expression_ptr<ConstExpression>(std::abs(x->val));
786 }
787
788 return make_expression_ptr<AbsExpression>(x);
789}
790
800 explicit AcosExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
801 val = std::acos(args[0]->val);
802 }
803
804 double value(double x, double) const override { return std::acos(x); }
805
806 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
807
808 double grad_l(double x, double, double parent_adjoint) const override {
809 return -parent_adjoint / std::sqrt(1.0 - x * x);
810 }
811
813 const ExpressionPtr& x, const ExpressionPtr&,
814 const ExpressionPtr& parent_adjoint) const override {
815 return -parent_adjoint /
816 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
817 }
818};
819
825inline ExpressionPtr acos(const ExpressionPtr& x) {
826 using enum ExpressionType;
827
828 // Prune expression
829 if (x->is_constant(0.0)) {
830 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
831 }
832
833 // Evaluate constant
834 if (x->type() == CONSTANT) {
835 return make_expression_ptr<ConstExpression>(std::acos(x->val));
836 }
837
838 return make_expression_ptr<AcosExpression>(x);
839}
840
850 explicit AsinExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
851 val = std::asin(args[0]->val);
852 }
853
854 double value(double x, double) const override { return std::asin(x); }
855
856 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
857
858 double grad_l(double x, double, double parent_adjoint) const override {
859 return parent_adjoint / std::sqrt(1.0 - x * x);
860 }
861
863 const ExpressionPtr& x, const ExpressionPtr&,
864 const ExpressionPtr& parent_adjoint) const override {
865 return parent_adjoint /
866 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
867 }
868};
869
875inline ExpressionPtr asin(const ExpressionPtr& x) {
876 using enum ExpressionType;
877
878 // Prune expression
879 if (x->is_constant(0.0)) {
880 // Return zero
881 return x;
882 }
883
884 // Evaluate constant
885 if (x->type() == CONSTANT) {
886 return make_expression_ptr<ConstExpression>(std::asin(x->val));
887 }
888
889 return make_expression_ptr<AsinExpression>(x);
890}
891
901 explicit AtanExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
902 val = std::atan(args[0]->val);
903 }
904
905 double value(double x, double) const override { return std::atan(x); }
906
907 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
908
909 double grad_l(double x, double, double parent_adjoint) const override {
910 return parent_adjoint / (1.0 + x * x);
911 }
912
914 const ExpressionPtr& x, const ExpressionPtr&,
915 const ExpressionPtr& parent_adjoint) const override {
916 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
917 }
918};
919
925inline ExpressionPtr atan(const ExpressionPtr& x) {
926 using enum ExpressionType;
927
928 // Prune expression
929 if (x->is_constant(0.0)) {
930 // Return zero
931 return x;
932 }
933
934 // Evaluate constant
935 if (x->type() == CONSTANT) {
936 return make_expression_ptr<ConstExpression>(std::atan(x->val));
937 }
938
939 return make_expression_ptr<AtanExpression>(x);
940}
941
953 : Expression{std::move(lhs), std::move(rhs)} {
954 val = std::atan2(args[0]->val, args[1]->val);
955 }
956
957 double value(double y, double x) const override { return std::atan2(y, x); }
958
959 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
960
961 double grad_l(double y, double x, double parent_adjoint) const override {
962 return parent_adjoint * x / (y * y + x * x);
963 }
964
965 double grad_r(double y, double x, double parent_adjoint) const override {
966 return parent_adjoint * -y / (y * y + x * x);
967 }
968
970 const ExpressionPtr& y, const ExpressionPtr& x,
971 const ExpressionPtr& parent_adjoint) const override {
972 return parent_adjoint * x / (y * y + x * x);
973 }
974
976 const ExpressionPtr& y, const ExpressionPtr& x,
977 const ExpressionPtr& parent_adjoint) const override {
978 return parent_adjoint * -y / (y * y + x * x);
979 }
980};
981
988inline ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) {
989 using enum ExpressionType;
990
991 // Prune expression
992 if (y->is_constant(0.0)) {
993 // Return zero
994 return y;
995 } else if (x->is_constant(0.0)) {
996 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
997 }
998
999 // Evaluate constant
1000 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1001 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
1002 }
1003
1004 return make_expression_ptr<Atan2Expression>(y, x);
1005}
1006
1016 explicit CosExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1017 val = std::cos(args[0]->val);
1018 }
1019
1020 double value(double x, double) const override { return std::cos(x); }
1021
1022 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1023
1024 double grad_l(double x, double, double parent_adjoint) const override {
1025 return -parent_adjoint * std::sin(x);
1026 }
1027
1029 const ExpressionPtr& x, const ExpressionPtr&,
1030 const ExpressionPtr& parent_adjoint) const override {
1031 return parent_adjoint * -slp::detail::sin(x);
1032 }
1033};
1034
1040inline ExpressionPtr cos(const ExpressionPtr& x) {
1041 using enum ExpressionType;
1042
1043 // Prune expression
1044 if (x->is_constant(0.0)) {
1045 return make_expression_ptr<ConstExpression>(1.0);
1046 }
1047
1048 // Evaluate constant
1049 if (x->type() == CONSTANT) {
1050 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1051 }
1052
1053 return make_expression_ptr<CosExpression>(x);
1054}
1055
1065 explicit CoshExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1066 val = std::cosh(args[0]->val);
1067 }
1068
1069 double value(double x, double) const override { return std::cosh(x); }
1070
1071 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1072
1073 double grad_l(double x, double, double parent_adjoint) const override {
1074 return parent_adjoint * std::sinh(x);
1075 }
1076
1078 const ExpressionPtr& x, const ExpressionPtr&,
1079 const ExpressionPtr& parent_adjoint) const override {
1080 return parent_adjoint * slp::detail::sinh(x);
1081 }
1082};
1083
1089inline ExpressionPtr cosh(const ExpressionPtr& x) {
1090 using enum ExpressionType;
1091
1092 // Prune expression
1093 if (x->is_constant(0.0)) {
1094 return make_expression_ptr<ConstExpression>(1.0);
1095 }
1096
1097 // Evaluate constant
1098 if (x->type() == CONSTANT) {
1099 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1100 }
1101
1102 return make_expression_ptr<CoshExpression>(x);
1103}
1104
1114 explicit ErfExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1115 val = std::erf(args[0]->val);
1116 }
1117
1118 double value(double x, double) const override { return std::erf(x); }
1119
1120 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1121
1122 double grad_l(double x, double, double parent_adjoint) const override {
1123 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1124 }
1125
1127 const ExpressionPtr& x, const ExpressionPtr&,
1128 const ExpressionPtr& parent_adjoint) const override {
1129 return parent_adjoint *
1130 make_expression_ptr<ConstExpression>(2.0 *
1131 std::numbers::inv_sqrtpi) *
1132 slp::detail::exp(-x * x);
1133 }
1134};
1135
1141inline ExpressionPtr erf(const ExpressionPtr& x) {
1142 using enum ExpressionType;
1143
1144 // Prune expression
1145 if (x->is_constant(0.0)) {
1146 // Return zero
1147 return x;
1148 }
1149
1150 // Evaluate constant
1151 if (x->type() == CONSTANT) {
1152 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1153 }
1154
1155 return make_expression_ptr<ErfExpression>(x);
1156}
1157
1167 explicit ExpExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1168 val = std::exp(args[0]->val);
1169 }
1170
1171 double value(double x, double) const override { return std::exp(x); }
1172
1173 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1174
1175 double grad_l(double x, double, double parent_adjoint) const override {
1176 return parent_adjoint * std::exp(x);
1177 }
1178
1180 const ExpressionPtr& x, const ExpressionPtr&,
1181 const ExpressionPtr& parent_adjoint) const override {
1182 return parent_adjoint * slp::detail::exp(x);
1183 }
1184};
1185
1191inline ExpressionPtr exp(const ExpressionPtr& x) {
1192 using enum ExpressionType;
1193
1194 // Prune expression
1195 if (x->is_constant(0.0)) {
1196 return make_expression_ptr<ConstExpression>(1.0);
1197 }
1198
1199 // Evaluate constant
1200 if (x->type() == CONSTANT) {
1201 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1202 }
1203
1204 return make_expression_ptr<ExpExpression>(x);
1205}
1206
1207inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y);
1208
1220 : Expression{std::move(lhs), std::move(rhs)} {
1221 val = std::hypot(args[0]->val, args[1]->val);
1222 }
1223
1224 double value(double x, double y) const override { return std::hypot(x, y); }
1225
1226 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1227
1228 double grad_l(double x, double y, double parent_adjoint) const override {
1229 return parent_adjoint * x / std::hypot(x, y);
1230 }
1231
1232 double grad_r(double x, double y, double parent_adjoint) const override {
1233 return parent_adjoint * y / std::hypot(x, y);
1234 }
1235
1237 const ExpressionPtr& x, const ExpressionPtr& y,
1238 const ExpressionPtr& parent_adjoint) const override {
1239 return parent_adjoint * x / slp::detail::hypot(x, y);
1240 }
1241
1243 const ExpressionPtr& x, const ExpressionPtr& y,
1244 const ExpressionPtr& parent_adjoint) const override {
1245 return parent_adjoint * y / slp::detail::hypot(x, y);
1246 }
1247};
1248
1255inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) {
1256 using enum ExpressionType;
1257
1258 // Prune expression
1259 if (x->is_constant(0.0)) {
1260 return y;
1261 } else if (y->is_constant(0.0)) {
1262 return x;
1263 }
1264
1265 // Evaluate constant
1266 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1267 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1268 }
1269
1270 return make_expression_ptr<HypotExpression>(x, y);
1271}
1272
1282 explicit LogExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1283 val = std::log(args[0]->val);
1284 }
1285
1286 double value(double x, double) const override { return std::log(x); }
1287
1288 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1289
1290 double grad_l(double x, double, double parent_adjoint) const override {
1291 return parent_adjoint / x;
1292 }
1293
1295 const ExpressionPtr& x, const ExpressionPtr&,
1296 const ExpressionPtr& parent_adjoint) const override {
1297 return parent_adjoint / x;
1298 }
1299};
1300
1306inline ExpressionPtr log(const ExpressionPtr& x) {
1307 using enum ExpressionType;
1308
1309 // Prune expression
1310 if (x->is_constant(0.0)) {
1311 // Return zero
1312 return x;
1313 }
1314
1315 // Evaluate constant
1316 if (x->type() == CONSTANT) {
1317 return make_expression_ptr<ConstExpression>(std::log(x->val));
1318 }
1319
1320 return make_expression_ptr<LogExpression>(x);
1321}
1322
1332 explicit Log10Expression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1333 val = std::log10(args[0]->val);
1334 }
1335
1336 double value(double x, double) const override { return std::log10(x); }
1337
1338 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1339
1340 double grad_l(double x, double, double parent_adjoint) const override {
1341 return parent_adjoint / (std::numbers::ln10 * x);
1342 }
1343
1345 const ExpressionPtr& x, const ExpressionPtr&,
1346 const ExpressionPtr& parent_adjoint) const override {
1347 return parent_adjoint /
1348 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1349 }
1350};
1351
1357inline ExpressionPtr log10(const ExpressionPtr& x) {
1358 using enum ExpressionType;
1359
1360 // Prune expression
1361 if (x->is_constant(0.0)) {
1362 // Return zero
1363 return x;
1364 }
1365
1366 // Evaluate constant
1367 if (x->type() == CONSTANT) {
1368 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1369 }
1370
1371 return make_expression_ptr<Log10Expression>(x);
1372}
1373
1374inline ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& power);
1375
1381template <ExpressionType T>
1390 : Expression{std::move(lhs), std::move(rhs)} {
1391 val = std::pow(args[0]->val, args[1]->val);
1392 }
1393
1394 double value(double base, double power) const override {
1395 return std::pow(base, power);
1396 }
1397
1398 ExpressionType type() const override { return T; }
1399
1400 double grad_l(double base, double power,
1401 double parent_adjoint) const override {
1402 return parent_adjoint * std::pow(base, power - 1) * power;
1403 }
1404
1405 double grad_r(double base, double power,
1406 double parent_adjoint) const override {
1407 // Since x * std::log(x) -> 0 as x -> 0
1408 if (base == 0.0) {
1409 return 0.0;
1410 } else {
1411 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1412 }
1413 }
1414
1416 const ExpressionPtr& base, const ExpressionPtr& power,
1417 const ExpressionPtr& parent_adjoint) const override {
1418 return parent_adjoint *
1419 slp::detail::pow(base,
1420 power - make_expression_ptr<ConstExpression>(1.0)) *
1421 power;
1422 }
1423
1425 const ExpressionPtr& base, const ExpressionPtr& power,
1426 const ExpressionPtr& parent_adjoint) const override {
1427 // Since x * std::log(x) -> 0 as x -> 0
1428 if (base->val == 0.0) {
1429 // Return zero
1430 return base;
1431 } else {
1432 return parent_adjoint *
1433 slp::detail::pow(
1434 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1435 base * slp::detail::log(base);
1436 }
1437 }
1438};
1439
1446inline ExpressionPtr pow(const ExpressionPtr& base,
1447 const ExpressionPtr& power) {
1448 using enum ExpressionType;
1449
1450 // Prune expression
1451 if (base->is_constant(0.0)) {
1452 // Return zero
1453 return base;
1454 } else if (base->is_constant(1.0)) {
1455 // Return one
1456 return base;
1457 }
1458 if (power->is_constant(0.0)) {
1459 return make_expression_ptr<ConstExpression>(1.0);
1460 } else if (power->is_constant(1.0)) {
1461 return base;
1462 }
1463
1464 // Evaluate constant
1465 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1466 return make_expression_ptr<ConstExpression>(
1467 std::pow(base->val, power->val));
1468 }
1469
1470 if (power->is_constant(2.0)) {
1471 if (base->type() == LINEAR) {
1472 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1473 } else {
1474 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1475 }
1476 }
1477
1478 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1479}
1480
1490 explicit constexpr SignExpression(ExpressionPtr lhs)
1491 : Expression{std::move(lhs)} {
1492 if (args[0]->val < 0.0) {
1493 val = -1.0;
1494 } else if (args[0]->val == 0.0) {
1495 val = 0.0;
1496 } else {
1497 val = 1.0;
1498 }
1499 }
1500
1501 double value(double x, double) const override {
1502 if (x < 0.0) {
1503 return -1.0;
1504 } else if (x == 0.0) {
1505 return 0.0;
1506 } else {
1507 return 1.0;
1508 }
1509 }
1510
1511 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1512
1513 double grad_l(double, double, double) const override { return 0.0; }
1514
1516 const ExpressionPtr&) const override {
1517 // Return zero
1518 return make_expression_ptr<ConstExpression>();
1519 }
1520};
1521
1527inline ExpressionPtr sign(const ExpressionPtr& x) {
1528 using enum ExpressionType;
1529
1530 // Evaluate constant
1531 if (x->type() == CONSTANT) {
1532 if (x->val < 0.0) {
1533 return make_expression_ptr<ConstExpression>(-1.0);
1534 } else if (x->val == 0.0) {
1535 // Return zero
1536 return x;
1537 } else {
1538 return make_expression_ptr<ConstExpression>(1.0);
1539 }
1540 }
1541
1542 return make_expression_ptr<SignExpression>(x);
1543}
1544
1554 explicit SinExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1555 val = std::sin(args[0]->val);
1556 }
1557
1558 double value(double x, double) const override { return std::sin(x); }
1559
1560 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1561
1562 double grad_l(double x, double, double parent_adjoint) const override {
1563 return parent_adjoint * std::cos(x);
1564 }
1565
1567 const ExpressionPtr& x, const ExpressionPtr&,
1568 const ExpressionPtr& parent_adjoint) const override {
1569 return parent_adjoint * slp::detail::cos(x);
1570 }
1571};
1572
1578inline ExpressionPtr sin(const ExpressionPtr& x) {
1579 using enum ExpressionType;
1580
1581 // Prune expression
1582 if (x->is_constant(0.0)) {
1583 // Return zero
1584 return x;
1585 }
1586
1587 // Evaluate constant
1588 if (x->type() == CONSTANT) {
1589 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1590 }
1591
1592 return make_expression_ptr<SinExpression>(x);
1593}
1594
1604 explicit SinhExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1605 val = std::sinh(args[0]->val);
1606 }
1607
1608 double value(double x, double) const override { return std::sinh(x); }
1609
1610 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1611
1612 double grad_l(double x, double, double parent_adjoint) const override {
1613 return parent_adjoint * std::cosh(x);
1614 }
1615
1617 const ExpressionPtr& x, const ExpressionPtr&,
1618 const ExpressionPtr& parent_adjoint) const override {
1619 return parent_adjoint * slp::detail::cosh(x);
1620 }
1621};
1622
1628inline ExpressionPtr sinh(const ExpressionPtr& x) {
1629 using enum ExpressionType;
1630
1631 // Prune expression
1632 if (x->is_constant(0.0)) {
1633 // Return zero
1634 return x;
1635 }
1636
1637 // Evaluate constant
1638 if (x->type() == CONSTANT) {
1639 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1640 }
1641
1642 return make_expression_ptr<SinhExpression>(x);
1643}
1644
1654 explicit SqrtExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1655 val = std::sqrt(args[0]->val);
1656 }
1657
1658 double value(double x, double) const override { return std::sqrt(x); }
1659
1660 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1661
1662 double grad_l(double x, double, double parent_adjoint) const override {
1663 return parent_adjoint / (2.0 * std::sqrt(x));
1664 }
1665
1667 const ExpressionPtr& x, const ExpressionPtr&,
1668 const ExpressionPtr& parent_adjoint) const override {
1669 return parent_adjoint /
1670 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1671 }
1672};
1673
1679inline ExpressionPtr sqrt(const ExpressionPtr& x) {
1680 using enum ExpressionType;
1681
1682 // Evaluate constant
1683 if (x->type() == CONSTANT) {
1684 if (x->val == 0.0) {
1685 // Return zero
1686 return x;
1687 } else if (x->val == 1.0) {
1688 return x;
1689 } else {
1690 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1691 }
1692 }
1693
1694 return make_expression_ptr<SqrtExpression>(x);
1695}
1696
1706 explicit TanExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1707 val = std::tan(args[0]->val);
1708 }
1709
1710 double value(double x, double) const override { return std::tan(x); }
1711
1712 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1713
1714 double grad_l(double x, double, double parent_adjoint) const override {
1715 return parent_adjoint / (std::cos(x) * std::cos(x));
1716 }
1717
1719 const ExpressionPtr& x, const ExpressionPtr&,
1720 const ExpressionPtr& parent_adjoint) const override {
1721 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1722 }
1723};
1724
1730inline ExpressionPtr tan(const ExpressionPtr& x) {
1731 using enum ExpressionType;
1732
1733 // Prune expression
1734 if (x->is_constant(0.0)) {
1735 // Return zero
1736 return x;
1737 }
1738
1739 // Evaluate constant
1740 if (x->type() == CONSTANT) {
1741 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1742 }
1743
1744 return make_expression_ptr<TanExpression>(x);
1745}
1746
1756 explicit TanhExpression(ExpressionPtr lhs) : Expression{std::move(lhs)} {
1757 val = std::tanh(args[0]->val);
1758 }
1759
1760 double value(double x, double) const override { return std::tanh(x); }
1761
1762 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1763
1764 double grad_l(double x, double, double parent_adjoint) const override {
1765 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1766 }
1767
1769 const ExpressionPtr& x, const ExpressionPtr&,
1770 const ExpressionPtr& parent_adjoint) const override {
1771 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1772 }
1773};
1774
1780inline ExpressionPtr tanh(const ExpressionPtr& x) {
1781 using enum ExpressionType;
1782
1783 // Prune expression
1784 if (x->is_constant(0.0)) {
1785 // Return zero
1786 return x;
1787 }
1788
1789 // Evaluate constant
1790 if (x->type() == CONSTANT) {
1791 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1792 }
1793
1794 return make_expression_ptr<TanhExpression>(x);
1795}
1796
1797} // namespace slp::detail
Definition expression.hpp:730
ExpressionType type() const override
Definition expression.hpp:743
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:745
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:755
double value(double x, double) const override
Definition expression.hpp:741
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:736
Definition expression.hpp:794
ExpressionType type() const override
Definition expression.hpp:806
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:808
double value(double x, double) const override
Definition expression.hpp:804
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:812
AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:800
Definition expression.hpp:844
double value(double x, double) const override
Definition expression.hpp:854
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:862
ExpressionType type() const override
Definition expression.hpp:856
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:858
AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:850
Definition expression.hpp:945
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:975
double value(double y, double x) const override
Definition expression.hpp:957
ExpressionType type() const override
Definition expression.hpp:959
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:961
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:965
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:969
Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:952
Definition expression.hpp:895
double value(double x, double) const override
Definition expression.hpp:905
AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:901
ExpressionType type() const override
Definition expression.hpp:907
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:913
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:909
Definition expression.hpp:430
double value(double lhs, double rhs) const override
Definition expression.hpp:442
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:437
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:450
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:446
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:454
ExpressionType type() const override
Definition expression.hpp:444
Definition expression.hpp:473
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:493
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:489
double value(double lhs, double rhs) const override
Definition expression.hpp:485
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:503
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:480
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:497
ExpressionType type() const override
Definition expression.hpp:487
Definition expression.hpp:513
constexpr ConstExpression(double value)
Definition expression.hpp:524
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:526
ExpressionType type() const override
Definition expression.hpp:528
Definition expression.hpp:1010
ExpressionType type() const override
Definition expression.hpp:1022
double value(double x, double) const override
Definition expression.hpp:1020
CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1016
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1028
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1024
Definition expression.hpp:1059
CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1065
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1073
double value(double x, double) const override
Definition expression.hpp:1069
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1077
ExpressionType type() const override
Definition expression.hpp:1071
Definition expression.hpp:534
double value(double, double) const override
Definition expression.hpp:548
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:545
ExpressionType type() const override
Definition expression.hpp:550
Definition expression.hpp:559
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:575
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:589
double value(double lhs, double rhs) const override
Definition expression.hpp:571
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:566
ExpressionType type() const override
Definition expression.hpp:573
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:583
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:579
Definition expression.hpp:1108
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1126
double value(double x, double) const override
Definition expression.hpp:1118
ExpressionType type() const override
Definition expression.hpp:1120
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1122
ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1114
Definition expression.hpp:1161
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1175
ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1167
ExpressionType type() const override
Definition expression.hpp:1173
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1179
double value(double x, double) const override
Definition expression.hpp:1171
Definition expression.hpp:76
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:87
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:373
constexpr bool is_constant(double constant) const
Definition expression.hpp:137
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:125
constexpr Expression(double value)
Definition expression.hpp:109
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:416
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:91
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:343
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:116
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:81
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:147
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:84
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:268
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:315
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:279
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:199
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:387
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:401
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:97
double val
The value of the expression node.
Definition expression.hpp:78
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:94
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:236
Definition expression.hpp:1212
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1236
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1232
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1228
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1242
HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1219
double value(double x, double y) const override
Definition expression.hpp:1224
ExpressionType type() const override
Definition expression.hpp:1226
Definition expression.hpp:1326
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1340
double value(double x, double) const override
Definition expression.hpp:1336
ExpressionType type() const override
Definition expression.hpp:1338
Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1332
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1344
Definition expression.hpp:1276
LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1282
ExpressionType type() const override
Definition expression.hpp:1288
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1294
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1290
double value(double x, double) const override
Definition expression.hpp:1286
Definition expression.hpp:602
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:628
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:609
ExpressionType type() const override
Definition expression.hpp:616
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:618
double value(double lhs, double rhs) const override
Definition expression.hpp:614
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:634
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:623
Definition expression.hpp:1382
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1415
double value(double base, double power) const override
Definition expression.hpp:1394
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1400
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1405
ExpressionType type() const override
Definition expression.hpp:1398
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1424
PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1389
Definition expression.hpp:1484
double grad_l(double, double, double) const override
Definition expression.hpp:1513
ExpressionType type() const override
Definition expression.hpp:1511
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1490
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1515
double value(double x, double) const override
Definition expression.hpp:1501
Definition expression.hpp:1548
double value(double x, double) const override
Definition expression.hpp:1558
SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1554
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1562
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1566
ExpressionType type() const override
Definition expression.hpp:1560
Definition expression.hpp:1598
ExpressionType type() const override
Definition expression.hpp:1610
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1612
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1616
SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1604
double value(double x, double) const override
Definition expression.hpp:1608
Definition expression.hpp:1648
SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1654
ExpressionType type() const override
Definition expression.hpp:1660
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1662
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1666
double value(double x, double) const override
Definition expression.hpp:1658
Definition expression.hpp:1700
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1714
double value(double x, double) const override
Definition expression.hpp:1710
TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1706
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1718
ExpressionType type() const override
Definition expression.hpp:1712
Definition expression.hpp:1750
TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1756
ExpressionType type() const override
Definition expression.hpp:1762
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1764
double value(double x, double) const override
Definition expression.hpp:1760
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1768
Definition expression.hpp:647
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:653
ExpressionType type() const override
Definition expression.hpp:660
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:666
double value(double lhs, double) const override
Definition expression.hpp:658
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:662