Sleipnir C++ API
Loading...
Searching...
No Matches
expression.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <stdint.h>
6
7#include <algorithm>
8#include <array>
9#include <cmath>
10#include <memory>
11#include <numbers>
12#include <utility>
13
14#include <gch/small_vector.hpp>
15
16#include "sleipnir/autodiff/expression_type.hpp"
17#include "sleipnir/util/intrusive_shared_ptr.hpp"
18#include "sleipnir/util/pool.hpp"
19
20namespace slp::detail {
21
// Expressions are normally carved out of a global pool allocator backed by a
// thread-local static pool resource. On Windows that resource isn't guaranteed
// to be initialized properly across DLL boundaries, so pooling is turned off
// there and nodes come from the regular heap instead.
#if defined(_WIN32)
inline constexpr bool USE_POOL_ALLOCATOR = false;
#else
inline constexpr bool USE_POOL_ALLOCATOR = true;
#endif
29
struct Expression;

// Intrusive reference-counting hooks for Expression. They adjust
// Expression::ref_count and are declared up front so ExpressionPtr can be
// formed while Expression is still incomplete. Presumably these are invoked by
// IntrusiveSharedPtr (see intrusive_shared_ptr.hpp) — confirm there.
// Definitions appear later in this header.
inline constexpr void inc_ref_count(Expression* expr);
inline constexpr void dec_ref_count(Expression* expr);

// Shared-ownership handle to an expression graph node, reference counted
// intrusively through Expression::ref_count.
using ExpressionPtr = IntrusiveSharedPtr<Expression>;
39
47template <typename T, typename... Args>
48static ExpressionPtr make_expression_ptr(Args&&... args) {
49 if constexpr (USE_POOL_ALLOCATOR) {
50 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
51 std::forward<Args>(args)...);
52 } else {
53 return make_intrusive_shared<T>(std::forward<Args>(args)...);
54 }
55}
56
// Forward declarations of the concrete node types that the operator overloads
// inside Expression construct through make_expression_ptr. For the templated
// node types, the ExpressionType parameter T is the type the node reports from
// its type() override.
template <ExpressionType T>
struct BinaryMinusExpression;

template <ExpressionType T>
struct BinaryPlusExpression;

// Leaf node holding a constant value.
struct ConstExpression;

template <ExpressionType T>
struct DivExpression;

template <ExpressionType T>
struct MultExpression;

template <ExpressionType T>
struct UnaryMinusExpression;
73
77struct Expression {
79 double val = 0.0;
80
82 double adjoint = 0.0;
83
85 uint32_t incoming_edges = 0;
86
88 int32_t col = -1;
89
93
95 uint32_t ref_count = 0;
96
98 std::array<ExpressionPtr, 2> args{nullptr, nullptr};
99
103 constexpr Expression() = default;
104
110 explicit constexpr Expression(double value) : val{value} {}
111
117 explicit constexpr Expression(ExpressionPtr lhs)
118 : args{std::move(lhs), nullptr} {}
119
127 : args{std::move(lhs), std::move(rhs)} {}
128
129 virtual ~Expression() = default;
130
138 constexpr bool is_constant(double constant) const {
139 return type() == ExpressionType::CONSTANT && val == constant;
140 }
141
149 const ExpressionPtr& rhs) {
150 using enum ExpressionType;
151
152 // Prune expression
153 if (lhs->is_constant(0.0)) {
154 // Return zero
155 return lhs;
156 } else if (rhs->is_constant(0.0)) {
157 // Return zero
158 return rhs;
159 } else if (lhs->is_constant(1.0)) {
160 return rhs;
161 } else if (rhs->is_constant(1.0)) {
162 return lhs;
163 }
164
165 // Evaluate constant
166 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
167 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
168 }
169
170 // Evaluate expression type
171 if (lhs->type() == CONSTANT) {
172 if (rhs->type() == LINEAR) {
173 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
174 } else if (rhs->type() == QUADRATIC) {
175 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
176 } else {
177 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
178 }
179 } else if (rhs->type() == CONSTANT) {
180 if (lhs->type() == LINEAR) {
181 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
182 } else if (lhs->type() == QUADRATIC) {
183 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
184 } else {
185 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
186 }
187 } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
188 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
189 } else {
190 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
191 }
192 }
193
201 const ExpressionPtr& rhs) {
202 using enum ExpressionType;
203
204 // Prune expression
205 if (lhs->is_constant(0.0)) {
206 // Return zero
207 return lhs;
208 } else if (rhs->is_constant(1.0)) {
209 return lhs;
210 }
211
212 // Evaluate constant
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
215 }
216
217 // Evaluate expression type
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
220 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
221 } else if (lhs->type() == QUADRATIC) {
222 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
223 } else {
224 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
225 }
226 } else {
227 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
228 }
229 }
230
238 const ExpressionPtr& rhs) {
239 using enum ExpressionType;
240
241 // Prune expression
242 if (lhs == nullptr || lhs->is_constant(0.0)) {
243 return rhs;
244 } else if (rhs == nullptr || rhs->is_constant(0.0)) {
245 return lhs;
246 }
247
248 // Evaluate constant
249 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
250 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
251 }
252
253 auto type = std::max(lhs->type(), rhs->type());
254 if (type == LINEAR) {
255 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
256 } else if (type == QUADRATIC) {
257 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
258 } else {
259 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
260 }
261 }
262
270 const ExpressionPtr& rhs) {
271 return lhs = lhs + rhs;
272 }
273
281 const ExpressionPtr& rhs) {
282 using enum ExpressionType;
283
284 // Prune expression
285 if (lhs->is_constant(0.0)) {
286 if (rhs->is_constant(0.0)) {
287 // Return zero
288 return rhs;
289 } else {
290 return -rhs;
291 }
292 } else if (rhs->is_constant(0.0)) {
293 return lhs;
294 }
295
296 // Evaluate constant
297 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
298 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
299 }
300
301 auto type = std::max(lhs->type(), rhs->type());
302 if (type == LINEAR) {
303 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
304 } else if (type == QUADRATIC) {
305 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
306 } else {
307 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
308 }
309 }
310
317 using enum ExpressionType;
318
319 // Prune expression
320 if (lhs->is_constant(0.0)) {
321 // Return zero
322 return lhs;
323 }
324
325 // Evaluate constant
326 if (lhs->type() == CONSTANT) {
327 return make_expression_ptr<ConstExpression>(-lhs->val);
328 }
329
330 if (lhs->type() == LINEAR) {
331 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
332 } else if (lhs->type() == QUADRATIC) {
333 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
334 } else {
335 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
336 }
337 }
338
344 friend ExpressionPtr operator+(const ExpressionPtr& lhs) { return lhs; }
345
355 virtual double value([[maybe_unused]] double lhs,
356 [[maybe_unused]] double rhs) const = 0;
357
364 virtual ExpressionType type() const = 0;
365
374 virtual double grad_l([[maybe_unused]] double lhs,
375 [[maybe_unused]] double rhs,
376 [[maybe_unused]] double parent_adjoint) const {
377 return 0.0;
378 }
379
388 virtual double grad_r([[maybe_unused]] double lhs,
389 [[maybe_unused]] double rhs,
390 [[maybe_unused]] double parent_adjoint) const {
391 return 0.0;
392 }
393
403 [[maybe_unused]] const ExpressionPtr& lhs,
404 [[maybe_unused]] const ExpressionPtr& rhs,
405 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
406 return make_expression_ptr<ConstExpression>();
407 }
408
418 [[maybe_unused]] const ExpressionPtr& lhs,
419 [[maybe_unused]] const ExpressionPtr& rhs,
420 [[maybe_unused]] const ExpressionPtr& parent_adjoint) const {
421 return make_expression_ptr<ConstExpression>();
422 }
423};
424
// Forward declarations of math functions called from the gradient-expression
// overrides of node types that are defined before these functions (e.g.
// CbrtExpression calls cbrt() and CosExpression calls sin()).
inline ExpressionPtr cbrt(const ExpressionPtr& x);
inline ExpressionPtr exp(const ExpressionPtr& x);
inline ExpressionPtr sin(const ExpressionPtr& x);
inline ExpressionPtr sinh(const ExpressionPtr& x);
inline ExpressionPtr sqrt(const ExpressionPtr& x);
430
436template <ExpressionType T>
445 : Expression{std::move(lhs), std::move(rhs)} {}
446
447 double value(double lhs, double rhs) const override { return lhs - rhs; }
448
449 ExpressionType type() const override { return T; }
450
451 double grad_l(double, double, double parent_adjoint) const override {
452 return parent_adjoint;
453 }
454
455 double grad_r(double, double, double parent_adjoint) const override {
456 return -parent_adjoint;
457 }
458
460 const ExpressionPtr&, const ExpressionPtr&,
461 const ExpressionPtr& parent_adjoint) const override {
462 return parent_adjoint;
463 }
464
466 const ExpressionPtr&, const ExpressionPtr&,
467 const ExpressionPtr& parent_adjoint) const override {
468 return -parent_adjoint;
469 }
470};
471
477template <ExpressionType T>
486 : Expression{std::move(lhs), std::move(rhs)} {}
487
488 double value(double lhs, double rhs) const override { return lhs + rhs; }
489
490 ExpressionType type() const override { return T; }
491
492 double grad_l(double, double, double parent_adjoint) const override {
493 return parent_adjoint;
494 }
495
496 double grad_r(double, double, double parent_adjoint) const override {
497 return parent_adjoint;
498 }
499
501 const ExpressionPtr&, const ExpressionPtr&,
502 const ExpressionPtr& parent_adjoint) const override {
503 return parent_adjoint;
504 }
505
507 const ExpressionPtr&, const ExpressionPtr&,
508 const ExpressionPtr& parent_adjoint) const override {
509 return parent_adjoint;
510 }
511};
512
522 explicit constexpr CbrtExpression(ExpressionPtr lhs)
523 : Expression{std::move(lhs)} {}
524
525 double value(double x, double) const override { return std::cbrt(x); }
526
527 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
528
529 double grad_l(double x, double, double parent_adjoint) const override {
530 double c = std::cbrt(x);
531 return parent_adjoint / (3.0 * c * c);
532 }
533
535 const ExpressionPtr& x, const ExpressionPtr&,
536 const ExpressionPtr& parent_adjoint) const override {
537 auto c = slp::detail::cbrt(x);
538 return parent_adjoint / (make_expression_ptr<ConstExpression>(3.0) * c * c);
539 }
540};
541
547inline ExpressionPtr cbrt(const ExpressionPtr& x) {
548 using enum ExpressionType;
549
550 // Evaluate constant
551 if (x->type() == CONSTANT) {
552 if (x->val == 0.0) {
553 // Return zero
554 return x;
555 } else if (x->val == -1.0 || x->val == 1.0) {
556 return x;
557 } else {
558 return make_expression_ptr<ConstExpression>(std::cbrt(x->val));
559 }
560 }
561
562 return make_expression_ptr<CbrtExpression>(x);
563}
564
572 constexpr ConstExpression() = default;
573
579 explicit constexpr ConstExpression(double value) : Expression{value} {}
580
581 double value(double, double) const override { return val; }
582
583 ExpressionType type() const override { return ExpressionType::CONSTANT; }
584};
585
593 constexpr DecisionVariableExpression() = default;
594
600 explicit constexpr DecisionVariableExpression(double value)
601 : Expression{value} {}
602
603 double value(double, double) const override { return val; }
604
605 ExpressionType type() const override { return ExpressionType::LINEAR; }
606};
607
613template <ExpressionType T>
614struct DivExpression final : Expression {
622 : Expression{std::move(lhs), std::move(rhs)} {}
623
624 double value(double lhs, double rhs) const override { return lhs / rhs; }
625
626 ExpressionType type() const override { return T; }
627
628 double grad_l(double, double rhs, double parent_adjoint) const override {
629 return parent_adjoint / rhs;
630 };
631
632 double grad_r(double lhs, double rhs, double parent_adjoint) const override {
633 return parent_adjoint * -lhs / (rhs * rhs);
634 }
635
637 const ExpressionPtr&, const ExpressionPtr& rhs,
638 const ExpressionPtr& parent_adjoint) const override {
639 return parent_adjoint / rhs;
640 }
641
643 const ExpressionPtr& lhs, const ExpressionPtr& rhs,
644 const ExpressionPtr& parent_adjoint) const override {
645 return parent_adjoint * -lhs / (rhs * rhs);
646 }
647};
648
654template <ExpressionType T>
663 : Expression{std::move(lhs), std::move(rhs)} {}
664
665 double value(double lhs, double rhs) const override { return lhs * rhs; }
666
667 ExpressionType type() const override { return T; }
668
669 double grad_l([[maybe_unused]] double lhs, double rhs,
670 double parent_adjoint) const override {
671 return parent_adjoint * rhs;
672 }
673
674 double grad_r(double lhs, [[maybe_unused]] double rhs,
675 double parent_adjoint) const override {
676 return parent_adjoint * lhs;
677 }
678
680 [[maybe_unused]] const ExpressionPtr& lhs, const ExpressionPtr& rhs,
681 const ExpressionPtr& parent_adjoint) const override {
682 return parent_adjoint * rhs;
683 }
684
686 const ExpressionPtr& lhs, [[maybe_unused]] const ExpressionPtr& rhs,
687 const ExpressionPtr& parent_adjoint) const override {
688 return parent_adjoint * lhs;
689 }
690};
691
697template <ExpressionType T>
704 explicit constexpr UnaryMinusExpression(ExpressionPtr lhs)
705 : Expression{std::move(lhs)} {}
706
707 double value(double lhs, double) const override { return -lhs; }
708
709 ExpressionType type() const override { return T; }
710
711 double grad_l(double, double, double parent_adjoint) const override {
712 return -parent_adjoint;
713 }
714
716 const ExpressionPtr&, const ExpressionPtr&,
717 const ExpressionPtr& parent_adjoint) const override {
718 return -parent_adjoint;
719 }
720};
721
727inline constexpr void inc_ref_count(Expression* expr) {
728 ++expr->ref_count;
729}
730
/// Releases one reference to an expression node.
///
/// The node's ref_count is decremented. When it reaches zero, the node's
/// adjoint_expr and its non-null args are pushed onto an explicit work stack
/// and released the same way, instead of recursing through destructors.
///
/// NOTE(review): node storage is only handed back (to the global pool
/// allocator) when USE_POOL_ALLOCATOR is true; when it is false, zero-count
/// nodes are not freed here at all — confirm this is intended and not a leak.
/// NOTE(review): the third argument of std::allocator_traits::deallocate is an
/// element count, but sizeof(Expression) is passed — confirm the pool
/// allocator ignores it.
///
/// @param expr The node losing a reference; it is dereferenced, so it must not
///   be null.
inline constexpr void dec_ref_count(Expression* expr) {
  // If a deeply nested tree is being deallocated all at once, calling the
  // Expression destructor when expr's refcount reaches zero can cause a stack
  // overflow. Instead, we iterate over its children to decrement their
  // refcounts and deallocate them.
  gch::small_vector<Expression*> stack;
  stack.emplace_back(expr);

  while (!stack.empty()) {
    auto elem = stack.back();
    stack.pop_back();

    // Decrement the current node's refcount. If it reaches zero, deallocate the
    // node and enqueue its children so their refcounts are decremented too.
    if (--elem->ref_count == 0) {
      if (elem->adjoint_expr != nullptr) {
        stack.emplace_back(elem->adjoint_expr.get());
      }
      for (auto& arg : elem->args) {
        if (arg != nullptr) {
          stack.emplace_back(arg.get());
        }
      }

      // Not calling the destructor here is safe because it only decrements
      // refcounts, which was already done above.
      if constexpr (USE_POOL_ALLOCATOR) {
        auto alloc = global_pool_allocator<Expression>();
        std::allocator_traits<decltype(alloc)>::deallocate(alloc, elem,
                                                            sizeof(Expression));
      }
    }
  }
}
770
774struct AbsExpression final : Expression {
780 explicit constexpr AbsExpression(ExpressionPtr lhs)
781 : Expression{std::move(lhs)} {}
782
783 double value(double x, double) const override { return std::abs(x); }
784
785 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
786
787 double grad_l(double x, double, double parent_adjoint) const override {
788 if (x < 0.0) {
789 return -parent_adjoint;
790 } else if (x > 0.0) {
791 return parent_adjoint;
792 } else {
793 return 0.0;
794 }
795 }
796
798 const ExpressionPtr& x, const ExpressionPtr&,
799 const ExpressionPtr& parent_adjoint) const override {
800 if (x->val < 0.0) {
801 return -parent_adjoint;
802 } else if (x->val > 0.0) {
803 return parent_adjoint;
804 } else {
805 // Return zero
806 return make_expression_ptr<ConstExpression>();
807 }
808 }
809};
810
816inline ExpressionPtr abs(const ExpressionPtr& x) {
817 using enum ExpressionType;
818
819 // Prune expression
820 if (x->is_constant(0.0)) {
821 // Return zero
822 return x;
823 }
824
825 // Evaluate constant
826 if (x->type() == CONSTANT) {
827 return make_expression_ptr<ConstExpression>(std::abs(x->val));
828 }
829
830 return make_expression_ptr<AbsExpression>(x);
831}
832
842 explicit constexpr AcosExpression(ExpressionPtr lhs)
843 : Expression{std::move(lhs)} {}
844
845 double value(double x, double) const override { return std::acos(x); }
846
847 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
848
849 double grad_l(double x, double, double parent_adjoint) const override {
850 return -parent_adjoint / std::sqrt(1.0 - x * x);
851 }
852
854 const ExpressionPtr& x, const ExpressionPtr&,
855 const ExpressionPtr& parent_adjoint) const override {
856 return -parent_adjoint /
857 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
858 }
859};
860
866inline ExpressionPtr acos(const ExpressionPtr& x) {
867 using enum ExpressionType;
868
869 // Prune expression
870 if (x->is_constant(0.0)) {
871 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
872 }
873
874 // Evaluate constant
875 if (x->type() == CONSTANT) {
876 return make_expression_ptr<ConstExpression>(std::acos(x->val));
877 }
878
879 return make_expression_ptr<AcosExpression>(x);
880}
881
891 explicit constexpr AsinExpression(ExpressionPtr lhs)
892 : Expression{std::move(lhs)} {}
893
894 double value(double x, double) const override { return std::asin(x); }
895
896 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
897
898 double grad_l(double x, double, double parent_adjoint) const override {
899 return parent_adjoint / std::sqrt(1.0 - x * x);
900 }
901
903 const ExpressionPtr& x, const ExpressionPtr&,
904 const ExpressionPtr& parent_adjoint) const override {
905 return parent_adjoint /
906 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
907 }
908};
909
915inline ExpressionPtr asin(const ExpressionPtr& x) {
916 using enum ExpressionType;
917
918 // Prune expression
919 if (x->is_constant(0.0)) {
920 // Return zero
921 return x;
922 }
923
924 // Evaluate constant
925 if (x->type() == CONSTANT) {
926 return make_expression_ptr<ConstExpression>(std::asin(x->val));
927 }
928
929 return make_expression_ptr<AsinExpression>(x);
930}
931
941 explicit constexpr AtanExpression(ExpressionPtr lhs)
942 : Expression{std::move(lhs)} {}
943
944 double value(double x, double) const override { return std::atan(x); }
945
946 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
947
948 double grad_l(double x, double, double parent_adjoint) const override {
949 return parent_adjoint / (1.0 + x * x);
950 }
951
953 const ExpressionPtr& x, const ExpressionPtr&,
954 const ExpressionPtr& parent_adjoint) const override {
955 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
956 }
957};
958
964inline ExpressionPtr atan(const ExpressionPtr& x) {
965 using enum ExpressionType;
966
967 // Prune expression
968 if (x->is_constant(0.0)) {
969 // Return zero
970 return x;
971 }
972
973 // Evaluate constant
974 if (x->type() == CONSTANT) {
975 return make_expression_ptr<ConstExpression>(std::atan(x->val));
976 }
977
978 return make_expression_ptr<AtanExpression>(x);
979}
980
992 : Expression{std::move(lhs), std::move(rhs)} {}
993
994 double value(double y, double x) const override { return std::atan2(y, x); }
995
996 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
997
998 double grad_l(double y, double x, double parent_adjoint) const override {
999 return parent_adjoint * x / (y * y + x * x);
1000 }
1001
1002 double grad_r(double y, double x, double parent_adjoint) const override {
1003 return parent_adjoint * -y / (y * y + x * x);
1004 }
1005
1007 const ExpressionPtr& y, const ExpressionPtr& x,
1008 const ExpressionPtr& parent_adjoint) const override {
1009 return parent_adjoint * x / (y * y + x * x);
1010 }
1011
1013 const ExpressionPtr& y, const ExpressionPtr& x,
1014 const ExpressionPtr& parent_adjoint) const override {
1015 return parent_adjoint * -y / (y * y + x * x);
1016 }
1017};
1018
1025inline ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) {
1026 using enum ExpressionType;
1027
1028 // Prune expression
1029 if (y->is_constant(0.0)) {
1030 // Return zero
1031 return y;
1032 } else if (x->is_constant(0.0)) {
1033 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
1034 }
1035
1036 // Evaluate constant
1037 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1038 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
1039 }
1040
1041 return make_expression_ptr<Atan2Expression>(y, x);
1042}
1043
1053 explicit constexpr CosExpression(ExpressionPtr lhs)
1054 : Expression{std::move(lhs)} {}
1055
1056 double value(double x, double) const override { return std::cos(x); }
1057
1058 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1059
1060 double grad_l(double x, double, double parent_adjoint) const override {
1061 return -parent_adjoint * std::sin(x);
1062 }
1063
1065 const ExpressionPtr& x, const ExpressionPtr&,
1066 const ExpressionPtr& parent_adjoint) const override {
1067 return parent_adjoint * -slp::detail::sin(x);
1068 }
1069};
1070
1076inline ExpressionPtr cos(const ExpressionPtr& x) {
1077 using enum ExpressionType;
1078
1079 // Prune expression
1080 if (x->is_constant(0.0)) {
1081 return make_expression_ptr<ConstExpression>(1.0);
1082 }
1083
1084 // Evaluate constant
1085 if (x->type() == CONSTANT) {
1086 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1087 }
1088
1089 return make_expression_ptr<CosExpression>(x);
1090}
1091
1101 explicit constexpr CoshExpression(ExpressionPtr lhs)
1102 : Expression{std::move(lhs)} {}
1103
1104 double value(double x, double) const override { return std::cosh(x); }
1105
1106 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1107
1108 double grad_l(double x, double, double parent_adjoint) const override {
1109 return parent_adjoint * std::sinh(x);
1110 }
1111
1113 const ExpressionPtr& x, const ExpressionPtr&,
1114 const ExpressionPtr& parent_adjoint) const override {
1115 return parent_adjoint * slp::detail::sinh(x);
1116 }
1117};
1118
1124inline ExpressionPtr cosh(const ExpressionPtr& x) {
1125 using enum ExpressionType;
1126
1127 // Prune expression
1128 if (x->is_constant(0.0)) {
1129 return make_expression_ptr<ConstExpression>(1.0);
1130 }
1131
1132 // Evaluate constant
1133 if (x->type() == CONSTANT) {
1134 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1135 }
1136
1137 return make_expression_ptr<CoshExpression>(x);
1138}
1139
1149 explicit constexpr ErfExpression(ExpressionPtr lhs)
1150 : Expression{std::move(lhs)} {}
1151
1152 double value(double x, double) const override { return std::erf(x); }
1153
1154 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1155
1156 double grad_l(double x, double, double parent_adjoint) const override {
1157 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1158 }
1159
1161 const ExpressionPtr& x, const ExpressionPtr&,
1162 const ExpressionPtr& parent_adjoint) const override {
1163 return parent_adjoint *
1164 make_expression_ptr<ConstExpression>(2.0 *
1165 std::numbers::inv_sqrtpi) *
1166 slp::detail::exp(-x * x);
1167 }
1168};
1169
1175inline ExpressionPtr erf(const ExpressionPtr& x) {
1176 using enum ExpressionType;
1177
1178 // Prune expression
1179 if (x->is_constant(0.0)) {
1180 // Return zero
1181 return x;
1182 }
1183
1184 // Evaluate constant
1185 if (x->type() == CONSTANT) {
1186 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1187 }
1188
1189 return make_expression_ptr<ErfExpression>(x);
1190}
1191
1201 explicit constexpr ExpExpression(ExpressionPtr lhs)
1202 : Expression{std::move(lhs)} {}
1203
1204 double value(double x, double) const override { return std::exp(x); }
1205
1206 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1207
1208 double grad_l(double x, double, double parent_adjoint) const override {
1209 return parent_adjoint * std::exp(x);
1210 }
1211
1213 const ExpressionPtr& x, const ExpressionPtr&,
1214 const ExpressionPtr& parent_adjoint) const override {
1215 return parent_adjoint * slp::detail::exp(x);
1216 }
1217};
1218
1224inline ExpressionPtr exp(const ExpressionPtr& x) {
1225 using enum ExpressionType;
1226
1227 // Prune expression
1228 if (x->is_constant(0.0)) {
1229 return make_expression_ptr<ConstExpression>(1.0);
1230 }
1231
1232 // Evaluate constant
1233 if (x->type() == CONSTANT) {
1234 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1235 }
1236
1237 return make_expression_ptr<ExpExpression>(x);
1238}
1239
// Forward declaration so HypotExpression's gradient expressions can call
// hypot() before its definition.
inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y);
1241
1253 : Expression{std::move(lhs), std::move(rhs)} {}
1254
1255 double value(double x, double y) const override { return std::hypot(x, y); }
1256
1257 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1258
1259 double grad_l(double x, double y, double parent_adjoint) const override {
1260 return parent_adjoint * x / std::hypot(x, y);
1261 }
1262
1263 double grad_r(double x, double y, double parent_adjoint) const override {
1264 return parent_adjoint * y / std::hypot(x, y);
1265 }
1266
1268 const ExpressionPtr& x, const ExpressionPtr& y,
1269 const ExpressionPtr& parent_adjoint) const override {
1270 return parent_adjoint * x / slp::detail::hypot(x, y);
1271 }
1272
1274 const ExpressionPtr& x, const ExpressionPtr& y,
1275 const ExpressionPtr& parent_adjoint) const override {
1276 return parent_adjoint * y / slp::detail::hypot(x, y);
1277 }
1278};
1279
1286inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) {
1287 using enum ExpressionType;
1288
1289 // Prune expression
1290 if (x->is_constant(0.0)) {
1291 return y;
1292 } else if (y->is_constant(0.0)) {
1293 return x;
1294 }
1295
1296 // Evaluate constant
1297 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1298 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1299 }
1300
1301 return make_expression_ptr<HypotExpression>(x, y);
1302}
1303
1313 explicit constexpr LogExpression(ExpressionPtr lhs)
1314 : Expression{std::move(lhs)} {}
1315
1316 double value(double x, double) const override { return std::log(x); }
1317
1318 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1319
1320 double grad_l(double x, double, double parent_adjoint) const override {
1321 return parent_adjoint / x;
1322 }
1323
1325 const ExpressionPtr& x, const ExpressionPtr&,
1326 const ExpressionPtr& parent_adjoint) const override {
1327 return parent_adjoint / x;
1328 }
1329};
1330
1336inline ExpressionPtr log(const ExpressionPtr& x) {
1337 using enum ExpressionType;
1338
1339 // Prune expression
1340 if (x->is_constant(0.0)) {
1341 // Return zero
1342 return x;
1343 }
1344
1345 // Evaluate constant
1346 if (x->type() == CONSTANT) {
1347 return make_expression_ptr<ConstExpression>(std::log(x->val));
1348 }
1349
1350 return make_expression_ptr<LogExpression>(x);
1351}
1352
1362 explicit constexpr Log10Expression(ExpressionPtr lhs)
1363 : Expression{std::move(lhs)} {}
1364
1365 double value(double x, double) const override { return std::log10(x); }
1366
1367 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1368
1369 double grad_l(double x, double, double parent_adjoint) const override {
1370 return parent_adjoint / (std::numbers::ln10 * x);
1371 }
1372
1374 const ExpressionPtr& x, const ExpressionPtr&,
1375 const ExpressionPtr& parent_adjoint) const override {
1376 return parent_adjoint /
1377 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1378 }
1379};
1380
1386inline ExpressionPtr log10(const ExpressionPtr& x) {
1387 using enum ExpressionType;
1388
1389 // Prune expression
1390 if (x->is_constant(0.0)) {
1391 // Return zero
1392 return x;
1393 }
1394
1395 // Evaluate constant
1396 if (x->type() == CONSTANT) {
1397 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1398 }
1399
1400 return make_expression_ptr<Log10Expression>(x);
1401}
1402
// Forward declaration so PowExpression's gradient expressions can call pow()
// before its definition.
inline ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& power);
1404
1410template <ExpressionType T>
1419 : Expression{std::move(lhs), std::move(rhs)} {}
1420
1421 double value(double base, double power) const override {
1422 return std::pow(base, power);
1423 }
1424
1425 ExpressionType type() const override { return T; }
1426
1427 double grad_l(double base, double power,
1428 double parent_adjoint) const override {
1429 return parent_adjoint * std::pow(base, power - 1) * power;
1430 }
1431
1432 double grad_r(double base, double power,
1433 double parent_adjoint) const override {
1434 // Since x * std::log(x) -> 0 as x -> 0
1435 if (base == 0.0) {
1436 return 0.0;
1437 } else {
1438 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1439 }
1440 }
1441
1443 const ExpressionPtr& base, const ExpressionPtr& power,
1444 const ExpressionPtr& parent_adjoint) const override {
1445 return parent_adjoint *
1446 slp::detail::pow(base,
1447 power - make_expression_ptr<ConstExpression>(1.0)) *
1448 power;
1449 }
1450
1452 const ExpressionPtr& base, const ExpressionPtr& power,
1453 const ExpressionPtr& parent_adjoint) const override {
1454 // Since x * std::log(x) -> 0 as x -> 0
1455 if (base->val == 0.0) {
1456 // Return zero
1457 return base;
1458 } else {
1459 return parent_adjoint *
1460 slp::detail::pow(
1461 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1462 base * slp::detail::log(base);
1463 }
1464 }
1465};
1466
1473inline ExpressionPtr pow(const ExpressionPtr& base,
1474 const ExpressionPtr& power) {
1475 using enum ExpressionType;
1476
1477 // Prune expression
1478 if (base->is_constant(0.0)) {
1479 // Return zero
1480 return base;
1481 } else if (base->is_constant(1.0)) {
1482 // Return one
1483 return base;
1484 }
1485 if (power->is_constant(0.0)) {
1486 return make_expression_ptr<ConstExpression>(1.0);
1487 } else if (power->is_constant(1.0)) {
1488 return base;
1489 }
1490
1491 // Evaluate constant
1492 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1493 return make_expression_ptr<ConstExpression>(
1494 std::pow(base->val, power->val));
1495 }
1496
1497 if (power->is_constant(2.0)) {
1498 if (base->type() == LINEAR) {
1499 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1500 } else {
1501 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1502 }
1503 }
1504
1505 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1506}
1507
1517 explicit constexpr SignExpression(ExpressionPtr lhs)
1518 : Expression{std::move(lhs)} {}
1519
1520 double value(double x, double) const override {
1521 if (x < 0.0) {
1522 return -1.0;
1523 } else if (x == 0.0) {
1524 return 0.0;
1525 } else {
1526 return 1.0;
1527 }
1528 }
1529
1530 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1531
1532 double grad_l(double, double, double) const override { return 0.0; }
1533
1535 const ExpressionPtr&) const override {
1536 // Return zero
1537 return make_expression_ptr<ConstExpression>();
1538 }
1539};
1540
1546inline ExpressionPtr sign(const ExpressionPtr& x) {
1547 using enum ExpressionType;
1548
1549 // Evaluate constant
1550 if (x->type() == CONSTANT) {
1551 if (x->val < 0.0) {
1552 return make_expression_ptr<ConstExpression>(-1.0);
1553 } else if (x->val == 0.0) {
1554 // Return zero
1555 return x;
1556 } else {
1557 return make_expression_ptr<ConstExpression>(1.0);
1558 }
1559 }
1560
1561 return make_expression_ptr<SignExpression>(x);
1562}
1563
1573 explicit constexpr SinExpression(ExpressionPtr lhs)
1574 : Expression{std::move(lhs)} {}
1575
1576 double value(double x, double) const override { return std::sin(x); }
1577
1578 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1579
1580 double grad_l(double x, double, double parent_adjoint) const override {
1581 return parent_adjoint * std::cos(x);
1582 }
1583
1585 const ExpressionPtr& x, const ExpressionPtr&,
1586 const ExpressionPtr& parent_adjoint) const override {
1587 return parent_adjoint * slp::detail::cos(x);
1588 }
1589};
1590
1596inline ExpressionPtr sin(const ExpressionPtr& x) {
1597 using enum ExpressionType;
1598
1599 // Prune expression
1600 if (x->is_constant(0.0)) {
1601 // Return zero
1602 return x;
1603 }
1604
1605 // Evaluate constant
1606 if (x->type() == CONSTANT) {
1607 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1608 }
1609
1610 return make_expression_ptr<SinExpression>(x);
1611}
1612
1622 explicit constexpr SinhExpression(ExpressionPtr lhs)
1623 : Expression{std::move(lhs)} {}
1624
1625 double value(double x, double) const override { return std::sinh(x); }
1626
1627 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1628
1629 double grad_l(double x, double, double parent_adjoint) const override {
1630 return parent_adjoint * std::cosh(x);
1631 }
1632
1634 const ExpressionPtr& x, const ExpressionPtr&,
1635 const ExpressionPtr& parent_adjoint) const override {
1636 return parent_adjoint * slp::detail::cosh(x);
1637 }
1638};
1639
1645inline ExpressionPtr sinh(const ExpressionPtr& x) {
1646 using enum ExpressionType;
1647
1648 // Prune expression
1649 if (x->is_constant(0.0)) {
1650 // Return zero
1651 return x;
1652 }
1653
1654 // Evaluate constant
1655 if (x->type() == CONSTANT) {
1656 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1657 }
1658
1659 return make_expression_ptr<SinhExpression>(x);
1660}
1661
1671 explicit constexpr SqrtExpression(ExpressionPtr lhs)
1672 : Expression{std::move(lhs)} {}
1673
1674 double value(double x, double) const override { return std::sqrt(x); }
1675
1676 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1677
1678 double grad_l(double x, double, double parent_adjoint) const override {
1679 return parent_adjoint / (2.0 * std::sqrt(x));
1680 }
1681
1683 const ExpressionPtr& x, const ExpressionPtr&,
1684 const ExpressionPtr& parent_adjoint) const override {
1685 return parent_adjoint /
1686 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1687 }
1688};
1689
1695inline ExpressionPtr sqrt(const ExpressionPtr& x) {
1696 using enum ExpressionType;
1697
1698 // Evaluate constant
1699 if (x->type() == CONSTANT) {
1700 if (x->val == 0.0) {
1701 // Return zero
1702 return x;
1703 } else if (x->val == 1.0) {
1704 return x;
1705 } else {
1706 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1707 }
1708 }
1709
1710 return make_expression_ptr<SqrtExpression>(x);
1711}
1712
1722 explicit constexpr TanExpression(ExpressionPtr lhs)
1723 : Expression{std::move(lhs)} {}
1724
1725 double value(double x, double) const override { return std::tan(x); }
1726
1727 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1728
1729 double grad_l(double x, double, double parent_adjoint) const override {
1730 return parent_adjoint / (std::cos(x) * std::cos(x));
1731 }
1732
1734 const ExpressionPtr& x, const ExpressionPtr&,
1735 const ExpressionPtr& parent_adjoint) const override {
1736 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1737 }
1738};
1739
1745inline ExpressionPtr tan(const ExpressionPtr& x) {
1746 using enum ExpressionType;
1747
1748 // Prune expression
1749 if (x->is_constant(0.0)) {
1750 // Return zero
1751 return x;
1752 }
1753
1754 // Evaluate constant
1755 if (x->type() == CONSTANT) {
1756 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1757 }
1758
1759 return make_expression_ptr<TanExpression>(x);
1760}
1761
1771 explicit constexpr TanhExpression(ExpressionPtr lhs)
1772 : Expression{std::move(lhs)} {}
1773
1774 double value(double x, double) const override { return std::tanh(x); }
1775
1776 ExpressionType type() const override { return ExpressionType::NONLINEAR; }
1777
1778 double grad_l(double x, double, double parent_adjoint) const override {
1779 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1780 }
1781
1783 const ExpressionPtr& x, const ExpressionPtr&,
1784 const ExpressionPtr& parent_adjoint) const override {
1785 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1786 }
1787};
1788
1794inline ExpressionPtr tanh(const ExpressionPtr& x) {
1795 using enum ExpressionType;
1796
1797 // Prune expression
1798 if (x->is_constant(0.0)) {
1799 // Return zero
1800 return x;
1801 }
1802
1803 // Evaluate constant
1804 if (x->type() == CONSTANT) {
1805 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1806 }
1807
1808 return make_expression_ptr<TanhExpression>(x);
1809}
1810
1811} // namespace slp::detail
Definition expression.hpp:774
ExpressionType type() const override
Definition expression.hpp:785
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:787
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:797
double value(double x, double) const override
Definition expression.hpp:783
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:780
Definition expression.hpp:836
ExpressionType type() const override
Definition expression.hpp:847
constexpr AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:842
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:849
double value(double x, double) const override
Definition expression.hpp:845
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:853
Definition expression.hpp:885
double value(double x, double) const override
Definition expression.hpp:894
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:902
ExpressionType type() const override
Definition expression.hpp:896
constexpr AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:891
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:898
Definition expression.hpp:984
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1012
double value(double y, double x) const override
Definition expression.hpp:994
ExpressionType type() const override
Definition expression.hpp:996
constexpr Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:991
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:998
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:1002
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1006
Definition expression.hpp:935
constexpr AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:941
double value(double x, double) const override
Definition expression.hpp:944
ExpressionType type() const override
Definition expression.hpp:946
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:952
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:948
Definition expression.hpp:437
double value(double lhs, double rhs) const override
Definition expression.hpp:447
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:444
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:455
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:451
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:465
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:459
ExpressionType type() const override
Definition expression.hpp:449
Definition expression.hpp:478
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:496
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:492
double value(double lhs, double rhs) const override
Definition expression.hpp:488
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:506
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:485
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:500
ExpressionType type() const override
Definition expression.hpp:490
Definition expression.hpp:516
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:529
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:534
ExpressionType type() const override
Definition expression.hpp:527
constexpr CbrtExpression(ExpressionPtr lhs)
Definition expression.hpp:522
double value(double x, double) const override
Definition expression.hpp:525
Definition expression.hpp:568
constexpr ConstExpression(double value)
Definition expression.hpp:579
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:581
ExpressionType type() const override
Definition expression.hpp:583
Definition expression.hpp:1047
ExpressionType type() const override
Definition expression.hpp:1058
double value(double x, double) const override
Definition expression.hpp:1056
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1064
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1060
constexpr CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1053
Definition expression.hpp:1095
constexpr CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1101
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1108
double value(double x, double) const override
Definition expression.hpp:1104
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1112
ExpressionType type() const override
Definition expression.hpp:1106
Definition expression.hpp:589
double value(double, double) const override
Definition expression.hpp:603
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:600
ExpressionType type() const override
Definition expression.hpp:605
Definition expression.hpp:614
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:628
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:642
double value(double lhs, double rhs) const override
Definition expression.hpp:624
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:621
ExpressionType type() const override
Definition expression.hpp:626
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:636
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:632
Definition expression.hpp:1143
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1160
double value(double x, double) const override
Definition expression.hpp:1152
constexpr ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1149
ExpressionType type() const override
Definition expression.hpp:1154
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1156
Definition expression.hpp:1195
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1208
ExpressionType type() const override
Definition expression.hpp:1206
constexpr ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1201
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1212
double value(double x, double) const override
Definition expression.hpp:1204
Definition expression.hpp:77
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:88
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:374
constexpr bool is_constant(double constant) const
Definition expression.hpp:138
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:126
constexpr Expression(double value)
Definition expression.hpp:110
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:417
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:92
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:344
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:117
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:82
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:148
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:85
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:269
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:316
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:280
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:200
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:388
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:402
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:98
double val
The value of the expression node.
Definition expression.hpp:79
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:95
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:237
Definition expression.hpp:1245
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1267
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1263
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1259
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1273
double value(double x, double y) const override
Definition expression.hpp:1255
ExpressionType type() const override
Definition expression.hpp:1257
constexpr HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1252
Definition expression.hpp:1356
constexpr Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1362
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1369
double value(double x, double) const override
Definition expression.hpp:1365
ExpressionType type() const override
Definition expression.hpp:1367
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1373
Definition expression.hpp:1307
ExpressionType type() const override
Definition expression.hpp:1318
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1324
constexpr LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1313
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1320
double value(double x, double) const override
Definition expression.hpp:1316
Definition expression.hpp:655
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:679
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:669
double value(double lhs, double rhs) const override
Definition expression.hpp:665
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:685
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:674
Definition expression.hpp:1411
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1442
double value(double base, double power) const override
Definition expression.hpp:1421
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1427
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1432
ExpressionType type() const override
Definition expression.hpp:1425
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1451
constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1418
Definition expression.hpp:1511
double grad_l(double, double, double) const override
Definition expression.hpp:1532
ExpressionType type() const override
Definition expression.hpp:1530
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1517
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1534
double value(double x, double) const override
Definition expression.hpp:1520
Definition expression.hpp:1567
double value(double x, double) const override
Definition expression.hpp:1576
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1580
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1584
ExpressionType type() const override
Definition expression.hpp:1578
constexpr SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1573
Definition expression.hpp:1616
ExpressionType type() const override
Definition expression.hpp:1627
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1629
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1633
constexpr SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1622
double value(double x, double) const override
Definition expression.hpp:1625
Definition expression.hpp:1665
ExpressionType type() const override
Definition expression.hpp:1676
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1678
constexpr SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1671
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1682
double value(double x, double) const override
Definition expression.hpp:1674
Definition expression.hpp:1716
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1729
double value(double x, double) const override
Definition expression.hpp:1725
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1733
ExpressionType type() const override
Definition expression.hpp:1727
constexpr TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1722
Definition expression.hpp:1765
constexpr TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1771
ExpressionType type() const override
Definition expression.hpp:1776
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1778
double value(double x, double) const override
Definition expression.hpp:1774
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1782
Definition expression.hpp:698
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:704
ExpressionType type() const override
Definition expression.hpp:709
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:715
double value(double lhs, double) const override
Definition expression.hpp:707
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:711