14#include <gch/small_vector.hpp>
16#include "sleipnir/autodiff/expression_type.hpp"
17#include "sleipnir/util/intrusive_shared_ptr.hpp"
18#include "sleipnir/util/pool.hpp"
20namespace slp::detail {
25inline constexpr bool USE_POOL_ALLOCATOR =
false;
27inline constexpr bool USE_POOL_ALLOCATOR =
true;
32inline constexpr void inc_ref_count(Expression* expr);
33inline constexpr void dec_ref_count(Expression* expr);
38using ExpressionPtr = IntrusiveSharedPtr<Expression>;
47template <
typename T,
typename... Args>
48static ExpressionPtr make_expression_ptr(Args&&... args) {
49 if constexpr (USE_POOL_ALLOCATOR) {
50 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
51 std::forward<Args>(args)...);
53 return make_intrusive_shared<T>(std::forward<Args>(args)...);
57template <ExpressionType T>
58struct BinaryMinusExpression;
60template <ExpressionType T>
61struct BinaryPlusExpression;
63struct ConstExpression;
65template <ExpressionType T>
68template <ExpressionType T>
71template <ExpressionType T>
72struct UnaryMinusExpression;
98 std::array<ExpressionPtr, 2>
args{
nullptr,
nullptr};
118 :
args{std::move(lhs), nullptr} {}
127 :
args{std::move(lhs), std::move(rhs)} {}
139 return type() == ExpressionType::CONSTANT &&
val == constant;
150 using enum ExpressionType;
153 if (lhs->is_constant(0.0)) {
156 }
else if (rhs->is_constant(0.0)) {
159 }
else if (lhs->is_constant(1.0)) {
161 }
else if (rhs->is_constant(1.0)) {
166 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
167 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
171 if (lhs->type() == CONSTANT) {
172 if (rhs->type() == LINEAR) {
173 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
174 }
else if (rhs->type() == QUADRATIC) {
175 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
177 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
179 }
else if (rhs->type() == CONSTANT) {
180 if (lhs->type() == LINEAR) {
181 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
182 }
else if (lhs->type() == QUADRATIC) {
183 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
185 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
187 }
else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
188 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
190 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
202 using enum ExpressionType;
205 if (lhs->is_constant(0.0)) {
208 }
else if (rhs->is_constant(1.0)) {
213 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
214 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
218 if (rhs->type() == CONSTANT) {
219 if (lhs->type() == LINEAR) {
220 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
221 }
else if (lhs->type() == QUADRATIC) {
222 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
224 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
227 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
239 using enum ExpressionType;
242 if (lhs ==
nullptr || lhs->is_constant(0.0)) {
244 }
else if (rhs ==
nullptr || rhs->is_constant(0.0)) {
249 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
250 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
253 auto type = std::max(lhs->type(), rhs->type());
254 if (
type == LINEAR) {
255 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
256 }
else if (
type == QUADRATIC) {
257 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
259 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
271 return lhs = lhs + rhs;
282 using enum ExpressionType;
285 if (lhs->is_constant(0.0)) {
286 if (rhs->is_constant(0.0)) {
292 }
else if (rhs->is_constant(0.0)) {
297 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
298 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
301 auto type = std::max(lhs->type(), rhs->type());
302 if (
type == LINEAR) {
303 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
304 }
else if (
type == QUADRATIC) {
305 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
307 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
317 using enum ExpressionType;
320 if (lhs->is_constant(0.0)) {
326 if (lhs->type() == CONSTANT) {
327 return make_expression_ptr<ConstExpression>(-lhs->val);
330 if (lhs->type() == LINEAR) {
331 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
332 }
else if (lhs->type() == QUADRATIC) {
333 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
335 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
355 virtual double value([[maybe_unused]]
double lhs,
356 [[maybe_unused]]
double rhs)
const = 0;
364 virtual ExpressionType
type()
const = 0;
374 virtual double grad_l([[maybe_unused]]
double lhs,
375 [[maybe_unused]]
double rhs,
376 [[maybe_unused]]
double parent_adjoint)
const {
388 virtual double grad_r([[maybe_unused]]
double lhs,
389 [[maybe_unused]]
double rhs,
390 [[maybe_unused]]
double parent_adjoint)
const {
405 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
406 return make_expression_ptr<ConstExpression>();
420 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
421 return make_expression_ptr<ConstExpression>();
425inline ExpressionPtr cbrt(
const ExpressionPtr& x);
426inline ExpressionPtr exp(
const ExpressionPtr& x);
427inline ExpressionPtr sin(
const ExpressionPtr& x);
428inline ExpressionPtr sinh(
const ExpressionPtr& x);
429inline ExpressionPtr sqrt(
const ExpressionPtr& x);
436template <ExpressionType T>
445 :
Expression{std::move(lhs), std::move(rhs)} {}
447 double value(
double lhs,
double rhs)
const override {
return lhs - rhs; }
449 ExpressionType
type()
const override {
return T; }
451 double grad_l(
double,
double,
double parent_adjoint)
const override {
452 return parent_adjoint;
455 double grad_r(
double,
double,
double parent_adjoint)
const override {
456 return -parent_adjoint;
462 return parent_adjoint;
468 return -parent_adjoint;
477template <ExpressionType T>
486 :
Expression{std::move(lhs), std::move(rhs)} {}
488 double value(
double lhs,
double rhs)
const override {
return lhs + rhs; }
490 ExpressionType
type()
const override {
return T; }
492 double grad_l(
double,
double,
double parent_adjoint)
const override {
493 return parent_adjoint;
496 double grad_r(
double,
double,
double parent_adjoint)
const override {
497 return parent_adjoint;
503 return parent_adjoint;
509 return parent_adjoint;
525 double value(
double x,
double)
const override {
return std::cbrt(x); }
527 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
529 double grad_l(
double x,
double,
double parent_adjoint)
const override {
530 double c = std::cbrt(x);
531 return parent_adjoint / (3.0 * c * c);
537 auto c = slp::detail::cbrt(x);
538 return parent_adjoint / (make_expression_ptr<ConstExpression>(3.0) * c * c);
547inline ExpressionPtr cbrt(
const ExpressionPtr& x) {
548 using enum ExpressionType;
551 if (x->type() == CONSTANT) {
555 }
else if (x->val == -1.0 || x->val == 1.0) {
558 return make_expression_ptr<ConstExpression>(std::cbrt(x->val));
562 return make_expression_ptr<CbrtExpression>(x);
581 double value(
double,
double)
const override {
return val; }
583 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
603 double value(
double,
double)
const override {
return val; }
605 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
613template <ExpressionType T>
622 :
Expression{std::move(lhs), std::move(rhs)} {}
624 double value(
double lhs,
double rhs)
const override {
return lhs / rhs; }
626 ExpressionType
type()
const override {
return T; }
628 double grad_l(
double,
double rhs,
double parent_adjoint)
const override {
629 return parent_adjoint / rhs;
632 double grad_r(
double lhs,
double rhs,
double parent_adjoint)
const override {
633 return parent_adjoint * -lhs / (rhs * rhs);
639 return parent_adjoint / rhs;
645 return parent_adjoint * -lhs / (rhs * rhs);
654template <ExpressionType T>
663 :
Expression{std::move(lhs), std::move(rhs)} {}
665 double value(
double lhs,
double rhs)
const override {
return lhs * rhs; }
667 ExpressionType
type()
const override {
return T; }
669 double grad_l([[maybe_unused]]
double lhs,
double rhs,
670 double parent_adjoint)
const override {
671 return parent_adjoint * rhs;
674 double grad_r(
double lhs, [[maybe_unused]]
double rhs,
675 double parent_adjoint)
const override {
676 return parent_adjoint * lhs;
682 return parent_adjoint * rhs;
688 return parent_adjoint * lhs;
697template <ExpressionType T>
707 double value(
double lhs,
double)
const override {
return -lhs; }
709 ExpressionType
type()
const override {
return T; }
711 double grad_l(
double,
double,
double parent_adjoint)
const override {
712 return -parent_adjoint;
718 return -parent_adjoint;
727inline constexpr void inc_ref_count(Expression* expr) {
736inline constexpr void dec_ref_count(Expression* expr) {
741 gch::small_vector<Expression*> stack;
742 stack.emplace_back(expr);
744 while (!stack.empty()) {
745 auto elem = stack.back();
750 if (--elem->ref_count == 0) {
751 if (elem->adjoint_expr !=
nullptr) {
752 stack.emplace_back(elem->adjoint_expr.get());
754 for (
auto& arg : elem->args) {
755 if (arg !=
nullptr) {
756 stack.emplace_back(arg.get());
762 if constexpr (USE_POOL_ALLOCATOR) {
763 auto alloc = global_pool_allocator<Expression>();
764 std::allocator_traits<
decltype(alloc)>::deallocate(alloc, elem,
783 double value(
double x,
double)
const override {
return std::abs(x); }
785 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
787 double grad_l(
double x,
double,
double parent_adjoint)
const override {
789 return -parent_adjoint;
790 }
else if (x > 0.0) {
791 return parent_adjoint;
801 return -parent_adjoint;
802 }
else if (x->val > 0.0) {
803 return parent_adjoint;
806 return make_expression_ptr<ConstExpression>();
816inline ExpressionPtr abs(
const ExpressionPtr& x) {
817 using enum ExpressionType;
820 if (x->is_constant(0.0)) {
826 if (x->type() == CONSTANT) {
827 return make_expression_ptr<ConstExpression>(std::abs(x->val));
830 return make_expression_ptr<AbsExpression>(x);
845 double value(
double x,
double)
const override {
return std::acos(x); }
847 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
849 double grad_l(
double x,
double,
double parent_adjoint)
const override {
850 return -parent_adjoint / std::sqrt(1.0 - x * x);
856 return -parent_adjoint /
857 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
866inline ExpressionPtr acos(
const ExpressionPtr& x) {
867 using enum ExpressionType;
870 if (x->is_constant(0.0)) {
871 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
875 if (x->type() == CONSTANT) {
876 return make_expression_ptr<ConstExpression>(std::acos(x->val));
879 return make_expression_ptr<AcosExpression>(x);
894 double value(
double x,
double)
const override {
return std::asin(x); }
896 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
898 double grad_l(
double x,
double,
double parent_adjoint)
const override {
899 return parent_adjoint / std::sqrt(1.0 - x * x);
905 return parent_adjoint /
906 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
915inline ExpressionPtr asin(
const ExpressionPtr& x) {
916 using enum ExpressionType;
919 if (x->is_constant(0.0)) {
925 if (x->type() == CONSTANT) {
926 return make_expression_ptr<ConstExpression>(std::asin(x->val));
929 return make_expression_ptr<AsinExpression>(x);
944 double value(
double x,
double)
const override {
return std::atan(x); }
946 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
948 double grad_l(
double x,
double,
double parent_adjoint)
const override {
949 return parent_adjoint / (1.0 + x * x);
955 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
964inline ExpressionPtr atan(
const ExpressionPtr& x) {
965 using enum ExpressionType;
968 if (x->is_constant(0.0)) {
974 if (x->type() == CONSTANT) {
975 return make_expression_ptr<ConstExpression>(std::atan(x->val));
978 return make_expression_ptr<AtanExpression>(x);
992 :
Expression{std::move(lhs), std::move(rhs)} {}
994 double value(
double y,
double x)
const override {
return std::atan2(y, x); }
996 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
998 double grad_l(
double y,
double x,
double parent_adjoint)
const override {
999 return parent_adjoint * x / (y * y + x * x);
1002 double grad_r(
double y,
double x,
double parent_adjoint)
const override {
1003 return parent_adjoint * -y / (y * y + x * x);
1009 return parent_adjoint * x / (y * y + x * x);
1015 return parent_adjoint * -y / (y * y + x * x);
1025inline ExpressionPtr atan2(
const ExpressionPtr& y,
const ExpressionPtr& x) {
1026 using enum ExpressionType;
1029 if (y->is_constant(0.0)) {
1032 }
else if (x->is_constant(0.0)) {
1033 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
1037 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1038 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
1041 return make_expression_ptr<Atan2Expression>(y, x);
1056 double value(
double x,
double)
const override {
return std::cos(x); }
1058 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1060 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1061 return -parent_adjoint * std::sin(x);
1067 return parent_adjoint * -slp::detail::sin(x);
1076inline ExpressionPtr cos(
const ExpressionPtr& x) {
1077 using enum ExpressionType;
1080 if (x->is_constant(0.0)) {
1081 return make_expression_ptr<ConstExpression>(1.0);
1085 if (x->type() == CONSTANT) {
1086 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1089 return make_expression_ptr<CosExpression>(x);
1104 double value(
double x,
double)
const override {
return std::cosh(x); }
1106 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1108 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1109 return parent_adjoint * std::sinh(x);
1115 return parent_adjoint * slp::detail::sinh(x);
1124inline ExpressionPtr cosh(
const ExpressionPtr& x) {
1125 using enum ExpressionType;
1128 if (x->is_constant(0.0)) {
1129 return make_expression_ptr<ConstExpression>(1.0);
1133 if (x->type() == CONSTANT) {
1134 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1137 return make_expression_ptr<CoshExpression>(x);
1152 double value(
double x,
double)
const override {
return std::erf(x); }
1154 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1156 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1157 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1163 return parent_adjoint *
1164 make_expression_ptr<ConstExpression>(2.0 *
1165 std::numbers::inv_sqrtpi) *
1166 slp::detail::exp(-x * x);
1175inline ExpressionPtr erf(
const ExpressionPtr& x) {
1176 using enum ExpressionType;
1179 if (x->is_constant(0.0)) {
1185 if (x->type() == CONSTANT) {
1186 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1189 return make_expression_ptr<ErfExpression>(x);
1204 double value(
double x,
double)
const override {
return std::exp(x); }
1206 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1208 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1209 return parent_adjoint * std::exp(x);
1215 return parent_adjoint * slp::detail::exp(x);
1224inline ExpressionPtr exp(
const ExpressionPtr& x) {
1225 using enum ExpressionType;
1228 if (x->is_constant(0.0)) {
1229 return make_expression_ptr<ConstExpression>(1.0);
1233 if (x->type() == CONSTANT) {
1234 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1237 return make_expression_ptr<ExpExpression>(x);
1240inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y);
1253 :
Expression{std::move(lhs), std::move(rhs)} {}
1255 double value(
double x,
double y)
const override {
return std::hypot(x, y); }
1257 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1259 double grad_l(
double x,
double y,
double parent_adjoint)
const override {
1260 return parent_adjoint * x / std::hypot(x, y);
1263 double grad_r(
double x,
double y,
double parent_adjoint)
const override {
1264 return parent_adjoint * y / std::hypot(x, y);
1270 return parent_adjoint * x / slp::detail::hypot(x, y);
1276 return parent_adjoint * y / slp::detail::hypot(x, y);
1286inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y) {
1287 using enum ExpressionType;
1290 if (x->is_constant(0.0)) {
1292 }
else if (y->is_constant(0.0)) {
1297 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1298 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1301 return make_expression_ptr<HypotExpression>(x, y);
1316 double value(
double x,
double)
const override {
return std::log(x); }
1318 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1320 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1321 return parent_adjoint / x;
1327 return parent_adjoint / x;
1336inline ExpressionPtr log(
const ExpressionPtr& x) {
1337 using enum ExpressionType;
1340 if (x->is_constant(0.0)) {
1346 if (x->type() == CONSTANT) {
1347 return make_expression_ptr<ConstExpression>(std::log(x->val));
1350 return make_expression_ptr<LogExpression>(x);
1365 double value(
double x,
double)
const override {
return std::log10(x); }
1367 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1369 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1370 return parent_adjoint / (std::numbers::ln10 * x);
1376 return parent_adjoint /
1377 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1386inline ExpressionPtr log10(
const ExpressionPtr& x) {
1387 using enum ExpressionType;
1390 if (x->is_constant(0.0)) {
1396 if (x->type() == CONSTANT) {
1397 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1400 return make_expression_ptr<Log10Expression>(x);
1403inline ExpressionPtr pow(
const ExpressionPtr& base,
const ExpressionPtr& power);
1410template <ExpressionType T>
1419 :
Expression{std::move(lhs), std::move(rhs)} {}
1421 double value(
double base,
double power)
const override {
1422 return std::pow(base, power);
1425 ExpressionType
type()
const override {
return T; }
1428 double parent_adjoint)
const override {
1429 return parent_adjoint * std::pow(base, power - 1) * power;
1433 double parent_adjoint)
const override {
1438 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1445 return parent_adjoint *
1446 slp::detail::pow(base,
1447 power - make_expression_ptr<ConstExpression>(1.0)) *
1455 if (base->val == 0.0) {
1459 return parent_adjoint *
1461 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1462 base * slp::detail::log(base);
1473inline ExpressionPtr pow(
const ExpressionPtr& base,
1474 const ExpressionPtr& power) {
1475 using enum ExpressionType;
1478 if (base->is_constant(0.0)) {
1481 }
else if (base->is_constant(1.0)) {
1485 if (power->is_constant(0.0)) {
1486 return make_expression_ptr<ConstExpression>(1.0);
1487 }
else if (power->is_constant(1.0)) {
1492 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1493 return make_expression_ptr<ConstExpression>(
1494 std::pow(base->val, power->val));
1497 if (power->is_constant(2.0)) {
1498 if (base->type() == LINEAR) {
1499 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1501 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1505 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1520 double value(
double x,
double)
const override {
1523 }
else if (x == 0.0) {
1530 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1532 double grad_l(
double,
double,
double)
const override {
return 0.0; }
1537 return make_expression_ptr<ConstExpression>();
1546inline ExpressionPtr sign(
const ExpressionPtr& x) {
1547 using enum ExpressionType;
1550 if (x->type() == CONSTANT) {
1552 return make_expression_ptr<ConstExpression>(-1.0);
1553 }
else if (x->val == 0.0) {
1557 return make_expression_ptr<ConstExpression>(1.0);
1561 return make_expression_ptr<SignExpression>(x);
1576 double value(
double x,
double)
const override {
return std::sin(x); }
1578 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1580 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1581 return parent_adjoint * std::cos(x);
1587 return parent_adjoint * slp::detail::cos(x);
1596inline ExpressionPtr sin(
const ExpressionPtr& x) {
1597 using enum ExpressionType;
1600 if (x->is_constant(0.0)) {
1606 if (x->type() == CONSTANT) {
1607 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1610 return make_expression_ptr<SinExpression>(x);
1625 double value(
double x,
double)
const override {
return std::sinh(x); }
1627 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1629 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1630 return parent_adjoint * std::cosh(x);
1636 return parent_adjoint * slp::detail::cosh(x);
1645inline ExpressionPtr sinh(
const ExpressionPtr& x) {
1646 using enum ExpressionType;
1649 if (x->is_constant(0.0)) {
1655 if (x->type() == CONSTANT) {
1656 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1659 return make_expression_ptr<SinhExpression>(x);
1674 double value(
double x,
double)
const override {
return std::sqrt(x); }
1676 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1678 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1679 return parent_adjoint / (2.0 * std::sqrt(x));
1685 return parent_adjoint /
1686 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1695inline ExpressionPtr sqrt(
const ExpressionPtr& x) {
1696 using enum ExpressionType;
1699 if (x->type() == CONSTANT) {
1700 if (x->val == 0.0) {
1703 }
else if (x->val == 1.0) {
1706 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1710 return make_expression_ptr<SqrtExpression>(x);
1725 double value(
double x,
double)
const override {
return std::tan(x); }
1727 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1729 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1730 return parent_adjoint / (std::cos(x) * std::cos(x));
1736 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1745inline ExpressionPtr tan(
const ExpressionPtr& x) {
1746 using enum ExpressionType;
1749 if (x->is_constant(0.0)) {
1755 if (x->type() == CONSTANT) {
1756 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1759 return make_expression_ptr<TanExpression>(x);
1774 double value(
double x,
double)
const override {
return std::tanh(x); }
1776 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1778 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1779 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1785 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1794inline ExpressionPtr tanh(
const ExpressionPtr& x) {
1795 using enum ExpressionType;
1798 if (x->is_constant(0.0)) {
1804 if (x->type() == CONSTANT) {
1805 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1808 return make_expression_ptr<TanhExpression>(x);
Definition expression.hpp:774
ExpressionType type() const override
Definition expression.hpp:785
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:787
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:797
double value(double x, double) const override
Definition expression.hpp:783
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:780
Definition expression.hpp:836
ExpressionType type() const override
Definition expression.hpp:847
constexpr AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:842
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:849
double value(double x, double) const override
Definition expression.hpp:845
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:853
Definition expression.hpp:885
double value(double x, double) const override
Definition expression.hpp:894
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:902
ExpressionType type() const override
Definition expression.hpp:896
constexpr AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:891
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:898
Definition expression.hpp:984
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1012
double value(double y, double x) const override
Definition expression.hpp:994
ExpressionType type() const override
Definition expression.hpp:996
constexpr Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:991
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:998
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:1002
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1006
Definition expression.hpp:935
constexpr AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:941
double value(double x, double) const override
Definition expression.hpp:944
ExpressionType type() const override
Definition expression.hpp:946
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:952
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:948
Definition expression.hpp:437
double value(double lhs, double rhs) const override
Definition expression.hpp:447
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:444
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:455
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:451
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:465
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:459
ExpressionType type() const override
Definition expression.hpp:449
Definition expression.hpp:478
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:496
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:492
double value(double lhs, double rhs) const override
Definition expression.hpp:488
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:506
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:485
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:500
ExpressionType type() const override
Definition expression.hpp:490
Definition expression.hpp:516
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:529
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:534
ExpressionType type() const override
Definition expression.hpp:527
constexpr CbrtExpression(ExpressionPtr lhs)
Definition expression.hpp:522
double value(double x, double) const override
Definition expression.hpp:525
Definition expression.hpp:568
constexpr ConstExpression(double value)
Definition expression.hpp:579
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:581
ExpressionType type() const override
Definition expression.hpp:583
Definition expression.hpp:1047
ExpressionType type() const override
Definition expression.hpp:1058
double value(double x, double) const override
Definition expression.hpp:1056
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1064
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1060
constexpr CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1053
Definition expression.hpp:1095
constexpr CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1101
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1108
double value(double x, double) const override
Definition expression.hpp:1104
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1112
ExpressionType type() const override
Definition expression.hpp:1106
Definition expression.hpp:589
double value(double, double) const override
Definition expression.hpp:603
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:600
ExpressionType type() const override
Definition expression.hpp:605
Definition expression.hpp:614
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:628
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:642
double value(double lhs, double rhs) const override
Definition expression.hpp:624
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:621
ExpressionType type() const override
Definition expression.hpp:626
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:636
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:632
Definition expression.hpp:1143
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1160
double value(double x, double) const override
Definition expression.hpp:1152
constexpr ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1149
ExpressionType type() const override
Definition expression.hpp:1154
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1156
Definition expression.hpp:1195
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1208
ExpressionType type() const override
Definition expression.hpp:1206
constexpr ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1201
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1212
double value(double x, double) const override
Definition expression.hpp:1204
Definition expression.hpp:77
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:88
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:374
constexpr bool is_constant(double constant) const
Definition expression.hpp:138
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:126
constexpr Expression(double value)
Definition expression.hpp:110
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:417
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:92
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:344
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:117
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:82
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:148
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:85
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:269
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:316
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:280
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:200
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:388
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:402
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:98
double val
The value of the expression node.
Definition expression.hpp:79
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:95
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:237
Definition expression.hpp:1245
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1267
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1263
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1259
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1273
double value(double x, double y) const override
Definition expression.hpp:1255
ExpressionType type() const override
Definition expression.hpp:1257
constexpr HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1252
Definition expression.hpp:1356
constexpr Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1362
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1369
double value(double x, double) const override
Definition expression.hpp:1365
ExpressionType type() const override
Definition expression.hpp:1367
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1373
Definition expression.hpp:1307
ExpressionType type() const override
Definition expression.hpp:1318
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1324
constexpr LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1313
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1320
double value(double x, double) const override
Definition expression.hpp:1316
Definition expression.hpp:655
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:679
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:662
ExpressionType type() const override
Definition expression.hpp:667
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:669
double value(double lhs, double rhs) const override
Definition expression.hpp:665
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:685
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:674
Definition expression.hpp:1411
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1442
double value(double base, double power) const override
Definition expression.hpp:1421
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1427
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1432
ExpressionType type() const override
Definition expression.hpp:1425
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1451
constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1418
Definition expression.hpp:1511
double grad_l(double, double, double) const override
Definition expression.hpp:1532
ExpressionType type() const override
Definition expression.hpp:1530
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1517
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1534
double value(double x, double) const override
Definition expression.hpp:1520
Definition expression.hpp:1567
double value(double x, double) const override
Definition expression.hpp:1576
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1580
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1584
ExpressionType type() const override
Definition expression.hpp:1578
constexpr SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1573
Definition expression.hpp:1616
ExpressionType type() const override
Definition expression.hpp:1627
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1629
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1633
constexpr SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1622
double value(double x, double) const override
Definition expression.hpp:1625
Definition expression.hpp:1665
ExpressionType type() const override
Definition expression.hpp:1676
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1678
constexpr SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1671
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1682
double value(double x, double) const override
Definition expression.hpp:1674
Definition expression.hpp:1716
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1729
double value(double x, double) const override
Definition expression.hpp:1725
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1733
ExpressionType type() const override
Definition expression.hpp:1727
constexpr TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1722
Definition expression.hpp:1765
constexpr TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1771
ExpressionType type() const override
Definition expression.hpp:1776
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1778
double value(double x, double) const override
Definition expression.hpp:1774
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1782
Definition expression.hpp:698
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:704
ExpressionType type() const override
Definition expression.hpp:709
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:715
double value(double lhs, double) const override
Definition expression.hpp:707
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:711