Skip to content

Commit

Permalink
Adding support for range-based for loops in the reverse mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Jul 31, 2024
1 parent 9a1f751 commit 092aa5f
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ namespace clad {
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitCXXForRangeStmt(const clang::CXXForRangeStmt* FRS);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
StmtDiff VisitIfStmt(const clang::IfStmt* If);
StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
Expand Down
125 changes: 125 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,131 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(condExpr, ResultRef);
}

StmtDiff
ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);

LoopCounter loopCounter(*this);
const VarDecl* LoopVD = FRS->getLoopVariable();

const Stmt* RangeDecl = FRS->getRangeStmt();
const Stmt* BeginDecl = FRS->getBeginStmt();
StmtDiff VisitRange = Visit(RangeDecl);
StmtDiff VisitBegin = Visit(BeginDecl);
Expr* BeginExpr = cast<BinaryOperator>(VisitBegin.getStmt())->getLHS();

beginBlock(direction::reverse);
// Create all declarations needed.
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);
Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());

Expr* RangeInit = Clone(FRS->getRangeInit());
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
auto* AssignEndVarDecl =
cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl());
DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl);
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);

beginBlock(direction::forward);
DeclDiff<VarDecl> LoopVDDiff = DifferentiateVarDecl(LoopVD);
Stmt* AdjLoopVDAddAssign =
utils::unwrapIfSingleStmt(endBlock(direction::forward));

if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() ||
LoopVD->getType() != LoopVDDiff.getDecl()->getType()))
m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl();
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop,
/*NewValue=*/true);

Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef);
Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef);
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
const Stmt* body = FRS->getBody();
StmtDiff bodyDiff = Visit(body);

StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl()));
StmtDiff storeAdjLoop =
StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx()));

addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));
Expr* CounterIncrement = loopCounter.getCounterIncrement();

Expr* LoopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl()));
Expr* AssignLoop =
BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit);

if (!LoopVD->getType()->isReferenceType()) {
Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx());
AdjLoopVDAddAssign =
BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef));
}

beginBlock(direction::forward);
addToCurrentBlock(CounterIncrement);
addToCurrentBlock(AdjLoopVDAddAssign);
addToCurrentBlock(AssignLoop);
addToCurrentBlock(storeLoop.getStmt());
addToCurrentBlock(storeAdjLoop.getStmt());
CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward);
CompoundStmt* bodyForward = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff);

beginBlock(direction::forward);
addToCurrentBlock(d_DecBegin);
addToCurrentBlock(storeLoop.getStmt_dx());
addToCurrentBlock(storeAdjLoop.getStmt_dx());
CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward);
CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff);

Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin);
Stmt* Forward = new (m_Context) ForStmt(
m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc,
bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc());
Expr* CounterCondition =
loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* Reverse = bodyReverse;
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);

Reverse = new (m_Context)
ForStmt(m_Context, /*Init=*/nullptr, CounterCondition,
/*CondVar=*/nullptr, CounterDecrement, Reverse,
FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc());
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();

return {utils::unwrapIfSingleStmt(Forward),
utils::unwrapIfSingleStmt(Reverse)};
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
Expand Down
69 changes: 69 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2685,6 +2685,74 @@ double fn33(double i, double j) {
//CHECK-NEXT: }
//CHECK-NEXT:}

double fn34(double x, double y){
double r = 0;
double a[] = {y, x*y, x*x + y};
for(auto& i: a){
r+=i;
}
return r;
}

//CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: double _d_r = 0;
//CHECK-NEXT: double _d_a[3] = {0};
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: double (*_d___range1)[3] = 0;
//CHECK-NEXT: double (*__range10)[3] = {};
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: double *__begin10 = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double *> _t2 = {};
//CHECK-NEXT: clad::tape<double *> _t3 = {};
//CHECK-NEXT: double r = 0;
//CHECK-NEXT: double a[3] = {y, x * y, x * x + y};
//CHECK-NEXT: _t0 = 0UL;
//CHECK-NEXT: _d___range1 = &_d_a;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &a;
//CHECK-NEXT: __begin10 = *__range10;
//CHECK-NEXT: double *__end10 = *__range10 + 3L;
//CHECK-NEXT: double *_d_i = 0;
//CHECK-NEXT: double *i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: _d_i = &*_d___begin1;
//CHECK-NEXT: i = &*__begin10;
//CHECK-NEXT: clad::push(_t2, i);
//CHECK-NEXT: clad::push(_t3, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: clad::push(_t1, r);
//CHECK-NEXT: r += *i;
//CHECK-NEXT: }
//CHECK-NEXT: _d_r += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: {
//CHECK-NEXT: _d___begin1--;
//CHECK-NEXT: i = clad::pop(_t2);
//CHECK-NEXT: _d_i = clad::pop(_t3);
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: r = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_r;
//CHECK-NEXT: *_d_i += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: *_d_y += _d_a[0];
//CHECK-NEXT: *_d_x += _d_a[1] * y;
//CHECK-NEXT: *_d_y += x * _d_a[1];
//CHECK-NEXT: *_d_x += _d_a[2] * x;
//CHECK-NEXT: *_d_x += x * _d_a[2];
//CHECK-NEXT: *_d_y += _d_a[2];
//CHECK-NEXT: }
//CHECK-NEXT:}



#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
Expand Down Expand Up @@ -2769,6 +2837,7 @@ int main() {
TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00}
TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00}

TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit 092aa5f

Please sign in to comment.