Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use move semantics to push objects on tapes #1249

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
/// Remove the last value from the tape, return it.
template <typename T>
CUDA_HOST_DEVICE T pop(tape<T>& to) {
T val = to.back();
T val = std::move(to.back());
to.pop_back();
return val;
}
Expand Down
6 changes: 4 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ namespace clad {
clang::Expr* GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t");
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
bool moveToTape = false);

//// A type returned by DelayedGlobalStoreAndRef
/// .Result is a reference to the created (yet uninitialized) global
Expand Down Expand Up @@ -314,7 +315,8 @@ namespace clad {
/// \returns A struct containg necessary call expressions for the built
/// tape
CladTapeResult MakeCladTapeFor(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
clang::QualType type = {});

/// A function to get the multi-argument "central_difference"
/// call expression for the given arguments.
Expand Down
21 changes: 16 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/SaveAndRestore.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/raw_ostream.h>

#include <algorithm>
Expand Down Expand Up @@ -91,11 +92,14 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}

ReverseModeVisitor::CladTapeResult
ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) {
ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix,
clang::QualType type) {
assert(E && "must be provided");
E = E->IgnoreImplicit();
if (type.isNull())
type = E->getType();
QualType TapeType =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: misleading indentation: statement is indented too deeply [readability-misleading-indentation]

    QualType TapeType =
    ^
Additional context

lib/Differentiator/ReverseModeVisitor.cpp:97: did you mean this line to be inside this 'if'

    if (type.isNull())
    ^

Copy link
Collaborator Author

@PetroZarytskyi PetroZarytskyi Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vgvassilev Is it broken or am I not seeing something?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the problem is the indentation of the entire function.

GetCladTapeOfType(getNonConstType(E->getType(), m_Context, m_Sema));
GetCladTapeOfType(getNonConstType(type, m_Context, m_Sema));
LookupResult& Push = GetCladTapePush();
LookupResult& Pop = GetCladTapePop();
Expr* TapeRef =
Expand Down Expand Up @@ -2986,7 +2990,8 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
auto* declRef = BuildDeclRef(decl);
auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit());
if (isInsideLoop) {
auto pushPop = StoreAndRestore(declRef);
auto pushPop = StoreAndRestore(declRef, /*prefix=*/"_t",
/*moveToTape=*/true);
if (pushPop.getExpr() != declRef)
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
assignment = BuildOp(BO_Comma, pushPop.getExpr(), assignment);
Expand Down Expand Up @@ -3291,12 +3296,18 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
llvm::StringRef prefix) {
llvm::StringRef prefix,
bool moveToTape) {
assert(E && "must be provided");
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (isInsideLoop) {
auto CladTape = MakeCladTapeFor(Clone(E), prefix);
Expr* clone = Clone(E);
if (moveToTape && E->getType()->isRecordType()) {
llvm::SmallVector<Expr*, 1> args = {clone};
clone = GetFunctionCall("move", "std", args);
}
auto CladTape = MakeCladTapeFor(clone, prefix, Type);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
auto* popAssign = BuildOp(BinaryOperatorKind::BO_Assign, Clone(E), Pop);
Expand Down
4 changes: 2 additions & 2 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ double func6(double seed) {
// CHECK-NEXT: break;
// CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr) , arr = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, std::move(arr)) , arr = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
Expand Down Expand Up @@ -377,7 +377,7 @@ double func7(double *params) {
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, paramsPrime) , paramsPrime = {params[0]};
// CHECK-NEXT: clad::push(_t1, std::move(paramsPrime)) , paramsPrime = {params[0]};
// CHECK-NEXT: clad::push(_t2, out);
// CHECK-NEXT: out = out + inv_square(paramsPrime);
// CHECK-NEXT: }
Expand Down
4 changes: 2 additions & 2 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -1773,7 +1773,7 @@ double fn21(double x) {
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, arr) , arr = {1, x, 2};
// CHECK-NEXT: clad::push(_t1, std::move(arr)) , arr = {1, x, 2};
// CHECK-NEXT: clad::push(_t2, res);
// CHECK-NEXT: res += arr[0] + arr[1];
// CHECK-NEXT: }
Expand Down Expand Up @@ -1825,7 +1825,7 @@ double fn22(double param) {
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, arr) , arr = {1.};
// CHECK-NEXT: clad::push(_t1, std::move(arr)) , arr = {1.};
// CHECK-NEXT: clad::push(_t2, out);
// CHECK-NEXT: clad::push(_t3, arr[0]);
// CHECK-NEXT: out += clad::back(_t3) * param;
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ int main() {
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, ls) , ls = {u, v}, alloc;
// CHECK-NEXT: clad::push(_t1, std::move(ls)) , ls = {u, v}, alloc;
// CHECK-NEXT: _d_ls = ls;
// CHECK-NEXT: clad::zero_init(_d_ls);
// CHECK-NEXT: clad::push(_t2, ls);
Expand Down
Loading