14#include <gch/small_vector.hpp>
16#include "sleipnir/autodiff/expression_type.hpp"
17#include "sleipnir/util/intrusive_shared_ptr.hpp"
18#include "sleipnir/util/pool.hpp"
20namespace slp::detail {
25inline constexpr bool USE_POOL_ALLOCATOR =
false;
27inline constexpr bool USE_POOL_ALLOCATOR =
true;
32inline constexpr void inc_ref_count(Expression* expr);
33inline constexpr void dec_ref_count(Expression* expr);
38using ExpressionPtr = IntrusiveSharedPtr<Expression>;
47template <
typename T,
typename... Args>
48static ExpressionPtr make_expression_ptr(Args&&... args) {
49 if constexpr (USE_POOL_ALLOCATOR) {
50 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
51 std::forward<Args>(args)...);
53 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57template <ExpressionType T>
58struct BinaryMinusExpression;
60template <ExpressionType T>
61struct BinaryPlusExpression;
63struct ConstExpression;
65template <ExpressionType T>
68template <ExpressionType T>
71template <ExpressionType T>
72struct UnaryMinusExpression;
98 std::array<ExpressionPtr, 2>
args{
nullptr,
nullptr};
118 :
args{std::move(lhs), nullptr} {}
127 :
args{std::move(lhs), std::move(rhs)} {}
139 return type() == ExpressionType::CONSTANT &&
val == constant;
150 using enum ExpressionType;
153 if (lhs->is_constant(0.0)) {
156 }
else if (rhs->is_constant(0.0)) {
159 }
else if (lhs->is_constant(1.0)) {
161 }
else if (rhs->is_constant(1.0)) {
166 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
167 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
171 if (lhs->type() == CONSTANT) {
172 if (rhs->type() == LINEAR) {
173 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
174 }
else if (rhs->type() == QUADRATIC) {
175 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
177 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
179 }
else if (rhs->type() == CONSTANT) {
180 if (lhs->type() == LINEAR) {
181 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
182 }
else if (lhs->type() == QUADRATIC) {
183 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
185 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
187 }
else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
188 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
190 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
202 using enum ExpressionType;
205 if (lhs->is_constant(0.0)) {
208 }
else if (rhs->is_constant(1.0)) {
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
220 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
221 }
else if (lhs->type() == QUADRATIC) {
222 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
224 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
227 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
239 using enum ExpressionType;
242 if (lhs ==
nullptr || lhs->is_constant(0.0)) {
244 }
else if (rhs ==
nullptr || rhs->is_constant(0.0)) {
249 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
250 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
253 auto type = std::max(lhs->type(), rhs->type());
254 if (
type == LINEAR) {
255 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
256 }
else if (
type == QUADRATIC) {
257 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
259 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
271 return lhs = lhs + rhs;
282 using enum ExpressionType;
285 if (lhs->is_constant(0.0)) {
286 if (rhs->is_constant(0.0)) {
292 }
else if (rhs->is_constant(0.0)) {
297 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
298 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
301 auto type = std::max(lhs->type(), rhs->type());
302 if (
type == LINEAR) {
303 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
304 }
else if (
type == QUADRATIC) {
305 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
307 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
317 using enum ExpressionType;
320 if (lhs->is_constant(0.0)) {
326 if (lhs->type() == CONSTANT) {
327 return make_expression_ptr<ConstExpression>(-lhs->val);
330 if (lhs->type() == LINEAR) {
331 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
332 }
else if (lhs->type() == QUADRATIC) {
333 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
335 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
355 virtual double value([[maybe_unused]]
double lhs,
356 [[maybe_unused]]
double rhs)
const = 0;
364 virtual ExpressionType
type()
const = 0;
374 virtual double grad_l([[maybe_unused]]
double lhs,
375 [[maybe_unused]]
double rhs,
376 [[maybe_unused]]
double parent_adjoint)
const {
388 virtual double grad_r([[maybe_unused]]
double lhs,
389 [[maybe_unused]]
double rhs,
390 [[maybe_unused]]
double parent_adjoint)
const {
405 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
406 return make_expression_ptr<ConstExpression>();
420 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
421 return make_expression_ptr<ConstExpression>();
430template <ExpressionType T>
439 :
Expression{std::move(lhs), std::move(rhs)} {}
441 double value(
double lhs,
double rhs)
const override {
return lhs - rhs; }
443 ExpressionType
type()
const override {
return T; }
445 double grad_l(
double,
double,
double parent_adjoint)
const override {
446 return parent_adjoint;
449 double grad_r(
double,
double,
double parent_adjoint)
const override {
450 return -parent_adjoint;
456 return parent_adjoint;
462 return -parent_adjoint;
471template <ExpressionType T>
480 :
Expression{std::move(lhs), std::move(rhs)} {}
482 double value(
double lhs,
double rhs)
const override {
return lhs + rhs; }
484 ExpressionType
type()
const override {
return T; }
486 double grad_l(
double,
double,
double parent_adjoint)
const override {
487 return parent_adjoint;
490 double grad_r(
double,
double,
double parent_adjoint)
const override {
491 return parent_adjoint;
497 return parent_adjoint;
503 return parent_adjoint;
523 double value(
double,
double)
const override {
return val; }
525 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
545 double value(
double,
double)
const override {
return val; }
547 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
555template <ExpressionType T>
564 :
Expression{std::move(lhs), std::move(rhs)} {}
566 double value(
double lhs,
double rhs)
const override {
return lhs / rhs; }
568 ExpressionType
type()
const override {
return T; }
570 double grad_l(
double,
double rhs,
double parent_adjoint)
const override {
571 return parent_adjoint / rhs;
574 double grad_r(
double lhs,
double rhs,
double parent_adjoint)
const override {
575 return parent_adjoint * -lhs / (rhs * rhs);
581 return parent_adjoint / rhs;
587 return parent_adjoint * -lhs / (rhs * rhs);
596template <ExpressionType T>
605 :
Expression{std::move(lhs), std::move(rhs)} {}
607 double value(
double lhs,
double rhs)
const override {
return lhs * rhs; }
609 ExpressionType
type()
const override {
return T; }
611 double grad_l([[maybe_unused]]
double lhs,
double rhs,
612 double parent_adjoint)
const override {
613 return parent_adjoint * rhs;
616 double grad_r(
double lhs, [[maybe_unused]]
double rhs,
617 double parent_adjoint)
const override {
618 return parent_adjoint * lhs;
624 return parent_adjoint * rhs;
630 return parent_adjoint * lhs;
639template <ExpressionType T>
649 double value(
double lhs,
double)
const override {
return -lhs; }
651 ExpressionType
type()
const override {
return T; }
653 double grad_l(
double,
double,
double parent_adjoint)
const override {
654 return -parent_adjoint;
660 return -parent_adjoint;
664inline ExpressionPtr exp(
const ExpressionPtr& x);
665inline ExpressionPtr sin(
const ExpressionPtr& x);
666inline ExpressionPtr sinh(
const ExpressionPtr& x);
667inline ExpressionPtr sqrt(
const ExpressionPtr& x);
674inline constexpr void inc_ref_count(Expression* expr) {
683inline constexpr void dec_ref_count(Expression* expr) {
688 gch::small_vector<Expression*> stack;
689 stack.emplace_back(expr);
691 while (!stack.empty()) {
692 auto elem = stack.back();
697 if (--elem->ref_count == 0) {
698 if (elem->adjoint_expr !=
nullptr) {
699 stack.emplace_back(elem->adjoint_expr.get());
701 for (
auto& arg : elem->args) {
702 if (arg !=
nullptr) {
703 stack.emplace_back(arg.get());
709 if constexpr (USE_POOL_ALLOCATOR) {
710 auto alloc = global_pool_allocator<Expression>();
711 std::allocator_traits<
decltype(alloc)>::deallocate(alloc, elem,
730 double value(
double x,
double)
const override {
return std::abs(x); }
732 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
734 double grad_l(
double x,
double,
double parent_adjoint)
const override {
736 return -parent_adjoint;
737 }
else if (x > 0.0) {
738 return parent_adjoint;
748 return -parent_adjoint;
749 }
else if (x->val > 0.0) {
750 return parent_adjoint;
753 return make_expression_ptr<ConstExpression>();
763inline ExpressionPtr abs(
const ExpressionPtr& x) {
764 using enum ExpressionType;
767 if (x->is_constant(0.0)) {
773 if (x->type() == CONSTANT) {
774 return make_expression_ptr<ConstExpression>(std::abs(x->val));
777 return make_expression_ptr<AbsExpression>(x);
792 double value(
double x,
double)
const override {
return std::acos(x); }
794 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
796 double grad_l(
double x,
double,
double parent_adjoint)
const override {
797 return -parent_adjoint / std::sqrt(1.0 - x * x);
803 return -parent_adjoint /
804 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
813inline ExpressionPtr acos(
const ExpressionPtr& x) {
814 using enum ExpressionType;
817 if (x->is_constant(0.0)) {
818 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
822 if (x->type() == CONSTANT) {
823 return make_expression_ptr<ConstExpression>(std::acos(x->val));
826 return make_expression_ptr<AcosExpression>(x);
841 double value(
double x,
double)
const override {
return std::asin(x); }
843 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
845 double grad_l(
double x,
double,
double parent_adjoint)
const override {
846 return parent_adjoint / std::sqrt(1.0 - x * x);
852 return parent_adjoint /
853 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
862inline ExpressionPtr asin(
const ExpressionPtr& x) {
863 using enum ExpressionType;
866 if (x->is_constant(0.0)) {
872 if (x->type() == CONSTANT) {
873 return make_expression_ptr<ConstExpression>(std::asin(x->val));
876 return make_expression_ptr<AsinExpression>(x);
891 double value(
double x,
double)
const override {
return std::atan(x); }
893 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
895 double grad_l(
double x,
double,
double parent_adjoint)
const override {
896 return parent_adjoint / (1.0 + x * x);
902 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
911inline ExpressionPtr atan(
const ExpressionPtr& x) {
912 using enum ExpressionType;
915 if (x->is_constant(0.0)) {
921 if (x->type() == CONSTANT) {
922 return make_expression_ptr<ConstExpression>(std::atan(x->val));
925 return make_expression_ptr<AtanExpression>(x);
939 :
Expression{std::move(lhs), std::move(rhs)} {}
941 double value(
double y,
double x)
const override {
return std::atan2(y, x); }
943 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
945 double grad_l(
double y,
double x,
double parent_adjoint)
const override {
946 return parent_adjoint * x / (y * y + x * x);
949 double grad_r(
double y,
double x,
double parent_adjoint)
const override {
950 return parent_adjoint * -y / (y * y + x * x);
956 return parent_adjoint * x / (y * y + x * x);
962 return parent_adjoint * -y / (y * y + x * x);
972inline ExpressionPtr atan2(
const ExpressionPtr& y,
const ExpressionPtr& x) {
973 using enum ExpressionType;
976 if (y->is_constant(0.0)) {
979 }
else if (x->is_constant(0.0)) {
980 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
984 if (y->type() == CONSTANT && x->type() == CONSTANT) {
985 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
988 return make_expression_ptr<Atan2Expression>(y, x);
1003 double value(
double x,
double)
const override {
return std::cos(x); }
1005 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1007 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1008 return -parent_adjoint * std::sin(x);
1014 return parent_adjoint * -slp::detail::sin(x);
1023inline ExpressionPtr cos(
const ExpressionPtr& x) {
1024 using enum ExpressionType;
1027 if (x->is_constant(0.0)) {
1028 return make_expression_ptr<ConstExpression>(1.0);
1032 if (x->type() == CONSTANT) {
1033 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1036 return make_expression_ptr<CosExpression>(x);
1051 double value(
double x,
double)
const override {
return std::cosh(x); }
1053 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1055 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1056 return parent_adjoint * std::sinh(x);
1062 return parent_adjoint * slp::detail::sinh(x);
1071inline ExpressionPtr cosh(
const ExpressionPtr& x) {
1072 using enum ExpressionType;
1075 if (x->is_constant(0.0)) {
1076 return make_expression_ptr<ConstExpression>(1.0);
1080 if (x->type() == CONSTANT) {
1081 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1084 return make_expression_ptr<CoshExpression>(x);
1099 double value(
double x,
double)
const override {
return std::erf(x); }
1101 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1103 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1104 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1110 return parent_adjoint *
1111 make_expression_ptr<ConstExpression>(2.0 *
1112 std::numbers::inv_sqrtpi) *
1113 slp::detail::exp(-x * x);
1122inline ExpressionPtr erf(
const ExpressionPtr& x) {
1123 using enum ExpressionType;
1126 if (x->is_constant(0.0)) {
1132 if (x->type() == CONSTANT) {
1133 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1136 return make_expression_ptr<ErfExpression>(x);
1151 double value(
double x,
double)
const override {
return std::exp(x); }
1153 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1155 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1156 return parent_adjoint * std::exp(x);
1162 return parent_adjoint * slp::detail::exp(x);
1171inline ExpressionPtr exp(
const ExpressionPtr& x) {
1172 using enum ExpressionType;
1175 if (x->is_constant(0.0)) {
1176 return make_expression_ptr<ConstExpression>(1.0);
1180 if (x->type() == CONSTANT) {
1181 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1184 return make_expression_ptr<ExpExpression>(x);
1187inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y);
1200 :
Expression{std::move(lhs), std::move(rhs)} {}
1202 double value(
double x,
double y)
const override {
return std::hypot(x, y); }
1204 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1206 double grad_l(
double x,
double y,
double parent_adjoint)
const override {
1207 return parent_adjoint * x / std::hypot(x, y);
1210 double grad_r(
double x,
double y,
double parent_adjoint)
const override {
1211 return parent_adjoint * y / std::hypot(x, y);
1217 return parent_adjoint * x / slp::detail::hypot(x, y);
1223 return parent_adjoint * y / slp::detail::hypot(x, y);
1233inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y) {
1234 using enum ExpressionType;
1237 if (x->is_constant(0.0)) {
1239 }
else if (y->is_constant(0.0)) {
1244 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1245 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1248 return make_expression_ptr<HypotExpression>(x, y);
1263 double value(
double x,
double)
const override {
return std::log(x); }
1265 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1267 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1268 return parent_adjoint / x;
1274 return parent_adjoint / x;
1283inline ExpressionPtr log(
const ExpressionPtr& x) {
1284 using enum ExpressionType;
1287 if (x->is_constant(0.0)) {
1293 if (x->type() == CONSTANT) {
1294 return make_expression_ptr<ConstExpression>(std::log(x->val));
1297 return make_expression_ptr<LogExpression>(x);
1312 double value(
double x,
double)
const override {
return std::log10(x); }
1314 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1316 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1317 return parent_adjoint / (std::numbers::ln10 * x);
1323 return parent_adjoint /
1324 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1333inline ExpressionPtr log10(
const ExpressionPtr& x) {
1334 using enum ExpressionType;
1337 if (x->is_constant(0.0)) {
1343 if (x->type() == CONSTANT) {
1344 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1347 return make_expression_ptr<Log10Expression>(x);
1350inline ExpressionPtr pow(
const ExpressionPtr& base,
const ExpressionPtr& power);
1357template <ExpressionType T>
1366 :
Expression{std::move(lhs), std::move(rhs)} {}
1368 double value(
double base,
double power)
const override {
1369 return std::pow(base, power);
1372 ExpressionType
type()
const override {
return T; }
1375 double parent_adjoint)
const override {
1376 return parent_adjoint * std::pow(base, power - 1) * power;
1380 double parent_adjoint)
const override {
1385 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1392 return parent_adjoint *
1393 slp::detail::pow(base,
1394 power - make_expression_ptr<ConstExpression>(1.0)) *
1402 if (base->val == 0.0) {
1406 return parent_adjoint *
1408 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1409 base * slp::detail::log(base);
1420inline ExpressionPtr pow(
const ExpressionPtr& base,
1421 const ExpressionPtr& power) {
1422 using enum ExpressionType;
1425 if (base->is_constant(0.0)) {
1428 }
else if (base->is_constant(1.0)) {
1432 if (power->is_constant(0.0)) {
1433 return make_expression_ptr<ConstExpression>(1.0);
1434 }
else if (power->is_constant(1.0)) {
1439 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1440 return make_expression_ptr<ConstExpression>(
1441 std::pow(base->val, power->val));
1444 if (power->is_constant(2.0)) {
1445 if (base->type() == LINEAR) {
1446 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1448 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1452 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1467 double value(
double x,
double)
const override {
1470 }
else if (x == 0.0) {
1477 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1479 double grad_l(
double,
double,
double)
const override {
return 0.0; }
1484 return make_expression_ptr<ConstExpression>();
1493inline ExpressionPtr sign(
const ExpressionPtr& x) {
1494 using enum ExpressionType;
1497 if (x->type() == CONSTANT) {
1499 return make_expression_ptr<ConstExpression>(-1.0);
1500 }
else if (x->val == 0.0) {
1504 return make_expression_ptr<ConstExpression>(1.0);
1508 return make_expression_ptr<SignExpression>(x);
1523 double value(
double x,
double)
const override {
return std::sin(x); }
1525 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1527 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1528 return parent_adjoint * std::cos(x);
1534 return parent_adjoint * slp::detail::cos(x);
1543inline ExpressionPtr sin(
const ExpressionPtr& x) {
1544 using enum ExpressionType;
1547 if (x->is_constant(0.0)) {
1553 if (x->type() == CONSTANT) {
1554 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1557 return make_expression_ptr<SinExpression>(x);
1572 double value(
double x,
double)
const override {
return std::sinh(x); }
1574 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1576 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1577 return parent_adjoint * std::cosh(x);
1583 return parent_adjoint * slp::detail::cosh(x);
1592inline ExpressionPtr sinh(
const ExpressionPtr& x) {
1593 using enum ExpressionType;
1596 if (x->is_constant(0.0)) {
1602 if (x->type() == CONSTANT) {
1603 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1606 return make_expression_ptr<SinhExpression>(x);
1621 double value(
double x,
double)
const override {
return std::sqrt(x); }
1623 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1625 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1626 return parent_adjoint / (2.0 * std::sqrt(x));
1632 return parent_adjoint /
1633 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1642inline ExpressionPtr sqrt(
const ExpressionPtr& x) {
1643 using enum ExpressionType;
1646 if (x->type() == CONSTANT) {
1647 if (x->val == 0.0) {
1650 }
else if (x->val == 1.0) {
1653 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1657 return make_expression_ptr<SqrtExpression>(x);
1672 double value(
double x,
double)
const override {
return std::tan(x); }
1674 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1676 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1677 return parent_adjoint / (std::cos(x) * std::cos(x));
1683 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1692inline ExpressionPtr tan(
const ExpressionPtr& x) {
1693 using enum ExpressionType;
1696 if (x->is_constant(0.0)) {
1702 if (x->type() == CONSTANT) {
1703 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1706 return make_expression_ptr<TanExpression>(x);
1721 double value(
double x,
double)
const override {
return std::tanh(x); }
1723 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1725 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1726 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1732 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1741inline ExpressionPtr tanh(
const ExpressionPtr& x) {
1742 using enum ExpressionType;
1745 if (x->is_constant(0.0)) {
1751 if (x->type() == CONSTANT) {
1752 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1755 return make_expression_ptr<TanhExpression>(x);
Definition expression.hpp:721
ExpressionType type() const override
Definition expression.hpp:732
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:734
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:744
double value(double x, double) const override
Definition expression.hpp:730
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:727
Definition expression.hpp:783
ExpressionType type() const override
Definition expression.hpp:794
constexpr AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:789
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:796
double value(double x, double) const override
Definition expression.hpp:792
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:800
Definition expression.hpp:832
double value(double x, double) const override
Definition expression.hpp:841
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:849
ExpressionType type() const override
Definition expression.hpp:843
constexpr AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:838
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:845
Definition expression.hpp:931
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:959
double value(double y, double x) const override
Definition expression.hpp:941
ExpressionType type() const override
Definition expression.hpp:943
constexpr Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:938
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:945
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:949
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:953
Definition expression.hpp:882
constexpr AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:888
double value(double x, double) const override
Definition expression.hpp:891
ExpressionType type() const override
Definition expression.hpp:893
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:899
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:895
Definition expression.hpp:431
double value(double lhs, double rhs) const override
Definition expression.hpp:441
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:438
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:449
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:445
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:459
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:453
ExpressionType type() const override
Definition expression.hpp:443
Definition expression.hpp:472
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:490
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:486
double value(double lhs, double rhs) const override
Definition expression.hpp:482
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:500
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:479
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:494
ExpressionType type() const override
Definition expression.hpp:484
Definition expression.hpp:510
constexpr ConstExpression(double value)
Definition expression.hpp:521
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:523
ExpressionType type() const override
Definition expression.hpp:525
Definition expression.hpp:994
ExpressionType type() const override
Definition expression.hpp:1005
double value(double x, double) const override
Definition expression.hpp:1003
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1011
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1007
constexpr CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1000
Definition expression.hpp:1042
constexpr CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1048
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1055
double value(double x, double) const override
Definition expression.hpp:1051
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1059
ExpressionType type() const override
Definition expression.hpp:1053
Definition expression.hpp:531
double value(double, double) const override
Definition expression.hpp:545
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:542
ExpressionType type() const override
Definition expression.hpp:547
Definition expression.hpp:556
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:570
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:584
double value(double lhs, double rhs) const override
Definition expression.hpp:566
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:563
ExpressionType type() const override
Definition expression.hpp:568
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:578
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:574
Definition expression.hpp:1090
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1107
double value(double x, double) const override
Definition expression.hpp:1099
constexpr ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1096
ExpressionType type() const override
Definition expression.hpp:1101
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1103
Definition expression.hpp:1142
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1155
ExpressionType type() const override
Definition expression.hpp:1153
constexpr ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1148
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1159
double value(double x, double) const override
Definition expression.hpp:1151
Definition expression.hpp:77
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:88
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:374
constexpr bool is_constant(double constant) const
Definition expression.hpp:138
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:126
constexpr Expression(double value)
Definition expression.hpp:110
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:417
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:92
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:344
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:117
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:82
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:148
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:85
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:269
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:316
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:280
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:200
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:388
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:402
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:98
double val
The value of the expression node.
Definition expression.hpp:79
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:95
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:237
Definition expression.hpp:1192
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1214
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1210
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1206
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1220
double value(double x, double y) const override
Definition expression.hpp:1202
ExpressionType type() const override
Definition expression.hpp:1204
constexpr HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1199
Definition expression.hpp:1303
constexpr Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1309
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1316
double value(double x, double) const override
Definition expression.hpp:1312
ExpressionType type() const override
Definition expression.hpp:1314
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1320
Definition expression.hpp:1254
ExpressionType type() const override
Definition expression.hpp:1265
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1271
constexpr LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1260
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1267
double value(double x, double) const override
Definition expression.hpp:1263
Definition expression.hpp:597
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:621
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:604
ExpressionType type() const override
Definition expression.hpp:609
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:611
double value(double lhs, double rhs) const override
Definition expression.hpp:607
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:627
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:616
Definition expression.hpp:1358
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1389
double value(double base, double power) const override
Definition expression.hpp:1368
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1374
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1379
ExpressionType type() const override
Definition expression.hpp:1372
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1398
constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1365
Definition expression.hpp:1458
double grad_l(double, double, double) const override
Definition expression.hpp:1479
ExpressionType type() const override
Definition expression.hpp:1477
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1464
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1481
double value(double x, double) const override
Definition expression.hpp:1467
Definition expression.hpp:1514
double value(double x, double) const override
Definition expression.hpp:1523
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1527
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1531
ExpressionType type() const override
Definition expression.hpp:1525
constexpr SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1520
Definition expression.hpp:1563
ExpressionType type() const override
Definition expression.hpp:1574
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1576
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1580
constexpr SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1569
double value(double x, double) const override
Definition expression.hpp:1572
Definition expression.hpp:1612
ExpressionType type() const override
Definition expression.hpp:1623
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1625
constexpr SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1618
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1629
double value(double x, double) const override
Definition expression.hpp:1621
Definition expression.hpp:1663
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1676
double value(double x, double) const override
Definition expression.hpp:1672
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1680
ExpressionType type() const override
Definition expression.hpp:1674
constexpr TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1669
Definition expression.hpp:1712
constexpr TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1718
ExpressionType type() const override
Definition expression.hpp:1723
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1725
double value(double x, double) const override
Definition expression.hpp:1721
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1729
Definition expression.hpp:640
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:646
ExpressionType type() const override
Definition expression.hpp:651
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:657
double value(double lhs, double) const override
Definition expression.hpp:649
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:653