Skip to content

Commit

Permalink
Support differentiating functions that return CXXConstructExpr.
Browse files Browse the repository at this point in the history
The goal of the changes is to support ``ReturnStmt`` with ``CXXConstructExpr``, e.g.
```
SimpleFunctions1 operator*(const SimpleFunctions1& rhs) {
  return SimpleFunctions1(rhs.x, this->y);
}
```
Currently, two problems stand in the way:
1) When forming the ``constructor_pullback`` call, we need to initialize a parameter of type ``SimpleFunctions1`` with a constructor. This is possible either with an explicit constructor call ``SimpleFunctions1(x, y)`` or with an ``InitListExpr`` ``{x, y}``. However, instead of doing the former, we generate a ``ParenListExpr`` ``(x, y)``. Clang can build the constructor call from it itself when handling a declaration's initializer, but not in this context.
2) The pullback argument is not propagated to the constructor pullback; therefore, we get a zero derivative.

This PR solves 2) completely and fixes 1) for usage in ``ReturnStmt`` and as a parameter.
Originally, the PR was opened to support all tests from #984. Therefore, all of them are added.
  • Loading branch information
PetroZarytskyi committed Jan 21, 2025
1 parent e9827a2 commit e30d747
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4151,6 +4151,10 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// responsible for updating these args.
Expr* thisE = getZeroInit(recordPointerType);
Expr* dThisE = getZeroInit(recordPointerType);
if (!m_TrackConstructorPullbackInfo && dfdx() &&
m_DiffReq.Mode == DiffMode::experimental_pullback)
dThisE = BuildOp(UnaryOperatorKind::UO_AddrOf, dfdx(),
m_DiffReq->getLocation());

pullbackArgs.push_back(thisE);
pullbackArgs.append(primalArgs.begin(), primalArgs.end());
Expand All @@ -4170,7 +4174,6 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
primalArgs.size() + 1);
m_TrackConstructorPullbackInfo = false;
}
}
// FIXME: If no compatible custom constructor pullback is found then try
Expand Down Expand Up @@ -4227,7 +4230,12 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
Expr* clonedArgsE = nullptr;

if (CE->getNumArgs() != 1) {
if (CE->isListInitialization()) {
// FIXME: We generate an InitListExpr when the constructor is called
// outside of a VarDecl init. This works out when it is later used in a
// ReturnStmt. However, to support member exprs/calls of constructors, we
// need to explicitly generate a constructor and not rely on higher level
// Sema functions.
if (CE->isListInitialization() || !m_TrackConstructorPullbackInfo) {
clonedArgsE = m_Sema.ActOnInitList(noLoc, primalArgs, noLoc).get();
} else {
if (CE->getNumArgs() == 0) {
Expand Down
138 changes: 138 additions & 0 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -525,12 +525,29 @@ public:
double x;
double y;
// Returns (x + y) * i + i * j * j, combining the members x and y with the
// arguments i and j.
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
// Returns (x + y) * i + i * j. Called by fn17, fn18 and fn19 further down in
// this file.
double mem_fn(double i, double j) { return (x + y) * i + i * j; }
// Member-wise sum: builds a named local from (x + other.x, y + other.y) and
// returns that local.
SimpleFunctions1 operator+(const SimpleFunctions1& other) const {
SimpleFunctions1 res(x + other.x, y + other.y);
return res;
}
// Member-wise product (x * rhs.x, y * rhs.y), returned directly from a
// braced-init-list rather than through a named local.
// NOTE(review): unlike operator+, this operator is not const-qualified --
// confirm that is intentional.
SimpleFunctions1 operator*(const SimpleFunctions1& rhs) {
return {this->x * rhs.x, this->y * rhs.y};
}
};

namespace clad {
namespace custom_derivatives {
namespace class_functions {

// Custom pullback for constructing a SimpleFunctions1 from two doubles.
// Adds the adjoint members d_f->x and d_f->y to *d_x and *d_y respectively.
// The primal object `f` and the primal arguments `x` and `y` are not read.
void constructor_pullback(SimpleFunctions1* f, double x, double y,
                          SimpleFunctions1* d_f, double* d_x, double* d_y) {
  (*d_x) += (*d_f).x;
  (*d_y) += (*d_f).y;
}

// Custom pullback for constructing a SimpleFunctions1 from another instance.
// Adds the adjoint members of *d_f to the matching members of *d_other.
// The primal object `f` and the primal argument `other` are not read.
void constructor_pullback(SimpleFunctions1* f, const SimpleFunctions1& other,
                          SimpleFunctions1* d_f, SimpleFunctions1* d_other) {
  (*d_other).x += (*d_f).x;
  (*d_other).y += (*d_f).y;
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad

// CHECK: void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const;

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);
Expand Down Expand Up @@ -562,13 +579,101 @@ double fn16(double i, double j) {
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&obj2, 3, 5, &_d_obj2, &_r2, &_r3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&obj1, 2, 3, &_d_obj1, &_r0, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: void mem_fn_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// Builds sf with constant members (3, 5) and returns sf.mem_fn(i, j), i.e.
// (3 + 5) * i + i * j. The gradient is {8 + j, i}, which evaluates to {11, 2}
// at (i, j) = (2, 3).
double fn17(double i, double j) {
SimpleFunctions1 sf(3, 5);
return sf.mem_fn(i, j);
}

// CHECK: void fn17_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 sf(3, 5);
// CHECK-NEXT: SimpleFunctions1 _d_sf(sf);
// CHECK-NEXT: clad::zero_init(_d_sf);
// CHECK-NEXT: SimpleFunctions1 _t0 = sf;
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: _t0.mem_fn_pullback(i, j, 1, &_d_sf, &_r2, &_r3);
// CHECK-NEXT: *_d_i += _r2;
// CHECK-NEXT: *_d_j += _r3;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&sf, 3, 5, &_d_sf, &_r0, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT:}

// Builds sf with members that depend on the inputs, (3 * i, 5 * j), and
// returns sf.mem_fn(i, j), i.e. (3 * i + 5 * j) * i + i * j. The gradient is
// {6 * i + 6 * j, 6 * i}, which evaluates to {30, 12} at (i, j) = (2, 3).
double fn18(double i, double j) {
SimpleFunctions1 sf(3 * i, 5 * j);
return sf.mem_fn(i, j);
}

// CHECK: void fn18_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 sf(3 * i, 5 * j);
// CHECK-NEXT: SimpleFunctions1 _d_sf(sf);
// CHECK-NEXT: clad::zero_init(_d_sf);
// CHECK-NEXT: SimpleFunctions1 _t0 = sf;
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: _t0.mem_fn_pullback(i, j, 1, &_d_sf, &_r2, &_r3);
// CHECK-NEXT: *_d_i += _r2;
// CHECK-NEXT: *_d_j += _r3;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&sf, 3 * i, 5 * j, &_d_sf, &_r0, &_r1);
// CHECK-NEXT: *_d_i += 3 * _r0;
// CHECK-NEXT: *_d_j += 5 * _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void operator_star_pullback(const SimpleFunctions1 &rhs, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_rhs);

// Calls mem_fn on the temporary produced by sf1 * sf2, where sf1 = (3, 5) and
// sf2 = (i, j). operator* yields members (3 * i, 5 * j), so the value and the
// gradient coincide with fn18: {30, 12} at (i, j) = (2, 3).
double fn19(double i, double j) {
SimpleFunctions1 sf1(3, 5);
SimpleFunctions1 sf2(i, j);
return (sf1 * sf2).mem_fn(i, j);
}

// CHECK: void fn19_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 sf1(3, 5);
// CHECK-NEXT: SimpleFunctions1 _d_sf1(sf1);
// CHECK-NEXT: clad::zero_init(_d_sf1);
// CHECK-NEXT: SimpleFunctions1 sf2(i, j);
// CHECK-NEXT: SimpleFunctions1 _d_sf2(sf2);
// CHECK-NEXT: clad::zero_init(_d_sf2);
// CHECK-NEXT: SimpleFunctions1 _t0 = sf1;
// CHECK-NEXT: SimpleFunctions1 _t1 = sf1.operator*(sf2);
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: SimpleFunctions1 _r4 = {};
// CHECK-NEXT: _t1.mem_fn_pullback(i, j, 1, &_r4, &_r2, &_r3);
// CHECK-NEXT: *_d_i += _r2;
// CHECK-NEXT: *_d_j += _r3;
// CHECK-NEXT: _t0.operator_star_pullback(sf2, _r4, &_d_sf1, &_d_sf2);
// CHECK-NEXT: }
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&sf2, i, j, &_d_sf2, &*_d_i, &*_d_j);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&sf1, 3, 5, &_d_sf1, &_r0, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: }

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand Down Expand Up @@ -645,6 +750,15 @@ int main() {

INIT_GRADIENT(fn16);
TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {22.00, 12.00}

INIT_GRADIENT(fn17);
TEST_GRADIENT(fn17, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {11.00, 2.00}

INIT_GRADIENT(fn18);
TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {30.00, 12.00}

INIT_GRADIENT(fn19);
TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {30.00, 12.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down Expand Up @@ -799,9 +913,11 @@ int main() {
// CHECK-NEXT: SimpleFunctions1 res(this->x + other.x, this->y + other.y);
// CHECK-NEXT: SimpleFunctions1 _d_res(res);
// CHECK-NEXT: clad::zero_init(_d_res);
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(nullptr, res, &_d_y, &_d_res);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&res, this->x + other.x, this->y + other.y, &_d_res, &_r0, &_r1);
// CHECK-NEXT: (*_d_this).x += _r0;
// CHECK-NEXT: (*_d_other).x += _r0;
// CHECK-NEXT: (*_d_this).y += _r1;
Expand All @@ -818,4 +934,26 @@ int main() {
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: void mem_fn_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: (*_d_this).y += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j;
// CHECK-NEXT: *_d_j += i * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: void operator_star_pullback(const SimpleFunctions1 &rhs, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_rhs) {
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(nullptr, this->x * rhs.x, this->y * rhs.y, &_d_y, &_r0, &_r1);
// CHECK-NEXT: (*_d_this).x += _r0 * rhs.x;
// CHECK-NEXT: (*_d_rhs).x += this->x * _r0;
// CHECK-NEXT: (*_d_this).y += _r1 * rhs.y;
// CHECK-NEXT: (*_d_rhs).y += this->y * _r1;
// CHECK-NEXT: }
// CHECK-NEXT:}

0 comments on commit e30d747

Please sign in to comment.