Skip to content

Commit

Permalink
Introduce SetDeclInit to avoid setting init and init style manually.
Browse files Browse the repository at this point in the history
Sometimes we initialize declarations using ``Decl::setInit``, which is a low-level function used by clang internally. This commit introduces a wrapper of ``Sema::AddInitializerToDecl`` called ``SetDeclInit``.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 7, 2025
1 parent d84cf58 commit e6d200d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 40 deletions.
11 changes: 10 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ namespace clad {
/// \returns Expression with correct Unary Operator placement.
clang::Expr* ResolveUnaryMinus(clang::Expr* E, clang::SourceLocation OpLoc);
clang::Expr* BuildParens(clang::Expr* E);
/// Sets Init as the initializer of the declaration VD and compute its
/// initialization kind.
///\param[in] VD - variable declaration
///\param[in] Init - can be nullptr, then only initialization kind is
/// computed.
///\param[in] DirectInit - tells whether the initialization is
/// direct.
void SetDeclInit(clang::VarDecl* VD, clang::Expr* Init = nullptr,
bool DirectInit = false);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand Down Expand Up @@ -478,7 +487,7 @@ namespace clad {
clang::Sema::AA_Casting);
assert(!ICAR.isInvalid() && "Invalid implicit conversion!");
// Assign the resulting expression to the variable declaration
VD->setInit(ICAR.get());
SetDeclInit(VD, ICAR.get());
}

/// Build a call to member function through Base expr and using the function
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
auto condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
SetDeclInit(condVarClone);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}
Expand Down Expand Up @@ -1696,7 +1696,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
if (condVarRes.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarRes.getDecl_dx()));
auto* condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
SetDeclInit(condVarClone);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/JacobianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ DerivativeAndOverload JacobianModeVisitor::DeriveJacobian() {
addToCurrentBlock(paramAssignment);
} else {
auto* paramDecl = cast<VarDecl>(cast<DeclRefExpr>(paramDiff)->getDecl());
m_Sema.AddInitializerToDecl(paramDecl, dVectorParam, true);
paramDecl->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(paramDecl, dVectorParam);
}
}

Expand Down
43 changes: 17 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,16 +494,15 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// Prepare the statements that assign the gradients to
// non array/pointer type parameters of the original function
if (!enzymeRealParams.empty()) {
auto* gradDeclStmt =
BuildVarDecl(QT, "grad", enzymeCall, /*DirectInit=*/true);
addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward);
VarDecl* gradVD = BuildVarDecl(QT, "grad", enzymeCall);
addToCurrentBlock(BuildDeclStmt(gradVD), direction::forward);

for (unsigned i = 0; i < enzymeRealParams.size(); i++) {
auto* LHSExpr =
BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsDerived[i]));

auto* ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
BuildDeclRef(gradDeclStmt), "d_arr");
BuildDeclRef(gradVD), "d_arr");

Expr* gradIndex = dyn_cast<Expr>(
IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)),
Expand Down Expand Up @@ -832,7 +831,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));

Expr* loopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
SetDeclInit(LoopVDDiff.getDecl(),
getZeroInit(LoopVDDiff.getDecl()->getType()));
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl()));
Expr* assignLoop =
BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit);
Expand Down Expand Up @@ -2810,8 +2810,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
copyExpr, range, range)
.get();
}
m_Sema.AddInitializerToDecl(VDDerived, copyExpr, /*DirectInit=*/true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CallInit);
SetDeclInit(VDDerived, copyExpr, /*DirectInit=*/true);
}

if (isPointerType && derivedVDE) {
Expand All @@ -2830,8 +2829,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
} else {
m_Sema.AddInitializerToDecl(VDDerived, initDiff.getExpr_dx(), true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VDDerived, initDiff.getExpr_dx());
}
}

Expand Down Expand Up @@ -2998,15 +2996,11 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
assignment = BuildOp(BO_Comma, pushPop.getExpr(), assignment);
}
inits.push_back(assignment);
if (const auto* AT = dyn_cast<ArrayType>(VD->getType())) {
m_Sema.AddInitializerToDecl(
decl, Clone(getArraySizeExpr(AT, m_Context, *this)), true);
decl->setInitStyle(VarDecl::InitializationStyle::CallInit);
} else {
m_Sema.AddInitializerToDecl(decl, getZeroInit(VD->getType()),
/*DirectInit=*/true);
decl->setInitStyle(VarDecl::InitializationStyle::CInit);
}
if (const auto* AT = dyn_cast<ArrayType>(VD->getType()))
SetDeclInit(decl, Clone(getArraySizeExpr(AT, m_Context, *this)),
/*DirectInit=*/true);
else
SetDeclInit(decl, getZeroInit(VD->getType()));
}
}

Expand Down Expand Up @@ -3109,8 +3103,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
auto* declRef = BuildDeclRef(vDecl);
auto* assignment = BuildOp(BO_Assign, declRef, init);
addToCurrentBlock(assignment, direction::forward);
m_Sema.AddInitializerToDecl(vDecl, getZeroInit(vDecl->getType()),
/*DirectInit=*/true);
SetDeclInit(vDecl, getZeroInit(vDecl->getType()),
/*DirectInit=*/true);
}
// Adjoints are initialized with copy-constructors only as a part of
// the strategy to maintain the structure of the original variable.
Expand Down Expand Up @@ -3283,8 +3277,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
addToCurrentBlock(decl, direction::forward);
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VD, E);
} else {
addToBlock(decl, m_Globals);
Expr* Set = BuildOp(BO_Assign, Ref, E);
Expand Down Expand Up @@ -3326,8 +3319,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
Store = decl;
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
SetDeclInit(VD, E);
} else {
addToBlock(decl, m_Globals);
Store = BuildOp(BO_Assign, Ref, Clone(E));
Expand Down Expand Up @@ -3401,8 +3393,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
} else if (isFnScope) {
V.m_Sema.AddInitializerToDecl(Declaration, New, true);
Declaration->setInitStyle(VarDecl::InitializationStyle::CInit);
V.SetDeclInit(Declaration, New);
V.addToCurrentBlock(V.BuildDeclStmt(Declaration), direction::forward);
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
Expand Down
32 changes: 23 additions & 9 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ namespace clad {
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
delete oldScope;
}

void VisitorBase::SetDeclInit(VarDecl* VD, Expr* Init, bool DirectInit) {
if (!Init) {
// Clang sets inits only once. Therefore, ActOnUninitializedDecl does
// not reset the init and we have to do it manually.
VD->setInit(nullptr);
m_Sema.ActOnUninitializedDecl(VD);
return;
}

// Clang sets inits only once. Therefore, AddInitializerToDecl does
// not reset the declaration style to default and we have to do it manually.
VarDecl::InitializationStyle defaultStyle{};
VD->setInitStyle(defaultStyle);

// Clang expects direct inits to be wrapped either in InitListExpr or
// ParenListExpr.
if (DirectInit && !isa<InitListExpr>(Init) && !isa<ParenListExpr>(Init))
Init = m_Sema.ActOnParenListExpr(noLoc, noLoc, Init).get();
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
}

VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI) {
Expand All @@ -122,15 +144,7 @@ namespace clad {
m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);

if (Init) {
// Clang expects direct inits to be wrapped either in InitListExpr or
// ParenListExpr.
if (DirectInit && !isa<InitListExpr>(Init) && !isa<ParenListExpr>(Init))
Init = m_Sema.ActOnParenListExpr(noLoc, noLoc, Init).get();
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
} else {
m_Sema.ActOnUninitializedDecl(VD);
}
SetDeclInit(VD, Init, DirectInit);
m_Sema.FinalizeDeclaration(VD);
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false);
Expand Down

0 comments on commit e6d200d

Please sign in to comment.