Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove VisitWithExplicitNoDfDx and replace it with plain Visit. #1226

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,28 +128,6 @@ namespace clad {
return result;
}

/// This visit method explicitly sets `dfdx` to `nullptr` for this visit.
///
/// This method is helpful when we need derivative of some expression but we
/// do not want `_d_expression += dfdx` statments to be (automatically)
/// added.
///
/// FIXME: Think of a better way for handling this situation. Maybe we
/// should improve the overall dfdx design and approach. One other way of
/// designing `VisitWithExplicitNoDfDx` in a more general way is
/// to develop a function that takes an expression E and returns the
/// corresponding derivative without any side effects. The difference
/// between this function and the current `VisitWithExplicitNoDfDx` will be
/// 1) better intent through the function name 2) We will also get
/// derivatives of expressions other than `DeclRefExpr` and `MemberExpr`.
StmtDiff VisitWithExplicitNoDfDx(const clang::Stmt* stmt) {
m_Stack.push(nullptr);
auto result =
clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>::Visit(stmt);
m_Stack.pop();
return result;
}

/// Get the latest block of code (i.e. place for statements output).
Stmts& getCurrentBlock(direction d = direction::forward) {
if (d == direction::forward)
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2139,7 +2139,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
valueForRevPass = utils::BuildParenExpr(m_Sema, sum);
} else if (opCode == UnaryOperatorKind::UO_Real ||
opCode == UnaryOperatorKind::UO_Imag) {
diff = VisitWithExplicitNoDfDx(E);
diff = Visit(E);
ResultRef = BuildOp(opCode, diff.getExpr_dx());
/// Create and add `__real r += dfdx()` expression.
if (dfdx()) {
Expand Down Expand Up @@ -3202,7 +3202,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}

StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase());
auto baseDiff = Visit(ME->getBase());
auto* field = ME->getMemberDecl();
assert(!isa<CXXMethodDecl>(field) &&
"CXXMethodDecl nodes not supported yet!");
Expand Down
Loading