Sleipnir C++ API
Loading...
Searching...
No Matches
jacobian.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <utility>
6
7#include <Eigen/SparseCore>
8#include <gch/small_vector.hpp>
9
10#include "sleipnir/autodiff/adjoint_expression_graph.hpp"
11#include "sleipnir/autodiff/variable.hpp"
12#include "sleipnir/autodiff/variable_matrix.hpp"
13#include "sleipnir/util/assert.hpp"
14#include "sleipnir/util/concepts.hpp"
15#include "sleipnir/util/empty.hpp"
16#include "sleipnir/util/symbol_exports.hpp"
17
18namespace slp {
19
29template <typename Scalar>
30class Jacobian {
31 public:
41
51
61 : m_variables{std::move(variables)}, m_wrt{std::move(wrt)} {
62 slp_assert(m_variables.cols() == 1);
63 slp_assert(m_wrt.cols() == 1);
64
65 // Initialize the column each expression's adjoint occupies in the Jacobian
66 for (size_t col = 0; col < m_wrt.size(); ++col) {
67 m_wrt[col].expr->col = col;
68 }
69
70 for (auto& variable : m_variables) {
71 m_graphs.emplace_back(variable);
72 }
73
74 // Reset col to -1
75 for (auto& node : m_wrt) {
76 node.expr->col = -1;
77 }
78
79 for (int row = 0; row < m_variables.rows(); ++row) {
80 if (m_variables[row].expr == nullptr) {
81 continue;
82 }
83
84 if (m_variables[row].type() == ExpressionType::LINEAR) {
85 // If the row is linear, compute its gradient once here and cache its
86 // triplets. Constant rows are ignored because their gradients have no
87 // nonzero triplets.
88 m_graphs[row].append_gradient_triplets(m_cached_triplets, row, m_wrt);
89 } else if (m_variables[row].type() > ExpressionType::LINEAR) {
90 // If the row is quadratic or nonlinear, add it to the list of nonlinear
91 // rows to be recomputed in value().
92 m_nonlinear_rows.emplace_back(row);
93 }
94 }
95
96 if (m_nonlinear_rows.empty()) {
97 m_J.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end());
98 }
99 }
100
110 VariableMatrix<Scalar> result{detail::empty, m_variables.rows(),
111 m_wrt.rows()};
112
113 for (int row = 0; row < m_variables.rows(); ++row) {
114 auto grad = m_graphs[row].generate_gradient_tree(m_wrt);
115 for (int col = 0; col < m_wrt.rows(); ++col) {
116 if (grad[col].expr != nullptr) {
117 result[row, col] = std::move(grad[col]);
118 } else {
119 result[row, col] = Variable{Scalar(0)};
120 }
121 }
122 }
123
124 return result;
125 }
126
132 const Eigen::SparseMatrix<Scalar>& value() {
133 if (m_nonlinear_rows.empty()) {
134 return m_J;
135 }
136
137 for (auto& graph : m_graphs) {
138 graph.update_values();
139 }
140
141 // Copy the cached triplets so triplets added for the nonlinear rows are
142 // thrown away at the end of the function
143 auto triplets = m_cached_triplets;
144
145 // Compute each nonlinear row of the Jacobian
146 for (int row : m_nonlinear_rows) {
147 m_graphs[row].append_gradient_triplets(triplets, row, m_wrt);
148 }
149
150 m_J.setFromTriplets(triplets.begin(), triplets.end());
151
152 return m_J;
153 }
154
155 private:
156 VariableMatrix<Scalar> m_variables;
158
159 gch::small_vector<detail::AdjointExpressionGraph<Scalar>> m_graphs;
160
161 Eigen::SparseMatrix<Scalar> m_J{m_variables.rows(), m_wrt.rows()};
162
163 // Cached triplets for gradients of linear rows
164 gch::small_vector<Eigen::Triplet<Scalar>> m_cached_triplets;
165
166 // List of row indices for nonlinear rows whose gradients will be computed in
167 // value()
168 gch::small_vector<int> m_nonlinear_rows;
169};
170
171extern template class EXPORT_TEMPLATE_DECLARE(
172 SLEIPNIR_DLLEXPORT) Jacobian<double>;
173
174} // namespace slp
Definition intrusive_shared_ptr.hpp:29
Definition jacobian.hpp:30
Jacobian(Variable< Scalar > variable, Variable< Scalar > wrt)
Definition jacobian.hpp:38
Jacobian(Variable< Scalar > variable, SleipnirMatrixLike< Scalar > auto wrt)
Definition jacobian.hpp:49
const Eigen::SparseMatrix< Scalar > & value()
Definition jacobian.hpp:132
VariableMatrix< Scalar > get() const
Definition jacobian.hpp:109
Jacobian(VariableMatrix< Scalar > variables, SleipnirMatrixLike< Scalar > auto wrt)
Definition jacobian.hpp:59
Definition variable_matrix.hpp:35
Definition variable.hpp:49
Definition concepts.hpp:33