Sleipnir C++ API
Loading...
Searching...
No Matches
hessian.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <utility>
6
7#include <Eigen/SparseCore>
8#include <gch/small_vector.hpp>
9
10#include "sleipnir/autodiff/gradient_expression_graph.hpp"
11#include "sleipnir/autodiff/variable.hpp"
12#include "sleipnir/autodiff/variable_matrix.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/concepts.hpp"
15
16namespace slp {
17
// Hessian of a scalar Variable with respect to a column vector of Variables.
//
// The gradient of `variable` with respect to `wrt` is built as an expression
// tree (m_variables). Each entry of that gradient is then differentiated again
// to produce one row of the Hessian:
//   - rows whose gradient entry is linear are differentiated once in the
//     constructor and their triplets cached;
//   - quadratic and nonlinear rows are recomputed on every call to value().
//
// Template parameters:
//   Scalar - scalar type of the expressions and of the sparse result matrix.
//   UpLo   - Eigen::Lower to keep only the lower triangle (including the
//            diagonal) of the matrix returned by value(), or
//            Eigen::Lower | Eigen::Upper to keep the full matrix.
27template <typename Scalar, int UpLo>
28 requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper))
29class Hessian {
30 public:
37
// m_variables: gradient of `variable` with respect to `wrt`, generated as a new
// expression tree so each entry can be differentiated a second time below.
44 : m_variables{
45 detail::GradientExpressionGraph<Scalar>{variable}.generate_tree(
46 wrt)},
47 m_wrt{wrt} {
// wrt must be a column vector.
48 slp_assert(m_wrt.cols() == 1);
49
// Temporarily tag each wrt expression with its column index. For a Hessian,
// the columns are the same as those of the Jacobian of the gradient. The
// per-row graphs below are constructed while these tags are set. Presumably
// GradientExpressionGraph reads `col` at construction time to locate the wrt
// leaves; confirm in gradient_expression_graph.hpp.
// NOTE(review): `size_t col` is compared against m_wrt.size(); if size()
// returns a signed int this is a signed/unsigned comparison.
50 // Initialize column each expression's adjoint occupies in the Jacobian
51 for (size_t col = 0; col < m_wrt.size(); ++col) {
52 m_wrt[col].expr->col = col;
53 }
54
// Build one expression graph per gradient entry (i.e. per Hessian row).
// NOTE(review): this loop variable shadows the constructor parameter
// `variable` used in the member initializer above.
55 for (auto& variable : m_variables) {
56 m_graphs.emplace_back(variable);
57 }
58
// Restore col to -1 (presumably its "no column" default) so the tags set
// above do not persist past construction.
59 // Reset col to -1
60 for (auto& node : m_wrt) {
61 node.expr->col = -1;
62 }
63
// Classify each row by the expression type of its gradient entry.
64 for (int row = 0; row < m_variables.rows(); ++row) {
// A row with no expression contributes nothing to the Hessian.
65 if (m_variables[row].expr == nullptr) {
66 continue;
67 }
68
69 if (m_variables[row].type() == ExpressionType::LINEAR) {
70 // If the row is linear, compute its gradient once here and cache its
71 // triplets. Constant rows are ignored because their gradients have no
72 // nonzero triplets.
73 m_graphs[row].append_triplets(m_cached_triplets, row, m_wrt);
74 } else if (m_variables[row].type() > ExpressionType::LINEAR) {
75 // If the row is quadratic or nonlinear, add it to the list of nonlinear
76 // rows to be recomputed in value().
77 m_nonlinear_rows.emplace_back(row);
78 }
// Types ordered below LINEAR (the constant rows mentioned above) fall
// through both branches and are skipped.
79 }
80
// With no nonlinear rows the Hessian never changes. Assemble it once here;
// value() then returns it without recomputation.
81 if (m_nonlinear_rows.empty()) {
82 m_H.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end());
// Keep only the lower triangle (including the diagonal) when requested.
83 if constexpr (UpLo == Eigen::Lower) {
84 m_H = m_H.template triangularView<Eigen::Lower>();
85 }
86 }
87 }
88
// Builds the Hessian symbolically as a dense VariableMatrix with
// m_variables.rows() rows and m_wrt.rows() columns. Every row is regenerated
// from its graph; m_cached_triplets is not used here.
// Every entry is assigned in the loop below. The detail::empty tag presumably
// skips default-initializing entries; confirm in variable_matrix.hpp.
// NOTE(review): unlike value(), this does not consult UpLo. With
// UpLo == Eigen::Lower the full matrix is still returned. Confirm that this is
// intended.
96 VariableMatrix<Scalar> result{detail::empty, m_variables.rows(),
97 m_wrt.rows()};
98
99 for (int row = 0; row < m_variables.rows(); ++row) {
// Differentiate this gradient entry with respect to wrt, giving one Hessian
// row.
100 auto grad = m_graphs[row].generate_tree(m_wrt);
101 for (int col = 0; col < m_wrt.rows(); ++col) {
// A null expression means no dependence on this column, so the entry is
// stored as the constant 0.
102 if (grad[col].expr != nullptr) {
103 result[row, col] = std::move(grad[col]);
104 } else {
105 result[row, col] = Variable{Scalar(0)};
106 }
107 }
108 }
109
110 return result;
111 }
112
116 const Eigen::SparseMatrix<Scalar>& value() {
117 if (m_nonlinear_rows.empty()) {
118 return m_H;
119 }
120
121 for (auto& graph : m_graphs) {
122 graph.update_values();
123 }
124
125 // Copy the cached triplets so triplets added for the nonlinear rows are
126 // thrown away at the end of the function
127 auto triplets = m_cached_triplets;
128
129 // Compute each nonlinear row of the Hessian
130 for (int row : m_nonlinear_rows) {
131 m_graphs[row].append_triplets(triplets, row, m_wrt);
132 }
133
134 m_H.setFromTriplets(triplets.begin(), triplets.end());
135 if constexpr (UpLo == Eigen::Lower) {
136 m_H = m_H.template triangularView<Eigen::Lower>();
137 }
138
139 return m_H;
140 }
141
142 private:
// Gradient of `variable` with respect to `wrt`. Entry `row` becomes row `row`
// of the Hessian once it is differentiated again.
143 VariableMatrix<Scalar> m_variables;
145
// One expression graph per entry of m_variables. Each graph differentiates
// that entry with respect to m_wrt.
146 gch::small_vector<detail::GradientExpressionGraph<Scalar>> m_graphs;
147
// Sparse Hessian returned by value(). Its initializer reads m_variables and
// m_wrt. Non-static members are initialized in declaration order, so both of
// those members must be declared before this one.
148 Eigen::SparseMatrix<Scalar> m_H{m_variables.rows(), m_wrt.rows()};
149
150 // Cached triplets for gradients of linear rows
// (computed once in the constructor and reused on every value() call)
151 gch::small_vector<Eigen::Triplet<Scalar>> m_cached_triplets;
152
153 // List of row indices for nonlinear rows whose gradients will be computed in
154 // value()
// ("nonlinear" here includes quadratic rows, i.e. any type above LINEAR)
155 gch::small_vector<int> m_nonlinear_rows;
156};
157
// Explicit instantiation declaration for the full (Lower | Upper) double
// Hessian. It tells including translation units that an explicit instantiation
// definition exists elsewhere (presumably in the library sources, exported via
// SLEIPNIR_DLLEXPORT), so they need not emit their own out-of-line copies.
// Other specializations, such as Hessian<double, Eigen::Lower>, are not
// declared here and are instantiated implicitly where used, unless declared
// elsewhere.
158// @cond Suppress Doxygen
159extern template class EXPORT_TEMPLATE_DECLARE(SLEIPNIR_DLLEXPORT)
160Hessian<double, Eigen::Lower | Eigen::Upper>;
161// @endcond
162
163} // namespace slp
Definition hessian.hpp:29
VariableMatrix< Scalar > get() const
Definition hessian.hpp:95
const Eigen::SparseMatrix< Scalar > & value()
Definition hessian.hpp:116
Hessian(Variable< Scalar > variable, SleipnirMatrixLike< Scalar > auto wrt)
Definition hessian.hpp:43
Hessian(Variable< Scalar > variable, Variable< Scalar > wrt)
Definition hessian.hpp:35
Definition intrusive_shared_ptr.hpp:27
Definition variable_matrix.hpp:33
Definition variable.hpp:47
Definition concepts.hpp:33