Skip to content

Commit

Permalink
Redesign of the rangebased for loops body
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk authored and Max Andriychuk committed Aug 9, 2024
1 parent 1b81084 commit a27a6fb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ namespace clad {
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool AddToBlock = true);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
100 changes: 60 additions & 40 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,75 +976,85 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff
ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);

LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
auto* activeBreakContHandler = PushBreakContStmtHandler();
activeBreakContHandler->BeginCFSwitchStmtScope();
const VarDecl* LoopVD = FRS->getLoopVariable();

const Stmt* RangeDecl = FRS->getRangeStmt();
const Stmt* BeginDecl = FRS->getBeginStmt();
StmtDiff VisitRange = Visit(RangeDecl);
StmtDiff VisitBegin = Visit(BeginDecl);
Expr* BeginExpr = cast<BinaryOperator>(VisitBegin.getStmt())->getLHS();
llvm::SaveAndRestore<bool> SaveIsInside(isInsideLoop,
/*NewValue=*/false);

const auto* RangeDecl = cast<VarDecl>(FRS->getRangeStmt()->getSingleDecl());
const auto* BeginDecl = cast<VarDecl>(FRS->getBeginStmt()->getSingleDecl());

DeclDiff<VarDecl> VisitRange = DifferentiateVarDecl(RangeDecl, false);
DeclDiff<VarDecl> VisitBegin = DifferentiateVarDecl(BeginDecl, false);

beginBlock(direction::reverse);
// Create all declarations needed.
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);
DeclRefExpr* BeginDeclRef = BuildDeclRef(VisitBegin.getDecl());
Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());
DeclRefExpr* RangeDeclRef = BuildDeclRef(VisitRange.getDecl());
Expr* d_RangeDeclRef = m_Variables[RangeDeclRef->getDecl()];

Expr* RangeInit = Clone(FRS->getRangeInit());
Expr* d_RangeInitDeclRef =
m_Variables[cast<DeclRefExpr>(RangeInit)->getDecl()];
VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_RangeInitDeclRef));
Expr* AssignAdjBegin = BuildOp(BO_Assign, d_BeginDeclRef, d_RangeDeclRef);
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
BuildOp(BO_Assign, RangeDeclRef, BuildOp(UO_AddrOf, RangeInit));

addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx()));
addToCurrentBlock(AssignAdjBegin);
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
Expr* EndInit = Visit(EndDecl->getInit()).getExpr();
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
auto* AssignEndVarDecl =
cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl());
DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl);
BuildGlobalVarDecl(EndType, EndName, EndInit, /*DirectInit=*/false);
addToCurrentBlock(BuildDeclStmt(EndVarDecl));
DeclRefExpr* EndExpr = BuildDeclRef(EndVarDecl);
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);

beginBlock(direction::forward);
DeclDiff<VarDecl> LoopVDDiff = DifferentiateVarDecl(LoopVD);
Stmt* AdjLoopVDAddAssign =
utils::unwrapIfSingleStmt(endBlock(direction::forward));

if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() ||
LoopVD->getType() != LoopVDDiff.getDecl()->getType()))
m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl();
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop,
/*NewValue=*/true);

Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef);
Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef);
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
const Stmt* body = FRS->getBody();
StmtDiff bodyDiff = Visit(body);
StmtDiff bodyDiff =
DifferentiateLoopBody(body, loopCounter, nullptr, nullptr,
/*isForLoop=*/true);

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl()));
StmtDiff storeAdjLoop =
StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx()));

addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));
Expr* CounterIncrement = loopCounter.getCounterIncrement();

// Expr* CounterIncrement = loopCounter.getCounterIncrement();
Expr* LoopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl()));
Expand All @@ -1058,7 +1068,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

beginBlock(direction::forward);
addToCurrentBlock(CounterIncrement);
// addToCurrentBlock(CounterIncrement);
addToCurrentBlock(AdjLoopVDAddAssign);
addToCurrentBlock(AssignLoop);
addToCurrentBlock(storeLoop.getStmt());
Expand Down Expand Up @@ -2685,10 +2695,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

DeclDiff<VarDecl>
ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
DeclDiff<VarDecl> ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
bool AddToBlock) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;

// Local declarations are promoted to the function global scope. This
// procedure is done to make declarations visible in the reverse sweep.
// The reverse_mode_forward_pass mode does not have a reverse pass so
Expand Down Expand Up @@ -2863,7 +2874,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
getZeroInit(VDDerivedType));
else
assignToZero = GetCladZeroInit(declRef);
addToCurrentBlock(assignToZero, direction::reverse);
if (AddToBlock)
addToCurrentBlock(assignToZero, direction::reverse);
}
}

Expand All @@ -2879,10 +2891,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (AddToBlock)
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
if (AddToBlock)
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
Expand All @@ -2908,10 +2922,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (AddToBlock)
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
if (AddToBlock)
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
Expand All @@ -2926,6 +2942,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (derivedVDE)
m_Variables.emplace(VDClone, derivedVDE);

if ((VD->getDeclName() != VDClone->getDeclName() ||
VD->getType() != VDClone->getType()))
m_DeclReplacements[VD] = VDClone;

return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down

0 comments on commit a27a6fb

Please sign in to comment.