59 slp_assert(wrt.
cols() == 1);
64 if (m_top_list.empty()) {
69 m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1));
75 for (
auto& node : m_top_list) {
76 auto& lhs = node->args[0];
77 auto& rhs = node->args[1];
82 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
83 rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
86 lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
93 for (
int row = 0; row < grad.rows(); ++row) {
94 grad[row] =
Variable{std::move(wrt[row].expr->adjoint_expr)};
100 for (
auto& node : m_top_list) {
101 node->adjoint_expr =
nullptr;
117 gch::small_vector<Eigen::Triplet<Scalar>>& triplets,
int row,
119 slp_assert(wrt.
cols() == 1);
125 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
126 for (
const auto& elem : wrt) {
127 elem.expr->adjoint = Scalar(0);
131 if (m_top_list.empty()) {
136 m_top_list[0]->adjoint = Scalar(1);
139 for (
auto& node : m_top_list | std::views::drop(1)) {
140 node->adjoint = Scalar(0);
147 for (
const auto& node : m_top_list) {
148 auto& lhs = node->args[0];
149 auto& rhs = node->args[1];
151 if (lhs !=
nullptr) {
152 if (rhs !=
nullptr) {
154 lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
155 rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
158 lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
164 if (
static_cast<size_t>(wrt.
rows()) < m_top_list.size()) {
165 for (
int col = 0; col < wrt.
rows(); ++col) {
166 const auto& node = wrt[col].expr;
169 if (node->adjoint != Scalar(0)) {
170 triplets.emplace_back(row, col, node->adjoint);
174 for (
const auto& [col, node] : std::views::zip(m_col_list, m_top_list)) {
176 if (col != -1 && node->adjoint != Scalar(0)) {
177 triplets.emplace_back(row, col, node->adjoint);