14#include "sleipnir/autodiff/expression_type.hpp"
15#include "sleipnir/util/intrusive_shared_ptr.hpp"
16#include "sleipnir/util/pool.hpp"
17#include "sleipnir/util/small_vector.hpp"
19namespace slp::detail {
24inline constexpr bool USE_POOL_ALLOCATOR =
false;
26inline constexpr bool USE_POOL_ALLOCATOR =
true;
31inline constexpr void inc_ref_count(Expression* expr);
32inline constexpr void dec_ref_count(Expression* expr);
37using ExpressionPtr = IntrusiveSharedPtr<Expression>;
46template <
typename T,
typename... Args>
47static ExpressionPtr make_expression_ptr(Args&&... args) {
48 if constexpr (USE_POOL_ALLOCATOR) {
49 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
50 std::forward<Args>(args)...);
52 return make_intrusive_shared<T>(std::forward<Args>(args)...);
56template <ExpressionType T>
57struct BinaryMinusExpression;
59template <ExpressionType T>
60struct BinaryPlusExpression;
62struct ConstExpression;
64template <ExpressionType T>
67template <ExpressionType T>
70template <ExpressionType T>
71struct UnaryMinusExpression;
97 std::array<ExpressionPtr, 2>
args{
nullptr,
nullptr};
117 :
args{std::move(lhs), nullptr} {}
126 :
args{std::move(lhs), std::move(rhs)} {}
138 return type() == ExpressionType::CONSTANT &&
val == constant;
149 using enum ExpressionType;
152 if (lhs->is_constant(0.0)) {
155 }
else if (rhs->is_constant(0.0)) {
158 }
else if (lhs->is_constant(1.0)) {
160 }
else if (rhs->is_constant(1.0)) {
165 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
166 return make_expression_ptr<ConstExpression>(lhs->val * rhs->val);
170 if (lhs->type() == CONSTANT) {
171 if (rhs->type() == LINEAR) {
172 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
173 }
else if (rhs->type() == QUADRATIC) {
174 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
176 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
178 }
else if (rhs->type() == CONSTANT) {
179 if (lhs->type() == LINEAR) {
180 return make_expression_ptr<MultExpression<LINEAR>>(lhs, rhs);
181 }
else if (lhs->type() == QUADRATIC) {
182 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
184 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
186 }
else if (lhs->type() == LINEAR && rhs->type() == LINEAR) {
187 return make_expression_ptr<MultExpression<QUADRATIC>>(lhs, rhs);
189 return make_expression_ptr<MultExpression<NONLINEAR>>(lhs, rhs);
201 using enum ExpressionType;
204 if (lhs->is_constant(0.0)) {
207 }
else if (rhs->is_constant(1.0)) {
212 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
213 return make_expression_ptr<ConstExpression>(lhs->val / rhs->val);
217 if (rhs->type() == CONSTANT) {
218 if (lhs->type() == LINEAR) {
219 return make_expression_ptr<DivExpression<LINEAR>>(lhs, rhs);
220 }
else if (lhs->type() == QUADRATIC) {
221 return make_expression_ptr<DivExpression<QUADRATIC>>(lhs, rhs);
223 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
226 return make_expression_ptr<DivExpression<NONLINEAR>>(lhs, rhs);
238 using enum ExpressionType;
241 if (lhs ==
nullptr || lhs->is_constant(0.0)) {
243 }
else if (rhs ==
nullptr || rhs->is_constant(0.0)) {
248 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
249 return make_expression_ptr<ConstExpression>(lhs->val + rhs->val);
252 auto type = std::max(lhs->type(), rhs->type());
253 if (
type == LINEAR) {
254 return make_expression_ptr<BinaryPlusExpression<LINEAR>>(lhs, rhs);
255 }
else if (
type == QUADRATIC) {
256 return make_expression_ptr<BinaryPlusExpression<QUADRATIC>>(lhs, rhs);
258 return make_expression_ptr<BinaryPlusExpression<NONLINEAR>>(lhs, rhs);
270 return lhs = lhs + rhs;
281 using enum ExpressionType;
284 if (lhs->is_constant(0.0)) {
285 if (rhs->is_constant(0.0)) {
291 }
else if (rhs->is_constant(0.0)) {
296 if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) {
297 return make_expression_ptr<ConstExpression>(lhs->val - rhs->val);
300 auto type = std::max(lhs->type(), rhs->type());
301 if (
type == LINEAR) {
302 return make_expression_ptr<BinaryMinusExpression<LINEAR>>(lhs, rhs);
303 }
else if (
type == QUADRATIC) {
304 return make_expression_ptr<BinaryMinusExpression<QUADRATIC>>(lhs, rhs);
306 return make_expression_ptr<BinaryMinusExpression<NONLINEAR>>(lhs, rhs);
316 using enum ExpressionType;
319 if (lhs->is_constant(0.0)) {
325 if (lhs->type() == CONSTANT) {
326 return make_expression_ptr<ConstExpression>(-lhs->val);
329 if (lhs->type() == LINEAR) {
330 return make_expression_ptr<UnaryMinusExpression<LINEAR>>(lhs);
331 }
else if (lhs->type() == QUADRATIC) {
332 return make_expression_ptr<UnaryMinusExpression<QUADRATIC>>(lhs);
334 return make_expression_ptr<UnaryMinusExpression<NONLINEAR>>(lhs);
354 virtual double value([[maybe_unused]]
double lhs,
355 [[maybe_unused]]
double rhs)
const = 0;
363 virtual ExpressionType
type()
const = 0;
373 virtual double grad_l([[maybe_unused]]
double lhs,
374 [[maybe_unused]]
double rhs,
375 [[maybe_unused]]
double parent_adjoint)
const {
387 virtual double grad_r([[maybe_unused]]
double lhs,
388 [[maybe_unused]]
double rhs,
389 [[maybe_unused]]
double parent_adjoint)
const {
404 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
405 return make_expression_ptr<ConstExpression>();
419 [[maybe_unused]]
const ExpressionPtr& parent_adjoint)
const {
420 return make_expression_ptr<ConstExpression>();
429template <ExpressionType T>
438 :
Expression{std::move(lhs), std::move(rhs)} {
442 double value(
double lhs,
double rhs)
const override {
return lhs - rhs; }
444 ExpressionType
type()
const override {
return T; }
446 double grad_l(
double,
double,
double parent_adjoint)
const override {
447 return parent_adjoint;
450 double grad_r(
double,
double,
double parent_adjoint)
const override {
451 return -parent_adjoint;
457 return parent_adjoint;
463 return -parent_adjoint;
472template <ExpressionType T>
481 :
Expression{std::move(lhs), std::move(rhs)} {
485 double value(
double lhs,
double rhs)
const override {
return lhs + rhs; }
487 ExpressionType
type()
const override {
return T; }
489 double grad_l(
double,
double,
double parent_adjoint)
const override {
490 return parent_adjoint;
493 double grad_r(
double,
double,
double parent_adjoint)
const override {
494 return parent_adjoint;
500 return parent_adjoint;
506 return parent_adjoint;
526 double value(
double,
double)
const override {
return val; }
528 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
548 double value(
double,
double)
const override {
return val; }
550 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
558template <ExpressionType T>
567 :
Expression{std::move(lhs), std::move(rhs)} {
571 double value(
double lhs,
double rhs)
const override {
return lhs / rhs; }
573 ExpressionType
type()
const override {
return T; }
575 double grad_l(
double,
double rhs,
double parent_adjoint)
const override {
576 return parent_adjoint / rhs;
579 double grad_r(
double lhs,
double rhs,
double parent_adjoint)
const override {
580 return parent_adjoint * -lhs / (rhs * rhs);
586 return parent_adjoint / rhs;
592 return parent_adjoint * -lhs / (rhs * rhs);
601template <ExpressionType T>
610 :
Expression{std::move(lhs), std::move(rhs)} {
614 double value(
double lhs,
double rhs)
const override {
return lhs * rhs; }
616 ExpressionType
type()
const override {
return T; }
618 double grad_l([[maybe_unused]]
double lhs,
double rhs,
619 double parent_adjoint)
const override {
620 return parent_adjoint * rhs;
623 double grad_r(
double lhs, [[maybe_unused]]
double rhs,
624 double parent_adjoint)
const override {
625 return parent_adjoint * lhs;
631 return parent_adjoint * rhs;
637 return parent_adjoint * lhs;
646template <ExpressionType T>
658 double value(
double lhs,
double)
const override {
return -lhs; }
660 ExpressionType
type()
const override {
return T; }
662 double grad_l(
double,
double,
double parent_adjoint)
const override {
663 return -parent_adjoint;
669 return -parent_adjoint;
673inline ExpressionPtr exp(
const ExpressionPtr& x);
674inline ExpressionPtr sin(
const ExpressionPtr& x);
675inline ExpressionPtr sinh(
const ExpressionPtr& x);
676inline ExpressionPtr sqrt(
const ExpressionPtr& x);
683inline constexpr void inc_ref_count(Expression* expr) {
692inline constexpr void dec_ref_count(Expression* expr) {
697 small_vector<Expression*> stack;
698 stack.emplace_back(expr);
700 while (!stack.empty()) {
701 auto elem = stack.back();
706 if (--elem->ref_count == 0) {
707 if (elem->adjoint_expr !=
nullptr) {
708 stack.emplace_back(elem->adjoint_expr.get());
710 for (
auto& arg : elem->args) {
711 if (arg !=
nullptr) {
712 stack.emplace_back(arg.get());
718 if constexpr (USE_POOL_ALLOCATOR) {
719 auto alloc = global_pool_allocator<Expression>();
720 std::allocator_traits<
decltype(alloc)>::deallocate(alloc, elem,
741 double value(
double x,
double)
const override {
return std::abs(x); }
743 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
745 double grad_l(
double x,
double,
double parent_adjoint)
const override {
747 return -parent_adjoint;
748 }
else if (x > 0.0) {
749 return parent_adjoint;
759 return -parent_adjoint;
760 }
else if (x->val > 0.0) {
761 return parent_adjoint;
764 return make_expression_ptr<ConstExpression>();
774inline ExpressionPtr abs(
const ExpressionPtr& x) {
775 using enum ExpressionType;
778 if (x->is_constant(0.0)) {
784 if (x->type() == CONSTANT) {
785 return make_expression_ptr<ConstExpression>(std::abs(x->val));
788 return make_expression_ptr<AbsExpression>(x);
804 double value(
double x,
double)
const override {
return std::acos(x); }
806 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
808 double grad_l(
double x,
double,
double parent_adjoint)
const override {
809 return -parent_adjoint / std::sqrt(1.0 - x * x);
815 return -parent_adjoint /
816 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
825inline ExpressionPtr acos(
const ExpressionPtr& x) {
826 using enum ExpressionType;
829 if (x->is_constant(0.0)) {
830 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
834 if (x->type() == CONSTANT) {
835 return make_expression_ptr<ConstExpression>(std::acos(x->val));
838 return make_expression_ptr<AcosExpression>(x);
854 double value(
double x,
double)
const override {
return std::asin(x); }
856 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
858 double grad_l(
double x,
double,
double parent_adjoint)
const override {
859 return parent_adjoint / std::sqrt(1.0 - x * x);
865 return parent_adjoint /
866 slp::detail::sqrt(make_expression_ptr<ConstExpression>(1.0) - x * x);
875inline ExpressionPtr asin(
const ExpressionPtr& x) {
876 using enum ExpressionType;
879 if (x->is_constant(0.0)) {
885 if (x->type() == CONSTANT) {
886 return make_expression_ptr<ConstExpression>(std::asin(x->val));
889 return make_expression_ptr<AsinExpression>(x);
905 double value(
double x,
double)
const override {
return std::atan(x); }
907 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
909 double grad_l(
double x,
double,
double parent_adjoint)
const override {
910 return parent_adjoint / (1.0 + x * x);
916 return parent_adjoint / (make_expression_ptr<ConstExpression>(1.0) + x * x);
925inline ExpressionPtr atan(
const ExpressionPtr& x) {
926 using enum ExpressionType;
929 if (x->is_constant(0.0)) {
935 if (x->type() == CONSTANT) {
936 return make_expression_ptr<ConstExpression>(std::atan(x->val));
939 return make_expression_ptr<AtanExpression>(x);
953 :
Expression{std::move(lhs), std::move(rhs)} {
957 double value(
double y,
double x)
const override {
return std::atan2(y, x); }
959 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
961 double grad_l(
double y,
double x,
double parent_adjoint)
const override {
962 return parent_adjoint * x / (y * y + x * x);
965 double grad_r(
double y,
double x,
double parent_adjoint)
const override {
966 return parent_adjoint * -y / (y * y + x * x);
972 return parent_adjoint * x / (y * y + x * x);
978 return parent_adjoint * -y / (y * y + x * x);
988inline ExpressionPtr atan2(
const ExpressionPtr& y,
const ExpressionPtr& x) {
989 using enum ExpressionType;
992 if (y->is_constant(0.0)) {
995 }
else if (x->is_constant(0.0)) {
996 return make_expression_ptr<ConstExpression>(std::numbers::pi / 2.0);
1000 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1001 return make_expression_ptr<ConstExpression>(std::atan2(y->val, x->val));
1004 return make_expression_ptr<Atan2Expression>(y, x);
1020 double value(
double x,
double)
const override {
return std::cos(x); }
1022 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1024 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1025 return -parent_adjoint * std::sin(x);
1031 return parent_adjoint * -slp::detail::sin(x);
1040inline ExpressionPtr cos(
const ExpressionPtr& x) {
1041 using enum ExpressionType;
1044 if (x->is_constant(0.0)) {
1045 return make_expression_ptr<ConstExpression>(1.0);
1049 if (x->type() == CONSTANT) {
1050 return make_expression_ptr<ConstExpression>(std::cos(x->val));
1053 return make_expression_ptr<CosExpression>(x);
1069 double value(
double x,
double)
const override {
return std::cosh(x); }
1071 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1073 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1074 return parent_adjoint * std::sinh(x);
1080 return parent_adjoint * slp::detail::sinh(x);
1089inline ExpressionPtr cosh(
const ExpressionPtr& x) {
1090 using enum ExpressionType;
1093 if (x->is_constant(0.0)) {
1094 return make_expression_ptr<ConstExpression>(1.0);
1098 if (x->type() == CONSTANT) {
1099 return make_expression_ptr<ConstExpression>(std::cosh(x->val));
1102 return make_expression_ptr<CoshExpression>(x);
1118 double value(
double x,
double)
const override {
return std::erf(x); }
1120 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1122 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1123 return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x);
1129 return parent_adjoint *
1130 make_expression_ptr<ConstExpression>(2.0 *
1131 std::numbers::inv_sqrtpi) *
1132 slp::detail::exp(-x * x);
1141inline ExpressionPtr erf(
const ExpressionPtr& x) {
1142 using enum ExpressionType;
1145 if (x->is_constant(0.0)) {
1151 if (x->type() == CONSTANT) {
1152 return make_expression_ptr<ConstExpression>(std::erf(x->val));
1155 return make_expression_ptr<ErfExpression>(x);
1171 double value(
double x,
double)
const override {
return std::exp(x); }
1173 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1175 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1176 return parent_adjoint * std::exp(x);
1182 return parent_adjoint * slp::detail::exp(x);
1191inline ExpressionPtr exp(
const ExpressionPtr& x) {
1192 using enum ExpressionType;
1195 if (x->is_constant(0.0)) {
1196 return make_expression_ptr<ConstExpression>(1.0);
1200 if (x->type() == CONSTANT) {
1201 return make_expression_ptr<ConstExpression>(std::exp(x->val));
1204 return make_expression_ptr<ExpExpression>(x);
1207inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y);
1220 :
Expression{std::move(lhs), std::move(rhs)} {
1224 double value(
double x,
double y)
const override {
return std::hypot(x, y); }
1226 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1228 double grad_l(
double x,
double y,
double parent_adjoint)
const override {
1229 return parent_adjoint * x / std::hypot(x, y);
1232 double grad_r(
double x,
double y,
double parent_adjoint)
const override {
1233 return parent_adjoint * y / std::hypot(x, y);
1239 return parent_adjoint * x / slp::detail::hypot(x, y);
1245 return parent_adjoint * y / slp::detail::hypot(x, y);
1255inline ExpressionPtr hypot(
const ExpressionPtr& x,
const ExpressionPtr& y) {
1256 using enum ExpressionType;
1259 if (x->is_constant(0.0)) {
1261 }
else if (y->is_constant(0.0)) {
1266 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1267 return make_expression_ptr<ConstExpression>(std::hypot(x->val, y->val));
1270 return make_expression_ptr<HypotExpression>(x, y);
1286 double value(
double x,
double)
const override {
return std::log(x); }
1288 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1290 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1291 return parent_adjoint / x;
1297 return parent_adjoint / x;
1306inline ExpressionPtr log(
const ExpressionPtr& x) {
1307 using enum ExpressionType;
1310 if (x->is_constant(0.0)) {
1316 if (x->type() == CONSTANT) {
1317 return make_expression_ptr<ConstExpression>(std::log(x->val));
1320 return make_expression_ptr<LogExpression>(x);
1336 double value(
double x,
double)
const override {
return std::log10(x); }
1338 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1340 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1341 return parent_adjoint / (std::numbers::ln10 * x);
1347 return parent_adjoint /
1348 (make_expression_ptr<ConstExpression>(std::numbers::ln10) * x);
1357inline ExpressionPtr log10(
const ExpressionPtr& x) {
1358 using enum ExpressionType;
1361 if (x->is_constant(0.0)) {
1367 if (x->type() == CONSTANT) {
1368 return make_expression_ptr<ConstExpression>(std::log10(x->val));
1371 return make_expression_ptr<Log10Expression>(x);
1374inline ExpressionPtr pow(
const ExpressionPtr& base,
const ExpressionPtr& power);
1381template <ExpressionType T>
1390 :
Expression{std::move(lhs), std::move(rhs)} {
1394 double value(
double base,
double power)
const override {
1395 return std::pow(base, power);
1398 ExpressionType
type()
const override {
return T; }
1401 double parent_adjoint)
const override {
1402 return parent_adjoint * std::pow(base, power - 1) * power;
1406 double parent_adjoint)
const override {
1411 return parent_adjoint * std::pow(base, power - 1) * base * std::log(base);
1418 return parent_adjoint *
1419 slp::detail::pow(base,
1420 power - make_expression_ptr<ConstExpression>(1.0)) *
1428 if (base->val == 0.0) {
1432 return parent_adjoint *
1434 base, power - make_expression_ptr<ConstExpression>(1.0)) *
1435 base * slp::detail::log(base);
1446inline ExpressionPtr pow(
const ExpressionPtr& base,
1447 const ExpressionPtr& power) {
1448 using enum ExpressionType;
1451 if (base->is_constant(0.0)) {
1454 }
else if (base->is_constant(1.0)) {
1458 if (power->is_constant(0.0)) {
1459 return make_expression_ptr<ConstExpression>(1.0);
1460 }
else if (power->is_constant(1.0)) {
1465 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1466 return make_expression_ptr<ConstExpression>(
1467 std::pow(base->val, power->val));
1470 if (power->is_constant(2.0)) {
1471 if (base->type() == LINEAR) {
1472 return make_expression_ptr<MultExpression<QUADRATIC>>(base, base);
1474 return make_expression_ptr<MultExpression<NONLINEAR>>(base, base);
1478 return make_expression_ptr<PowExpression<NONLINEAR>>(base, power);
1494 }
else if (
args[0]->
val == 0.0) {
1501 double value(
double x,
double)
const override {
1504 }
else if (x == 0.0) {
1511 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1513 double grad_l(
double,
double,
double)
const override {
return 0.0; }
1518 return make_expression_ptr<ConstExpression>();
1527inline ExpressionPtr sign(
const ExpressionPtr& x) {
1528 using enum ExpressionType;
1531 if (x->type() == CONSTANT) {
1533 return make_expression_ptr<ConstExpression>(-1.0);
1534 }
else if (x->val == 0.0) {
1538 return make_expression_ptr<ConstExpression>(1.0);
1542 return make_expression_ptr<SignExpression>(x);
1558 double value(
double x,
double)
const override {
return std::sin(x); }
1560 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1562 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1563 return parent_adjoint * std::cos(x);
1569 return parent_adjoint * slp::detail::cos(x);
1578inline ExpressionPtr sin(
const ExpressionPtr& x) {
1579 using enum ExpressionType;
1582 if (x->is_constant(0.0)) {
1588 if (x->type() == CONSTANT) {
1589 return make_expression_ptr<ConstExpression>(std::sin(x->val));
1592 return make_expression_ptr<SinExpression>(x);
1608 double value(
double x,
double)
const override {
return std::sinh(x); }
1610 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1612 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1613 return parent_adjoint * std::cosh(x);
1619 return parent_adjoint * slp::detail::cosh(x);
1628inline ExpressionPtr sinh(
const ExpressionPtr& x) {
1629 using enum ExpressionType;
1632 if (x->is_constant(0.0)) {
1638 if (x->type() == CONSTANT) {
1639 return make_expression_ptr<ConstExpression>(std::sinh(x->val));
1642 return make_expression_ptr<SinhExpression>(x);
1658 double value(
double x,
double)
const override {
return std::sqrt(x); }
1660 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1662 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1663 return parent_adjoint / (2.0 * std::sqrt(x));
1669 return parent_adjoint /
1670 (make_expression_ptr<ConstExpression>(2.0) * slp::detail::sqrt(x));
1679inline ExpressionPtr sqrt(
const ExpressionPtr& x) {
1680 using enum ExpressionType;
1683 if (x->type() == CONSTANT) {
1684 if (x->val == 0.0) {
1687 }
else if (x->val == 1.0) {
1690 return make_expression_ptr<ConstExpression>(std::sqrt(x->val));
1694 return make_expression_ptr<SqrtExpression>(x);
1710 double value(
double x,
double)
const override {
return std::tan(x); }
1712 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1714 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1715 return parent_adjoint / (std::cos(x) * std::cos(x));
1721 return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x));
1730inline ExpressionPtr tan(
const ExpressionPtr& x) {
1731 using enum ExpressionType;
1734 if (x->is_constant(0.0)) {
1740 if (x->type() == CONSTANT) {
1741 return make_expression_ptr<ConstExpression>(std::tan(x->val));
1744 return make_expression_ptr<TanExpression>(x);
1760 double value(
double x,
double)
const override {
return std::tanh(x); }
1762 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1764 double grad_l(
double x,
double,
double parent_adjoint)
const override {
1765 return parent_adjoint / (std::cosh(x) * std::cosh(x));
1771 return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x));
1780inline ExpressionPtr tanh(
const ExpressionPtr& x) {
1781 using enum ExpressionType;
1784 if (x->is_constant(0.0)) {
1790 if (x->type() == CONSTANT) {
1791 return make_expression_ptr<ConstExpression>(std::tanh(x->val));
1794 return make_expression_ptr<TanhExpression>(x);
Definition expression.hpp:730
ExpressionType type() const override
Definition expression.hpp:743
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:745
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:755
double value(double x, double) const override
Definition expression.hpp:741
constexpr AbsExpression(ExpressionPtr lhs)
Definition expression.hpp:736
Definition expression.hpp:794
ExpressionType type() const override
Definition expression.hpp:806
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:808
double value(double x, double) const override
Definition expression.hpp:804
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:812
AcosExpression(ExpressionPtr lhs)
Definition expression.hpp:800
Definition expression.hpp:844
double value(double x, double) const override
Definition expression.hpp:854
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:862
ExpressionType type() const override
Definition expression.hpp:856
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:858
AsinExpression(ExpressionPtr lhs)
Definition expression.hpp:850
Definition expression.hpp:945
ExpressionPtr grad_expr_r(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:975
double value(double y, double x) const override
Definition expression.hpp:957
ExpressionType type() const override
Definition expression.hpp:959
double grad_l(double y, double x, double parent_adjoint) const override
Definition expression.hpp:961
double grad_r(double y, double x, double parent_adjoint) const override
Definition expression.hpp:965
ExpressionPtr grad_expr_l(const ExpressionPtr &y, const ExpressionPtr &x, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:969
Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:952
Definition expression.hpp:895
double value(double x, double) const override
Definition expression.hpp:905
AtanExpression(ExpressionPtr lhs)
Definition expression.hpp:901
ExpressionType type() const override
Definition expression.hpp:907
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:913
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:909
Definition expression.hpp:430
double value(double lhs, double rhs) const override
Definition expression.hpp:442
constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:437
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:450
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:446
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:460
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:454
ExpressionType type() const override
Definition expression.hpp:444
Definition expression.hpp:473
double grad_r(double, double, double parent_adjoint) const override
Definition expression.hpp:493
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:489
double value(double lhs, double rhs) const override
Definition expression.hpp:485
ExpressionPtr grad_expr_r(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:503
constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:480
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:497
ExpressionType type() const override
Definition expression.hpp:487
Definition expression.hpp:513
constexpr ConstExpression(double value)
Definition expression.hpp:524
constexpr ConstExpression()=default
double value(double, double) const override
Definition expression.hpp:526
ExpressionType type() const override
Definition expression.hpp:528
Definition expression.hpp:1010
ExpressionType type() const override
Definition expression.hpp:1022
double value(double x, double) const override
Definition expression.hpp:1020
CosExpression(ExpressionPtr lhs)
Definition expression.hpp:1016
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1028
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1024
Definition expression.hpp:1059
CoshExpression(ExpressionPtr lhs)
Definition expression.hpp:1065
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1073
double value(double x, double) const override
Definition expression.hpp:1069
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1077
ExpressionType type() const override
Definition expression.hpp:1071
Definition expression.hpp:534
double value(double, double) const override
Definition expression.hpp:548
constexpr DecisionVariableExpression()=default
constexpr DecisionVariableExpression(double value)
Definition expression.hpp:545
ExpressionType type() const override
Definition expression.hpp:550
Definition expression.hpp:559
double grad_l(double, double rhs, double parent_adjoint) const override
Definition expression.hpp:575
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:589
double value(double lhs, double rhs) const override
Definition expression.hpp:571
constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:566
ExpressionType type() const override
Definition expression.hpp:573
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:583
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:579
Definition expression.hpp:1108
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1126
double value(double x, double) const override
Definition expression.hpp:1118
ExpressionType type() const override
Definition expression.hpp:1120
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1122
ErfExpression(ExpressionPtr lhs)
Definition expression.hpp:1114
Definition expression.hpp:1161
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1175
ExpExpression(ExpressionPtr lhs)
Definition expression.hpp:1167
ExpressionType type() const override
Definition expression.hpp:1173
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1179
double value(double x, double) const override
Definition expression.hpp:1171
Definition expression.hpp:76
int32_t col
This expression's column in a Jacobian, or -1 otherwise.
Definition expression.hpp:87
virtual double grad_l(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:373
constexpr bool is_constant(double constant) const
Definition expression.hpp:137
constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:125
constexpr Expression(double value)
Definition expression.hpp:109
virtual ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:416
virtual ExpressionType type() const =0
ExpressionPtr adjoint_expr
Definition expression.hpp:91
friend ExpressionPtr operator+(const ExpressionPtr &lhs)
Definition expression.hpp:343
constexpr Expression(ExpressionPtr lhs)
Definition expression.hpp:116
double adjoint
The adjoint of the expression node used during autodiff.
Definition expression.hpp:81
friend ExpressionPtr operator*(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:147
uint32_t incoming_edges
Counts incoming edges for this node.
Definition expression.hpp:84
friend ExpressionPtr operator+=(ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:268
friend ExpressionPtr operator-(const ExpressionPtr &lhs)
Definition expression.hpp:315
friend ExpressionPtr operator-(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:279
friend ExpressionPtr operator/(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:199
virtual double grad_r(double lhs, double rhs, double parent_adjoint) const
Definition expression.hpp:387
virtual double value(double lhs, double rhs) const =0
virtual ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const
Definition expression.hpp:401
std::array< ExpressionPtr, 2 > args
Expression arguments.
Definition expression.hpp:97
double val
The value of the expression node.
Definition expression.hpp:78
constexpr Expression()=default
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:94
friend ExpressionPtr operator+(const ExpressionPtr &lhs, const ExpressionPtr &rhs)
Definition expression.hpp:236
Definition expression.hpp:1212
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1236
double grad_r(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1232
double grad_l(double x, double y, double parent_adjoint) const override
Definition expression.hpp:1228
ExpressionPtr grad_expr_r(const ExpressionPtr &x, const ExpressionPtr &y, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1242
HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1219
double value(double x, double y) const override
Definition expression.hpp:1224
ExpressionType type() const override
Definition expression.hpp:1226
Definition expression.hpp:1326
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1340
double value(double x, double) const override
Definition expression.hpp:1336
ExpressionType type() const override
Definition expression.hpp:1338
Log10Expression(ExpressionPtr lhs)
Definition expression.hpp:1332
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1344
Definition expression.hpp:1276
LogExpression(ExpressionPtr lhs)
Definition expression.hpp:1282
ExpressionType type() const override
Definition expression.hpp:1288
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1294
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1290
double value(double x, double) const override
Definition expression.hpp:1286
Definition expression.hpp:602
ExpressionPtr grad_expr_l(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:628
constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:609
ExpressionType type() const override
Definition expression.hpp:616
double grad_l(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:618
double value(double lhs, double rhs) const override
Definition expression.hpp:614
ExpressionPtr grad_expr_r(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:634
double grad_r(double lhs, double rhs, double parent_adjoint) const override
Definition expression.hpp:623
Definition expression.hpp:1382
ExpressionPtr grad_expr_l(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1415
double value(double base, double power) const override
Definition expression.hpp:1394
double grad_l(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1400
double grad_r(double base, double power, double parent_adjoint) const override
Definition expression.hpp:1405
ExpressionType type() const override
Definition expression.hpp:1398
ExpressionPtr grad_expr_r(const ExpressionPtr &base, const ExpressionPtr &power, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1424
PowExpression(ExpressionPtr lhs, ExpressionPtr rhs)
Definition expression.hpp:1389
Definition expression.hpp:1484
double grad_l(double, double, double) const override
Definition expression.hpp:1513
ExpressionType type() const override
Definition expression.hpp:1511
constexpr SignExpression(ExpressionPtr lhs)
Definition expression.hpp:1490
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &) const override
Definition expression.hpp:1515
double value(double x, double) const override
Definition expression.hpp:1501
Definition expression.hpp:1548
double value(double x, double) const override
Definition expression.hpp:1558
SinExpression(ExpressionPtr lhs)
Definition expression.hpp:1554
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1562
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1566
ExpressionType type() const override
Definition expression.hpp:1560
Definition expression.hpp:1598
ExpressionType type() const override
Definition expression.hpp:1610
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1612
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1616
SinhExpression(ExpressionPtr lhs)
Definition expression.hpp:1604
double value(double x, double) const override
Definition expression.hpp:1608
Definition expression.hpp:1648
SqrtExpression(ExpressionPtr lhs)
Definition expression.hpp:1654
ExpressionType type() const override
Definition expression.hpp:1660
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1662
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1666
double value(double x, double) const override
Definition expression.hpp:1658
Definition expression.hpp:1700
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1714
double value(double x, double) const override
Definition expression.hpp:1710
TanExpression(ExpressionPtr lhs)
Definition expression.hpp:1706
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1718
ExpressionType type() const override
Definition expression.hpp:1712
Definition expression.hpp:1750
TanhExpression(ExpressionPtr lhs)
Definition expression.hpp:1756
ExpressionType type() const override
Definition expression.hpp:1762
double grad_l(double x, double, double parent_adjoint) const override
Definition expression.hpp:1764
double value(double x, double) const override
Definition expression.hpp:1760
ExpressionPtr grad_expr_l(const ExpressionPtr &x, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:1768
Definition expression.hpp:647
constexpr UnaryMinusExpression(ExpressionPtr lhs)
Definition expression.hpp:653
ExpressionType type() const override
Definition expression.hpp:660
ExpressionPtr grad_expr_l(const ExpressionPtr &, const ExpressionPtr &, const ExpressionPtr &parent_adjoint) const override
Definition expression.hpp:666
double value(double lhs, double) const override
Definition expression.hpp:658
double grad_l(double, double, double parent_adjoint) const override
Definition expression.hpp:662