Skip to content

Commit

Permalink
Remove m_DerivativeForForwSweep from StmtDiff.
Browse files Browse the repository at this point in the history
Currently, ``m_DerivativeForForwSweep`` is used to store the adjoint of the reference to a given expression, e.g.
1) for ``x``, it stores ``_d_x``
2) for ``x[i]``, it stores ``_d_x[i]``
etc.

We only use ``m_DerivativeForForwSweep`` to initialize reference-type variables. However, to initialize pointers, we use ``getExpr_dx()``. Pointers and references are equivalent and should be initialized the same way. When differentiating expressions, we initialize ``m_DerivativeForForwSweep`` either with the same value as ``Expr_dx`` or with ``nullptr``. The member is therefore redundant: this commit removes it from ``StmtDiff`` and uses ``getExpr_dx()``/``getStmt_dx()`` in its place.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jan 23, 2025
1 parent d4d29d8 commit 715f6d2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 46 deletions.
13 changes: 1 addition & 12 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ namespace clad {
class StmtDiff {
private:
std::array<clang::Stmt*, 2> data;
clang::Stmt* m_DerivativeForForwSweep;
clang::Stmt* m_ValueForRevSweep;

public:
StmtDiff(clang::Stmt* orig = nullptr, clang::Stmt* diff = nullptr,
clang::Stmt* forwSweepDiff = nullptr,
clang::Stmt* valueForRevSweep = nullptr)
: m_DerivativeForForwSweep(forwSweepDiff),
m_ValueForRevSweep(valueForRevSweep) {
: m_ValueForRevSweep(valueForRevSweep) {
data[1] = orig;
data[0] = diff;
}
Expand All @@ -58,8 +55,6 @@ namespace clad {
// Stmt_dx goes first!
std::array<clang::Stmt*, 2>& getBothStmts() { return data; }

clang::Stmt* getForwSweepStmt_dx() { return m_DerivativeForForwSweep; }

clang::Expr* getRevSweepAsExpr() {
return llvm::cast_or_null<clang::Expr>(getRevSweepStmt());
}
Expand All @@ -71,12 +66,6 @@ namespace clad {
return data[1];
return m_ValueForRevSweep;
}

clang::Expr* getForwSweepExpr_dx() {
return llvm::cast_or_null<clang::Expr>(m_DerivativeForForwSweep);
}

void setForwSweepStmt_dx(clang::Stmt* S) { m_DerivativeForForwSweep = S; }
};

template <typename T> class DeclDiff {
Expand Down
60 changes: 26 additions & 34 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
endScope();
return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock),
utils::unwrapIfSingleStmt(ReverseBlock),
/*forwSweepDiff=*/nullptr,
/*valueForRevSweep=*/condDiffStored);
}

Expand Down Expand Up @@ -1151,7 +1150,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) {
StmtDiff subStmtDiff = Visit(PE->getSubExpr(), dfdx());
return StmtDiff(BuildParens(subStmtDiff.getExpr()),
BuildParens(subStmtDiff.getExpr_dx()), nullptr,
BuildParens(subStmtDiff.getExpr_dx()),
BuildParens(subStmtDiff.getRevSweepAsExpr()));
}

Expand Down Expand Up @@ -1229,13 +1228,11 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
StmtDiff BaseDiff = Visit(Base);
llvm::SmallVector<Expr*, 4> clonedIndices(Indices.size());
llvm::SmallVector<Expr*, 4> reverseIndices(Indices.size());
llvm::SmallVector<Expr*, 4> forwSweepDerivativeIndices(Indices.size());
for (std::size_t i = 0; i < Indices.size(); i++) {
// FIXME: Remove redundant indices vectors.
StmtDiff IdxDiff = Visit(Indices[i]);
clonedIndices[i] = Clone(IdxDiff.getExpr());
reverseIndices[i] = Clone(IdxDiff.getExpr());
forwSweepDerivativeIndices[i] = IdxDiff.getExpr();
reverseIndices[i] = IdxDiff.getExpr();
}
auto* cloned = BuildArraySubscript(BaseDiff.getExpr(), clonedIndices);
auto* valueForRevSweep =
Expand All @@ -1244,11 +1241,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
if (!target)
return cloned;
Expr* result = nullptr;
Expr* forwSweepDerivative = nullptr;
// Create the target[idx] expression.
result = BuildArraySubscript(target, reverseIndices);
forwSweepDerivative =
BuildArraySubscript(target, forwSweepDerivativeIndices);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps(target)) {
Expand All @@ -1263,7 +1257,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}
if (m_ExternalSource)
m_ExternalSource->ActAfterProcessingArraySubscriptExpr(valueForRevSweep);
return StmtDiff(cloned, result, forwSweepDerivative, valueForRevSweep);
return StmtDiff(cloned, result, valueForRevSweep);
}

StmtDiff ReverseModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
Expand Down Expand Up @@ -1312,7 +1306,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}
}
}
return StmtDiff(clonedDRE, it->second, it->second);
return StmtDiff(clonedDRE, it->second);
}

return StmtDiff(clonedDRE);
Expand Down Expand Up @@ -1954,7 +1948,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, resAdjoint, resAdjoint);
return StmtDiff(resValue, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType) ||
returnType->isPointerType()) {
Expand Down Expand Up @@ -1987,7 +1981,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, resAdjoint, resAdjoint);
return StmtDiff(resValue, resAdjoint);
} // Recreate the original call expression.

if (isMethodOperatorCall) {
Expand Down Expand Up @@ -2177,7 +2171,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}
}
}
return {cloneE, derivedE, derivedE};
return {cloneE, derivedE};
} else {
if (opCode != UO_LNot)
// We should only output warnings on visiting boolean conditions
Expand All @@ -2190,7 +2184,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
ResultRef = diff.getExpr_dx();
}
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
return StmtDiff(op, ResultRef, valueForRevPass);
}

StmtDiff
Expand Down Expand Up @@ -2515,7 +2509,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
if (opCode == BO_Sub)
derivedR = BuildParens(derivedR);
return StmtDiff(op, BuildOp(opCode, derivedL, derivedR), nullptr,
return StmtDiff(op, BuildOp(opCode, derivedL, derivedR),
valueForRevPass);
}
if (opCode == BO_Assign || opCode == BO_AddAssign ||
Expand All @@ -2531,7 +2525,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
addToCurrentBlock(memsetCall, direction::forward);
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
return StmtDiff(op, ResultRef, valueForRevPass);
}

QualType ReverseModeVisitor::CloneType(QualType T) {
Expand Down Expand Up @@ -2638,14 +2632,14 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {

if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
if (!initDiff.getStmt_dx()) {
VDDerivedType = ComputeAdjointType(VDType.getNonReferenceType());
isRefType = false;
}
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
VDDerivedInit = initDiff.getExpr_dx();
}

if (isConstructInit) {
Expand All @@ -2654,8 +2648,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx()) {
VDDerivedInit = initDiff.getForwSweepExpr_dx();
if (initDiff.getExpr_dx()) {
VDDerivedInit = initDiff.getExpr_dx();
emptyInitListInit = false;
}
// ListInit style combined with `_t0.value`/`_t0.adjoint` inits will be
Expand Down Expand Up @@ -2764,8 +2758,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
if (isRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
Expand Down Expand Up @@ -3210,7 +3203,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
BuildOp(BinaryOperatorKind::BO_AddAssign, derivedME, dfdx());
addToCurrentBlock(addAssign, direction::reverse);
}
return {clonedME, derivedME, derivedME};
return {clonedME, derivedME};
}

StmtDiff
Expand Down Expand Up @@ -3446,22 +3439,21 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// easier debugging.
Expr* PH = ConstantFolder::synthesizeLiteral(E->getType(), m_Context,
/*val=*/~0U);
return DelayedStoreResult{
*this,
StmtDiff{PH, /*diff=*/nullptr, /*forwSweepDiff=*/nullptr, PH},
/*Declaration=*/nullptr,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
return DelayedStoreResult{*this,
StmtDiff{PH, /*diff=*/nullptr, PH},
/*Declaration=*/nullptr,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
StmtDiff{Push, nullptr, Pop},
/*Declaration=*/nullptr,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
Expand All @@ -3476,7 +3468,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
addToBlock(BuildDeclStmt(VD), m_Globals);
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
StmtDiff{Ref, nullptr, Ref},
/*Declaration=*/VD,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
Expand Down Expand Up @@ -4226,7 +4218,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
Expr* adjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return {val, nullptr, adjoint};
return {val, adjoint};
}

Expr* clonedArgsE = nullptr;
Expand Down

0 comments on commit 715f6d2

Please sign in to comment.