Sleipnir C++ API
Loading...
Searching...
No Matches
expression_graph.hpp
1// Copyright (c) Sleipnir contributors
2
3#pragma once
4
5#include <ranges>
6
7#include "sleipnir/autodiff/expression.hpp"
8#include "sleipnir/util/small_vector.hpp"
9
10namespace slp::detail {
11
19inline small_vector<Expression*> topological_sort(const ExpressionPtr& root) {
20 small_vector<Expression*> list;
21
22 // If the root type is a constant, Update() is a no-op, so there's no work
23 // to do
24 if (root == nullptr || root->type() == ExpressionType::CONSTANT) {
25 return list;
26 }
27
28 // Stack of nodes to explore
29 small_vector<Expression*> stack;
30
31 // Enumerate incoming edges for each node via depth-first search
32 stack.emplace_back(root.get());
33 while (!stack.empty()) {
34 auto node = stack.back();
35 stack.pop_back();
36
37 for (auto& arg : node->args) {
38 // If the node hasn't been explored yet, add it to the stack
39 if (arg != nullptr && ++arg->incoming_edges == 1) {
40 stack.push_back(arg.get());
41 }
42 }
43 }
44
45 // Generate topological sort of graph from parent to child.
46 //
47 // A node is only added to the stack after all its incoming edges have been
48 // traversed. Expression::incoming_edges is a decrementing counter for
49 // tracking this.
50 //
51 // https://en.wikipedia.org/wiki/Topological_sorting
52 stack.emplace_back(root.get());
53 while (!stack.empty()) {
54 auto node = stack.back();
55 stack.pop_back();
56
57 list.emplace_back(node);
58
59 for (auto& arg : node->args) {
60 // If we traversed all this node's incoming edges, add it to the stack
61 if (arg != nullptr && --arg->incoming_edges == 0) {
62 stack.push_back(arg.get());
63 }
64 }
65 }
66
67 return list;
68}
69
76inline void update_values(const small_vector<Expression*>& list) {
77 // Traverse graph from child to parent and update values
78 for (auto& node : list | std::views::reverse) {
79 auto& lhs = node->args[0];
80 auto& rhs = node->args[1];
81
82 if (lhs != nullptr) {
83 if (rhs != nullptr) {
84 node->val = node->value(lhs->val, rhs->val);
85 } else {
86 node->val = node->value(lhs->val, 0.0);
87 }
88 }
89 }
90}
91
92} // namespace slp::detail