Skip to content

Commit

Permalink
Introduce SetDeclInit to avoid setting init and init style manually.
Browse files Browse the repository at this point in the history
Sometimes we initialize declarations using ``Decl::setInit``, which is a low-level function used by clang internally. This commit introduces a wrapper of ``Sema::AddInitializerToDecl`` called ``SetDeclInit``.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 4, 2025
1 parent 1a0a9bc commit c373885
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 40 deletions.
9 changes: 8 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ namespace clad {
/// \returns Expression with correct Unary Operator placement.
clang::Expr* ResolveUnaryMinus(clang::Expr* E, clang::SourceLocation OpLoc);
clang::Expr* BuildParens(clang::Expr* E);
/// Sets Init as the initializer of the declaration VD and compute its
/// initialization kind.
///\param[in] VD - variable declaration
///\param[in] Init - can a nullptr, then only initialization kind is computed.
///\param[in] DirectInit - tells whether the initialization is direct.
void SetDeclInit(clang::VarDecl* VD, clang::Expr* Init = nullptr,
bool DirectInit = false);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand Down Expand Up @@ -477,7 +484,7 @@ namespace clad {
clang::Sema::AA_Casting);
assert(!ICAR.isInvalid() && "Invalid implicit conversion!");
// Assign the resulting expression to the variable declaration
VD->setInit(ICAR.get());
SetDeclInit(VD, ICAR.get());
}

/// Build a call to member function through Base expr and using the function
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
auto condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
SetDeclInit(condVarClone);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}
Expand Down Expand Up @@ -1696,7 +1696,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
if (condVarRes.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarRes.getDecl_dx()));
auto* condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
SetDeclInit(condVarClone);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/JacobianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ DerivativeAndOverload JacobianModeVisitor::DeriveJacobian() {
addToCurrentBlock(paramAssignment);
} else {
auto* paramDecl = cast<VarDecl>(cast<DeclRefExpr>(paramDiff)->getDecl());
m_Sema.AddInitializerToDecl(paramDecl, dVectorParam, true);
paramDecl->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(paramDecl, dVectorParam);
}
}

Expand Down
43 changes: 17 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,16 +494,15 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// Prepare the statements that assign the gradients to
// non array/pointer type parameters of the original function
if (!enzymeRealParams.empty()) {
auto* gradDeclStmt =
BuildVarDecl(QT, "grad", enzymeCall, /*DirectInit=*/true);
addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward);
VarDecl* gradVD = BuildVarDecl(QT, "grad", enzymeCall);
addToCurrentBlock(BuildDeclStmt(gradVD), direction::forward);

Check warning on line 498 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L497-L498

Added lines #L497 - L498 were not covered by tests

for (unsigned i = 0; i < enzymeRealParams.size(); i++) {
auto* LHSExpr =
BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsDerived[i]));

auto* ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
BuildDeclRef(gradDeclStmt), "d_arr");
BuildDeclRef(gradVD), "d_arr");

Check warning on line 505 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L505

Added line #L505 was not covered by tests

Expr* gradIndex = dyn_cast<Expr>(
IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)),
Expand Down Expand Up @@ -832,7 +831,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));

Expr* loopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
SetDeclInit(LoopVDDiff.getDecl(),
getZeroInit(LoopVDDiff.getDecl()->getType()));
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl()));
Expr* assignLoop =
BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit);
Expand Down Expand Up @@ -2810,8 +2810,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
copyExpr, range, range)
.get();
}
m_Sema.AddInitializerToDecl(VDDerived, copyExpr, /*DirectInit=*/true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CallInit);
SetDeclInit(VDDerived, copyExpr, /*DirectInit=*/true);
}

if (isPointerType && derivedVDE) {
Expand All @@ -2830,8 +2829,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
} else {
m_Sema.AddInitializerToDecl(VDDerived, initDiff.getExpr_dx(), true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VDDerived, initDiff.getExpr_dx());
}
}

Expand Down Expand Up @@ -2998,15 +2996,11 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
assignment = BuildOp(BO_Comma, pushPop.getExpr(), assignment);
}
inits.push_back(assignment);
if (const auto* AT = dyn_cast<ArrayType>(VD->getType())) {
m_Sema.AddInitializerToDecl(
decl, Clone(getArraySizeExpr(AT, m_Context, *this)), true);
decl->setInitStyle(VarDecl::InitializationStyle::CallInit);
} else {
m_Sema.AddInitializerToDecl(decl, getZeroInit(VD->getType()),
/*DirectInit=*/true);
decl->setInitStyle(VarDecl::InitializationStyle::CInit);
}
if (const auto* AT = dyn_cast<ArrayType>(VD->getType()))
SetDeclInit(decl, Clone(getArraySizeExpr(AT, m_Context, *this)),
/*DirectInit=*/true);
else
SetDeclInit(decl, getZeroInit(VD->getType()));
}
}

Expand Down Expand Up @@ -3109,8 +3103,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
auto* declRef = BuildDeclRef(vDecl);
auto* assignment = BuildOp(BO_Assign, declRef, init);
addToCurrentBlock(assignment, direction::forward);
m_Sema.AddInitializerToDecl(vDecl, getZeroInit(vDecl->getType()),
/*DirectInit=*/true);
SetDeclInit(vDecl, getZeroInit(vDecl->getType()),
/*DirectInit=*/true);
}
// Adjoints are initialized with copy-constructors only as a part of
// the strategy to maintain the structure of the original variable.
Expand Down Expand Up @@ -3283,8 +3277,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
addToCurrentBlock(decl, direction::forward);
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VD, E);
} else {
addToBlock(decl, m_Globals);
Expr* Set = BuildOp(BO_Assign, Ref, E);
Expand Down Expand Up @@ -3326,8 +3319,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
Store = decl;
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VD, E);
} else {
addToBlock(decl, m_Globals);
Store = BuildOp(BO_Assign, Ref, Clone(E));
Expand Down Expand Up @@ -3401,8 +3393,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
} else if (isFnScope) {
V.m_Sema.AddInitializerToDecl(Declaration, New, true);
Declaration->setInitStyle(VarDecl::InitializationStyle::CInit);
V.SetDeclInit(Declaration, New);
V.addToCurrentBlock(V.BuildDeclStmt(Declaration), direction::forward);
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
Expand Down
32 changes: 23 additions & 9 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ namespace clad {
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
delete oldScope;
}

void VisitorBase::SetDeclInit(VarDecl* VD, Expr* Init, bool DirectInit) {
if (!Init) {
// Clang sets inits only once. Therefore, ActOnUninitializedDecl does
// not reset the init and we have to do it manually.
VD->setInit(nullptr);
m_Sema.ActOnUninitializedDecl(VD);
return;
}

// Clang sets inits only once. Therefore, AddInitializerToDecl does
// not reset the declaration style to default and we have to do it manually.
VarDecl::InitializationStyle defaultStyle{};
VD->setInitStyle(defaultStyle);

// Clang expects direct inits to be wrapped either in InitListExpr or
// ParenListExpr.
if (DirectInit && !isa<InitListExpr>(Init) && !isa<ParenListExpr>(Init))
Init = m_Sema.ActOnParenListExpr(noLoc, noLoc, Init).get();
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
}

VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI) {
Expand All @@ -122,15 +144,7 @@ namespace clad {
m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);

if (Init) {
// Clang expects direct inits to be wrapped either in InitListExpr or
// ParenListExpr.
if (DirectInit && !isa<InitListExpr>(Init) && !isa<ParenListExpr>(Init))
Init = m_Sema.ActOnParenListExpr(noLoc, noLoc, Init).get();
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
} else {
m_Sema.ActOnUninitializedDecl(VD);
}
SetDeclInit(VD, Init, DirectInit);
m_Sema.FinalizeDeclaration(VD);
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false);
Expand Down

0 comments on commit c373885

Please sign in to comment.