diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index b98bf4a2b..c689b95b3 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -128,28 +128,6 @@ namespace clad { return result; } - /// This visit method explicitly sets `dfdx` to `nullptr` for this visit. - /// - /// This method is helpful when we need derivative of some expression but we - /// do not want `_d_expression += dfdx` statments to be (automatically) - /// added. - /// - /// FIXME: Think of a better way for handling this situation. Maybe we - /// should improve the overall dfdx design and approach. One other way of - /// designing `VisitWithExplicitNoDfDx` in a more general way is - /// to develop a function that takes an expression E and returns the - /// corresponding derivative without any side effects. The difference - /// between this function and the current `VisitWithExplicitNoDfDx` will be - /// 1) better intent through the function name 2) We will also get - /// derivatives of expressions other than `DeclRefExpr` and `MemberExpr`. - StmtDiff VisitWithExplicitNoDfDx(const clang::Stmt* stmt) { - m_Stack.push(nullptr); - auto result = - clang::ConstStmtVisitor::Visit(stmt); - m_Stack.pop(); - return result; - } - /// Get the latest block of code (i.e. place for statements output). Stmts& getCurrentBlock(direction d = direction::forward) { if (d == direction::forward) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f25509278..6d4a78eec 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2139,7 +2139,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { valueForRevPass = utils::BuildParenExpr(m_Sema, sum); } else if (opCode == UnaryOperatorKind::UO_Real || opCode == UnaryOperatorKind::UO_Imag) { - diff = VisitWithExplicitNoDfDx(E); + diff = Visit(E); ResultRef = BuildOp(opCode, diff.getExpr_dx()); /// Create and add `__real r += dfdx()` expression. if (dfdx()) { @@ -3202,7 +3202,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { } StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) { - auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase()); + auto baseDiff = Visit(ME->getBase()); auto* field = ME->getMemberDecl(); assert(!isa(field) && "CXXMethodDecl nodes not supported yet!");