46 while (!
stack.empty()) {
56 if (
arg->duplications == 0) {
66 while (!
stack.empty()) {
72 m_adjointList.emplace_back(
node);
73 if (
node->args[0] !=
nullptr) {
76 m_valueList.emplace_back(
node);
87 if (
arg->duplications == 0) {
102 for (
auto&
node : m_valueList | std::views::reverse) {
106 if (
lhs !=
nullptr) {
107 if (
rhs !=
nullptr) {
126 if (m_adjointList.size() > 0) {
134 for (
auto&
node : m_adjointList) {
138 if (
lhs !=
nullptr) {
141 if (
rhs !=
nullptr) {
149 for (
int row = 0; row <
grad.Rows(); ++row) {
156 for (
auto&
node : m_adjointList) {
158 if (
arg !=
nullptr) {
159 arg->adjointExpr =
nullptr;
162 node->adjointExpr =
nullptr;
177 m_adjointList[0]->adjoint = 1.0;
178 for (
auto&
node : m_adjointList | std::views::drop(1)) {
186 for (
size_t col = 0;
col < m_adjointList.size(); ++
col) {
187 auto&
node = m_adjointList[
col];
191 if (
lhs !=
nullptr) {
192 if (
rhs !=
nullptr) {
199 node->GradientValueLhs(
lhs->value, 0.0,
node->adjoint);
204 int row = m_rowList[
col];
Definition VariableMatrix.hpp:28
static constexpr empty_t empty
Definition VariableMatrix.hpp:31
Definition Variable.hpp:33
Definition ExpressionGraph.hpp:19
void ComputeAdjoints(function_ref< void(int row, double adjoint)> func)
Definition ExpressionGraph.hpp:175
VariableMatrix GenerateGradientTree(const VariableMatrix &wrt) const
Definition ExpressionGraph.hpp:121
ExpressionGraph(Variable &root)
Definition ExpressionGraph.hpp:26
void Update()
Definition ExpressionGraph.hpp:99
Definition FunctionRef.hpp:17
Definition small_vector.hpp:3616
::value &&MoveInsertable constexpr reference emplace_back(Args &&... args)
Definition small_vector.hpp:4071
Definition Expression.hpp:18
IntrusiveSharedPtr< T > AllocateIntrusiveShared(Alloc alloc, Args &&... args)
Definition IntrusiveSharedPtr.hpp:275
@ kConstant
The expression is a constant.