From 7bf55f3a56597cbee522c586f599ce4be64c0bcb Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Thu, 12 Dec 2024 16:21:19 +0100 Subject: [PATCH] Remove static set from DiffReq --- .../clad/Differentiator/DerivativeBuilder.h | 2 ++ include/clad/Differentiator/DiffPlanner.h | 24 +++++++++++--- lib/Differentiator/DiffPlanner.cpp | 32 ++++++++++++------- lib/Differentiator/ReverseModeVisitor.cpp | 32 ++++++++++--------- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..8a70fc6f3 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -13,6 +13,7 @@ #include "clang/Sema/Sema.h" #include "clad/Differentiator/DerivedFnCollector.h" #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DynamicGraph.h" #include #include @@ -72,6 +73,7 @@ namespace clad { class DerivativeBuilder { private: friend class VisitorBase; + friend class DiffRequest; friend class BaseForwardModeVisitor; friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index a33e10f79..b025b5c8f 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -3,12 +3,13 @@ #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include +#include +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" -#include -#include namespace clang { class CallExpr; class CompilerInstance; @@ -21,6 +22,7 @@ class Type; } // namespace clang namespace clad { +class DerivativeBuilder; /// A struct containing information about request to differentiate a function. struct DiffRequest { @@ -34,12 +36,13 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { + std::set VariedDecls; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: - /// All varied declarations. - static std::set AllVariedDecls; + const DerivativeBuilder* Builder = nullptr; + // static std::set AllVariedDecls; /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -128,7 +131,8 @@ struct DiffRequest { Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && - DeclarationOnly == other.DeclarationOnly; + DeclarationOnly == other.DeclarationOnly && + getVariedDecls() == other.getVariedDecls(); } const clang::FunctionDecl* operator->() const { return Function; } @@ -145,6 +149,16 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setVariedDecls(std::set init) { + for (auto* vd : init) + this->m_ActivityRunInfo.VariedDecls.insert(vd); + } + std::set getVariedDecls() const { + return this->m_ActivityRunInfo.VariedDecls; + } + DiffRequest() {} + DiffRequest(DerivativeBuilder& builder) : Builder(&builder) {} }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index f404391a3..ec5aaf6ed 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,8 +22,8 @@ using namespace clang; namespace clad { -std::set DiffRequest::AllVariedDecls; -static SourceLocation noLoc; +// std::set DiffRequest::AllVariedDecls; +static SourceLocation noloc; /// Returns `DeclRefExpr` node corresponding to the function, method or /// functor argument which is to be differentiated. @@ -62,7 +62,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto callOperatorDeclName = m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, callOperatorDeclName, noLoc, + LookupResult R(m_SemaRef, callOperatorDeclName, noloc, Sema::LookupNameKind::LookupMemberName); // We do not want diagnostics that would fire because of this lookup. R.suppressDiagnostics(); @@ -149,7 +149,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto* newFnDRE = clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( callOperator, callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS)); + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noloc, &CSS)); m_FnDRE = cast(newFnDRE); } return false; @@ -198,7 +198,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto kernelArgIdx = numArgs - 1; auto* cudaKernelFlag = SemaRef - .ActOnCXXBoolLiteral(noLoc, + .ActOnCXXBoolLiteral(noloc, replacementFD->hasAttr() ? tok::kw_true : tok::kw_false) @@ -209,7 +209,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Create ref to generated FD. DeclRefExpr* DRE = - DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD, + DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noloc, replacementFD, /*RefersToEnclosingVariableOrCapture=*/false, replacementFD->getNameInfo(), replacementFD->getType(), oldDRE->getValueKind()); @@ -225,7 +225,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Add the "&" operator auto* newUnOp = SemaRef - .BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) + .BuildUnaryOp(nullptr, noloc, UnaryOperatorKind::UO_AddrOf, DRE) .get(); call->setArg(derivedFnArgIdx, newUnOp); } @@ -618,15 +618,25 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { + if (Builder) + for (auto diffreq : this->Builder->m_DiffRequestGraph.getNodes()) + for (auto vd : diffreq.getVariedDecls()) + m_ActivityRunInfo.VariedDecls.insert(vd); + if (Args) for (const auto& dParam : DVI) - AllVariedDecls.insert(cast(dParam.param)); - VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls); + m_ActivityRunInfo.VariedDecls.insert(cast(dParam.param)); + VariedAnalyzer analyzer(Function->getASTContext(), + m_ActivityRunInfo.VariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; + if (Builder) + this->Builder->m_DiffRequestGraph.addNode(*this); } - auto found = AllVariedDecls.find(VD); - return found != AllVariedDecls.end(); + auto found = m_ActivityRunInfo.VariedDecls.find(VD); + return found != m_ActivityRunInfo.VariedDecls.end(); + + return false; } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 61532f656..7b8f9966c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -216,6 +216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); auto derivativeBaseName = m_DiffReq.BaseFunctionName; + // llvm::errs() << "\nBaseFunctionName: " << derivativeBaseName << "\n"; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if // we differentiate w.r.t. all the parameters at once. @@ -1946,27 +1947,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Overloaded derivative was not found, request the CladPlugin to // derive the called function. - DiffRequest pullbackRequest{}; + DiffRequest pullbackRequest(m_Builder); pullbackRequest.Function = FD; // Mark the indexes of the global args. Necessary if the argument of the // call has a different name than the function's signature parameter. pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; - pullbackRequest.BaseFunctionName = - clad::utils::ComputeEffectiveFnName(FD); - pullbackRequest.Mode = DiffMode::experimental_pullback; - // Silence diag outputs in nested derivation process. - pullbackRequest.VerboseDiags = false; - pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; - pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - bool isaMethod = isa(FD); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (MD && isLambdaCallOperator(MD)) { - if (const auto* paramDecl = FD->getParamDecl(i)) - pullbackRequest.DVI.push_back(paramDecl); - } else if (DerivedCallOutputArgs[i + isaMethod]) - pullbackRequest.DVI.push_back(FD->getParamDecl(i)); + pullbackRequest.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); + pullbackRequest.Mode = DiffMode::experimental_pullback; + // Silence diag outputs in nested derivation process. + pullbackRequest.VerboseDiags = false; + pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setVariedDecls(m_DiffReq.getVariedDecls()); + bool isaMethod = isa(FD); + for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) + pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; if (m_ExternalSource)