From 4f487e3e42f165f05df8bc718bb71369b8aaf557 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 24 Jun 2024 16:08:57 +0300 Subject: [PATCH 1/4] Don't synthesize 0 of the ArraySubscriptExpr type since it may not be numerical, use getZeroInit instead. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 4 ++-- test/Arrays/ArrayInputsForwardMode.C | 4 ++-- test/FirstDerivative/StructGlobalObjects.C | 2 +- test/ForwardMode/Functors.C | 4 ++-- test/ForwardMode/UserDefinedTypes.C | 17 +++++++++++++++++ test/ROOT/TFormula.C | 14 +++++++------- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 34116dc39..48e1cf35c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -894,9 +894,9 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { std::transform(std::begin(Indices), std::end(Indices), std::begin(clonedIndices), [this](const Expr* E) { return Clone(E); }); - auto cloned = BuildArraySubscript(clonedBase, clonedIndices); + Expr* cloned = BuildArraySubscript(clonedBase, clonedIndices); - auto zero = ConstantFolder::synthesizeLiteral(ExprTy, m_Context, 0); + Expr* zero = getZeroInit(ExprTy); ValueDecl* VD = nullptr; // Derived variables for member variables are also created when we are // differentiating a call operator. diff --git a/test/Arrays/ArrayInputsForwardMode.C b/test/Arrays/ArrayInputsForwardMode.C index 57559a6be..633b969aa 100644 --- a/test/Arrays/ArrayInputsForwardMode.C +++ b/test/Arrays/ArrayInputsForwardMode.C @@ -10,7 +10,7 @@ double multiply(const double *arr) { } //CHECK: double multiply_darg0_1(const double *arr) { -//CHECK-NEXT: return 0. * arr[1] + arr[0] * 1.; +//CHECK-NEXT: return 0 * arr[1] + arr[0] * 1.; //CHECK-NEXT: } double divide(const double *arr) { @@ -18,7 +18,7 @@ double divide(const double *arr) { } //CHECK: double divide_darg0_1(const double *arr) { -//CHECK-NEXT: return (0. * arr[1] - arr[0] * 1.) / (arr[1] * arr[1]); +//CHECK-NEXT: return (0 * arr[1] - arr[0] * 1.) / (arr[1] * arr[1]); //CHECK-NEXT: } double addArr(const double *arr, int n) { diff --git a/test/FirstDerivative/StructGlobalObjects.C b/test/FirstDerivative/StructGlobalObjects.C index 784a68188..681e6d132 100644 --- a/test/FirstDerivative/StructGlobalObjects.C +++ b/test/FirstDerivative/StructGlobalObjects.C @@ -37,7 +37,7 @@ double fn_array(double i, double j) { // CHECK-NEXT: double _d_j = 0; // CHECK-NEXT: double &_t0 = array.data[0]; // CHECK-NEXT: double &_t1 = array.data[1]; -// CHECK-NEXT: return 0. * i + _t0 * _d_i + 0. * j + _t1 * _d_j; +// CHECK-NEXT: return 0 * i + _t0 * _d_i + 0 * j + _t1 * _d_j; // CHECK-NEXT: } int main () { diff --git a/test/ForwardMode/Functors.C b/test/ForwardMode/Functors.C index ee15045d3..fdfcab92b 100644 --- a/test/ForwardMode/Functors.C +++ b/test/ForwardMode/Functors.C @@ -338,7 +338,7 @@ struct WidgetPointer { // CHECK-NEXT: double &_t5 = this->arr[5]; // CHECK-NEXT: double &_t6 = this->arr[5]; // CHECK-NEXT: double &_t7 = this->j; - // CHECK-NEXT: double _t8 = 0. * _t6 + _t5 * 0.; + // CHECK-NEXT: double _t8 = 0 * _t6 + _t5 * 0; // CHECK-NEXT: double _t9 = _t5 * _t6; // CHECK-NEXT: _d_j = _d_j * _t9 + _t7 * _t8; // CHECK-NEXT: _t7 *= _t9; @@ -363,7 +363,7 @@ struct WidgetPointer { // CHECK-NEXT: double &_t0 = this->arr[3]; // CHECK-NEXT: double &_t1 = this->arr[3]; // CHECK-NEXT: double &_t2 = this->i; - // CHECK-NEXT: double _t3 = 0. * _t1 + _t0 * 0.; + // CHECK-NEXT: double _t3 = 0 * _t1 + _t0 * 0; // CHECK-NEXT: double _t4 = _t0 * _t1; // CHECK-NEXT: _d_i = _d_i * _t4 + _t2 * _t3; // CHECK-NEXT: _t2 *= _t4; diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index 65f4cc3d6..b590a7571 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -950,6 +950,21 @@ double fn17(A a, B b) { // CHECK-NEXT: return _d_a.mem * _t1 + _t0 * _d_b.mem; // CHECK-NEXT: } +double fn18(double i, double j) { + A v[2] = {2, 3}; + v[0] = 9 * i; + return v[0].mem; +} + +// CHECK: double fn18_darg0(double i, double j) { +// CHECK-NEXT: double _d_i = 1; +// CHECK-NEXT: double _d_j = 0; +// CHECK-NEXT: A _d_v[2] = {0, 0}; +// CHECK-NEXT: A v[2] = {2, 3}; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = v[0].operator_equal_pushforward(9 * i, &_d_v[0], 0 * i + 9 * _d_i); +// CHECK-NEXT: return _d_v[0].mem; +// CHECK-NEXT: } + template void print(const Tensor& t) { for (int i=0; i sum_pushforward(Tensor *_d_this) { // CHECK-NEXT: double _d_res = 0; diff --git a/test/ROOT/TFormula.C b/test/ROOT/TFormula.C index d4d62174a..6d44fc921 100644 --- a/test/ROOT/TFormula.C +++ b/test/ROOT/TFormula.C @@ -57,22 +57,22 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p); //CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) { //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.); -//CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0); +//CHECK-NEXT: return 0 * _t0 + x[0] * (1. + 0 + 0) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) { //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); -//CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.); +//CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 1. + 0) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) { //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); -//CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.); -//CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0); +//CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0); +//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 0 + 1.) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } Double_t TFormula_hess1(Double_t *x, Double_t *p) { From 4c824edf6519086f41b0b456a1045a095a1206d5 Mon Sep 17 00:00:00 2001 From: parth-07 Date: Wed, 12 Jun 2024 00:33:50 +0530 Subject: [PATCH 2/4] Add support for initializer_list in forward mode AD This commit adds primitive support for initializer_list in the forward mode AD. --- .../Differentiator/BaseForwardModeVisitor.h | 2 + include/clad/Differentiator/STLBuiltins.h | 17 +++++++++ lib/Differentiator/BaseForwardModeVisitor.cpp | 5 +++ test/FirstDerivative/Loops.C | 38 +++++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 17e6ba6a4..484d17d4b 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -112,6 +112,8 @@ class BaseForwardModeVisitor StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE); StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE); StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; + StmtDiff + VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 41fc05518..39e786aba 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -2,6 +2,7 @@ #define CLAD_STL_BUILTINS_H #include +#include "clad/Differentiator/BuiltinDerivatives.h" namespace clad { namespace custom_derivatives { @@ -26,6 +27,22 @@ void resize_pushforward(::std::vector* v, unsigned sz, U val, d_v->resize(sz, d_val); v->resize(sz, val); } + +template +ValueAndPushforward::iterator, + typename ::std::initializer_list::iterator> +begin_pushforward(::std::initializer_list* il, + ::std::initializer_list* d_il) { + return {il->begin(), d_il->begin()}; +} + +template +ValueAndPushforward::iterator, + typename ::std::initializer_list::iterator> +end_pushforward(const ::std::initializer_list* il, + const ::std::initializer_list* d_il) { + return {il->end(), d_il->end()}; +} } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 48e1cf35c..067fd77d1 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -2159,4 +2159,9 @@ BaseForwardModeVisitor::DifferentiateStaticAssertDecl( const clang::StaticAssertDecl* SAD) { return DeclDiff(); } + +StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr( + const clang::CXXStdInitializerListExpr* ILE) { + return Visit(ILE->getSubExpr()); +} } // end namespace clad diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 152984bcd..5534291b5 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -3,7 +3,10 @@ // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include #include +#include "../TestUtils.h" double f1(double x, int y) { double r = 1.0; @@ -548,6 +551,38 @@ double fn17_darg0(double x); // CHECK-NEXT: return _d_x; // CHECK-NEXT: } +double fn18(double u, double v) { + auto dl = {u, v, u*v}; + double res = 0; + auto dl_end = dl.end(); + for (auto i = dl.begin(); i != dl_end; ++i) + res += *i; + return res; +} + +// CHECK: double fn18_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: {{.*}}initializer_list _d_dl = {_d_u, _d_v, _d_u * v + u * _d_v}; +// CHECK-NEXT: {{.*}}initializer_list dl = {u, v, u * v}; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t0 = {{.*}}end_pushforward(&dl, &_d_dl); +// CHECK-NEXT: {{.*}}_d_dl_end = _t0.pushforward; +// CHECK-NEXT: {{.*}}dl_end = _t0.value; +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t1 = {{.*}}begin_pushforward(&dl, &_d_dl); +// CHECK-NEXT: {{.*}}_d_i = _t1.pushforward; +// CHECK-NEXT: for ({{.*}}i = _t1.value; i != dl_end; ++_d_i , ++i) { +// CHECK-NEXT: _d_res += *_d_i; +// CHECK-NEXT: res += *i; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + + + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -614,4 +649,7 @@ int main() { clad::differentiate(fn17, 0); printf("Result is = %.2f\n", fn17_darg0(5)); // CHECK-EXEC: Result is = 0 + + INIT_DIFFERENTIATE(fn18, "u"); + TEST_DIFFERENTIATE(fn18, 3, 5); // CHECK-EXEC: {6.00} } From 1d56ef8997a3225d7d37cde4383de2caf54e8574 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sat, 29 Jun 2024 18:16:16 +0200 Subject: [PATCH 3/4] Add support for const cast in fwd mode --- .../Differentiator/BaseForwardModeVisitor.h | 1 + lib/Differentiator/BaseForwardModeVisitor.cpp | 19 +++++++++++++ test/ForwardMode/Pointer.C | 28 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 484d17d4b..311fc5769 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -55,6 +55,7 @@ class BaseForwardModeVisitor StmtDiff VisitCallExpr(const clang::CallExpr* CE); StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); + StmtDiff VisitCXXConstCastExpr(const clang::CXXConstCastExpr* CCE); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 067fd77d1..2e26444e4 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1600,6 +1600,25 @@ StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr( return StmtDiff(Clone(E), Clone(E)); } +StmtDiff +BaseForwardModeVisitor::VisitCXXConstCastExpr(const CXXConstCastExpr* CCE) { + StmtDiff subExprDiff = Visit(CCE->getSubExpr()); + Expr* castExpr = + m_Sema + .BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast, + CCE->getTypeInfoAsWritten(), subExprDiff.getExpr(), + CCE->getAngleBrackets(), CCE->getSourceRange()) + .get(); + Expr* castExprDiff = + m_Sema + .BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast, + CCE->getTypeInfoAsWritten(), + subExprDiff.getExpr_dx(), CCE->getAngleBrackets(), + CCE->getSourceRange()) + .get(); + return StmtDiff(castExpr, castExprDiff); +} + StmtDiff BaseForwardModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) { StmtDiff subExprDiff = Visit(CSCE->getSubExpr()); diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index cfe3598a1..317cd28a5 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -174,6 +174,29 @@ double fn7(double i) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +void* cling_runtime_internal_throwIfInvalidPointer(void *Sema, void *Expr, const void *Arg) { + return const_cast(Arg); +} + +double fn8(double* params) { + double arr[] = {3.0}; + return params[0]*params[0] + *(double*)(cling_runtime_internal_throwIfInvalidPointer((void*)0UL, (void*)0UL, arr)); +} + +// CHECK: clad::ValueAndPushforward cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg); + +// CHECK: double fn8_darg0_0(double *params) { +// CHECK-NEXT: double _d_arr[1] = {0.}; +// CHECK-NEXT: double arr[1] = {3.}; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = cling_runtime_internal_throwIfInvalidPointer_pushforward((void *)0UL, (void *)0UL, arr, (void *)0UL, (void *)0UL, _d_arr); +// CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward; +// CHECK-NEXT: } + +// CHECK: clad::ValueAndPushforward cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) { +// CHECK-NEXT: return {const_cast(Arg), const_cast(_d_Arg)}; +// CHECK-NEXT: } + + int main() { INIT_DIFFERENTIATE(fn1, "i"); INIT_DIFFERENTIATE(fn2, "i"); @@ -190,4 +213,9 @@ int main() { TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00} TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00} TEST_DIFFERENTIATE(fn7, 3); // CHECK-EXEC: {4.00} + + double params[] = {3.0}; + auto fn8_dx = clad::differentiate(fn8, "params[0]"); + double d_param = fn8_dx.execute(params); + printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00} } From 22b2590c0df1b322af81c100857dd7dc935bca4c Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Fri, 21 Jun 2024 23:01:21 +0200 Subject: [PATCH 4/4] Add support for simple lambda expressions in reverse mode This commit provides support for primitive lambda expressions with no captures in reverse mode in the same way they are currently supported in the forward mode (#937). That is, the lambda expressions are not visited yet. Instead, the lambda functions are treated as a special case of functors. Fixes: #789 --- lib/Differentiator/ReverseModeVisitor.cpp | 51 +++++++++++++-- test/Gradient/Lambdas.C | 79 +++++++++++++++++++++++ 2 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 test/Gradient/Lambdas.C diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6f1df389a..6394ee9dd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -17,6 +17,7 @@ #include "clad/Differentiator/StmtClone.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Expr.h" #include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" @@ -1596,13 +1597,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff baseDiff; // If it has more args or f_darg0 was not found, we look for its pullback // function. + const auto* MD = dyn_cast(FD); if (!OverloadedDerivedFn) { size_t idx = 0; /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. - if (const auto* MD = dyn_cast(FD)) { - if (MD->isInstance()) { + if (MD) { + if (isLambdaCallOperator(MD)) { + QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( + FD->getDeclContext()->getOuterLexicalRecordContext())); + baseDiff = + StmtDiff(Clone(dyn_cast(CE)->getArg(0)), + new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc)); + } else if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); @@ -1700,7 +1708,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (DerivedCallOutputArgs[i + isaMethod]) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; @@ -2735,6 +2746,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool promoteToFnScope = !getCurrentScope()->isFunctionScope() && m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; + + // If the DeclStmt is not empty, check the first declaration in case it is a + // lambda function. This case it is treated separately for now and we don't + // create a variable for its derivative. + bool isLambda = false; + const auto* declsBegin = DS->decls().begin(); + if (declsBegin != DS->decls().end() && isa(*declsBegin)) { + auto* VD = dyn_cast(*declsBegin); + QualType QT = VD->getType(); + if (!QT->isPointerType()) { + auto* typeDecl = QT->getAsCXXRecordDecl(); + // We should also simply copy the original lambda. The differentiation + // of lambdas is happening in the `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); + } + } + } + // For each variable declaration v, create another declaration _d_v to // store derivatives for potential reassignments. E.g. // double y = x; @@ -2742,7 +2778,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = _d_x; double y = x; for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { - DeclDiff VDDiff = DifferentiateVarDecl(VD); + DeclDiff VDDiff; + if (!isLambda) + VDDiff = DifferentiateVarDecl(VD); // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. @@ -2762,8 +2800,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = x; // copied from original function, collides with // _d_y // } - if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType()) + if (!isLambda && + (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || + VD->getType() != VDDiff.getDecl()->getType())) m_DeclReplacements[VD] = VDDiff.getDecl(); // Here, we move the declaration to the function global scope. diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C new file mode 100644 index 000000000..98b5e5536 --- /dev/null +++ b/test/Gradient/Lambdas.C @@ -0,0 +1,79 @@ +// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s +// RUN: ./Lambdas.out | %filecheck_exec %s +// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oLambdas.out +// RUN: ./Lambdas.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double f1(double i, double j) { + auto _f = [] (double t) { + return t*t + 1.0; + }; + return i + _f(j); +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; +// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: auto _f = []{{ ?}}(double t) { +// CHECK-NEXT: return t * t + 1.; +// CHECK-NEXT: }{{;?}} +// CHECK: { +// CHECK-NEXT: *_d_i += 1; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _f.operator_call_pullback(j, 1, &_r0); +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f2(double i, double j) { + auto _f = [] (double t, double k) { + return t + k; + }; + double x = _f(i + j, i); + return x; +} + +// CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; +// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: double _d_x = 0; +// CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { +// CHECK-NEXT: return t + k; +// CHECK-NEXT: }{{;?}} +// CHECK: double x = operator()(i + j, i); +// CHECK-NEXT: _d_x += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: _f.operator_call_pullback(i + j, i, _d_x, &_r0, &_r1); +// CHECK-NEXT: *_d_i += _r0; +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: *_d_i += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + + +int main() { + auto df1 = clad::gradient(f1); + double di = 0, dj = 0; + df1.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 1.00 8.00 + + auto df2 = clad::gradient(f2); + di = 0, dj = 0; + df2.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 2.00 1.00 +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y * t; +// CHECK-NEXT: *_d_t += t * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y; +// CHECK-NEXT: *_d_k += _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file