Sleipnir C++ API
gradient_expression_graph.hpp
// Copyright (c) Sleipnir contributors

#pragma once

#include <ranges>
#include <utility>

#include <Eigen/SparseCore>
#include <gch/small_vector.hpp>

#include "sleipnir/autodiff/expression_graph.hpp"
#include "sleipnir/autodiff/variable.hpp"
#include "sleipnir/autodiff/variable_matrix.hpp"
#include "sleipnir/util/assert.hpp"
#include "sleipnir/util/empty.hpp"

namespace slp::detail {

/**
 * Caches an expression's computational graph in a topologically sorted list
 * and generates gradients from it via reverse accumulation automatic
 * differentiation.
 *
 * @tparam Scalar Scalar type.
 */
template <typename Scalar>
class GradientExpressionGraph {
 public:
  /**
   * Constructs a GradientExpressionGraph from an expression's root node.
   *
   * @param root The root node of the expression.
   */
  explicit GradientExpressionGraph(const Variable<Scalar>& root)
      : m_top_list{topological_sort(root.expr)} {
    for (const auto& node : m_top_list) {
      m_col_list.emplace_back(node->col);
    }
  }

  /**
   * Updates the values of all nodes in this computational graph in
   * topological order.
   */
  void update_values() { detail::update_values(m_top_list); }

  /**
   * Returns the gradient of the root expression with respect to the given
   * variables as a tree of expressions.
   *
   * @param wrt Column vector of variables with respect to which to compute
   *     the gradient.
   * @return The gradient as a column vector of expression trees.
   */
  VariableMatrix<Scalar> generate_tree(
      const VariableMatrix<Scalar>& wrt) const {
    slp_assert(wrt.cols() == 1);

    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
    // for background on reverse accumulation automatic differentiation.

    if (m_top_list.empty()) {
      return VariableMatrix<Scalar>{detail::empty, wrt.rows(), 1};
    }

    // Set root node's adjoint to 1 since df/df is 1
    m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));

    // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
    // multiplied by dy/dx. If there are multiple "paths" from the root node
    // to a variable, the variable's adjoint is the sum of each path's adjoint
    // contribution.
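    // For example, if f(x) = x * x, both arguments of the multiply node are
    // the same variable x, so x accumulates an adjoint of 1 * x from each
    // path for a total of 2x, matching df/dx.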
    for (auto& node : m_top_list) {
      auto& lhs = node->args[0];
      auto& rhs = node->args[1];

      if (lhs != nullptr) {
        if (rhs != nullptr) {
          // Binary operator
          lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
          rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
        } else {
          // Unary operator
          lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
        }
      }
    }

    // Move gradient tree to return value
    VariableMatrix<Scalar> grad{detail::empty, wrt.rows(), 1};
    for (int row = 0; row < grad.rows(); ++row) {
      grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
    }

    // Unlink adjoints to avoid circular references between them and their
    // parent expressions. This ensures all expressions are returned to the
    // free list.
    for (auto& node : m_top_list) {
      node->adjoint_expr = nullptr;
    }

    return grad;
  }

  /**
   * Updates the adjoints in this computational graph, then appends the
   * nonzero adjoints of wrt to a list of sparse matrix triplets.
   *
   * @param triplets The sparse matrix triplets to which to append.
   * @param row The sparse matrix row in which to place the adjoints.
   * @param wrt Column vector of variables with respect to which to compute
   *     the gradient.
   */
  void append_triplets(gch::small_vector<Eigen::Triplet<Scalar>>& triplets,
                       int row, const VariableMatrix<Scalar>& wrt) const {
    slp_assert(wrt.cols() == 1);

    // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
    // for background on reverse accumulation automatic differentiation.

    // If wrt has fewer nodes than the graph, zero wrt's adjoints up front;
    // wrt variables the root expression doesn't depend on aren't in
    // m_top_list, so the loops below would otherwise leave them with stale
    // adjoints
    if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
      for (const auto& elem : wrt) {
        elem.expr->adjoint = Scalar(0);
      }
    }

    if (m_top_list.empty()) {
      return;
    }

    // Set root node's adjoint to 1 since df/df is 1
    m_top_list[0]->adjoint = Scalar(1);

    // Zero the rest of the adjoints
    for (auto& node : m_top_list | std::views::drop(1)) {
      node->adjoint = Scalar(0);
    }

    // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
    // multiplied by dy/dx. If there are multiple "paths" from the root node
    // to a variable, the variable's adjoint is the sum of each path's adjoint
    // contribution.
    for (const auto& node : m_top_list) {
      auto& lhs = node->args[0];
      auto& rhs = node->args[1];

      if (lhs != nullptr) {
        if (rhs != nullptr) {
          // Binary operator
          lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
          rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
        } else {
          // Unary operator
          lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
        }
      }
    }

    // If wrt has fewer nodes than the graph, iterate over wrt
    if (static_cast<size_t>(wrt.rows()) < m_top_list.size()) {
      for (int col = 0; col < wrt.rows(); ++col) {
        const auto& node = wrt[col].expr;

        // Append adjoints of wrt to sparse matrix triplets
        if (node->adjoint != Scalar(0)) {
          triplets.emplace_back(row, col, node->adjoint);
        }
      }
    } else {
      for (const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
        // Append adjoints of wrt to sparse matrix triplets
        if (col != -1 && node->adjoint != Scalar(0)) {
          triplets.emplace_back(row, col, node->adjoint);
        }
      }
    }
  }

 private:
  // Topologically sorted list of the graph's nodes, from root to leaves
  gch::small_vector<Expression<Scalar>*> m_top_list;

  // Column index in wrt of each node in m_top_list (-1 if the node isn't a
  // wrt variable)
  gch::small_vector<int> m_col_list;
};

} // namespace slp::detail
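
Below is a minimal usage sketch, not part of the header above. It assumes
Sleipnir's public Variable/VariableMatrix API (set_value, overloaded
arithmetic operators, and a VariableMatrix(rows, cols) constructor) and the
include path "sleipnir/autodiff/gradient_expression_graph.hpp"; in the
library itself, higher-level types like slp::Gradient and slp::Jacobian are
the intended entry points to this detail class.

#include <Eigen/SparseCore>
#include <gch/small_vector.hpp>

#include "sleipnir/autodiff/gradient_expression_graph.hpp"
#include "sleipnir/autodiff/variable.hpp"
#include "sleipnir/autodiff/variable_matrix.hpp"

void gradient_example() {
  // Decision variables x₀ = 3, x₁ = 4 (set_value is assumed public API)
  slp::VariableMatrix<double> x{2, 1};
  x[0].set_value(3.0);
  x[1].set_value(4.0);

  // f(x) = x₀² + x₀x₁
  slp::Variable<double> f = x[0] * x[0] + x[0] * x[1];

  slp::detail::GradientExpressionGraph<double> graph{f};

  // Symbolic gradient: a column vector of expression trees for ∂f/∂x₀ and
  // ∂f/∂x₁
  auto grad = graph.generate_tree(x);

  // Numeric gradient row: refresh node values, then gather nonzero adjoints
  // as triplets and assemble a 1×2 sparse row
  graph.update_values();
  gch::small_vector<Eigen::Triplet<double>> triplets;
  graph.append_triplets(triplets, 0, x);

  Eigen::SparseMatrix<double> g{1, x.rows()};
  g.setFromTriplets(triplets.begin(), triplets.end());
  // Expected: g = [2x₀ + x₁, x₀] = [10, 3]
}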