7#include <gch/small_vector.hpp>
9#include "sleipnir/autodiff/expression.hpp"
11namespace slp::detail {
20inline gch::small_vector<Expression*> topological_sort(
21 const ExpressionPtr& root) {
22 gch::small_vector<Expression*> list;
25 if (root ==
nullptr || root->type() == ExpressionType::CONSTANT) {
30 gch::small_vector<Expression*> stack;
33 stack.emplace_back(root.get());
34 while (!stack.empty()) {
35 auto node = stack.back();
38 for (
auto& arg : node->args) {
40 if (arg !=
nullptr && ++arg->incoming_edges == 1) {
41 stack.push_back(arg.get());
53 stack.emplace_back(root.get());
54 while (!stack.empty()) {
55 auto node = stack.back();
58 list.emplace_back(node);
60 for (
auto& arg : node->args) {
62 if (arg !=
nullptr && --arg->incoming_edges == 0) {
63 stack.push_back(arg.get());
77inline void update_values(
const gch::small_vector<Expression*>& list) {
79 for (
auto& node : list | std::views::reverse) {
80 auto& lhs = node->args[0];
81 auto& rhs = node->args[1];
85 node->val = node->value(lhs->val, rhs->val);
87 node->val = node->value(lhs->val, 0.0);