8#include <Eigen/SparseCore>
9#include <gch/small_vector.hpp>
11#include "sleipnir/autodiff/expression_graph.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/autodiff/variable_matrix.hpp"
14#include "sleipnir/util/assert.hpp"
15#include "sleipnir/util/empty.hpp"
17namespace slp::detail {
24template <
typename Scalar>
31 : m_top_list{topological_sort(
root.expr)} {
32 for (
const auto&
node : m_top_list) {
33 m_col_list.emplace_back(
node->col);
52 slp_assert(
wrt.cols() == 1);
57 if (m_top_list.empty()) {
62 m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));
68 for (
auto&
node : m_top_list) {
86 for (
int row = 0; row <
grad.rows(); ++row) {
93 for (
auto&
node : m_top_list) {
94 node->adjoint_expr =
nullptr;
110 if (m_top_list.empty()) {
115 m_top_list[0]->adjoint = Scalar(1);
118 for (
auto&
node : m_top_list | std::views::drop(1)) {
119 node->adjoint = Scalar(0);
126 for (
const auto&
node : m_top_list) {
130 if (
lhs !=
nullptr) {
131 if (
rhs !=
nullptr) {
142 for (
const auto& [col,
node] : std::views::zip(m_col_list, m_top_list)) {
152 gch::small_vector<Expression<Scalar>*> m_top_list;
155 gch::small_vector<int> m_col_list;
Definition intrusive_shared_ptr.hpp:27
Definition variable.hpp:47
Definition gradient_expression_graph.hpp:25
GradientExpressionGraph(const Variable< Scalar > &root)
Definition gradient_expression_graph.hpp:30
void update_values()
Definition gradient_expression_graph.hpp:39
VariableMatrix< Scalar > generate_tree(const VariableMatrix< Scalar > &wrt) const
Definition gradient_expression_graph.hpp:50
void append_triplets(gch::small_vector< Eigen::Triplet< Scalar > > &triplets, int row) const
Definition gradient_expression_graph.hpp:105