Sleipnir C++ API
Loading...
Searching...
No Matches
adjoint_expression_graph.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6#include <utility>
7
8#include <Eigen/SparseCore>
9
10#include "sleipnir/autodiff/expression_graph.hpp"
11#include "sleipnir/autodiff/variable.hpp"
12#include "sleipnir/autodiff/variable_matrix.hpp"
13#include "sleipnir/util/small_vector.hpp"
14
15namespace slp::detail {
16
22 public:
28 explicit AdjointExpressionGraph(const Variable& root)
29 : m_top_list{topological_sort(root.expr)} {
30 for (const auto& node : m_top_list) {
31 m_col_list.emplace_back(node->col);
32 }
33 }
34
39 void update_values() { detail::update_values(m_top_list); }
40
53 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
54 // for background on reverse accumulation automatic differentiation.
55
56 if (m_top_list.empty()) {
58 }
59
60 // Set root node's adjoint to 1 since df/df is 1
61 m_top_list[0]->adjoint_expr = make_expression_ptr<ConstExpression>(1.0);
62
63 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
64 // multiplied by dy/dx. If there are multiple "paths" from the root node to
65 // variable; the variable's adjoint is the sum of each path's adjoint
66 // contribution.
67 for (auto& node : m_top_list) {
68 auto& lhs = node->args[0];
69 auto& rhs = node->args[1];
70
71 if (lhs != nullptr) {
72 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
73 if (rhs != nullptr) {
74 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
75 }
76 }
77 }
78
79 // Move gradient tree to return value
81 for (int row = 0; row < grad.rows(); ++row) {
82 grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
83 }
84
85 // Unlink adjoints to avoid circular references between them and their
86 // parent expressions. This ensures all expressions are returned to the free
87 // list.
88 for (auto& node : m_top_list) {
89 node->adjoint_expr = nullptr;
90 }
91
92 return grad;
93 }
94
104 void append_adjoint_triplets(small_vector<Eigen::Triplet<double>>& triplets,
105 int row, const VariableMatrix& wrt) const {
106 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
107 // for background on reverse accumulation automatic differentiation.
108
109 // If wrt has fewer nodes than graph, zero wrt's adjoints
110 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
111 for (const auto& elem : wrt) {
112 elem.expr->adjoint = 0.0;
113 }
114 }
115
116 if (m_top_list.empty()) {
117 return;
118 }
119
120 // Set root node's adjoint to 1 since df/df is 1
121 m_top_list[0]->adjoint = 1.0;
122
123 // Zero the rest of the adjoints
124 for (auto& node : m_top_list | std::views::drop(1)) {
125 node->adjoint = 0.0;
126 }
127
128 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
129 // multiplied by dy/dx. If there are multiple "paths" from the root node to
130 // variable; the variable's adjoint is the sum of each path's adjoint
131 // contribution.
132 for (const auto& node : m_top_list) {
133 auto& lhs = node->args[0];
134 auto& rhs = node->args[1];
135
136 if (lhs != nullptr) {
137 if (rhs != nullptr) {
138 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
139 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
140 } else {
141 lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
142 }
143 }
144 }
145
146 // If wrt has fewer nodes than graph, iterate over wrt
147 if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
148 for (int col = 0; col < wrt.rows(); ++col) {
149 const auto& node = wrt[col].expr;
150
151 // Append adjoints of wrt to sparse matrix triplets
152 if (node->adjoint != 0.0) {
153 triplets.emplace_back(row, col, node->adjoint);
154 }
155 }
156 } else {
157 for (const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
158 // Append adjoints of wrt to sparse matrix triplets
159 if (col != -1 && node->adjoint != 0.0) {
160 triplets.emplace_back(row, col, node->adjoint);
161 }
162 }
163 }
164 }
165
166 private:
167 // Topological sort of graph from parent to child
168 small_vector<Expression*> m_top_list;
169
170 // List that maps nodes to their respective column
171 small_vector<int> m_col_list;
172};
173
174} // namespace slp::detail
Definition variable_matrix.hpp:29
int rows() const
Definition variable_matrix.hpp:907
static constexpr empty_t empty
Definition variable_matrix.hpp:39
Definition variable.hpp:41
Definition adjoint_expression_graph.hpp:21
VariableMatrix generate_gradient_tree(const VariableMatrix &wrt) const
Definition adjoint_expression_graph.hpp:52
void update_values()
Definition adjoint_expression_graph.hpp:39
void append_adjoint_triplets(small_vector< Eigen::Triplet< double > > &triplets, int row, const VariableMatrix &wrt) const
Definition adjoint_expression_graph.hpp:104
AdjointExpressionGraph(const Variable &root)
Definition adjoint_expression_graph.hpp:28