Sleipnir C++ API
Loading...
Searching...
No Matches
expression_graph.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6
7#include <Eigen/SparseCore>
8#include <gch/small_vector.hpp>
9
10#include "sleipnir/autodiff/expression.hpp"
11
12namespace slp::detail {
13
17template <typename Scalar>
18using ExpressionGraph = gch::small_vector<Expression<Scalar>*>;
19
26template <typename Scalar>
27ExpressionGraph<Scalar> topological_sort(const ExpressionPtr<Scalar>& root) {
28 ExpressionGraph<Scalar> list;
29
30 // If the root type is constant, updates are a no-op, so return an empty list
31 if (root == nullptr || root->type() == ExpressionType::CONSTANT) {
32 return list;
33 }
34
35 // Stack of nodes to explore
36 gch::small_vector<Expression<Scalar>*> stack;
37
38 // Enumerate incoming edges for each node via depth-first search
39 //
40 // NOTE: scratch counts incoming edges, offset by -1 so -1 means no edges.
41 stack.emplace_back(root.get());
42 while (!stack.empty()) {
43 auto node = stack.back();
44 stack.pop_back();
45
46 for (auto& arg : node->args) {
47 // If the node hasn't been explored yet, add it to the stack
48 if (arg != nullptr && ++arg->scratch == 0) {
49 stack.push_back(arg.get());
50 }
51 }
52 }
53
54 // Generate topological sort of graph from parent to child.
55 //
56 // A node is only added to the stack after all its incoming edges have been
57 // traversed. Expression::scratch is a decrementing counter for tracking this.
58 //
59 // https://en.wikipedia.org/wiki/Topological_sorting
60 stack.emplace_back(root.get());
61 while (!stack.empty()) {
62 auto node = stack.back();
63 stack.pop_back();
64
65 list.emplace_back(node);
66
67 for (auto& arg : node->args) {
68 // If we traversed all this node's incoming edges, add it to the stack
69 if (arg != nullptr && --arg->scratch == -1) {
70 stack.push_back(arg.get());
71 }
72 }
73 }
74
75 return list;
76}
77
83template <typename Scalar>
84void update_values(const ExpressionGraph<Scalar>& list) {
85 // Traverse graph from child to parent and update values
86 for (auto& node : list | std::views::reverse) {
87 auto& lhs = node->args[0];
88 auto& rhs = node->args[1];
89
90 if (lhs != nullptr) {
91 node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0));
92 }
93 }
94}
95
104template <typename Scalar>
105void append_triplets(
106 const ExpressionGraph<Scalar>& top_list,
107 const gch::small_vector<std::pair<int, detail::Expression<Scalar>*>>&
108 output_list,
109 gch::small_vector<Eigen::Triplet<Scalar>>& triplets, int row) {
110 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
111 // for background on reverse accumulation automatic differentiation.
112
113 if (top_list.empty()) {
114 return;
115 }
116
117 // Set root node's adjoint to 1 since df/df is 1
118 top_list[0]->adjoint = Scalar(1);
119
120 // Zero the rest of the adjoints
121 for (auto& node : top_list | std::views::drop(1)) {
122 node->adjoint = Scalar(0);
123 }
124
125 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
126 // multiplied by dy/dx. If there are multiple "paths" from the root node to
127 // variable; the variable's adjoint is the sum of each path's adjoint
128 // contribution.
129 for (const auto& node : top_list) {
130 auto& lhs = node->args[0];
131 auto& rhs = node->args[1];
132
133 if (lhs != nullptr) {
134 if (rhs != nullptr) {
135 // Binary operator
136 lhs->adjoint += node->grad_l(lhs->val, rhs->val);
137 rhs->adjoint += node->grad_r(lhs->val, rhs->val);
138 } else {
139 // Unary operator
140 lhs->adjoint += node->grad_l(lhs->val, Scalar(0));
141 }
142 }
143 }
144
145 // Exploit the row's sparsity pattern by only appending wrt adjoints that
146 // appear in the expression graph
147 for (const auto& [col, node] : output_list) {
148 // Append adjoints of wrt to sparse matrix triplets
149 triplets.emplace_back(row, col, node->adjoint);
150 }
151}
152
153} // namespace slp::detail