Sleipnir C++ API
Loading...
Searching...
No Matches
expression_graph.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6
7#include <gch/small_vector.hpp>
8
9#include "sleipnir/autodiff/expression.hpp"
10
11namespace slp::detail {
12
21template <typename Scalar>
22gch::small_vector<Expression<Scalar>*> topological_sort(
23 const ExpressionPtr<Scalar>& root) {
24 gch::small_vector<Expression<Scalar>*> list;
25
26 // If the root type is constant, updates are a no-op, so return an empty list
27 if (root == nullptr || root->type() == ExpressionType::CONSTANT) {
28 return list;
29 }
30
31 // Stack of nodes to explore
32 gch::small_vector<Expression<Scalar>*> stack;
33
34 // Enumerate incoming edges for each node via depth-first search
35 stack.emplace_back(root.get());
36 while (!stack.empty()) {
37 auto node = stack.back();
38 stack.pop_back();
39
40 for (auto& arg : node->args) {
41 // If the node hasn't been explored yet, add it to the stack
42 if (arg != nullptr && ++arg->incoming_edges == 1) {
43 stack.push_back(arg.get());
44 }
45 }
46 }
47
48 // Generate topological sort of graph from parent to child.
49 //
50 // A node is only added to the stack after all its incoming edges have been
51 // traversed. Expression::incoming_edges is a decrementing counter for
52 // tracking this.
53 //
54 // https://en.wikipedia.org/wiki/Topological_sorting
55 stack.emplace_back(root.get());
56 while (!stack.empty()) {
57 auto node = stack.back();
58 stack.pop_back();
59
60 list.emplace_back(node);
61
62 for (auto& arg : node->args) {
63 // If we traversed all this node's incoming edges, add it to the stack
64 if (arg != nullptr && --arg->incoming_edges == 0) {
65 stack.push_back(arg.get());
66 }
67 }
68 }
69
70 return list;
71}
72
80template <typename Scalar>
81void update_values(const gch::small_vector<Expression<Scalar>*>& list) {
82 // Traverse graph from child to parent and update values
83 for (auto& node : list | std::views::reverse) {
84 auto& lhs = node->args[0];
85 auto& rhs = node->args[1];
86
87 if (lhs != nullptr) {
88 node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0));
89 }
90 }
91}
92
93} // namespace slp::detail