Sleipnir C++ API
Loading...
Searching...
No Matches
jacobian.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <utility>
6
7#include <Eigen/SparseCore>
8#include <gch/small_vector.hpp>
9
10#include "sleipnir/autodiff/adjoint_expression_graph.hpp"
11#include "sleipnir/autodiff/variable.hpp"
12#include "sleipnir/autodiff/variable_matrix.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/concepts.hpp"
15#include "sleipnir/util/symbol_exports.hpp"
16
17namespace slp {
18
26class SLEIPNIR_DLLEXPORT Jacobian {
27 public:
34 Jacobian(Variable variable, Variable wrt)
35 : Jacobian{VariableMatrix{std::move(variable)},
36 VariableMatrix{std::move(wrt)}} {}
37
46 : Jacobian{VariableMatrix{std::move(variable)}, std::move(wrt)} {}
47
56 : m_variables{std::move(variables)}, m_wrt{std::move(wrt)} {
57 slp_assert(m_variables.cols() == 1);
58 slp_assert(m_wrt.cols() == 1);
59
60 // Initialize column each expression's adjoint occupies in the Jacobian
61 for (size_t col = 0; col < m_wrt.size(); ++col) {
62 m_wrt[col].expr->col = col;
63 }
64
65 for (auto& variable : m_variables) {
66 m_graphs.emplace_back(variable);
67 }
68
69 // Reset col to -1
70 for (auto& node : m_wrt) {
71 node.expr->col = -1;
72 }
73
74 for (int row = 0; row < m_variables.rows(); ++row) {
75 if (m_variables[row].expr == nullptr) {
76 continue;
77 }
78
79 if (m_variables[row].type() == ExpressionType::LINEAR) {
80 // If the row is linear, compute its gradient once here and cache its
81 // triplets. Constant rows are ignored because their gradients have no
82 // nonzero triplets.
83 m_graphs[row].append_gradient_triplets(m_cached_triplets, row, m_wrt);
84 } else if (m_variables[row].type() > ExpressionType::LINEAR) {
85 // If the row is quadratic or nonlinear, add it to the list of nonlinear
86 // rows to be recomputed in Value().
87 m_nonlinear_rows.emplace_back(row);
88 }
89 }
90
91 if (m_nonlinear_rows.empty()) {
92 m_J.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end());
93 }
94 }
95
105 VariableMatrix result{VariableMatrix::empty, m_variables.rows(),
106 m_wrt.rows()};
107
108 for (int row = 0; row < m_variables.rows(); ++row) {
109 auto grad = m_graphs[row].generate_gradient_tree(m_wrt);
110 for (int col = 0; col < m_wrt.rows(); ++col) {
111 if (grad[col].expr != nullptr) {
112 result[row, col] = std::move(grad[col]);
113 } else {
114 result[row, col] = Variable{0.0};
115 }
116 }
117 }
118
119 return result;
120 }
121
127 const Eigen::SparseMatrix<double>& value() {
128 if (m_nonlinear_rows.empty()) {
129 return m_J;
130 }
131
132 for (auto& graph : m_graphs) {
133 graph.update_values();
134 }
135
136 // Copy the cached triplets so triplets added for the nonlinear rows are
137 // thrown away at the end of the function
138 auto triplets = m_cached_triplets;
139
140 // Compute each nonlinear row of the Jacobian
141 for (int row : m_nonlinear_rows) {
142 m_graphs[row].append_gradient_triplets(triplets, row, m_wrt);
143 }
144
145 m_J.setFromTriplets(triplets.begin(), triplets.end());
146
147 return m_J;
148 }
149
150 private:
151 VariableMatrix m_variables;
152 VariableMatrix m_wrt;
153
154 gch::small_vector<detail::AdjointExpressionGraph> m_graphs;
155
156 Eigen::SparseMatrix<double> m_J{m_variables.rows(), m_wrt.rows()};
157
158 // Cached triplets for gradients of linear rows
159 gch::small_vector<Eigen::Triplet<double>> m_cached_triplets;
160
161 // List of row indices for nonlinear rows whose graients will be computed in
162 // Value()
163 gch::small_vector<int> m_nonlinear_rows;
164};
165
166} // namespace slp
Definition jacobian.hpp:26
Jacobian(Variable variable, SleipnirMatrixLike auto wrt)
Definition jacobian.hpp:45
VariableMatrix get() const
Definition jacobian.hpp:104
Jacobian(VariableMatrix variables, SleipnirMatrixLike auto wrt)
Definition jacobian.hpp:55
Jacobian(Variable variable, Variable wrt)
Definition jacobian.hpp:34
const Eigen::SparseMatrix< double > & value()
Definition jacobian.hpp:127
Definition variable_matrix.hpp:29
int rows() const
Definition variable_matrix.hpp:949
Definition variable.hpp:40
Definition concepts.hpp:30