Sleipnir C++ API
Loading...
Searching...
No Matches
jacobian.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <utility>
6
7#include <Eigen/SparseCore>
8
9#include "sleipnir/autodiff/adjoint_expression_graph.hpp"
10#include "sleipnir/autodiff/variable.hpp"
11#include "sleipnir/autodiff/variable_matrix.hpp"
12#include "sleipnir/util/concepts.hpp"
13#include "sleipnir/util/small_vector.hpp"
14#include "sleipnir/util/symbol_exports.hpp"
15
16namespace slp {
17
25class SLEIPNIR_DLLEXPORT Jacobian {
26 public:
33 Jacobian(Variable variable, Variable wrt) noexcept
34 : Jacobian{VariableMatrix{std::move(variable)},
35 VariableMatrix{std::move(wrt)}} {}
36
44 Jacobian(VariableMatrix variables, SleipnirMatrixLike auto wrt) noexcept
45 : m_variables{std::move(variables)}, m_wrt{std::move(wrt)} {
46 // Initialize column each expression's adjoint occupies in the Jacobian
47 for (size_t col = 0; col < m_wrt.size(); ++col) {
48 m_wrt[col].expr->col = col;
49 }
50
51 for (auto& variable : m_variables) {
52 m_graphs.emplace_back(variable);
53 }
54
55 // Reset col to -1
56 for (auto& node : m_wrt) {
57 node.expr->col = -1;
58 }
59
60 for (int row = 0; row < m_variables.rows(); ++row) {
61 if (m_variables[row].expr == nullptr) {
62 continue;
63 }
64
65 if (m_variables[row].type() == ExpressionType::LINEAR) {
66 // If the row is linear, compute its gradient once here and cache its
67 // triplets. Constant rows are ignored because their gradients have no
68 // nonzero triplets.
69 m_graphs[row].append_adjoint_triplets(m_cached_triplets, row, m_wrt);
70 } else if (m_variables[row].type() > ExpressionType::LINEAR) {
71 // If the row is quadratic or nonlinear, add it to the list of nonlinear
72 // rows to be recomputed in Value().
73 m_nonlinear_rows.emplace_back(row);
74 }
75 }
76
77 if (m_nonlinear_rows.empty()) {
78 m_J.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end());
79 }
80 }
81
91 VariableMatrix result{VariableMatrix::empty, m_variables.rows(),
92 m_wrt.rows()};
93
94 for (int row = 0; row < m_variables.rows(); ++row) {
95 auto grad = m_graphs[row].generate_gradient_tree(m_wrt);
96 for (int col = 0; col < m_wrt.rows(); ++col) {
97 if (grad[col].expr != nullptr) {
98 result[row, col] = std::move(grad[col]);
99 } else {
100 result[row, col] = Variable{0.0};
101 }
102 }
103 }
104
105 return result;
106 }
107
113 const Eigen::SparseMatrix<double>& value() {
114 if (m_nonlinear_rows.empty()) {
115 return m_J;
116 }
117
118 for (auto& graph : m_graphs) {
119 graph.update_values();
120 }
121
122 // Copy the cached triplets so triplets added for the nonlinear rows are
123 // thrown away at the end of the function
124 auto triplets = m_cached_triplets;
125
126 // Compute each nonlinear row of the Jacobian
127 for (int row : m_nonlinear_rows) {
128 m_graphs[row].append_adjoint_triplets(triplets, row, m_wrt);
129 }
130
131 if (!triplets.empty()) {
132 m_J.setFromTriplets(triplets.begin(), triplets.end());
133 } else {
134 // setFromTriplets() is a no-op on empty triplets, so explicitly zero out
135 // the storage
136 m_J.setZero();
137 }
138
139 return m_J;
140 }
141
142 private:
143 VariableMatrix m_variables;
144 VariableMatrix m_wrt;
145
146 small_vector<detail::AdjointExpressionGraph> m_graphs;
147
148 Eigen::SparseMatrix<double> m_J{m_variables.rows(), m_wrt.rows()};
149
150 // Cached triplets for gradients of linear rows
151 small_vector<Eigen::Triplet<double>> m_cached_triplets;
152
153 // List of row indices for nonlinear rows whose graients will be computed in
154 // Value()
155 small_vector<int> m_nonlinear_rows;
156};
157
158} // namespace slp
Definition jacobian.hpp:25
Jacobian(Variable variable, Variable wrt) noexcept
Definition jacobian.hpp:33
VariableMatrix get() const
Definition jacobian.hpp:90
const Eigen::SparseMatrix< double > & value()
Definition jacobian.hpp:113
Jacobian(VariableMatrix variables, SleipnirMatrixLike auto wrt) noexcept
Definition jacobian.hpp:44
Definition variable_matrix.hpp:29
int rows() const
Definition variable_matrix.hpp:951
Definition variable.hpp:40
Definition concepts.hpp:30