Sleipnir C++ API
adjoint_expression_graph.hpp
// Copyright (c) Sleipnir contributors

#pragma once

#include <ranges>
#include <utility>

#include <Eigen/SparseCore>
#include <gch/small_vector.hpp>

#include "sleipnir/autodiff/expression_graph.hpp"
#include "sleipnir/autodiff/variable.hpp"
#include "sleipnir/autodiff/variable_matrix.hpp"
#include "sleipnir/util/assert.hpp"
#include "sleipnir/util/empty.hpp"

namespace slp::detail {

/// Topologically sorted graph of an expression's nodes, used to compute
/// gradients via reverse accumulation automatic differentiation.
template <typename Scalar>
class AdjointExpressionGraph {
 public:
  /// Generates the topological sort of the expression graph rooted at root
  /// and records each node's column index.
  ///
  /// @param root The root node of the expression graph.
  explicit AdjointExpressionGraph(const Variable<Scalar>& root)
      : m_top_list{topological_sort(root.expr)} {
    for (const auto& node : m_top_list) {
      m_col_list.emplace_back(node->col);
    }
  }

  /// Update the values of all nodes in this graph based on the values of
  /// the nodes they depend on.
  void update_values() { detail::update_values(m_top_list); }

  /// Returns the gradient of the root expression with respect to wrt as a
  /// tree of expressions. Because the result is itself an expression tree,
  /// it can be differentiated again to obtain higher-order derivatives.
  ///
  /// @param wrt Column vector of variables with respect to which to compute
  ///   the gradient.
  VariableMatrix<Scalar> generate_gradient_tree(
      const VariableMatrix<Scalar>& wrt) const {
    slp_assert(wrt.cols() == 1);

    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
    // for background on reverse accumulation automatic differentiation.

    if (m_top_list.empty()) {
      return VariableMatrix<Scalar>{detail::empty, wrt.rows(), 1};
    }

    // Set root node's adjoint to 1 since df/df is 1
    m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));

    // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
    // multiplied by dy/dx. If there are multiple "paths" from the root node
    // to a variable, the variable's adjoint is the sum of each path's adjoint
    // contribution.
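    //
    // For example, if f = x·x is a single multiply node whose lhs and rhs
    // both point at x, the loop below adds a contribution to x's adjoint for
    // each argument slot: adjoint·rhs = 1·x and adjoint·lhs = 1·x, which sum
    // to the expected df/dx = 2x.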
    for (auto& node : m_top_list) {
      auto& lhs = node->args[0];
      auto& rhs = node->args[1];

      if (lhs != nullptr) {
        if (rhs != nullptr) {
          // Binary operator
          lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
          rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
        } else {
          // Unary operator
          lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
        }
      }
    }

    // Move gradient tree to return value
    VariableMatrix<Scalar> grad{detail::empty, wrt.rows(), 1};
    for (int row = 0; row < grad.rows(); ++row) {
      grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
    }

    // Unlink adjoints to avoid circular references between them and their
    // parent expressions. This ensures all expressions are returned to the
    // free list.
    for (auto& node : m_top_list) {
      node->adjoint_expr = nullptr;
    }

    return grad;
  }

  /// Updates the adjoints in the expression graph (computes the gradient),
  /// then appends the adjoints of wrt to the given sparse matrix triplets.
  ///
  /// @param triplets The sparse matrix triplets to append the gradient to.
  /// @param row The triplet row in which this gradient belongs.
  /// @param wrt Column vector of variables with respect to which to compute
  ///   the gradient.
  void append_gradient_triplets(
      gch::small_vector<Eigen::Triplet<Scalar>>& triplets, int row,
      const VariableMatrix<Scalar>& wrt) const {
    slp_assert(wrt.cols() == 1);

    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
    // for background on reverse accumulation automatic differentiation.

    // If wrt has fewer nodes than the graph, zero wrt's adjoints
    if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
      for (const auto& elem : wrt) {
        elem.expr->adjoint = Scalar(0);
      }
    }

    if (m_top_list.empty()) {
      return;
    }

    // Set root node's adjoint to 1 since df/df is 1
    m_top_list[0]->adjoint = Scalar(1);

    // Zero the rest of the adjoints
    for (auto& node : m_top_list | std::views::drop(1)) {
      node->adjoint = Scalar(0);
    }

    // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
    // multiplied by dy/dx. If there are multiple "paths" from the root node
    // to a variable, the variable's adjoint is the sum of each path's adjoint
    // contribution.
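    //
    // Unlike generate_gradient_tree() above, this pass propagates numeric
    // adjoints rather than building adjoint expressions, so no unlinking
    // cleanup is needed afterward.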
    for (const auto& node : m_top_list) {
      auto& lhs = node->args[0];
      auto& rhs = node->args[1];

      if (lhs != nullptr) {
        if (rhs != nullptr) {
          // Binary operator
          lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
          rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
        } else {
          // Unary operator
          lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
        }
      }
    }

    // If wrt has fewer nodes than the graph, iterate over wrt
    if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
      for (int col = 0; col < wrt.rows(); ++col) {
        const auto& node = wrt[col].expr;

        // Append adjoints of wrt to sparse matrix triplets
        if (node->adjoint != Scalar(0)) {
          triplets.emplace_back(row, col, node->adjoint);
        }
      }
    } else {
      for (const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
        // Append adjoints of wrt to sparse matrix triplets; a column of -1
        // marks a node that isn't a decision variable
        if (col != -1 && node->adjoint != Scalar(0)) {
          triplets.emplace_back(row, col, node->adjoint);
        }
      }
    }
  }

 private:
  // Topological sort of graph from parent to child
  gch::small_vector<Expression<Scalar>*> m_top_list;

  // List that maps nodes to their respective column
  gch::small_vector<int> m_col_list;
};

}  // namespace slp::detail
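
To make the numeric pass above concrete, here is a minimal self-contained sketch of the same reverse accumulation scheme. Node, top_list, and the gradient lambdas below are hypothetical stand-ins for illustration, not Sleipnir's Expression machinery:

// Toy reverse accumulation over a hand-built expression graph for
// f(x, y) = x*y + x at x = 3, y = 2. Mirrors the adjoint propagation loop
// in append_gradient_triplets() above.
#include <array>
#include <cstdio>
#include <vector>

struct Node {
  double val = 0.0;      // forward value
  double adjoint = 0.0;  // df/d(this node), accumulated in the reverse pass
  std::array<Node*, 2> args{nullptr, nullptr};
  // Per-argument gradient contributions, scaled by this node's adjoint
  double (*grad_l)(double lhs, double rhs, double adjoint) = nullptr;
  double (*grad_r)(double lhs, double rhs, double adjoint) = nullptr;
};

int main() {
  Node x{.val = 3.0};
  Node y{.val = 2.0};
  Node mul{.val = x.val * y.val,
           .args = {&x, &y},
           .grad_l = [](double, double rhs, double adj) { return adj * rhs; },
           .grad_r = [](double lhs, double, double adj) { return adj * lhs; }};
  Node add{.val = mul.val + x.val,
           .args = {&mul, &x},
           .grad_l = [](double, double, double adj) { return adj; },
           .grad_r = [](double, double, double adj) { return adj; }};

  // Parent-to-child topological sort, like m_top_list
  std::vector<Node*> top_list{&add, &mul, &x, &y};

  // Set root node's adjoint to 1 since df/df is 1
  top_list[0]->adjoint = 1.0;

  // Propagate adjoints from each node to its arguments
  for (Node* node : top_list) {
    Node* lhs = node->args[0];
    Node* rhs = node->args[1];
    if (lhs != nullptr) {
      if (rhs != nullptr) {
        lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
        rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
      } else {
        lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
      }
    }
  }

  // df/dx = y + 1 = 3, df/dy = x = 3
  std::printf("df/dx = %g, df/dy = %g\n", x.adjoint, y.adjoint);
}

In the real class, the nonzero adjoints of wrt would then be appended as Eigen::Triplet entries and assembled into a sparse Jacobian row, e.g. with Eigen::SparseMatrix's setFromTriplets().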