Sleipnir C++ API
Loading...
Searching...
No Matches
gradient_expression_graph.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6#include <utility>
7
8#include <Eigen/SparseCore>
9#include <gch/small_vector.hpp>
10
11#include "sleipnir/autodiff/expression_graph.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/autodiff/variable_matrix.hpp"
14#include "sleipnir/util/assert.hpp"
15#include "sleipnir/util/empty.hpp"
16
17namespace slp::detail {
18
24template <typename Scalar>
26 public:
31 : m_top_list{topological_sort(root.expr)} {
32 for (const auto& node : m_top_list) {
33 m_col_list.emplace_back(node->col);
34 }
35 }
36
39 void update_values() { detail::update_values(m_top_list); }
40
51 const VariableMatrix<Scalar>& wrt) const {
52 slp_assert(wrt.cols() == 1);
53
54 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
55 // for background on reverse accumulation automatic differentiation.
56
57 if (m_top_list.empty()) {
58 return VariableMatrix<Scalar>{detail::empty, wrt.rows(), 1};
59 }
60
61 // Set root node's adjoint to 1 since df/df is 1
62 m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));
63
64 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
65 // multiplied by dy/dx. If there are multiple "paths" from the root node to
66 // a variable, the variable's adjoint is the sum of each path's adjoint
67 // contribution.
68 for (auto& node : m_top_list) {
69 auto& lhs = node->args[0];
70 auto& rhs = node->args[1];
71
72 if (lhs != nullptr) {
73 if (rhs != nullptr) {
74 // Binary operator
75 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
76 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
77 } else {
78 // Unary operator
79 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
80 }
81 }
82 }
83
84 // Move gradient tree to return value
85 VariableMatrix<Scalar> grad{detail::empty, wrt.rows(), 1};
86 for (int row = 0; row < grad.rows(); ++row) {
87 grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)};
88 }
89
90 // Unlink adjoints to avoid circular references between them and their
91 // parent expressions. This ensures all expressions are returned to the free
92 // list.
93 for (auto& node : m_top_list) {
94 node->adjoint_expr = nullptr;
95 }
96
97 return grad;
98 }
99
105 void append_triplets(gch::small_vector<Eigen::Triplet<Scalar>>& triplets,
106 int row) const {
107 // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation
108 // for background on reverse accumulation automatic differentiation.
109
110 if (m_top_list.empty()) {
111 return;
112 }
113
114 // Set root node's adjoint to 1 since df/df is 1
115 m_top_list[0]->adjoint = Scalar(1);
116
117 // Zero the rest of the adjoints
118 for (auto& node : m_top_list | std::views::drop(1)) {
119 node->adjoint = Scalar(0);
120 }
121
122 // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y
123 // multiplied by dy/dx. If there are multiple "paths" from the root node to
124 // variable; the variable's adjoint is the sum of each path's adjoint
125 // contribution.
126 for (const auto& node : m_top_list) {
127 auto& lhs = node->args[0];
128 auto& rhs = node->args[1];
129
130 if (lhs != nullptr) {
131 if (rhs != nullptr) {
132 // Binary operator
133 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
134 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
135 } else {
136 // Unary operator
137 lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
138 }
139 }
140 }
141
142 for (const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
143 // Append adjoints of wrt to sparse matrix triplets
144 if (col != -1) {
145 triplets.emplace_back(row, col, node->adjoint);
146 }
147 }
148 }
149
150 private:
152 gch::small_vector<Expression<Scalar>*> m_top_list;
153
155 gch::small_vector<int> m_col_list;
156};
157
158} // namespace slp::detail
Definition intrusive_shared_ptr.hpp:27
Definition variable.hpp:47
Definition gradient_expression_graph.hpp:25
GradientExpressionGraph(const Variable< Scalar > &root)
Definition gradient_expression_graph.hpp:30
void update_values()
Definition gradient_expression_graph.hpp:39
VariableMatrix< Scalar > generate_tree(const VariableMatrix< Scalar > &wrt) const
Definition gradient_expression_graph.hpp:50
void append_triplets(gch::small_vector< Eigen::Triplet< Scalar > > &triplets, int row) const
Definition gradient_expression_graph.hpp:105