From 092aa5fcbd36d0f05e5645a92c5996f00f1457f3 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Wed, 24 Jul 2024 22:08:02 +0200 Subject: [PATCH] Adding support for range-based for loops in the reverse mode. Fixes:#723 --- .../clad/Differentiator/ReverseModeVisitor.h | 1 + lib/Differentiator/ReverseModeVisitor.cpp | 125 ++++++++++++++++++ test/Gradient/Loops.C | 69 ++++++++++ 3 files changed, 195 insertions(+) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index ced679921..a161f1f58 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -381,6 +381,7 @@ namespace clad { virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); + StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS); StmtDiff VisitForStmt(const clang::ForStmt* FS); StmtDiff VisitIfStmt(const clang::IfStmt* If); StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index cd629eea1..38a7fbe39 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -976,6 +976,131 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(condExpr, ResultRef); } + StmtDiff + ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); + beginBlock(direction::reverse); + + LoopCounter loopCounter(*this); + const VarDecl* LoopVD = FRS->getLoopVariable(); + + const Stmt* RangeDecl = FRS->getRangeStmt(); + const Stmt* BeginDecl = FRS->getBeginStmt(); + StmtDiff VisitRange = Visit(RangeDecl); + StmtDiff VisitBegin = Visit(BeginDecl); + Expr* BeginExpr = cast(VisitBegin.getStmt())->getLHS(); + + beginBlock(direction::reverse); + // Create all declarations needed. + auto* BeginDeclRef = cast(BeginExpr); + Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()]; + + auto* RangeExpr = + cast(cast(VisitRange.getStmt())->getLHS()); + + Expr* RangeInit = Clone(FRS->getRangeInit()); + Expr* AssignRange = + BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); + Expr* AssignBegin = + BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); + addToCurrentBlock(AssignRange); + addToCurrentBlock(AssignBegin); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + + Expr* EndInit = cast(EndDecl->getInit())->getRHS(); + QualType EndType = CloneType(EndDecl->getType()); + std::string EndName = EndDecl->getNameAsString(); + Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); + VarDecl* EndVarDecl = + BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false); + DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); + + addToCurrentBlock(AssignEnd); + auto* AssignEndVarDecl = + cast(cast(AssignEnd)->getSingleDecl()); + DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl); + Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); + + beginBlock(direction::forward); + DeclDiff LoopVDDiff = DifferentiateVarDecl(LoopVD); + Stmt* AdjLoopVDAddAssign = + utils::unwrapIfSingleStmt(endBlock(direction::forward)); + + if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() || + LoopVD->getType() != LoopVDDiff.getDecl()->getType())) + m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl(); + llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop, + /*NewValue=*/true); + + Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef); + Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef); + Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); + // Add item assignment statement to the body. + const Stmt* body = FRS->getBody(); + StmtDiff bodyDiff = Visit(body); + + StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl())); + StmtDiff storeAdjLoop = + StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx())); + + addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx())); + Expr* CounterIncrement = loopCounter.getCounterIncrement(); + + Expr* LoopInit = LoopVDDiff.getDecl()->getInit(); + LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType())); + addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl())); + Expr* AssignLoop = + BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit); + + if (!LoopVD->getType()->isReferenceType()) { + Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx()); + AdjLoopVDAddAssign = + BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef)); + } + + beginBlock(direction::forward); + addToCurrentBlock(CounterIncrement); + addToCurrentBlock(AdjLoopVDAddAssign); + addToCurrentBlock(AssignLoop); + addToCurrentBlock(storeLoop.getStmt()); + addToCurrentBlock(storeAdjLoop.getStmt()); + CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward); + CompoundStmt* bodyForward = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff); + + beginBlock(direction::forward); + addToCurrentBlock(d_DecBegin); + addToCurrentBlock(storeLoop.getStmt_dx()); + addToCurrentBlock(storeAdjLoop.getStmt_dx()); + CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward); + CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt( + m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff); + + Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin); + Stmt* Forward = new (m_Context) ForStmt( + m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc, + bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); + Expr* CounterCondition = + loopCounter.getCounterConditionResult().get().second; + Expr* CounterDecrement = loopCounter.getCounterDecrement(); + + Stmt* Reverse = bodyReverse; + addToCurrentBlock(Reverse, direction::reverse); + Reverse = endBlock(direction::reverse); + + Reverse = new (m_Context) + ForStmt(m_Context, /*Init=*/nullptr, CounterCondition, + /*CondVar=*/nullptr, CounterDecrement, Reverse, + FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); + addToCurrentBlock(Reverse, direction::reverse); + Reverse = endBlock(direction::reverse); + endScope(); + + return {utils::unwrapIfSingleStmt(Forward), + utils::unwrapIfSingleStmt(Reverse)}; + } + StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) { beginBlock(direction::reverse); LoopCounter loopCounter(*this); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 7159d98e3..b6d78c8a5 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2685,6 +2685,74 @@ double fn33(double i, double j) { //CHECK-NEXT: } //CHECK-NEXT:} +double fn34(double x, double y){ + double r = 0; + double a[] = {y, x*y, x*x + y}; + for(auto& i: a){ + r+=i; + } + return r; +} + +//CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: double _d_r = 0; +//CHECK-NEXT: double _d_a[3] = {0}; +//CHECK-NEXT: unsigned {{int|long}} _t0; +//CHECK-NEXT: double (*_d___range1)[3] = 0; +//CHECK-NEXT: double (*__range10)[3] = {}; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: double *__begin10 = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double r = 0; +//CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; +//CHECK-NEXT: _t0 = 0UL; +//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = &a; +//CHECK-NEXT: __begin10 = *__range10; +//CHECK-NEXT: double *__end10 = *__range10 + 3L; +//CHECK-NEXT: double *_d_i = 0; +//CHECK-NEXT: double *i = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _t0++; +//CHECK-NEXT: _d_i = &*_d___begin1; +//CHECK-NEXT: i = &*__begin10; +//CHECK-NEXT: clad::push(_t2, i); +//CHECK-NEXT: clad::push(_t3, _d_i); +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t1, r); +//CHECK-NEXT: r += *i; +//CHECK-NEXT: } +//CHECK-NEXT: _d_r += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: i = clad::pop(_t2); +//CHECK-NEXT: _d_i = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: r = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_r; +//CHECK-NEXT: *_d_i += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_y += _d_a[0]; +//CHECK-NEXT: *_d_x += _d_a[1] * y; +//CHECK-NEXT: *_d_y += x * _d_a[1]; +//CHECK-NEXT: *_d_x += _d_a[2] * x; +//CHECK-NEXT: *_d_x += x * _d_a[2]; +//CHECK-NEXT: *_d_y += _d_a[2]; +//CHECK-NEXT: } +//CHECK-NEXT:} + + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -2769,6 +2837,7 @@ int main() { TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00} TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} + TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {