Minimal reproducer:
// Minimal class type used to reproduce the Clad bug.
// Note: the defaulted default constructor leaves `x` uninitialized; callers
// (e.g. operator+) assign it explicitly before use.
struct A {
  double x;
  A() = default;
  A(double px) : x(px) {}
  // User-provided copy constructor. Per the generated pullback in this report,
  // `return res;` in operator+ is differentiated through the custom
  // constructor_pullback for this constructor, so it must stay user-provided.
  A(const A& b) : x(b.x) {}
};
namespaceclad {
namespacecustom_derivatives {
namespaceclass_functions {
voidconstructor_pullback(A* a, double x, A* d_a, double* d_x) {
*d_x += d_a->x;
}
voidconstructor_pullback(A* a, const A& b, A* d_a, A* d_b) {
d_b->x += d_a->x;
}
}}}
// Returns an A whose x is a.x + b.x.
// Keep this exact shape for the reproducer: the `return res;` statement is
// what Clad differentiates through the copy-constructor constructor_pullback,
// and the generated pullback passes nullptr instead of the result's adjoint.
A operator+(const A& a, const A& b) {
A res;
res.x = a.x + b.x;
return res; // reverse pass: constructor_pullback(A*, const A&, A*, A*)
}
// Function under test: f(x, y) = 1 + x. `y` is intentionally unused, so the
// expected gradient is df/dx = 1, df/dy = 0.
// Differentiating it goes through A's custom constructor pullbacks and the
// pullback Clad generates for operator+; the statements are kept as-is so the
// bug still reproduces.
double f(double x, double y) {
  A t1{1};
  A t2{x};
  A sum = t1 + t2;
  return sum.x;
}
// Computes the gradient of f at (x, y) = (3, 4) with Clad and prints
// "df/dx df/dy". Expected output: "1 0".
// NOTE(review): needs <iostream> and the Clad differentiator header, which are
// not shown in this report.
int main(int argc, char* argv[]) {
  auto df = clad::gradient(f);
  // Adjoint outputs must start at zero: the pullbacks accumulate into them.
  double dx = 0;
  double dy = 0;
  df.execute(3, 4, &dx, &dy);
  // Separate the two derivatives with a space (an empty '' literal is ill-formed).
  std::cout << dx << ' ' << dy << '\n';
}
Expected output: 1 0
Actual output: segmentation fault
Note that the bug is caused by the way Clad differentiates `operator+`. The generated pullback is:
// Pullback generated by Clad for operator+ (as reported; `...` abbreviates the
// namespace, presumably clad::custom_derivatives::class_functions).
// Defect: the incoming adjoint of the result, `_d_y`, is never used. The
// return-statement pullback receives nullptr in its place and then
// dereferences it.
void operator_plus_pullback(const A &a, const A &b, A _d_y, A *_d_a, A *_d_b) {
A res;
A _d_res({});
clad::zero_init(_d_res);
double _t0 = res.x;
res.x = a.x + b.x;
...::constructor_pullback(nullptr, res, nullptr, &_d_res); // <--- Reverse of `return res;`. The third argument (d_a) should be &_d_y; with nullptr, constructor_pullback evaluates d_a->x and segfaults.
{ // Reverse of `res.x = a.x + b.x;`: restore res.x and push _d_res.x into both operands' adjoints.
res.x = _t0;
double _r_d0 = _d_res.x;
_d_res.x = 0.;
(*_d_a).x += _r_d0;
(*_d_b).x += _r_d0;
}
}
Expected output:
1 0
Output:
segmentation fault
Note that the bug happens because of the way `operator+` is differentiated (see the generated `operator_plus_pullback` above).