56 if (m_top_list.empty()) {
61 m_top_list[0]->adjoint_expr = make_expression_ptr<ConstExpression>(1.0);
67 for (
auto& node : m_top_list) {
68 auto& lhs = node->args[0];
69 auto& rhs = node->args[1];
72 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
74 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
81 for (
int row = 0; row < grad.rows(); ++row) {
82 grad[row] =
Variable{std::move(wrt[row].expr->adjoint_expr)};
88 for (
auto& node : m_top_list) {
89 node->adjoint_expr =
nullptr;
105 gch::small_vector<Eigen::Triplet<double>>& triplets,
int row,
111 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
112 for (
const auto& elem : wrt) {
113 elem.expr->adjoint = 0.0;
117 if (m_top_list.empty()) {
122 m_top_list[0]->adjoint = 1.0;
125 for (
auto& node : m_top_list | std::views::drop(1)) {
133 for (
const auto& node : m_top_list) {
134 auto& lhs = node->args[0];
135 auto& rhs = node->args[1];
137 if (lhs !=
nullptr) {
138 if (rhs !=
nullptr) {
139 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
140 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
142 lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint);
148 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
149 for (
int col = 0; col < wrt.
rows(); ++col) {
150 const auto& node = wrt[col].expr;
153 if (node->adjoint != 0.0) {
154 triplets.emplace_back(row, col, node->adjoint);
158 for (
const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
160 if (col != -1 && node->adjoint != 0.0) {
161 triplets.emplace_back(row, col, node->adjoint);