#include <ranges>

#include <Eigen/SparseCore>
#include <gch/small_vector.hpp>

#include "sleipnir/autodiff/expression.hpp"
12namespace slp::detail {
17template <
typename Scalar>
18using ExpressionGraph = gch::small_vector<Expression<Scalar>*>;
26template <
typename Scalar>
27ExpressionGraph<Scalar> topological_sort(
const ExpressionPtr<Scalar>& root) {
28 ExpressionGraph<Scalar> list;
31 if (root ==
nullptr || root->type() == ExpressionType::CONSTANT) {
36 gch::small_vector<Expression<Scalar>*> stack;
41 stack.emplace_back(root.get());
42 while (!stack.empty()) {
43 auto node = stack.back();
46 for (
auto& arg : node->args) {
48 if (arg !=
nullptr && ++arg->scratch == 0) {
49 stack.push_back(arg.get());
60 stack.emplace_back(root.get());
61 while (!stack.empty()) {
62 auto node = stack.back();
65 list.emplace_back(node);
67 for (
auto& arg : node->args) {
69 if (arg !=
nullptr && --arg->scratch == -1) {
70 stack.push_back(arg.get());
83template <
typename Scalar>
84void update_values(
const ExpressionGraph<Scalar>& list) {
86 for (
auto& node : list | std::views::reverse) {
87 auto& lhs = node->args[0];
88 auto& rhs = node->args[1];
91 node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0));
104template <
typename Scalar>
106 const ExpressionGraph<Scalar>& top_list,
107 const gch::small_vector<std::pair<
int, detail::Expression<Scalar>*>>&
109 gch::small_vector<Eigen::Triplet<Scalar>>& triplets,
int row) {
113 if (top_list.empty()) {
118 top_list[0]->adjoint = Scalar(1);
121 for (
auto& node : top_list | std::views::drop(1)) {
122 node->adjoint = Scalar(0);
129 for (
const auto& node : top_list) {
130 auto& lhs = node->args[0];
131 auto& rhs = node->args[1];
133 if (lhs !=
nullptr) {
134 if (rhs !=
nullptr) {
136 lhs->adjoint += node->grad_l(lhs->val, rhs->val);
137 rhs->adjoint += node->grad_r(lhs->val, rhs->val);
140 lhs->adjoint += node->grad_l(lhs->val, Scalar(0));
147 for (
const auto& [col, node] : output_list) {
149 triplets.emplace_back(row, col, node->adjoint);