Sleipnir C++ API
Loading...
Searching...
No Matches
hessian.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <utility>
6
7#include <Eigen/SparseCore>
8#include <gch/small_vector.hpp>
9
10#include "sleipnir/autodiff/expression_graph.hpp"
11#include "sleipnir/autodiff/variable.hpp"
12#include "sleipnir/autodiff/variable_matrix.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/concepts.hpp"
15#include "sleipnir/util/symbol_exports.hpp"
16
17namespace slp {
18
28template <typename Scalar, int UpLo>
29 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
30class Hessian {
31 public:
38
45 : m_variables{detail::gradient_tree(
46 detail::topological_sort(variable.expr), wrt)},
47 m_wrt{std::move(wrt)} {
48 slp_assert(wrt.cols() == 1);
49
50 for (auto& variable : m_variables) {
51 m_top_lists.emplace_back(detail::topological_sort(variable.expr));
52 }
53
54 // Initialize column each expression's adjoint occupies in the Jacobian
55 for (size_t col = 0; col < m_wrt.size(); ++col) {
56 m_wrt[col].expr->scratch = col;
57 }
58
59 // Make list of only nodes in output row, and their output columns
60 for (auto& top_list : m_top_lists) {
61 m_output_lists.emplace_back();
62 for (const auto& node : top_list) {
63 if (node->scratch != -1) {
64 m_output_lists.back().emplace_back(node->scratch, node);
65 }
66 }
67 }
68
69 // Reset col to -1
70 for (auto& node : m_wrt) {
71 node.expr->scratch = -1;
72 }
73
74 for (int row = 0; row < m_variables.rows(); ++row) {
75 if (m_variables[row].expr == nullptr) {
76 continue;
77 }
78
79 if (m_variables[row].type() == ExpressionType::LINEAR) {
80 // If the row is linear, compute its gradient once here and cache its
81 // triplets. Constant rows are ignored because their gradients have no
82 // nonzero triplets.
83 detail::append_triplets(m_top_lists[row], m_output_lists[row],
84 m_cached_triplets, row);
85 } else if (m_variables[row].type() > ExpressionType::LINEAR) {
86 // If the row is quadratic or nonlinear, add it to the list of nonlinear
87 // rows to be recomputed in value().
88 m_nonlinear_rows.emplace_back(row);
89 }
90 }
91
92 if (m_nonlinear_rows.empty()) {
93 m_H.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end());
94 if constexpr (UpLo == Eigen::Lower) {
95 m_H = m_H.template triangularView<Eigen::Lower>();
96 }
97 }
98 }
99
107 VariableMatrix<Scalar> result{detail::empty, m_variables.rows(),
108 m_wrt.rows()};
109
110 for (int row = 0; row < m_variables.rows(); ++row) {
111 auto grad = detail::gradient_tree(m_top_lists[row], m_wrt);
112 for (int col = 0; col < m_wrt.rows(); ++col) {
113 if (grad[col].expr != nullptr) {
114 result[row, col] = std::move(grad[col]);
115 } else {
116 result[row, col] = Variable{Scalar(0)};
117 }
118 }
119 }
120
121 return result;
122 }
123
127 const Eigen::SparseMatrix<Scalar>& value() {
128 if (m_nonlinear_rows.empty()) {
129 return m_H;
130 }
131
132 for (auto& top_list : m_top_lists) {
133 detail::update_values(top_list);
134 }
135
136 // Copy the cached triplets so triplets added for the nonlinear rows are
137 // thrown away at the end of the function
138 auto triplets = m_cached_triplets;
139
140 // Compute each nonlinear row of the Hessian
141 for (int row : m_nonlinear_rows) {
142 detail::append_triplets(m_top_lists[row], m_output_lists[row], triplets,
143 row);
144 }
145
146 m_H.setFromTriplets(triplets.begin(), triplets.end());
147 if constexpr (UpLo == Eigen::Lower) {
148 m_H = m_H.template triangularView<Eigen::Lower>();
149 }
150
151 return m_H;
152 }
153
154 private:
155 VariableMatrix<Scalar> m_variables;
157
159 gch::small_vector<detail::ExpressionGraph<Scalar>> m_top_lists;
160
162 gch::small_vector<
163 gch::small_vector<std::pair<int, detail::Expression<Scalar>*>>>
164 m_output_lists;
165
166 Eigen::SparseMatrix<Scalar> m_H{m_variables.rows(), m_wrt.rows()};
167
168 // Cached triplets for gradients of linear rows
169 gch::small_vector<Eigen::Triplet<Scalar>> m_cached_triplets;
170
171 // List of row indices for nonlinear rows whose gradients will be computed in
172 // value()
173 gch::small_vector<int> m_nonlinear_rows;
174};
175
176// @cond Suppress Doxygen
177extern template class EXPORT_TEMPLATE_DECLARE(SLEIPNIR_DLLEXPORT)
178Hessian<double, Eigen::Lower | Eigen::Upper>;
179// @endcond
180
181} // namespace slp
Definition hessian.hpp:30
VariableMatrix< Scalar > get() const
Definition hessian.hpp:106
const Eigen::SparseMatrix< Scalar > & value()
Definition hessian.hpp:127
Hessian(Variable< Scalar > variable, SleipnirMatrixLike< Scalar > auto wrt)
Definition hessian.hpp:44
Hessian(Variable< Scalar > variable, Variable< Scalar > wrt)
Definition hessian.hpp:36
Definition intrusive_shared_ptr.hpp:27
Definition variable_matrix.hpp:33
Definition variable.hpp:52
Definition concepts.hpp:33