Sleipnir C++ API
Loading...
Searching...
No Matches
ExpressionGraph.hpp
Go to the documentation of this file.
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6#include <utility>
7
12
13namespace sleipnir::detail {
14
20 public:
27 // If the root type is a constant, Update() is a no-op, so there's no work
28 // to do
29 if (root.expr == nullptr || root.Type() == ExpressionType::kConstant) {
30 return;
31 }
32
33 // Breadth-first search (BFS) is used as opposed to a depth-first search
34 // (DFS) to avoid counting duplicate nodes multiple times. A list of nodes
35 // ordered from parent to child with no duplicates is generated.
36 //
37 // https://en.wikipedia.org/wiki/Breadth-first_search
38
39 // BFS list sorted from parent to child.
41
42 stack.emplace_back(root.expr.Get());
43
44 // Initialize the number of instances of each node in the tree
45 // (Expression::duplications)
46 while (!stack.empty()) {
47 auto node = stack.back();
48 stack.pop_back();
49
50 for (auto& arg : node->args) {
51 // Only continue if the node is not a constant and hasn't already been
52 // explored.
53 if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
54 // If this is the first instance of the node encountered (it hasn't
55 // been explored yet), add it to stack so it's recursed upon
56 if (arg->duplications == 0) {
57 stack.push_back(arg.Get());
58 }
59 ++arg->duplications;
60 }
61 }
62 }
63
64 stack.emplace_back(root.expr.Get());
65
66 while (!stack.empty()) {
67 auto node = stack.back();
68 stack.pop_back();
69
70 // BFS lists sorted from parent to child.
71 m_rowList.emplace_back(node->row);
72 m_adjointList.emplace_back(node);
73 if (node->args[0] != nullptr) {
74 // Constants (expressions with no arguments) are skipped because they
75 // don't need to be updated
76 m_valueList.emplace_back(node);
77 }
78
79 for (auto& arg : node->args) {
80 // Only add node if it's not a constant and doesn't already exist in the
81 // tape.
82 if (arg != nullptr && arg->Type() != ExpressionType::kConstant) {
83 // Once the number of node visitations equals the number of
84 // duplications (the counter hits zero), add it to the stack. Note
85 // that this means the node is only enqueued once.
86 --arg->duplications;
87 if (arg->duplications == 0) {
88 stack.push_back(arg.Get());
89 }
90 }
91 }
92 }
93 }
94
99 void Update() {
100 // Traverse the BFS list backward from child to parent and update the value
101 // of each node.
102 for (auto& node : m_valueList | std::views::reverse) {
103 auto& lhs = node->args[0];
104 auto& rhs = node->args[1];
105
106 if (lhs != nullptr) {
107 if (rhs != nullptr) {
108 node->value = node->Value(lhs->value, rhs->value);
109 } else {
110 node->value = node->Value(lhs->value, 0.0);
111 }
112 }
113 }
114 }
115
122 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
123 // for background on reverse accumulation automatic differentiation.
124
125 // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
126 if (m_adjointList.size() > 0) {
127 m_adjointList[0]->adjointExpr = MakeExpressionPtr<ConstExpression>(1.0);
128 }
129
130 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
131 // multiplied by dy/dx. If there are multiple "paths" from the root node to
132 // variable; the variable's adjoint is the sum of each path's adjoint
133 // contribution.
134 for (auto& node : m_adjointList) {
135 auto& lhs = node->args[0];
136 auto& rhs = node->args[1];
137
138 if (lhs != nullptr) {
139 lhs->adjointExpr =
140 lhs->adjointExpr + node->GradientLhs(lhs, rhs, node->adjointExpr);
141 if (rhs != nullptr) {
142 rhs->adjointExpr =
143 rhs->adjointExpr + node->GradientRhs(lhs, rhs, node->adjointExpr);
144 }
145 }
146 }
147
149 for (int row = 0; row < grad.Rows(); ++row) {
150 grad(row) = Variable{std::move(wrt(row).expr->adjointExpr)};
151 }
152
153 // Unlink adjoints to avoid circular references between them and their
154 // parent expressions. This ensures all expressions are returned to the free
155 // list.
156 for (auto& node : m_adjointList) {
157 for (auto& arg : node->args) {
158 if (arg != nullptr) {
159 arg->adjointExpr = nullptr;
160 }
161 }
162 node->adjointExpr = nullptr;
163 }
164
165 return grad;
166 }
167
175 void ComputeAdjoints(function_ref<void(int row, double adjoint)> func) {
176 // Zero adjoints. The root node's adjoint is 1.0 as df/df is always 1.
177 m_adjointList[0]->adjoint = 1.0;
178 for (auto& node : m_adjointList | std::views::drop(1)) {
179 node->adjoint = 0.0;
180 }
181
182 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
183 // multiplied by dy/dx. If there are multiple "paths" from the root node to
184 // variable; the variable's adjoint is the sum of each path's adjoint
185 // contribution.
186 for (size_t col = 0; col < m_adjointList.size(); ++col) {
187 auto& node = m_adjointList[col];
188 auto& lhs = node->args[0];
189 auto& rhs = node->args[1];
190
191 if (lhs != nullptr) {
192 if (rhs != nullptr) {
193 lhs->adjoint +=
194 node->GradientValueLhs(lhs->value, rhs->value, node->adjoint);
195 rhs->adjoint +=
196 node->GradientValueRhs(lhs->value, rhs->value, node->adjoint);
197 } else {
198 lhs->adjoint +=
199 node->GradientValueLhs(lhs->value, 0.0, node->adjoint);
200 }
201 }
202
203 // If variable is a leaf node, assign its adjoint to the gradient.
204 int row = m_rowList[col];
205 if (row != -1) {
206 func(row, node->adjoint);
207 }
208 }
209 }
210
211 private:
212 // List that maps nodes to their respective row.
213 small_vector<int> m_rowList;
214
215 // List for updating adjoints
216 small_vector<Expression*> m_adjointList;
217
218 // List for updating values
219 small_vector<Expression*> m_valueList;
220};
221
222} // namespace sleipnir::detail
Definition VariableMatrix.hpp:28
static constexpr empty_t empty
Definition VariableMatrix.hpp:31
Definition Variable.hpp:33
Definition ExpressionGraph.hpp:19
void ComputeAdjoints(function_ref< void(int row, double adjoint)> func)
Definition ExpressionGraph.hpp:175
VariableMatrix GenerateGradientTree(const VariableMatrix &wrt) const
Definition ExpressionGraph.hpp:121
ExpressionGraph(Variable &root)
Definition ExpressionGraph.hpp:26
void Update()
Definition ExpressionGraph.hpp:99
Definition FunctionRef.hpp:17
Definition small_vector.hpp:3616
::value &&MoveInsertable constexpr reference emplace_back(Args &&... args)
Definition small_vector.hpp:4071
Definition Expression.hpp:18
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Definition IntrusiveSharedPtr.hpp:275
@ kConstant
The expression is a constant.