7#include "sleipnir/autodiff/expression.hpp"
8#include "sleipnir/util/small_vector.hpp"
10namespace slp::detail {
19inline small_vector<Expression*> topological_sort(
const ExpressionPtr& root) {
20 small_vector<Expression*> list;
24 if (root ==
nullptr || root->type() == ExpressionType::CONSTANT) {
29 small_vector<Expression*> stack;
32 stack.emplace_back(root.get());
33 while (!stack.empty()) {
34 auto node = stack.back();
37 for (
auto& arg : node->args) {
39 if (arg !=
nullptr && ++arg->incoming_edges == 1) {
40 stack.push_back(arg.get());
52 stack.emplace_back(root.get());
53 while (!stack.empty()) {
54 auto node = stack.back();
57 list.emplace_back(node);
59 for (
auto& arg : node->args) {
61 if (arg !=
nullptr && --arg->incoming_edges == 0) {
62 stack.push_back(arg.get());
76inline void update_values(
const small_vector<Expression*>& list) {
78 for (
auto& node : list | std::views::reverse) {
79 auto& lhs = node->args[0];
80 auto& rhs = node->args[1];
84 node->val = node->value(lhs->val, rhs->val);
86 node->val = node->value(lhs->val, 0.0);