15#include <gch/small_vector.hpp>
17#include "sleipnir/autodiff/expression_type.hpp"
18#include "sleipnir/util/intrusive_shared_ptr.hpp"
19#include "sleipnir/util/pool.hpp"
21namespace slp::detail {
26inline constexpr bool USE_POOL_ALLOCATOR =
false;
28inline constexpr bool USE_POOL_ALLOCATOR =
true;
31template <
typename Scalar>
34template <
typename Scalar>
35constexpr void inc_ref_count(Expression<Scalar>* expr);
36template <
typename Scalar>
37constexpr void dec_ref_count(Expression<Scalar>* expr);
42template <
typename Scalar>
43using ExpressionPtr = IntrusiveSharedPtr<Expression<Scalar>>;
50template <
typename T,
typename... Args>
51static ExpressionPtr<typename T::Scalar> make_expression_ptr(Args&&... args) {
52 if constexpr (USE_POOL_ALLOCATOR) {
53 return allocate_intrusive_shared<T>(global_pool_allocator<T>(),
54 std::forward<Args>(args)...);
56 return make_intrusive_shared<T>(std::forward<Args>(args)...);
60template <
typename Scalar, ExpressionType T>
61struct BinaryMinusExpression;
63template <
typename Scalar, ExpressionType T>
64struct BinaryPlusExpression;
66template <
typename Scalar>
67struct ConstantExpression;
69template <
typename Scalar, ExpressionType T>
72template <
typename Scalar, ExpressionType T>
75template <
typename Scalar, ExpressionType T>
76struct UnaryMinusExpression;
82template <
typename Scalar>
83ExpressionPtr<Scalar> constant_ptr(Scalar value);
88template <
typename Scalar_>
104 std::array<ExpressionPtr<Scalar>, 2>
args{
nullptr,
nullptr};
148 return type() == ExpressionType::CONSTANT &&
val == constant;
157 using enum ExpressionType;
163 }
else if (
rhs->is_constant(
Scalar(0))) {
166 }
else if (
lhs->is_constant(
Scalar(1))) {
169 }
else if (
rhs->is_constant(
Scalar(1))) {
175 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
176 return constant_ptr(
lhs->val *
rhs->val);
180 if (
lhs->type() == CONSTANT) {
181 if (
rhs->type() == LINEAR) {
183 }
else if (
rhs->type() == QUADRATIC) {
188 }
else if (
rhs->type() == CONSTANT) {
189 if (
lhs->type() == LINEAR) {
191 }
else if (
lhs->type() == QUADRATIC) {
196 }
else if (
lhs->type() == LINEAR &&
rhs->type() == LINEAR) {
209 using enum ExpressionType;
215 }
else if (
rhs->is_constant(
Scalar(1))) {
221 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
222 return constant_ptr(
lhs->val /
rhs->val);
226 if (
rhs->type() == CONSTANT) {
227 if (
lhs->type() == LINEAR) {
229 }
else if (
lhs->type() == QUADRATIC) {
245 using enum ExpressionType;
252 }
else if (
rhs ==
nullptr ||
rhs->is_constant(
Scalar(0))) {
258 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
259 return constant_ptr(
lhs->val +
rhs->val);
262 auto type = std::max(
lhs->type(),
rhs->type());
263 if (
type == LINEAR) {
266 }
else if (
type == QUADRATIC) {
290 using enum ExpressionType;
301 }
else if (
rhs->is_constant(
Scalar(0))) {
307 if (
lhs->type() == CONSTANT &&
rhs->type() == CONSTANT) {
308 return constant_ptr(
lhs->val -
rhs->val);
311 auto type = std::max(
lhs->type(),
rhs->type());
312 if (
type == LINEAR) {
315 }
else if (
type == QUADRATIC) {
328 using enum ExpressionType;
337 if (
lhs->type() == CONSTANT) {
338 return constant_ptr(-
lhs->val);
341 if (
lhs->type() == LINEAR) {
343 }
else if (
lhs->type() == QUADRATIC) {
371 virtual ExpressionType
type()
const = 0;
376 virtual std::string_view
name()
const = 0;
406 return constant_ptr(
Scalar(0));
417 return constant_ptr(
Scalar(0));
421template <
typename Scalar>
426template <
typename Scalar>
427ExpressionPtr<Scalar> cbrt(
const ExpressionPtr<Scalar>& x);
428template <
typename Scalar>
429ExpressionPtr<Scalar> exp(
const ExpressionPtr<Scalar>& x);
430template <
typename Scalar>
431ExpressionPtr<Scalar> sin(
const ExpressionPtr<Scalar>& x);
432template <
typename Scalar>
433ExpressionPtr<Scalar> sinh(
const ExpressionPtr<Scalar>& x);
434template <
typename Scalar>
435ExpressionPtr<Scalar> sqrt(
const ExpressionPtr<Scalar>& x);
441template <
typename Scalar, ExpressionType T>
453 ExpressionType
type()
const override {
return T; }
455 std::string_view
name()
const override {
return "binary minus"; }
478template <
typename Scalar, ExpressionType T>
490 ExpressionType
type()
const override {
return T; }
492 std::string_view
name()
const override {
return "binary plus"; }
514template <
typename Scalar>
527 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
529 std::string_view
name()
const override {
return "cbrt"; }
550template <
typename Scalar>
552 using enum ExpressionType;
556 if (x->type() == CONSTANT) {
557 if (x->val == Scalar(0)) {
560 }
else if (x->val == Scalar(-1) || x->val == Scalar(1)) {
563 return constant_ptr(cbrt(x->val));
567 return make_expression_ptr<CbrtExpression<Scalar>>(x);
573template <
typename Scalar>
583 ExpressionType
type()
const override {
return ExpressionType::CONSTANT; }
585 std::string_view
name()
const override {
return "constant"; }
591template <
typename Scalar>
604 ExpressionType
type()
const override {
return ExpressionType::LINEAR; }
606 std::string_view
name()
const override {
return "decision variable"; }
613template <
typename Scalar, ExpressionType T>
624 ExpressionType
type()
const override {
return T; }
626 std::string_view
name()
const override {
return "division"; }
653template <
typename Scalar, ExpressionType T>
664 ExpressionType
type()
const override {
return T; }
666 std::string_view
name()
const override {
return "multiplication"; }
693template <
typename Scalar, ExpressionType T>
703 ExpressionType
type()
const override {
return T; }
705 std::string_view
name()
const override {
return "unary minus"; }
720template <
typename Scalar>
729template <
typename Scalar>
730constexpr void dec_ref_count(Expression<Scalar>* expr) {
735 gch::small_vector<Expression<Scalar>*> stack;
736 stack.emplace_back(expr);
738 while (!stack.empty()) {
739 auto elem = stack.back();
744 if (--elem->ref_count == 0) {
745 if (elem->adjoint_expr !=
nullptr) {
746 stack.emplace_back(elem->adjoint_expr.get());
748 for (
auto& arg : elem->args) {
749 if (arg !=
nullptr) {
750 stack.emplace_back(arg.get());
756 if constexpr (USE_POOL_ALLOCATOR) {
757 auto alloc = global_pool_allocator<Expression<Scalar>>();
758 std::allocator_traits<
decltype(alloc)>::deallocate(
759 alloc, elem,
sizeof(Expression<Scalar>));
761 operator delete(elem);
770template <
typename Scalar>
783 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
785 std::string_view
name()
const override {
return "abs"; }
790 }
else if (x >
Scalar(0)) {
802 }
else if (x->val >
Scalar(0)) {
805 return constant_ptr(
Scalar(0));
814template <
typename Scalar>
816 using enum ExpressionType;
820 if (x->is_constant(Scalar(0))) {
826 if (x->type() == CONSTANT) {
827 return constant_ptr(abs(x->val));
830 return make_expression_ptr<AbsExpression<Scalar>>(x);
836template <
typename Scalar>
849 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
851 std::string_view
name()
const override {
return "acos"; }
869template <
typename Scalar>
871 using enum ExpressionType;
875 if (x->is_constant(Scalar(0))) {
876 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
880 if (x->type() == CONSTANT) {
881 return constant_ptr(acos(x->val));
884 return make_expression_ptr<AcosExpression<Scalar>>(x);
890template <
typename Scalar>
903 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
905 std::string_view
name()
const override {
return "asin"; }
923template <
typename Scalar>
925 using enum ExpressionType;
929 if (x->is_constant(Scalar(0))) {
935 if (x->type() == CONSTANT) {
936 return constant_ptr(asin(x->val));
939 return make_expression_ptr<AsinExpression<Scalar>>(x);
945template <
typename Scalar>
958 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
960 std::string_view
name()
const override {
return "atan"; }
977template <
typename Scalar>
979 using enum ExpressionType;
983 if (x->is_constant(Scalar(0))) {
989 if (x->type() == CONSTANT) {
990 return constant_ptr(atan(x->val));
993 return make_expression_ptr<AtanExpression<Scalar>>(x);
999template <
typename Scalar>
1014 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1016 std::string_view
name()
const override {
return "atan2"; }
1019 return this->
adjoint * x / (y * y + x * x);
1023 return this->
adjoint * -y / (y * y + x * x);
1044template <
typename Scalar>
1047 using enum ExpressionType;
1051 if (y->is_constant(Scalar(0))) {
1054 }
else if (x->is_constant(Scalar(0))) {
1055 return constant_ptr(Scalar(std::numbers::pi) / Scalar(2));
1059 if (y->type() == CONSTANT && x->type() == CONSTANT) {
1060 return constant_ptr(atan2(y->val, x->val));
1063 return make_expression_ptr<Atan2Expression<Scalar>>(y, x);
1069template <
typename Scalar>
1082 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1084 std::string_view
name()
const override {
return "cos"; }
1088 return this->
adjoint * -sin(x);
1102template <
typename Scalar>
1104 using enum ExpressionType;
1108 if (x->is_constant(Scalar(0))) {
1109 return constant_ptr(Scalar(1));
1113 if (x->type() == CONSTANT) {
1114 return constant_ptr(cos(x->val));
1117 return make_expression_ptr<CosExpression<Scalar>>(x);
1123template <
typename Scalar>
1136 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1138 std::string_view
name()
const override {
return "cosh"; }
1142 return this->
adjoint * sinh(x);
1156template <
typename Scalar>
1158 using enum ExpressionType;
1162 if (x->is_constant(Scalar(0))) {
1163 return constant_ptr(Scalar(1));
1167 if (x->type() == CONSTANT) {
1168 return constant_ptr(cosh(x->val));
1171 return make_expression_ptr<CoshExpression<Scalar>>(x);
1177template <
typename Scalar>
1190 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1192 std::string_view
name()
const override {
return "erf"; }
1196 return this->
adjoint *
Scalar(2.0 * std::numbers::inv_sqrtpi) * exp(-x * x);
1203 constant_ptr(
Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
1211template <
typename Scalar>
1213 using enum ExpressionType;
1217 if (x->is_constant(Scalar(0))) {
1223 if (x->type() == CONSTANT) {
1224 return constant_ptr(erf(x->val));
1227 return make_expression_ptr<ErfExpression<Scalar>>(x);
1233template <
typename Scalar>
1246 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1248 std::string_view
name()
const override {
return "exp"; }
1252 return this->
adjoint * exp(x);
1266template <
typename Scalar>
1268 using enum ExpressionType;
1272 if (x->is_constant(Scalar(0))) {
1273 return constant_ptr(Scalar(1));
1277 if (x->type() == CONSTANT) {
1278 return constant_ptr(exp(x->val));
1281 return make_expression_ptr<ExpExpression<Scalar>>(x);
1284template <
typename Scalar>
1285ExpressionPtr<Scalar> hypot(
const ExpressionPtr<Scalar>& x,
1286 const ExpressionPtr<Scalar>& y);
1291template <
typename Scalar>
1306 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1308 std::string_view
name()
const override {
return "hypot"; }
1312 return this->
adjoint * x / hypot(x, y);
1317 return this->
adjoint * y / hypot(x, y);
1338template <
typename Scalar>
1341 using enum ExpressionType;
1345 if (x->is_constant(Scalar(0))) {
1347 }
else if (y->is_constant(Scalar(0))) {
1352 if (x->type() == CONSTANT && y->type() == CONSTANT) {
1353 return constant_ptr(hypot(x->val, y->val));
1356 return make_expression_ptr<HypotExpression<Scalar>>(x, y);
1362template <
typename Scalar>
1375 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1377 std::string_view
name()
const override {
return "log"; }
1392template <
typename Scalar>
1394 using enum ExpressionType;
1398 if (x->is_constant(Scalar(0))) {
1404 if (x->type() == CONSTANT) {
1405 return constant_ptr(log(x->val));
1408 return make_expression_ptr<LogExpression<Scalar>>(x);
1414template <
typename Scalar>
1427 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1429 std::string_view
name()
const override {
return "log10"; }
1446template <
typename Scalar>
1448 using enum ExpressionType;
1452 if (x->is_constant(Scalar(0))) {
1458 if (x->type() == CONSTANT) {
1459 return constant_ptr(log10(x->val));
1462 return make_expression_ptr<Log10Expression<Scalar>>(x);
1470template <
typename Scalar>
1484 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1486 std::string_view
name()
const override {
return "max"; }
1507 if (
a->val >=
b->val) {
1510 return constant_ptr(
Scalar(0));
1517 if (
b->val >
a->val) {
1520 return constant_ptr(
Scalar(0));
1530template <
typename Scalar>
1533 using enum ExpressionType;
1537 if (
a->type() == CONSTANT &&
b->type() == CONSTANT) {
1538 return constant_ptr(max(
a->val,
b->val));
1541 return make_expression_ptr<MaxExpression<Scalar>>(a, b);
1549template <
typename Scalar>
1563 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1565 std::string_view
name()
const override {
return "min"; }
1587 if (
a->val <=
b->val) {
1590 return constant_ptr(
Scalar(0));
1597 if (
b->val <
a->val) {
1600 return constant_ptr(
Scalar(0));
1610template <
typename Scalar>
1613 using enum ExpressionType;
1617 if (
a->type() == CONSTANT &&
b->type() == CONSTANT) {
1618 return constant_ptr(min(
a->val,
b->val));
1621 return make_expression_ptr<MinExpression<Scalar>>(a, b);
1624template <
typename Scalar>
1625ExpressionPtr<Scalar> pow(
const ExpressionPtr<Scalar>& base,
1626 const ExpressionPtr<Scalar>& power);
1632template <
typename Scalar, ExpressionType T>
1646 ExpressionType
type()
const override {
return T; }
1648 std::string_view
name()
const override {
return "pow"; }
1692template <
typename Scalar>
1695 using enum ExpressionType;
1699 if (
base->is_constant(Scalar(0))) {
1702 }
else if (base->is_constant(Scalar(1))) {
1706 if (power->is_constant(Scalar(0))) {
1707 return constant_ptr(Scalar(1));
1708 }
else if (power->is_constant(Scalar(1))) {
1714 if (base->type() == CONSTANT && power->type() == CONSTANT) {
1715 return constant_ptr(pow(base->val, power->val));
1718 if (power->is_constant(Scalar(2))) {
1719 if (base->type() == LINEAR) {
1720 return make_expression_ptr<MultExpression<Scalar, QUADRATIC>>(base, base);
1722 return make_expression_ptr<MultExpression<Scalar, NONLINEAR>>(base, base);
1726 return make_expression_ptr<PowExpression<Scalar, NONLINEAR>>(base, power);
1732template <
typename Scalar>
1743 }
else if (x ==
Scalar(0)) {
1750 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1752 std::string_view
name()
const override {
return "sign"; }
1759template <
typename Scalar>
1761 using enum ExpressionType;
1764 if (x->type() == CONSTANT) {
1765 if (x->val < Scalar(0)) {
1766 return constant_ptr(Scalar(-1));
1767 }
else if (x->val == Scalar(0)) {
1771 return constant_ptr(Scalar(1));
1775 return make_expression_ptr<SignExpression<Scalar>>(x);
1781template <
typename Scalar>
1794 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1796 std::string_view
name()
const override {
return "sin"; }
1800 return this->
adjoint * cos(x);
1814template <
typename Scalar>
1816 using enum ExpressionType;
1820 if (x->is_constant(Scalar(0))) {
1826 if (x->type() == CONSTANT) {
1827 return constant_ptr(sin(x->val));
1830 return make_expression_ptr<SinExpression<Scalar>>(x);
1836template <
typename Scalar>
1849 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1851 std::string_view
name()
const override {
return "sinh"; }
1855 return this->
adjoint * cosh(x);
1869template <
typename Scalar>
1871 using enum ExpressionType;
1875 if (x->is_constant(Scalar(0))) {
1881 if (x->type() == CONSTANT) {
1882 return constant_ptr(sinh(x->val));
1885 return make_expression_ptr<SinhExpression<Scalar>>(x);
1891template <
typename Scalar>
1904 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1906 std::string_view
name()
const override {
return "sqrt"; }
1924template <
typename Scalar>
1926 using enum ExpressionType;
1930 if (x->type() == CONSTANT) {
1931 if (x->val == Scalar(0)) {
1934 }
else if (x->val == Scalar(1)) {
1937 return constant_ptr(sqrt(x->val));
1941 return make_expression_ptr<SqrtExpression<Scalar>>(x);
1947template <
typename Scalar>
1960 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
1962 std::string_view
name()
const override {
return "tan"; }
1983template <
typename Scalar>
1985 using enum ExpressionType;
1989 if (x->is_constant(Scalar(0))) {
1995 if (x->type() == CONSTANT) {
1996 return constant_ptr(tan(x->val));
1999 return make_expression_ptr<TanExpression<Scalar>>(x);
2005template <
typename Scalar>
2018 ExpressionType
type()
const override {
return ExpressionType::NONLINEAR; }
2020 std::string_view
name()
const override {
return "tanh"; }
2041template <
typename Scalar>
2043 using enum ExpressionType;
2047 if (x->is_constant(Scalar(0))) {
2053 if (x->type() == CONSTANT) {
2054 return constant_ptr(tanh(x->val));
2057 return make_expression_ptr<TanhExpression<Scalar>>(x);
Definition intrusive_shared_ptr.hpp:27
Definition expression.hpp:771
std::string_view name() const override
Definition expression.hpp:785
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:797
constexpr AbsExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:775
ExpressionType type() const override
Definition expression.hpp:783
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:787
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:778
Definition expression.hpp:837
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:844
constexpr AcosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:841
ExpressionType type() const override
Definition expression.hpp:849
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:853
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:858
std::string_view name() const override
Definition expression.hpp:851
Definition expression.hpp:891
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:898
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:907
std::string_view name() const override
Definition expression.hpp:905
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:912
ExpressionType type() const override
Definition expression.hpp:903
constexpr AsinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:895
Definition expression.hpp:1000
std::string_view name() const override
Definition expression.hpp:1016
Scalar value(Scalar y, Scalar x) const override
Definition expression.hpp:1009
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x) const override
Definition expression.hpp:1026
constexpr Atan2Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1005
Scalar grad_l(Scalar y, Scalar x) const override
Definition expression.hpp:1018
Scalar grad_r(Scalar y, Scalar x) const override
Definition expression.hpp:1022
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &y, const ExpressionPtr< Scalar > &x) const override
Definition expression.hpp:1032
ExpressionType type() const override
Definition expression.hpp:1014
Definition expression.hpp:946
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:966
std::string_view name() const override
Definition expression.hpp:960
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:962
constexpr AtanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:950
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:953
ExpressionType type() const override
Definition expression.hpp:958
Definition expression.hpp:442
constexpr BinaryMinusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:447
std::string_view name() const override
Definition expression.hpp:455
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:467
Scalar grad_r(Scalar, Scalar) const override
Definition expression.hpp:459
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:461
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:451
ExpressionType type() const override
Definition expression.hpp:453
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:457
Definition expression.hpp:479
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:488
Scalar grad_r(Scalar, Scalar) const override
Definition expression.hpp:496
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:494
constexpr BinaryPlusExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:484
ExpressionType type() const override
Definition expression.hpp:490
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:504
std::string_view name() const override
Definition expression.hpp:492
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:498
Definition expression.hpp:515
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:538
std::string_view name() const override
Definition expression.hpp:529
constexpr CbrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:519
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:531
ExpressionType type() const override
Definition expression.hpp:527
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:522
Definition expression.hpp:574
ExpressionType type() const override
Definition expression.hpp:583
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:581
std::string_view name() const override
Definition expression.hpp:585
constexpr ConstantExpression(Scalar value)
Definition expression.hpp:578
Definition expression.hpp:1070
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1077
ExpressionType type() const override
Definition expression.hpp:1082
constexpr CosExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1074
std::string_view name() const override
Definition expression.hpp:1084
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1086
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1091
Definition expression.hpp:1124
ExpressionType type() const override
Definition expression.hpp:1136
std::string_view name() const override
Definition expression.hpp:1138
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1131
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1145
constexpr CoshExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1128
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1140
Definition expression.hpp:592
constexpr DecisionVariableExpression()=default
Constructs a decision variable expression with a value of zero.
std::string_view name() const override
Definition expression.hpp:606
Scalar value(Scalar, Scalar) const override
Definition expression.hpp:602
constexpr DecisionVariableExpression(Scalar value)
Definition expression.hpp:599
ExpressionType type() const override
Definition expression.hpp:604
Definition expression.hpp:614
constexpr DivExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:619
ExpressionType type() const override
Definition expression.hpp:624
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:642
std::string_view name() const override
Definition expression.hpp:626
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:622
Scalar grad_l(Scalar, Scalar rhs) const override
Definition expression.hpp:628
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:636
Scalar grad_r(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:632
Definition expression.hpp:1178
std::string_view name() const override
Definition expression.hpp:1192
constexpr ErfExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1182
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1194
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1199
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1185
ExpressionType type() const override
Definition expression.hpp:1190
Definition expression.hpp:1234
std::string_view name() const override
Definition expression.hpp:1248
ExpressionType type() const override
Definition expression.hpp:1246
constexpr ExpExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1238
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1250
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1255
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1241
Definition expression.hpp:89
Scalar val
The value of the expression node.
Definition expression.hpp:94
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:327
virtual Scalar grad_r(Scalar lhs, Scalar rhs) const
Definition expression.hpp:393
std::array< ExpressionPtr< Scalar >, 2 > args
Expression arguments.
Definition expression.hpp:104
uint32_t ref_count
Reference count for intrusive shared pointer.
Definition expression.hpp:118
constexpr Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:131
Scalar adjoint
The adjoint of the expression node, used during autodiff.
Definition expression.hpp:97
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:243
int32_t scratch
Definition expression.hpp:115
virtual Scalar grad_l(Scalar lhs, Scalar rhs) const
Definition expression.hpp:383
constexpr bool is_constant(Scalar constant) const
Definition expression.hpp:147
constexpr Expression()=default
Constructs a constant expression with a value of zero.
constexpr Expression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:138
virtual ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const
Definition expression.hpp:403
ExpressionPtr< Scalar > adjoint_expr
Definition expression.hpp:101
Scalar_ Scalar
Scalar type alias.
Definition expression.hpp:91
virtual ExpressionType type() const =0
friend ExpressionPtr< Scalar > operator/(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:207
friend ExpressionPtr< Scalar > operator-(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:288
friend ExpressionPtr< Scalar > operator*(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:155
virtual Scalar value(Scalar lhs, Scalar rhs) const =0
virtual std::string_view name() const =0
virtual ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const
Definition expression.hpp:414
friend ExpressionPtr< Scalar > operator+=(ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs)
Definition expression.hpp:279
constexpr Expression(Scalar value)
Definition expression.hpp:126
friend ExpressionPtr< Scalar > operator+(const ExpressionPtr< Scalar > &lhs)
Definition expression.hpp:353
Definition expression.hpp:1292
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y) const override
Definition expression.hpp:1320
constexpr HypotExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1297
ExpressionType type() const override
Definition expression.hpp:1306
std::string_view name() const override
Definition expression.hpp:1308
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &y) const override
Definition expression.hpp:1326
Scalar value(Scalar x, Scalar y) const override
Definition expression.hpp:1301
Scalar grad_r(Scalar x, Scalar y) const override
Definition expression.hpp:1315
Scalar grad_l(Scalar x, Scalar y) const override
Definition expression.hpp:1310
Definition expression.hpp:1415
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1422
std::string_view name() const override
Definition expression.hpp:1429
constexpr Log10Expression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1419
ExpressionType type() const override
Definition expression.hpp:1427
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1435
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1431
Definition expression.hpp:1363
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1381
ExpressionType type() const override
Definition expression.hpp:1375
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1379
std::string_view name() const override
Definition expression.hpp:1377
constexpr LogExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1367
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1370
Definition expression.hpp:1471
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1504
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1479
Scalar grad_r(Scalar a, Scalar b) const override
Definition expression.hpp:1496
constexpr MaxExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1476
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1514
Scalar grad_l(Scalar a, Scalar b) const override
Definition expression.hpp:1488
ExpressionType type() const override
Definition expression.hpp:1484
std::string_view name() const override
Definition expression.hpp:1486
Definition expression.hpp:1550
std::string_view name() const override
Definition expression.hpp:1565
Scalar value(Scalar a, Scalar b) const override
Definition expression.hpp:1558
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1584
ExpressionType type() const override
Definition expression.hpp:1563
constexpr MinExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1555
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &a, const ExpressionPtr< Scalar > &b) const override
Definition expression.hpp:1594
Scalar grad_l(Scalar a, Scalar b) const override
Definition expression.hpp:1567
Scalar grad_r(Scalar a, Scalar b) const override
Definition expression.hpp:1575
Definition expression.hpp:654
Scalar grad_l(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:668
Scalar value(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:662
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:676
constexpr MultExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:659
ExpressionType type() const override
Definition expression.hpp:664
std::string_view name() const override
Definition expression.hpp:666
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &lhs, const ExpressionPtr< Scalar > &rhs) const override
Definition expression.hpp:682
Scalar grad_r(Scalar lhs, Scalar rhs) const override
Definition expression.hpp:672
Definition expression.hpp:1633
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power) const override
Definition expression.hpp:1667
Scalar value(Scalar base, Scalar power) const override
Definition expression.hpp:1641
ExpressionType type() const override
Definition expression.hpp:1646
Scalar grad_l(Scalar base, Scalar power) const override
Definition expression.hpp:1650
ExpressionPtr< Scalar > grad_expr_r(const ExpressionPtr< Scalar > &base, const ExpressionPtr< Scalar > &power) const override
Definition expression.hpp:1674
Scalar grad_r(Scalar base, Scalar power) const override
Definition expression.hpp:1655
std::string_view name() const override
Definition expression.hpp:1648
constexpr PowExpression(ExpressionPtr< Scalar > lhs, ExpressionPtr< Scalar > rhs)
Definition expression.hpp:1638
Definition expression.hpp:1733
constexpr SignExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1737
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1740
ExpressionType type() const override
Definition expression.hpp:1750
std::string_view name() const override
Definition expression.hpp:1752
Definition expression.hpp:1782
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1789
std::string_view name() const override
Definition expression.hpp:1796
ExpressionType type() const override
Definition expression.hpp:1794
constexpr SinExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1786
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1803
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1798
Definition expression.hpp:1837
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1853
std::string_view name() const override
Definition expression.hpp:1851
ExpressionType type() const override
Definition expression.hpp:1849
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1858
constexpr SinhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1841
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1844
Definition expression.hpp:1892
ExpressionType type() const override
Definition expression.hpp:1904
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1908
std::string_view name() const override
Definition expression.hpp:1906
constexpr SqrtExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1896
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1899
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1913
Definition expression.hpp:1948
std::string_view name() const override
Definition expression.hpp:1962
constexpr TanExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:1952
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:1964
ExpressionType type() const override
Definition expression.hpp:1960
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:1971
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:1955
Definition expression.hpp:2006
Scalar grad_l(Scalar x, Scalar) const override
Definition expression.hpp:2022
ExpressionType type() const override
Definition expression.hpp:2018
std::string_view name() const override
Definition expression.hpp:2020
Scalar value(Scalar x, Scalar) const override
Definition expression.hpp:2013
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &x, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:2029
constexpr TanhExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:2010
Definition expression.hpp:694
Scalar grad_l(Scalar, Scalar) const override
Definition expression.hpp:707
constexpr UnaryMinusExpression(ExpressionPtr< Scalar > lhs)
Definition expression.hpp:698
ExpressionPtr< Scalar > grad_expr_l(const ExpressionPtr< Scalar > &, const ExpressionPtr< Scalar > &) const override
Definition expression.hpp:709
ExpressionType type() const override
Definition expression.hpp:703
Scalar value(Scalar lhs, Scalar) const override
Definition expression.hpp:701
std::string_view name() const override
Definition expression.hpp:705