8#include <Eigen/SparseCore>
9#include <gch/small_vector.hpp>
11#include "sleipnir/autodiff/expression_graph.hpp"
12#include "sleipnir/autodiff/variable.hpp"
13#include "sleipnir/autodiff/variable_matrix.hpp"
14#include "sleipnir/util/assert.hpp"
15#include "sleipnir/util/empty.hpp"
17namespace slp::detail {
24template <
typename Scalar>
31 : m_top_list{topological_sort(
root.expr)} {
32 for (
const auto&
node : m_top_list) {
33 m_col_list.emplace_back(
node->col);
52 slp_assert(
wrt.cols() == 1);
57 if (m_top_list.empty()) {
62 m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));
68 for (
auto&
node : m_top_list) {
86 for (
int row = 0; row <
grad.rows(); ++row) {
93 for (
auto&
node : m_top_list) {
94 node->adjoint_expr =
nullptr;
109 slp_assert(
wrt.cols() == 1);
115 if (
static_cast<size_t>(
wrt.rows()) < m_top_list.size()) {
117 elem.expr->adjoint = Scalar(0);
121 if (m_top_list.empty()) {
126 m_top_list[0]->adjoint = Scalar(1);
129 for (
auto&
node : m_top_list | std::views::drop(1)) {
130 node->adjoint = Scalar(0);
137 for (
const auto&
node : m_top_list) {
141 if (
lhs !=
nullptr) {
142 if (
rhs !=
nullptr) {
154 if (
static_cast<size_t>(
wrt.rows()) < m_top_list.size()) {
155 for (
int col = 0; col <
wrt.rows(); ++col) {
156 const auto&
node =
wrt[col].expr;
159 if (
node->adjoint != Scalar(0)) {
164 for (
const auto& [col,
node] : std::views::zip(m_col_list, m_top_list)) {
166 if (col != -1 &&
node->adjoint != Scalar(0)) {
175 gch::small_vector<Expression<Scalar>*> m_top_list;
178 gch::small_vector<int> m_col_list;
Definition intrusive_shared_ptr.hpp:27
Definition variable.hpp:47
Definition gradient_expression_graph.hpp:25
GradientExpressionGraph(const Variable< Scalar > &root)
Definition gradient_expression_graph.hpp:30
void append_triplets(gch::small_vector< Eigen::Triplet< Scalar > > &triplets, int row, const VariableMatrix< Scalar > &wrt) const
Definition gradient_expression_graph.hpp:107
void update_values()
Definition gradient_expression_graph.hpp:39
VariableMatrix< Scalar > generate_tree(const VariableMatrix< Scalar > &wrt) const
Definition gradient_expression_graph.hpp:50