Skip to content

Commit

Permalink
Merge branch 'master' into rangebased
Browse files Browse the repository at this point in the history
  • Loading branch information
ovdiiuv authored Jul 2, 2024
2 parents d5d292c + 22b2590 commit 44e2aae
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 53 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BaseForwardModeVisitor
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXConstCastExpr(const clang::CXXConstCastExpr* CCE);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
Expand Down Expand Up @@ -113,6 +114,8 @@ class BaseForwardModeVisitor
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

Expand Down
17 changes: 17 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define CLAD_STL_BUILTINS_H

#include <vector>
#include "clad/Differentiator/BuiltinDerivatives.h"

namespace clad {
namespace custom_derivatives {
Expand All @@ -26,6 +27,22 @@ void resize_pushforward(::std::vector<T>* v, unsigned sz, U val,
d_v->resize(sz, d_val);
v->resize(sz, val);
}

template <typename T>
ValueAndPushforward<typename ::std::initializer_list<T>::iterator,
typename ::std::initializer_list<T>::iterator>
begin_pushforward(::std::initializer_list<T>* il,
::std::initializer_list<T>* d_il) {
return {il->begin(), d_il->begin()};
}

template <typename T>
ValueAndPushforward<typename ::std::initializer_list<T>::iterator,
typename ::std::initializer_list<T>::iterator>
end_pushforward(const ::std::initializer_list<T>* il,
const ::std::initializer_list<T>* d_il) {
return {il->end(), d_il->end()};
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
28 changes: 26 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -964,9 +964,9 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
std::transform(std::begin(Indices), std::end(Indices),
std::begin(clonedIndices),
[this](const Expr* E) { return Clone(E); });
auto cloned = BuildArraySubscript(clonedBase, clonedIndices);
Expr* cloned = BuildArraySubscript(clonedBase, clonedIndices);

auto zero = ConstantFolder::synthesizeLiteral(ExprTy, m_Context, 0);
Expr* zero = getZeroInit(ExprTy);
ValueDecl* VD = nullptr;
// Derived variables for member variables are also created when we are
// differentiating a call operator.
Expand Down Expand Up @@ -1671,6 +1671,25 @@ StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr(
return StmtDiff(Clone(E), Clone(E));
}

StmtDiff
BaseForwardModeVisitor::VisitCXXConstCastExpr(const CXXConstCastExpr* CCE) {
StmtDiff subExprDiff = Visit(CCE->getSubExpr());
Expr* castExpr =
m_Sema
.BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast,
CCE->getTypeInfoAsWritten(), subExprDiff.getExpr(),
CCE->getAngleBrackets(), CCE->getSourceRange())
.get();
Expr* castExprDiff =
m_Sema
.BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast,
CCE->getTypeInfoAsWritten(),
subExprDiff.getExpr_dx(), CCE->getAngleBrackets(),
CCE->getSourceRange())
.get();
return StmtDiff(castExpr, castExprDiff);
}

StmtDiff
BaseForwardModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) {
StmtDiff subExprDiff = Visit(CSCE->getSubExpr());
Expand Down Expand Up @@ -2230,4 +2249,9 @@ BaseForwardModeVisitor::DifferentiateStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return DeclDiff<StaticAssertDecl>();
}

StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr());
}
} // end namespace clad
51 changes: 45 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clad/Differentiator/StmtClone.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
Expand Down Expand Up @@ -1596,13 +1597,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff baseDiff;
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
if (!OverloadedDerivedFn) {
size_t idx = 0;

/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
if (MD) {
if (isLambdaCallOperator(MD)) {
QualType ptrType = m_Context.getPointerType(m_Context.getRecordType(
FD->getDeclContext()->getOuterLexicalRecordContext()));
baseDiff =
StmtDiff(Clone(dyn_cast<CXXOperatorCallExpr>(CE)->getArg(0)),
new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc));
} else if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
Expand Down Expand Up @@ -1700,7 +1708,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
if (MD && isLambdaCallOperator(MD)) {
if (const auto* paramDecl = FD->getParamDecl(i))
pullbackRequest.DVI.push_back(paramDecl);
} else if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD = nullptr;
Expand Down Expand Up @@ -2735,14 +2746,41 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool promoteToFnScope =
!getCurrentScope()->isFunctionScope() &&
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass;

// If the DeclStmt is not empty, check the first declaration in case it is a
// lambda function. This case it is treated separately for now and we don't
// create a variable for its derivative.
bool isLambda = false;
const auto* declsBegin = DS->decls().begin();
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
QualType QT = VD->getType();
if (!QT->isPointerType()) {
auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}
}

// For each variable declaration v, create another declaration _d_v to
// store derivatives for potential reassignments. E.g.
// double y = x;
// ->
// double _d_y = _d_x; double y = x;
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
DeclDiff<VarDecl> VDDiff = DifferentiateVarDecl(VD);
DeclDiff<VarDecl> VDDiff;
if (!isLambda)
VDDiff = DifferentiateVarDecl(VD);

// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
Expand All @@ -2762,8 +2800,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double _d_y = x; // copied from original function, collides with
// _d_y
// }
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() ||
VD->getType() != VDDiff.getDecl()->getType())
if (!isLambda &&
(VDDiff.getDecl()->getDeclName() != VD->getDeclName() ||
VD->getType() != VDDiff.getDecl()->getType()))
m_DeclReplacements[VD] = VDDiff.getDecl();

// Here, we move the declaration to the function global scope.
Expand Down
4 changes: 2 additions & 2 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ double multiply(const double *arr) {
}

//CHECK: double multiply_darg0_1(const double *arr) {
//CHECK-NEXT: return 0. * arr[1] + arr[0] * 1.;
//CHECK-NEXT: return 0 * arr[1] + arr[0] * 1.;
//CHECK-NEXT: }

double divide(const double *arr) {
return arr[0] / arr[1];
}

//CHECK: double divide_darg0_1(const double *arr) {
//CHECK-NEXT: return (0. * arr[1] - arr[0] * 1.) / (arr[1] * arr[1]);
//CHECK-NEXT: return (0 * arr[1] - arr[0] * 1.) / (arr[1] * arr[1]);
//CHECK-NEXT: }

double addArr(const double *arr, int n) {
Expand Down
97 changes: 64 additions & 33 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/STLBuiltins.h"
#include <initializer_list>
#include <cmath>
#include "../TestUtils.h"

double f1(double x, int y) {
double r = 1.0;
Expand Down Expand Up @@ -548,38 +551,32 @@ double fn17_darg0(double x);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn18(double x, double y){
int coefficients[3] = {4, 7, 3};
double res = 0;
for(auto i: coefficients){
if(i%2==0)
continue;
res+= x*y*i;
}
return res;
double fn18(double u, double v) {
auto dl = {u, v, u*v};
double res = 0;
auto dl_end = dl.end();
for (auto i = dl.begin(); i != dl_end; ++i)
res += *i;
return res;
}

double fn18_darg0(double x, double y);
// CHECK: double fn18_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0};
// CHECK-NEXT: int coefficients[3] = {4, 7, 3};
// CHECK: double fn18_darg0(double u, double v) {
// CHECK-NEXT: double _d_u = 1;
// CHECK-NEXT: double _d_v = 0;
// CHECK-NEXT: {{.*}}initializer_list<double> _d_dl = {_d_u, _d_v, _d_u * v + u * _d_v};
// CHECK-NEXT: {{.*}}initializer_list<double> dl = {u, v, u * v};
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients;
// CHECK-NEXT: int (&__range10)[3] = coefficients;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int _d_i = *_d___begin1;
// CHECK-NEXT: int i = *__begin10;
// CHECK-NEXT: if (i % 2 == 0)
// CHECK-NEXT: continue;
// CHECK-NEXT: double _t0 = x * y;
// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i;
// CHECK-NEXT: res += _t0 * i;
// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t0 = {{.*}}end_pushforward(&dl, &_d_dl);
// CHECK-NEXT: {{.*}}_d_dl_end = _t0.pushforward;
// CHECK-NEXT: {{.*}}dl_end = _t0.value;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t1 = {{.*}}begin_pushforward(&dl, &_d_dl);
// CHECK-NEXT: {{.*}}_d_i = _t1.pushforward;
// CHECK-NEXT: for ({{.*}}i = _t1.value; i != dl_end; ++_d_i , ++i) {
// CHECK-NEXT: _d_res += *_d_i;
// CHECK-NEXT: res += *i;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }
Expand Down Expand Up @@ -629,6 +626,7 @@ double fn20(double x){
return x;
}

double fn20_darg0(double x);
// CHECK: double fn20_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: int _d_a[1] = {0};
Expand All @@ -647,9 +645,38 @@ double fn20(double x){
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

ouble fn21(double x, double y){
int coefficients[3] = {4, 7, 3};
double res = 0;
for(auto i: coefficients){
if(i%2==0)
continue;
res+= x*y*i;
}
return res;
}

double fn20_darg0(double x);

double fn21_darg0(double x, double y);
// CHECK: double fn21_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: int _d_coefficients[3] = {0, 0, 0};
// CHECK-NEXT: int coefficients[3] = {4, 7, 3};
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: int (&_d___range1)[3] = _d_coefficients;
// CHECK-NEXT: int (&__range10)[3] = coefficients;
// CHECK-NEXT: int *_d___begin1 = _d___range1;
// CHECK-NEXT: int *__begin10 = __range10;
// CHECK-NEXT: int *__end10 = __range10 + {{3|3L}};
// CHECK-NEXT: for (; __begin10 != __end10; ++_d___begin1 , ++__begin10) {
// CHECK-NEXT: int _d_i = *_d___begin1;
// CHECK-NEXT: int i = *__begin10;
// CHECK-NEXT: if (i % 2 == 0)
// CHECK-NEXT: continue;
// CHECK-NEXT: double _t0 = x * y;
// CHECK-NEXT: _d_res += (_d_x * y + x * _d_y) * i + _t0 * _d_i;
// CHECK-NEXT: res += _t0 * i;

#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
Expand Down Expand Up @@ -717,13 +744,17 @@ int main() {

clad::differentiate(fn17, 0);
printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0

clad::differentiate(fn18, 0);
printf("Result is = %.2f\n", fn18_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00
INIT_DIFFERENTIATE(fn18, "u");
TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00}

clad::differentiate(fn19, 0);
printf("Result is = %.2f\n", fn19_darg0(5, 2)); // CHECK-EXEC: Result is = 14.00

clad::differentiate(fn20, 0);
printf("Result is = %.2f\n", fn20_darg0(5)); // CHECK-EXEC: Result is = 6.00

clad::differentiate(fn21, 0);
printf("Result is = %.2f\n", fn21_darg0(5, 1)); // CHECK-EXEC: Result is = 10.00

}
2 changes: 1 addition & 1 deletion test/FirstDerivative/StructGlobalObjects.C
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ double fn_array(double i, double j) {
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: double &_t0 = array.data[0];
// CHECK-NEXT: double &_t1 = array.data[1];
// CHECK-NEXT: return 0. * i + _t0 * _d_i + 0. * j + _t1 * _d_j;
// CHECK-NEXT: return 0 * i + _t0 * _d_i + 0 * j + _t1 * _d_j;
// CHECK-NEXT: }

int main () {
Expand Down
4 changes: 2 additions & 2 deletions test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ struct WidgetPointer {
// CHECK-NEXT: double &_t5 = this->arr[5];
// CHECK-NEXT: double &_t6 = this->arr[5];
// CHECK-NEXT: double &_t7 = this->j;
// CHECK-NEXT: double _t8 = 0. * _t6 + _t5 * 0.;
// CHECK-NEXT: double _t8 = 0 * _t6 + _t5 * 0;
// CHECK-NEXT: double _t9 = _t5 * _t6;
// CHECK-NEXT: _d_j = _d_j * _t9 + _t7 * _t8;
// CHECK-NEXT: _t7 *= _t9;
Expand All @@ -363,7 +363,7 @@ struct WidgetPointer {
// CHECK-NEXT: double &_t0 = this->arr[3];
// CHECK-NEXT: double &_t1 = this->arr[3];
// CHECK-NEXT: double &_t2 = this->i;
// CHECK-NEXT: double _t3 = 0. * _t1 + _t0 * 0.;
// CHECK-NEXT: double _t3 = 0 * _t1 + _t0 * 0;
// CHECK-NEXT: double _t4 = _t0 * _t1;
// CHECK-NEXT: _d_i = _d_i * _t4 + _t2 * _t3;
// CHECK-NEXT: _t2 *= _t4;
Expand Down
Loading

0 comments on commit 44e2aae

Please sign in to comment.