7#include <gch/small_vector.hpp>
9#include "sleipnir/autodiff/expression.hpp"
11namespace slp::detail {
21template <
typename Scalar>
22gch::small_vector<Expression<Scalar>*> topological_sort(
23 const ExpressionPtr<Scalar>& root) {
24 gch::small_vector<Expression<Scalar>*> list;
27 if (root ==
nullptr || root->type() == ExpressionType::CONSTANT) {
32 gch::small_vector<Expression<Scalar>*> stack;
35 stack.emplace_back(root.get());
36 while (!stack.empty()) {
37 auto node = stack.back();
40 for (
auto& arg : node->args) {
42 if (arg !=
nullptr && ++arg->incoming_edges == 1) {
43 stack.push_back(arg.get());
55 stack.emplace_back(root.get());
56 while (!stack.empty()) {
57 auto node = stack.back();
60 list.emplace_back(node);
62 for (
auto& arg : node->args) {
64 if (arg !=
nullptr && --arg->incoming_edges == 0) {
65 stack.push_back(arg.get());
80template <
typename Scalar>
81void update_values(
const gch::small_vector<Expression<Scalar>*>& list) {
83 for (
auto& node : list | std::views::reverse) {
84 auto& lhs = node->args[0];
85 auto& rhs = node->args[1];
88 node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0));