diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index f5652b295de16..3bc0bdff2bdd1 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -158,8 +158,8 @@ class ASTNodeTraverser ConstStmtVisitor::Visit(S); // Some statements have custom mechanisms for dumping their children. - if (isa(S) || isa(S) || - isa(S) || isa(S)) + if (isa(S)) return; if (Traversal == TK_IgnoreUnlessSpelledInSource && @@ -585,6 +585,12 @@ class ASTNodeTraverser void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); } + void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) { + for (const ImplicitParamDecl *Parameter : D->parameters()) + Visit(Parameter); + Visit(D->getBody()); + } + void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); } void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) { @@ -815,6 +821,12 @@ class ASTNodeTraverser Visit(Node->getCapturedDecl()); } + void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) { + Visit(Node->getOriginalStmt()); + if (Traversal != TK_IgnoreUnlessSpelledInSource) + Visit(Node->getOutlinedFunctionDecl()); + } + void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) { for (const auto *C : Node->clauses()) Visit(C); diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h index 16fc98aa1a57f..901ec1e48ca08 100644 --- a/clang/include/clang/AST/Decl.h +++ b/clang/include/clang/AST/Decl.h @@ -4678,6 +4678,96 @@ class BlockDecl : public Decl, public DeclContext { } }; +/// Represents a partial function definition. +/// +/// An outlined function declaration contains the parameters and body of +/// a function independent of other function definition concerns such +/// as function name, type, and calling convention. Such declarations may +/// be used to hold a parameterized and transformed sequence of statements +/// used to generate a target dependent function definition without losing +/// association with the original statements. See SYCLKernelCallStmt as an +/// example. +class OutlinedFunctionDecl final + : public Decl, + public DeclContext, + private llvm::TrailingObjects { +protected: + size_t numTrailingObjects(OverloadToken) { + return NumParams; + } + +private: + /// The number of parameters to the outlined function. + unsigned NumParams; + + /// The body of the outlined function. + llvm::PointerIntPair BodyAndNothrow; + + explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams); + + ImplicitParamDecl *const *getParams() const { + return getTrailingObjects(); + } + + ImplicitParamDecl **getParams() { + return getTrailingObjects(); + } + +public: + friend class ASTDeclReader; + friend class ASTDeclWriter; + friend TrailingObjects; + + static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC, + unsigned NumParams); + static OutlinedFunctionDecl *CreateDeserialized(ASTContext &C, + GlobalDeclID ID, + unsigned NumParams); + + Stmt *getBody() const override; + void setBody(Stmt *B); + + bool isNothrow() const; + void setNothrow(bool Nothrow = true); + + unsigned getNumParams() const { return NumParams; } + + ImplicitParamDecl *getParam(unsigned i) const { + assert(i < NumParams); + return getParams()[i]; + } + void setParam(unsigned i, ImplicitParamDecl *P) { + assert(i < NumParams); + getParams()[i] = P; + } + + // ArrayRef interface to parameters. + ArrayRef parameters() const { + return {getParams(), getNumParams()}; + } + MutableArrayRef parameters() { + return {getParams(), getNumParams()}; + } + + using param_iterator = ImplicitParamDecl *const *; + using param_range = llvm::iterator_range; + + /// Retrieve an iterator pointing to the first parameter decl. + param_iterator param_begin() const { return getParams(); } + /// Retrieve an iterator one past the last parameter decl. + param_iterator param_end() const { return getParams() + NumParams; } + + // Implement isa/cast/dyncast/etc. + static bool classof(const Decl *D) { return classofKind(D->getKind()); } + static bool classofKind(Kind K) { return K == OutlinedFunction; } + static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) { + return static_cast(const_cast(D)); + } + static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) { + return static_cast(const_cast(DC)); + } +}; + /// Represents the body of a CapturedStmt, and serves as its DeclContext. class CapturedDecl final : public Decl, diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index d500f4eadef75..c4a1d03f1b3d1 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -37,6 +37,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/TemplateName.h" #include "clang/AST/Type.h" @@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, { ShouldVisitChildren = false; }) +DEF_TRAVERSE_DECL(OutlinedFunctionDecl, { + TRY_TO(TraverseStmt(D->getBody())); + ShouldVisitChildren = false; +}) + DEF_TRAVERSE_DECL(CapturedDecl, { TRY_TO(TraverseStmt(D->getBody())); ShouldVisitChildren = false; @@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {}) DEF_TRAVERSE_STMT(SEHLeaveStmt, {}) DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); }) +DEF_TRAVERSE_STMT(SYCLKernelCallStmt, { + if (getDerived().shouldVisitImplicitCode()) { + TRY_TO(TraverseStmt(S->getOriginalStmt())); + TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl())); + ShouldVisitChildren = false; + } +}) + DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {}) DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, { if (!getDerived().shouldVisitImplicitCode()) { diff --git a/clang/include/clang/AST/StmtSYCL.h b/clang/include/clang/AST/StmtSYCL.h new file mode 100644 index 0000000000000..ac356cb0bf384 --- /dev/null +++ b/clang/include/clang/AST/StmtSYCL.h @@ -0,0 +1,95 @@ +//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines SYCL AST classes used to represent calls to SYCL kernels. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_AST_STMTSYCL_H +#define LLVM_CLANG_AST_STMTSYCL_H + +#include "clang/AST/ASTContext.h" +#include "clang/AST/Decl.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/SourceLocation.h" + +namespace clang { + +//===----------------------------------------------------------------------===// +// AST classes for SYCL kernel calls. +//===----------------------------------------------------------------------===// + +/// SYCLKernelCallStmt represents the transformation that is applied to the body +/// of a function declared with the sycl_kernel_entry_point attribute. The body +/// of such a function specifies the statements to be executed on a SYCL device +/// to invoke a SYCL kernel with a particular set of kernel arguments. The +/// SYCLKernelCallStmt associates an original statement (the compound statement +/// that is the function body) with an OutlinedFunctionDecl that holds the +/// kernel parameters and the transformed body. During code generation, the +/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable +/// for invocation from a SYCL library implementation. If executed, the +/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for +/// it. +class SYCLKernelCallStmt : public Stmt { + friend class ASTStmtReader; + friend class ASTStmtWriter; + +private: + Stmt *OriginalStmt = nullptr; + OutlinedFunctionDecl *OFDecl = nullptr; + +public: + /// Construct a SYCL kernel call statement. + SYCLKernelCallStmt(Stmt *OS, OutlinedFunctionDecl *OFD) + : Stmt(SYCLKernelCallStmtClass), OriginalStmt(OS), OFDecl(OFD) {} + + /// Construct an empty SYCL kernel call statement. + SYCLKernelCallStmt(EmptyShell Empty) + : Stmt(SYCLKernelCallStmtClass, Empty) {} + + /// Retrieve the model statement. + Stmt *getOriginalStmt() { return OriginalStmt; } + const Stmt *getOriginalStmt() const { return OriginalStmt; } + void setOriginalStmt(Stmt *S) { OriginalStmt = S; } + + /// Retrieve the outlined function declaration. + OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; } + const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; } + + /// Set the outlined function declaration. + void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) { + OFDecl = OFD; + } + + SourceLocation getBeginLoc() const LLVM_READONLY { + return getOriginalStmt()->getBeginLoc(); + } + + SourceLocation getEndLoc() const LLVM_READONLY { + return getOriginalStmt()->getEndLoc(); + } + + SourceRange getSourceRange() const LLVM_READONLY { + return getOriginalStmt()->getSourceRange(); + } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == SYCLKernelCallStmtClass; + } + + child_range children() { + return child_range(&OriginalStmt, &OriginalStmt + 1); + } + + const_child_range children() const { + return const_child_range(&OriginalStmt, &OriginalStmt + 1); + } +}; + +} // end namespace clang + +#endif diff --git a/clang/include/clang/AST/StmtVisitor.h b/clang/include/clang/AST/StmtVisitor.h index 990aa2df180d4..8b7b728deaff2 100644 --- a/clang/include/clang/AST/StmtVisitor.h +++ b/clang/include/clang/AST/StmtVisitor.h @@ -22,6 +22,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/Basic/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" diff --git a/clang/include/clang/Basic/DeclNodes.td b/clang/include/clang/Basic/DeclNodes.td index 48396e85c5ada..723113dc2486e 100644 --- a/clang/include/clang/Basic/DeclNodes.td +++ b/clang/include/clang/Basic/DeclNodes.td @@ -101,6 +101,7 @@ def Friend : DeclNode; def FriendTemplate : DeclNode; def StaticAssert : DeclNode; def Block : DeclNode, DeclContext; +def OutlinedFunction : DeclNode, DeclContext; def Captured : DeclNode, DeclContext; def Import : DeclNode; def OMPThreadPrivate : DeclNode; diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index ce2c48bd3c84e..53fc77bbbcecc 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -24,6 +24,7 @@ def SwitchCase : StmtNode; def CaseStmt : StmtNode; def DefaultStmt : StmtNode; def CapturedStmt : StmtNode; +def SYCLKernelCallStmt : StmtNode; // Statements that might produce a value (for example, as the last non-null // statement in a GNU statement-expression). diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h index 5bb0de40c886c..b4f607d1287bc 100644 --- a/clang/include/clang/Sema/SemaSYCL.h +++ b/clang/include/clang/Sema/SemaSYCL.h @@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase { void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL); void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD); + StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body); }; } // namespace clang diff --git a/clang/include/clang/Sema/Template.h b/clang/include/clang/Sema/Template.h index 9800f75f676aa..4206bd50b13dd 100644 --- a/clang/include/clang/Sema/Template.h +++ b/clang/include/clang/Sema/Template.h @@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char { #define EMPTY(DERIVED, BASE) #define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE) - // Decls which use special-case instantiation code. +// Decls which never appear inside a template. +#define OUTLINEDFUNCTION(DERIVED, BASE) + +// Decls which use special-case instantiation code. #define BLOCK(DERIVED, BASE) #define CAPTURED(DERIVED, BASE) #define IMPLICITPARAM(DERIVED, BASE) diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h index aac165130b719..87ac62551a142 100644 --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1312,6 +1312,9 @@ enum DeclCode { /// A BlockDecl record. DECL_BLOCK, + /// A OutlinedFunctionDecl record. + DECL_OUTLINEDFUNCTION, + /// A CapturedDecl record. DECL_CAPTURED, @@ -1588,6 +1591,9 @@ enum StmtCode { /// A CapturedStmt record. STMT_CAPTURED, + /// A SYCLKernelCallStmt record. + STMT_SYCLKERNELCALL, + /// A GCC-style AsmStmt record. STMT_GCCASM, diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp index 308551c306151..eaf0748395268 100644 --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -76,6 +76,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/TemplateName.h" #include "clang/AST/Type.h" diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp index 97e23dd1aaa92..5bce2c37bf058 100644 --- a/clang/lib/AST/Decl.cpp +++ b/clang/lib/AST/Decl.cpp @@ -5440,6 +5440,31 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) { return new (C, ID) BlockDecl(nullptr, SourceLocation()); } + +OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams) + : Decl(OutlinedFunction, DC, SourceLocation()), DeclContext(OutlinedFunction), + NumParams(NumParams), BodyAndNothrow(nullptr, false) {} + +OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C, DeclContext *DC, + unsigned NumParams) { + return new (C, DC, additionalSizeToAlloc(NumParams)) + OutlinedFunctionDecl(DC, NumParams); +} + +OutlinedFunctionDecl *OutlinedFunctionDecl::CreateDeserialized(ASTContext &C, + GlobalDeclID ID, + unsigned NumParams) { + return new (C, ID, additionalSizeToAlloc(NumParams)) + OutlinedFunctionDecl(nullptr, NumParams); +} + +Stmt *OutlinedFunctionDecl::getBody() const { return BodyAndNothrow.getPointer(); } +void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); } + +bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); } +void OutlinedFunctionDecl::setNothrow(bool Nothrow) { BodyAndNothrow.setInt(Nothrow); } + + CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams) : Decl(Captured, DC, SourceLocation()), DeclContext(Captured), NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {} diff --git a/clang/lib/AST/DeclBase.cpp b/clang/lib/AST/DeclBase.cpp index fb701f76231bc..77ca8c5c8accd 100644 --- a/clang/lib/AST/DeclBase.cpp +++ b/clang/lib/AST/DeclBase.cpp @@ -958,6 +958,7 @@ unsigned Decl::getIdentifierNamespaceForKind(Kind DeclKind) { case PragmaDetectMismatch: case Block: case Captured: + case OutlinedFunction: case TranslationUnit: case ExternCContext: case Decomposition: @@ -1237,6 +1238,8 @@ template static Decl *getNonClosureContext(T *D) { return getNonClosureContext(BD->getParent()); if (auto *CD = dyn_cast(D)) return getNonClosureContext(CD->getParent()); + if (auto *OFD = dyn_cast(D)) + return getNonClosureContext(OFD->getParent()); return nullptr; } @@ -1429,6 +1432,7 @@ DeclContext *DeclContext::getPrimaryContext() { case Decl::TopLevelStmt: case Decl::Block: case Decl::Captured: + case Decl::OutlinedFunction: case Decl::OMPDeclareReduction: case Decl::OMPDeclareMapper: case Decl::RequiresExprBody: diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp index d6a351a78c7ba..685c00d0cb44f 100644 --- a/clang/lib/AST/Stmt.cpp +++ b/clang/lib/AST/Stmt.cpp @@ -25,6 +25,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/Type.h" #include "clang/Basic/CharInfo.h" #include "clang/Basic/LLVM.h" diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 52bcb5135d351..b5def6fbe525c 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -30,6 +30,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" @@ -582,6 +583,10 @@ void StmtPrinter::VisitCapturedStmt(CapturedStmt *Node) { PrintStmt(Node->getCapturedDecl()->getBody()); } +void StmtPrinter::VisitSYCLKernelCallStmt(SYCLKernelCallStmt *Node) { + PrintStmt(Node->getOutlinedFunctionDecl()->getBody()); +} + void StmtPrinter::VisitObjCAtTryStmt(ObjCAtTryStmt *Node) { Indent() << "@try"; if (auto *TS = dyn_cast(Node->getTryBody())) { diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 0f1ebc68a4f76..85b59f714ba84 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -392,6 +392,10 @@ void StmtProfiler::VisitCapturedStmt(const CapturedStmt *S) { VisitStmt(S); } +void StmtProfiler::VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *S) { + VisitStmt(S); +} + void StmtProfiler::VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { VisitStmt(S); } diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp index 6f3ff050cb697..cda1a15e92f62 100644 --- a/clang/lib/CodeGen/CGDecl.cpp +++ b/clang/lib/CodeGen/CGDecl.cpp @@ -97,6 +97,7 @@ void CodeGenFunction::EmitDecl(const Decl &D) { case Decl::Friend: case Decl::FriendTemplate: case Decl::Block: + case Decl::OutlinedFunction: case Decl::Captured: case Decl::UsingShadow: case Decl::ConstructorUsingShadow: diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp index c87ec899798aa..9d8baa85ba592 100644 --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -114,6 +114,7 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef Attrs) { case Stmt::DefaultStmtClass: case Stmt::CaseStmtClass: case Stmt::SEHLeaveStmtClass: + case Stmt::SYCLKernelCallStmtClass: llvm_unreachable("should have emitted these statements as simple"); #define STMT(Type, Base) @@ -527,6 +528,23 @@ bool CodeGenFunction::EmitSimpleStmt(const Stmt *S, case Stmt::SEHLeaveStmtClass: EmitSEHLeaveStmt(cast(*S)); break; + case Stmt::SYCLKernelCallStmtClass: + // SYCL kernel call statements are generated as wrappers around the body + // of functions declared with the sycl_kernel_entry_point attribute. Such + // functions are used to specify how a SYCL kernel (a function object) is + // to be invoked; the SYCL kernel call statement contains a transformed + // variation of the function body and is used to generate a SYCL kernel + // caller function; a function that serves as the device side entry point + // used to execute the SYCL kernel. The sycl_kernel_entry_point attributed + // function is invoked by host code in order to trigger emission of the + // device side SYCL kernel caller function and to generate metadata needed + // by SYCL run-time library implementations; the function is otherwise + // intended to have no effect. As such, the function body is not evaluated + // as part of the invocation during host compilation (and the function + // should not be called or emitted during device compilation); the SYCL + // kernel call statement is thus handled as a null statement for the + // purpose of code generation. + break; } return true; } diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 457e1477bb2ee..83eb2ff2b3e0a 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -28,6 +28,7 @@ #include "clang/AST/ExprOpenMP.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/Type.h" #include "clang/Basic/ABI.h" #include "clang/Basic/CapturedStmt.h" diff --git a/clang/lib/Sema/JumpDiagnostics.cpp b/clang/lib/Sema/JumpDiagnostics.cpp index d465599450e7f..2361c567581e3 100644 --- a/clang/lib/Sema/JumpDiagnostics.cpp +++ b/clang/lib/Sema/JumpDiagnostics.cpp @@ -18,6 +18,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/Basic/SourceLocation.h" #include "clang/Sema/SemaInternal.h" #include "llvm/ADT/BitVector.h" diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 75920052c4f0c..96441a919f7e6 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -15986,7 +15986,8 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, CheckCoroutineWrapper(FD); } - // Diagnose invalid SYCL kernel entry point function declarations. + // Diagnose invalid SYCL kernel entry point function declarations + // and build SYCLKernelCallStmts for valid ones. if (FD && !FD->isInvalidDecl() && FD->hasAttr()) { SYCLKernelEntryPointAttr *SKEPAttr = FD->getAttr(); @@ -16003,6 +16004,13 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, << /*coroutine*/ 7; SKEPAttr->setInvalidAttr(); } + + if (Body && !FD->isTemplated() && !SKEPAttr->isInvalidAttr()) { + StmtResult SR = SYCL().BuildSYCLKernelCallStmt(FD, Body); + if (SR.isInvalid()) + return nullptr; + Body = SR.get(); + } } { diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp index 254ad05c5ba74..470d0d753b558 100644 --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1427,6 +1427,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) { case Stmt::AttributedStmtClass: case Stmt::BreakStmtClass: case Stmt::CapturedStmtClass: + case Stmt::SYCLKernelCallStmtClass: case Stmt::CaseStmtClass: case Stmt::CompoundStmtClass: case Stmt::ContinueStmtClass: diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index ce53990fdcb18..7f2ccb36bfad9 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -8,8 +8,10 @@ // This implements Semantic Analysis for SYCL constructs. //===----------------------------------------------------------------------===// +#include "TreeTransform.h" #include "clang/Sema/SemaSYCL.h" #include "clang/AST/Mangle.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/SYCLKernelInfo.h" #include "clang/AST/TypeOrdering.h" #include "clang/Basic/Diagnostic.h" @@ -362,3 +364,95 @@ void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) { } } } + +namespace { + +// The body of a function declared with the [[sycl_kernel_entry_point]] +// attribute is cloned and transformed to substitute references to the original +// function parameters with references to replacement variables that stand in +// for SYCL kernel parameters or local variables that reconstitute a decomposed +// SYCL kernel argument. +class OutlinedFunctionDeclBodyInstantiator + : public TreeTransform { +public: + using ParmDeclMap = llvm::DenseMap; + + OutlinedFunctionDeclBodyInstantiator(Sema &S, ParmDeclMap &M) + : TreeTransform(S), + SemaRef(S), MapRef(M) {} + + // A new set of AST nodes is always required. + bool AlwaysRebuild() { + return true; + } + + // Transform ParmVarDecl references to the supplied replacement variables. + ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) { + const ParmVarDecl *PVD = dyn_cast(DRE->getDecl()); + if (PVD) { + ParmDeclMap::iterator I = MapRef.find(PVD); + if (I != MapRef.end()) { + VarDecl *VD = I->second; + assert(SemaRef.getASTContext().hasSameUnqualifiedType(PVD->getType(), + VD->getType())); + assert(!VD->getType().isMoreQualifiedThan(PVD->getType(), + SemaRef.getASTContext())); + VD->setIsUsed(); + return DeclRefExpr::Create( + SemaRef.getASTContext(), DRE->getQualifierLoc(), + DRE->getTemplateKeywordLoc(), VD, false, DRE->getNameInfo(), + DRE->getType(), DRE->getValueKind()); + } + } + return DRE; + } + +private: + Sema &SemaRef; + ParmDeclMap &MapRef; +}; + +} // unnamed namespace + +StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body) { + assert(!FD->isInvalidDecl()); + assert(!FD->isTemplated()); + assert(FD->hasPrototype()); + + const auto *SKEPAttr = FD->getAttr(); + assert(SKEPAttr && "Missing sycl_kernel_entry_point attribute"); + assert(!SKEPAttr->isInvalidAttr() && + "sycl_kernel_entry_point attribute is invalid"); + + // Ensure that the kernel name was previously registered and that the + // stored declaration matches. + const SYCLKernelInfo &SKI = + getASTContext().getSYCLKernelInfo(SKEPAttr->getKernelName()); + assert(declaresSameEntity(SKI.getKernelEntryPointDecl(), FD) && + "SYCL kernel name conflict"); + + using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap; + ParmDeclMap ParmMap; + + assert(SemaRef.CurContext == FD); + OutlinedFunctionDecl *OFD = + OutlinedFunctionDecl::Create(getASTContext(), FD, FD->getNumParams()); + unsigned i = 0; + for (ParmVarDecl *PVD : FD->parameters()) { + ImplicitParamDecl *IPD = + ImplicitParamDecl::Create(getASTContext(), OFD, SourceLocation(), + PVD->getIdentifier(), PVD->getType(), + ImplicitParamKind::Other); + OFD->setParam(i, IPD); + ParmMap[PVD] = IPD; + ++i; + } + + OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap); + Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get(); + OFD->setBody(OFDBody); + OFD->setNothrow(); + Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD); + + return NewBody; +} diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 4a3c739ecbeab..f1a40fbc97eed 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -29,6 +29,7 @@ #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenACC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/Basic/DiagnosticParse.h" #include "clang/Basic/OpenMPKinds.h" #include "clang/Sema/Designator.h" @@ -17404,6 +17405,16 @@ TreeTransform::TransformCapturedStmt(CapturedStmt *S) { return getSema().ActOnCapturedRegionEnd(Body.get()); } +template +StmtResult +TreeTransform::TransformSYCLKernelCallStmt(SYCLKernelCallStmt *S) { + // SYCLKernelCallStmt nodes are inserted upon completion of a (non-template) + // function definition or instantiation of a function template specialization + // and will therefore never appear in a dependent context. + llvm_unreachable("SYCL kernel call statement cannot appear in dependent " + "context"); +} + template ExprResult TreeTransform::TransformHLSLOutArgExpr(HLSLOutArgExpr *E) { // We can transform the base expression and allow argument resolution to fill diff --git a/clang/lib/Serialization/ASTCommon.cpp b/clang/lib/Serialization/ASTCommon.cpp index ec18e84255ca8..3a62c4ea5595b 100644 --- a/clang/lib/Serialization/ASTCommon.cpp +++ b/clang/lib/Serialization/ASTCommon.cpp @@ -338,6 +338,7 @@ serialization::getDefinitiveDeclContext(const DeclContext *DC) { case Decl::CXXConversion: case Decl::ObjCMethod: case Decl::Block: + case Decl::OutlinedFunction: case Decl::Captured: // Objective C categories, category implementations, and class // implementations can only be defined in one place. @@ -439,6 +440,7 @@ bool serialization::isRedeclarableDeclKind(unsigned Kind) { case Decl::FriendTemplate: case Decl::StaticAssert: case Decl::Block: + case Decl::OutlinedFunction: case Decl::Captured: case Decl::Import: case Decl::OMPThreadPrivate: diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp index dee5169ae5723..0242dabf799c0 100644 --- a/clang/lib/Serialization/ASTReaderDecl.cpp +++ b/clang/lib/Serialization/ASTReaderDecl.cpp @@ -409,6 +409,7 @@ class ASTDeclReader : public DeclVisitor { void VisitFriendTemplateDecl(FriendTemplateDecl *D); void VisitStaticAssertDecl(StaticAssertDecl *D); void VisitBlockDecl(BlockDecl *BD); + void VisitOutlinedFunctionDecl(OutlinedFunctionDecl *D); void VisitCapturedDecl(CapturedDecl *CD); void VisitEmptyDecl(EmptyDecl *D); void VisitLifetimeExtendedTemporaryDecl(LifetimeExtendedTemporaryDecl *D); @@ -451,9 +452,9 @@ class ASTDeclReader : public DeclVisitor { void VisitOMPDeclareMapperDecl(OMPDeclareMapperDecl *D); void VisitOMPRequiresDecl(OMPRequiresDecl *D); void VisitOMPCapturedExprDecl(OMPCapturedExprDecl *D); - }; +}; - } // namespace clang +} // namespace clang namespace { @@ -1793,6 +1794,16 @@ void ASTDeclReader::VisitBlockDecl(BlockDecl *BD) { BD->setCaptures(Reader.getContext(), captures, capturesCXXThis); } +void ASTDeclReader::VisitOutlinedFunctionDecl(OutlinedFunctionDecl *D) { + // NumParams is deserialized by OutlinedFunctionDecl::CreateDeserialized(). + VisitDecl(D); + D->setNothrow(Record.readInt() != 0); + for (unsigned I = 0; I < D->NumParams; ++I) { + D->setParam(I, readDeclAs()); + } + D->setBody(cast_or_null(Record.readStmt())); +} + void ASTDeclReader::VisitCapturedDecl(CapturedDecl *CD) { VisitDecl(CD); unsigned ContextParamPos = Record.readInt(); @@ -4092,6 +4103,9 @@ Decl *ASTReader::ReadDeclRecord(GlobalDeclID ID) { case DECL_TEMPLATE_PARAM_OBJECT: D = TemplateParamObjectDecl::CreateDeserialized(Context, ID); break; + case DECL_OUTLINEDFUNCTION: + D = OutlinedFunctionDecl::CreateDeserialized(Context, ID, Record.readInt()); + break; case DECL_CAPTURED: D = CapturedDecl::CreateDeserialized(Context, ID, Record.readInt()); break; diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 4766f34e9f3a8..484a89462dc63 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -33,6 +33,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtSYCL.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" @@ -528,6 +529,12 @@ void ASTStmtReader::VisitCapturedStmt(CapturedStmt *S) { } } +void ASTStmtReader::VisitSYCLKernelCallStmt(SYCLKernelCallStmt *S) { + VisitStmt(S); + S->setOriginalStmt(Record.readSubStmt()); + S->setOutlinedFunctionDecl(readDeclAs()); +} + void ASTStmtReader::VisitExpr(Expr *E) { VisitStmt(E); CurrentUnpackingBits.emplace(Record.readInt()); @@ -3112,6 +3119,10 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) { Context, Record[ASTStmtReader::NumStmtFields]); break; + case STMT_SYCLKERNELCALL: + S = new (Context) SYCLKernelCallStmt(Empty); + break; + case EXPR_CONSTANT: S = ConstantExpr::CreateEmpty( Context, static_cast( diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp index f8ed155ca389d..415dc9b80b216 100644 --- a/clang/lib/Serialization/ASTWriterDecl.cpp +++ b/clang/lib/Serialization/ASTWriterDecl.cpp @@ -132,6 +132,7 @@ namespace clang { void VisitFriendTemplateDecl(FriendTemplateDecl *D); void VisitStaticAssertDecl(StaticAssertDecl *D); void VisitBlockDecl(BlockDecl *D); + void VisitOutlinedFunctionDecl(OutlinedFunctionDecl *D); void VisitCapturedDecl(CapturedDecl *D); void VisitEmptyDecl(EmptyDecl *D); void VisitLifetimeExtendedTemporaryDecl(LifetimeExtendedTemporaryDecl *D); @@ -1376,6 +1377,16 @@ void ASTDeclWriter::VisitBlockDecl(BlockDecl *D) { Code = serialization::DECL_BLOCK; } +void ASTDeclWriter::VisitOutlinedFunctionDecl(OutlinedFunctionDecl *D) { + Record.push_back(D->getNumParams()); + VisitDecl(D); + Record.push_back(D->isNothrow() ? 1 : 0); + for (unsigned I = 0; I < D->getNumParams(); ++I) + Record.AddDeclRef(D->getParam(I)); + Record.AddStmt(D->getBody()); + Code = serialization::DECL_OUTLINEDFUNCTION; +} + void ASTDeclWriter::VisitCapturedDecl(CapturedDecl *CD) { Record.push_back(CD->getNumParams()); VisitDecl(CD); diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index 7eedf7da7d3fc..651553244812f 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -609,6 +609,14 @@ void ASTStmtWriter::VisitCapturedStmt(CapturedStmt *S) { Code = serialization::STMT_CAPTURED; } +void ASTStmtWriter::VisitSYCLKernelCallStmt(SYCLKernelCallStmt *S) { + VisitStmt(S); + Record.AddStmt(S->getOriginalStmt()); + Record.AddDeclRef(S->getOutlinedFunctionDecl()); + + Code = serialization::STMT_SYCLKERNELCALL; +} + void ASTStmtWriter::VisitExpr(Expr *E) { VisitStmt(E); diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp index ff8bdcea9a220..140c77790496d 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1822,6 +1822,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred, case Stmt::OMPParallelGenericLoopDirectiveClass: case Stmt::OMPTargetParallelGenericLoopDirectiveClass: case Stmt::CapturedStmtClass: + case Stmt::SYCLKernelCallStmtClass: case Stmt::OpenACCComputeConstructClass: case Stmt::OpenACCLoopConstructClass: case Stmt::OpenACCCombinedConstructClass: diff --git a/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp b/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp new file mode 100644 index 0000000000000..27604e237adbb --- /dev/null +++ b/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp @@ -0,0 +1,275 @@ +// Tests without serialization: +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-device \ +// RUN: -ast-dump %s \ +// RUN: | FileCheck --match-full-lines %s +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-host \ +// RUN: -ast-dump %s \ +// RUN: | FileCheck --match-full-lines %s +// +// Tests with serialization: +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-device \ +// RUN: -emit-pch -o %t %s +// RUN: %clang_cc1 -x c++ -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-device \ +// RUN: -include-pch %t -ast-dump-all /dev/null \ +// RUN: | sed -e "s/ //" -e "s/ imported//" \ +// RUN: | FileCheck --match-full-lines %s +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-host \ +// RUN: -emit-pch -o %t %s +// RUN: %clang_cc1 -x c++ -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-host \ +// RUN: -include-pch %t -ast-dump-all /dev/null \ +// RUN: | sed -e "s/ //" -e "s/ imported//" \ +// RUN: | FileCheck --match-full-lines %s + +// These tests validate the AST body produced for functions declared with the +// sycl_kernel_entry_point attribute. + +// CHECK: TranslationUnitDecl {{.*}} + +// A unique kernel name type is required for each declared kernel entry point. +template struct KN; + +// A unique invocable type for use with each declared kernel entry point. +template struct K { + template + void operator()(Ts...) const {} +}; + + +[[clang::sycl_kernel_entry_point(KN<1>)]] +void skep1() { +} +// CHECK: |-FunctionDecl {{.*}} skep1 'void ()' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<1> + +template +[[clang::sycl_kernel_entry_point(KNT)]] +void skep2(KT k) { + k(); +} +template +void skep2>(K<2>); +// CHECK: |-FunctionTemplateDecl {{.*}} skep2 +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} KNT +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} KT +// CHECK-NEXT: | |-FunctionDecl {{.*}} skep2 'void (KT)' +// CHECK-NEXT: | | |-ParmVarDecl {{.*}} k 'KT' +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CallExpr {{.*}} '' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'KT' lvalue ParmVar {{.*}} 'k' 'KT' +// CHECK-NEXT: | | `-SYCLKernelEntryPointAttr {{.*}} KNT + +// CHECK-NEXT: | `-FunctionDecl {{.*}} skep2 'void (K<2>)' explicit_instantiation_definition +// CHECK-NEXT: | |-TemplateArgument type 'KN<2>' +// CHECK-NEXT: | | `-RecordType {{.*}} 'KN<2>' +// CHECK-NEXT: | | `-ClassTemplateSpecialization {{.*}} 'KN' +// CHECK-NEXT: | |-TemplateArgument type 'K<2>' +// CHECK-NEXT: | | `-RecordType {{.*}} 'K<2>' +// CHECK-NEXT: | | `-ClassTemplateSpecialization {{.*}} 'K' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} k 'K<2>' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const K<2>' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<2>' lvalue ParmVar {{.*}} 'k' 'K<2>' +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<2>' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const K<2>' lvalue +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<2>' lvalue ImplicitParam {{.*}} 'k' 'K<2>' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<2> + +template +[[clang::sycl_kernel_entry_point(KNT)]] +void skep3(KT k) { + k(); +} +template<> +[[clang::sycl_kernel_entry_point(KN<3>)]] +void skep3>(K<3> k) { + k(); +} +// CHECK: |-FunctionTemplateDecl {{.*}} skep3 +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} KNT +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} KT +// CHECK-NEXT: | |-FunctionDecl {{.*}} skep3 'void (KT)' +// CHECK-NEXT: | | |-ParmVarDecl {{.*}} k 'KT' +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CallExpr {{.*}} '' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'KT' lvalue ParmVar {{.*}} 'k' 'KT' +// CHECK-NEXT: | | `-SYCLKernelEntryPointAttr {{.*}} KNT + +// CHECK-NEXT: | `-Function {{.*}} 'skep3' 'void (K<3>)' +// CHECK-NEXT: |-FunctionDecl {{.*}} skep3 'void (K<3>)' explicit_specialization +// CHECK-NEXT: | |-TemplateArgument type 'KN<3>' +// CHECK-NEXT: | | `-RecordType {{.*}} 'KN<3>' +// CHECK-NEXT: | | `-ClassTemplateSpecialization {{.*}} 'KN' +// CHECK-NEXT: | |-TemplateArgument type 'K<3>' +// CHECK-NEXT: | | `-RecordType {{.*}} 'K<3>' +// CHECK-NEXT: | | `-ClassTemplateSpecialization {{.*}} 'K' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} k 'K<3>' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const K<3>' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<3>' lvalue ParmVar {{.*}} 'k' 'K<3>' +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<3>' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const K<3>' lvalue +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<3>' lvalue ImplicitParam {{.*}} 'k' 'K<3>' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<3> + +[[clang::sycl_kernel_entry_point(KN<4>)]] +void skep4(K<4> k, int p1, int p2) { + k(p1, p2); +} +// CHECK: |-FunctionDecl {{.*}} skep4 'void (K<4>, int, int)' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} k 'K<4>' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} p1 'int' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} p2 'int' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)(int, int) const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void (int, int) const' lvalue CXXMethod {{.*}} 'operator()' 'void (int, int) const' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'const K<4>' lvalue +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'K<4>' lvalue ParmVar {{.*}} 'k' 'K<4>' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} 'p1' 'int' +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} 'p2' 'int' +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<4>' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p1 'int' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p2 'int' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)(int, int) const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void (int, int) const' lvalue CXXMethod {{.*}} 'operator()' 'void (int, int) const' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'const K<4>' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<4>' lvalue ImplicitParam {{.*}} 'k' 'K<4>' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p1' 'int' +// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p2' 'int' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<4> + +[[clang::sycl_kernel_entry_point(KN<5>)]] +void skep5(int unused1, K<5> k, int unused2, int p, int unused3) { + static int slv = 0; + int lv = 4; + k(slv, 1, p, 3, lv, 5, []{ return 6; }); +} +// CHECK: |-FunctionDecl {{.*}} skep5 'void (int, K<5>, int, int, int)' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused1 'int' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} used k 'K<5>' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused2 'int' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} used p 'int' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused3 'int' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused1 'int' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<5>' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused2 'int' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p 'int' +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused3 'int' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | |-DeclStmt {{.*}} +// CHECK-NEXT: | | | `-VarDecl {{.*}} used slv 'int' static cinit +// CHECK-NEXT: | | | `-IntegerLiteral {{.*}} 'int' 0 +// CHECK-NEXT: | | |-DeclStmt {{.*}} +// CHECK-NEXT: | | | `-VarDecl {{.*}} used lv 'int' cinit +// CHECK-NEXT: | | | `-IntegerLiteral {{.*}} 'int' 4 +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)(int, int, int, int, int, int, (lambda {{.*}}) const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void (int, int, int, int, int, int, (lambda {{.*}})) const' lvalue CXXMethod {{.*}} 'operator()' 'void (int, int, int, int, int, int, (lambda {{.*}})) const' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'const K<5>' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<5>' lvalue ImplicitParam {{.*}} 'k' 'K<5>' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue Var {{.*}} 'slv' 'int' +// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 1 +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p' 'int' +// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 3 +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue Var {{.*}} 'lv' 'int' +// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 5 +// CHECK-NEXT: | | `-LambdaExpr {{.*}} '(lambda {{.*}})' +// CHECK: | `-SYCLKernelEntryPointAttr {{.*}} KN<5> + +struct S6 { + void operator()() const; +}; +[[clang::sycl_kernel_entry_point(KN<6>)]] +void skep6(const S6 &k) { + k(); +} +// CHECK: |-FunctionDecl {{.*}} skep6 'void (const S6 &)' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} used k 'const S6 &' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'const S6' lvalue ParmVar {{.*}} 'k' 'const S6 &' +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'const S6 &' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'const S6' lvalue ImplicitParam {{.*}} 'k' 'const S6 &' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<6> + +// Parameter types are not required to be complete at the point of a +// non-defining declaration. +struct S7; +[[clang::sycl_kernel_entry_point(KN<7>)]] +void skep7(S7 k); +struct S7 { + void operator()() const; +}; +[[clang::sycl_kernel_entry_point(KN<7>)]] +void skep7(S7 k) { + k(); +} +// CHECK: |-FunctionDecl {{.*}} skep7 'void (S7)' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} k 'S7' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<7> +// CHECK: |-FunctionDecl {{.*}} prev {{.*}} skep7 'void (S7)' +// CHECK-NEXT: | |-ParmVarDecl {{.*}} used k 'S7' +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const S7' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'S7' lvalue ParmVar {{.*}} 'k' 'S7' +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'S7' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const S7' lvalue +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'S7' lvalue ImplicitParam {{.*}} 'k' 'S7' +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<7> + + +void the_end() {} +// CHECK: `-FunctionDecl {{.*}} the_end 'void ()' diff --git a/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp b/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp index 0189cf0402d3a..b112e9e1db850 100644 --- a/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp +++ b/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp @@ -143,16 +143,14 @@ void skep6() { // CHECK: |-FunctionDecl {{.*}} skep6 'void ()' // CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<6> // CHECK-NEXT: |-FunctionDecl {{.*}} prev {{.*}} skep6 'void ()' -// CHECK-NEXT: | |-CompoundStmt {{.*}} -// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<6> +// CHECK: | `-SYCLKernelEntryPointAttr {{.*}} KN<6> // Ensure that matching attributes from the same declaration are ok. [[clang::sycl_kernel_entry_point(KN<7>), clang::sycl_kernel_entry_point(KN<7>)]] void skep7() { } // CHECK: |-FunctionDecl {{.*}} skep7 'void ()' -// CHECK-NEXT: | |-CompoundStmt {{.*}} -// CHECK-NEXT: | |-SYCLKernelEntryPointAttr {{.*}} KN<7> +// CHECK: | |-SYCLKernelEntryPointAttr {{.*}} KN<7> // CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<7> void the_end() {} diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index e175aab4499ff..42f095fea2db2 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -7202,6 +7202,7 @@ CXCursor clang_getCursorDefinition(CXCursor C) { case Decl::TopLevelStmt: case Decl::StaticAssert: case Decl::Block: + case Decl::OutlinedFunction: case Decl::Captured: case Decl::OMPCapturedExpr: case Decl::Label: // FIXME: Is this right?? diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp index ee276d8e4e148..b9f0b089e41b0 100644 --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -375,6 +375,10 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent, K = CXCursor_UnexposedStmt; break; + case Stmt::SYCLKernelCallStmtClass: + K = CXCursor_UnexposedStmt; + break; + case Stmt::IntegerLiteralClass: K = CXCursor_IntegerLiteral; break;