Skip to content

Commit

Permalink
[SYCL] AST support for SYCL kernel entry point functions.
Browse files Browse the repository at this point in the history
A SYCL kernel entry point function is a non-member function or a static member
function declared with the `sycl_kernel_entry_point` attribute. Such functions
define a pattern for an offload kernel entry point function to be generated to
enable execution of a SYCL kernel on a device. A SYCL library implementation
orchestrates the invocation of these functions with corresponding SYCL kernel
arguments in response to calls to SYCL kernel invocation functions specified
by the SYCL 2020 specification.

The offload kernel entry point function (sometimes referred to as the SYCL
kernel caller function) is generated from the SYCL kernel entry point function
by a transformation of the function parameters followed by a transformation of
the function body to replace references to the original parameters with
references to the transformed ones. Exactly how parameters are transformed will
be explained in a future change that implements non-trivial transformations.
For now, it suffices to state that a given parameter of the SYCL kernel entry
point function may be transformed to multiple parameters of the offload kernel
entry point as needed to satisfy offload kernel argument passing requirements.
Parameters that are decomposed in this way are reconstituted as local variables
in the body of the generated offload kernel entry point function.

For example, given the following SYCL kernel entry point function definition:

  template<typename KernelNameType, typename KernelType>
  [[clang::sycl_kernel_entry_point(KernelNameType)]]
  void sycl_kernel_entry_point(KernelType kernel) {
    kernel();
  }

and the following call:

  struct Kernel {
    int dm1;
    int dm2;
    void operator()() const;
  };
  Kernel k;
  sycl_kernel_entry_point<class kernel_name>(k);

the corresponding offload kernel entry point function that is generated might
look as follows (assuming `Kernel` is a type that requires decomposition):

  void offload_kernel_entry_point_for_kernel_name(int dm1, int dm2) {
    Kernel kernel{dm1, dm2};
    kernel();
  }

Other details of the generated offload kernel entry point function, such as
its name and callng convention, are implementation details that need not be
reflected in the AST and may differ across target devices. For that reason,
only the transformation described above is represented in the AST; other
details will be filled in during code generation.

These transformations are represented using new AST nodes introduced with this
change. `OutlinedFunctionDecl` holds a sequence of `ImplicitParamDecl` nodes
and a sequence of statement nodes that correspond to the transformed
parameters and function body. `SYCLKernelCallStmt` wraps the original function
body and associates it with an `OutlinedFunctionDecl` instance. For the example
above, the AST generated for the `sycl_kernel_entry_point<kernel_name>`
specialization would look as follows:

  FunctionDecl 'sycl_kernel_entry_point<kernel_name>(Kernel)'
    TemplateArgument type 'kernel_name'
    TemplateArgument type 'Kernel'
    ParmVarDecl kernel 'Kernel'
    SYCLKernelCallStmt
      CompoundStmt
        <original statements>
      OutlinedFunctionDecl
        ImplicitParamDecl 'dm1' 'int'
        ImplicitParamDecl 'dm2' 'int'
        CompoundStmt
          VarDecl 'kernel' 'Kernel'
            <initialization of 'kernel' with 'dm1' and 'dm2'>
          <transformed statements with redirected references of 'kernel'>

Any ODR-use of the SYCL kernel entry point function will (with future changes)
suffice for the offload kernel entry point to be emitted. An actual call to
the SYCL kernel entry point function will result in a call to the function.
However, evaluation of a `SYCLKernelCallStmt` statement is a no-op, so such
calls will have no effect other than to trigger emission of the offload kernel
entry point.
  • Loading branch information
tahonermann committed Jan 9, 2025
1 parent 03eb786 commit d021c2b
Show file tree
Hide file tree
Showing 34 changed files with 734 additions and 10 deletions.
16 changes: 14 additions & 2 deletions clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class ASTNodeTraverser
ConstStmtVisitor<Derived>::Visit(S);

// Some statements have custom mechanisms for dumping their children.
if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
if (isa<DeclStmt, GenericSelectionExpr, RequiresExpr,
OpenACCWaitConstruct, SYCLKernelCallStmt>(S))
return;

if (Traversal == TK_IgnoreUnlessSpelledInSource &&
Expand Down Expand Up @@ -585,6 +585,12 @@ class ASTNodeTraverser

void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); }

void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) {
for (const ImplicitParamDecl *Parameter : D->parameters())
Visit(Parameter);
Visit(D->getBody());
}

void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); }

void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) {
Expand Down Expand Up @@ -815,6 +821,12 @@ class ASTNodeTraverser
Visit(Node->getCapturedDecl());
}

void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) {
Visit(Node->getOriginalStmt());
if (Traversal != TK_IgnoreUnlessSpelledInSource)
Visit(Node->getOutlinedFunctionDecl());
}

void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
for (const auto *C : Node->clauses())
Visit(C);
Expand Down
90 changes: 90 additions & 0 deletions clang/include/clang/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4678,6 +4678,96 @@ class BlockDecl : public Decl, public DeclContext {
}
};

/// Represents a partial function definition.
///
/// An outlined function declaration contains the parameters and body of
/// a function independent of other function definition concerns such
/// as function name, type, and calling convention. Such declarations may
/// be used to hold a parameterized and transformed sequence of statements
/// used to generate a target dependent function definition without losing
/// association with the original statements. See SYCLKernelCallStmt as an
/// example.
class OutlinedFunctionDecl final
: public Decl,
public DeclContext,
private llvm::TrailingObjects<OutlinedFunctionDecl, ImplicitParamDecl *> {
protected:
size_t numTrailingObjects(OverloadToken<ImplicitParamDecl>) {
return NumParams;
}

private:
/// The number of parameters to the outlined function.
unsigned NumParams;

/// The body of the outlined function.
llvm::PointerIntPair<Stmt *, 1, bool> BodyAndNothrow;

explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams);

ImplicitParamDecl *const *getParams() const {
return getTrailingObjects<ImplicitParamDecl *>();
}

ImplicitParamDecl **getParams() {
return getTrailingObjects<ImplicitParamDecl *>();
}

public:
friend class ASTDeclReader;
friend class ASTDeclWriter;
friend TrailingObjects;

static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC,
unsigned NumParams);
static OutlinedFunctionDecl *CreateDeserialized(ASTContext &C,
GlobalDeclID ID,
unsigned NumParams);

Stmt *getBody() const override;
void setBody(Stmt *B);

bool isNothrow() const;
void setNothrow(bool Nothrow = true);

unsigned getNumParams() const { return NumParams; }

ImplicitParamDecl *getParam(unsigned i) const {
assert(i < NumParams);
return getParams()[i];
}
void setParam(unsigned i, ImplicitParamDecl *P) {
assert(i < NumParams);
getParams()[i] = P;
}

// ArrayRef interface to parameters.
ArrayRef<ImplicitParamDecl *> parameters() const {
return {getParams(), getNumParams()};
}
MutableArrayRef<ImplicitParamDecl *> parameters() {
return {getParams(), getNumParams()};
}

using param_iterator = ImplicitParamDecl *const *;
using param_range = llvm::iterator_range<param_iterator>;

/// Retrieve an iterator pointing to the first parameter decl.
param_iterator param_begin() const { return getParams(); }
/// Retrieve an iterator one past the last parameter decl.
param_iterator param_end() const { return getParams() + NumParams; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Decl *D) { return classofKind(D->getKind()); }
static bool classofKind(Kind K) { return K == OutlinedFunction; }
static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) {
return static_cast<DeclContext *>(const_cast<OutlinedFunctionDecl *>(D));
}
static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) {
return static_cast<OutlinedFunctionDecl *>(const_cast<DeclContext *>(DC));
}
};

/// Represents the body of a CapturedStmt, and serves as its DeclContext.
class CapturedDecl final
: public Decl,
Expand Down
14 changes: 14 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
Expand Down Expand Up @@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, {
ShouldVisitChildren = false;
})

DEF_TRAVERSE_DECL(OutlinedFunctionDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
})

DEF_TRAVERSE_DECL(CapturedDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
Expand Down Expand Up @@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {})
DEF_TRAVERSE_STMT(SEHLeaveStmt, {})
DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); })

DEF_TRAVERSE_STMT(SYCLKernelCallStmt, {
if (getDerived().shouldVisitImplicitCode()) {
TRY_TO(TraverseStmt(S->getOriginalStmt()));
TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl()));
ShouldVisitChildren = false;
}
})

DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {})
DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, {
if (!getDerived().shouldVisitImplicitCode()) {
Expand Down
95 changes: 95 additions & 0 deletions clang/include/clang/AST/StmtSYCL.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines SYCL AST classes used to represent calls to SYCL kernels.
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_AST_STMTSYCL_H
#define LLVM_CLANG_AST_STMTSYCL_H

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Stmt.h"
#include "clang/Basic/SourceLocation.h"

namespace clang {

//===----------------------------------------------------------------------===//
// AST classes for SYCL kernel calls.
//===----------------------------------------------------------------------===//

/// SYCLKernelCallStmt represents the transformation that is applied to the body
/// of a function declared with the sycl_kernel_entry_point attribute. The body
/// of such a function specifies the statements to be executed on a SYCL device
/// to invoke a SYCL kernel with a particular set of kernel arguments. The
/// SYCLKernelCallStmt associates an original statement (the compound statement
/// that is the function body) with an OutlinedFunctionDecl that holds the
/// kernel parameters and the transformed body. During code generation, the
/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable
/// for invocation from a SYCL library implementation. If executed, the
/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for
/// it.
class SYCLKernelCallStmt : public Stmt {
friend class ASTStmtReader;
friend class ASTStmtWriter;

private:
Stmt *OriginalStmt = nullptr;
OutlinedFunctionDecl *OFDecl = nullptr;

public:
/// Construct a SYCL kernel call statement.
SYCLKernelCallStmt(Stmt *OS, OutlinedFunctionDecl *OFD)
: Stmt(SYCLKernelCallStmtClass), OriginalStmt(OS), OFDecl(OFD) {}

/// Construct an empty SYCL kernel call statement.
SYCLKernelCallStmt(EmptyShell Empty)
: Stmt(SYCLKernelCallStmtClass, Empty) {}

/// Retrieve the model statement.
Stmt *getOriginalStmt() { return OriginalStmt; }
const Stmt *getOriginalStmt() const { return OriginalStmt; }
void setOriginalStmt(Stmt *S) { OriginalStmt = S; }

/// Retrieve the outlined function declaration.
OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; }
const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; }

/// Set the outlined function declaration.
void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) {
OFDecl = OFD;
}

SourceLocation getBeginLoc() const LLVM_READONLY {
return getOriginalStmt()->getBeginLoc();
}

SourceLocation getEndLoc() const LLVM_READONLY {
return getOriginalStmt()->getEndLoc();
}

SourceRange getSourceRange() const LLVM_READONLY {
return getOriginalStmt()->getSourceRange();
}

static bool classof(const Stmt *T) {
return T->getStmtClass() == SYCLKernelCallStmtClass;
}

child_range children() {
return child_range(&OriginalStmt, &OriginalStmt + 1);
}

const_child_range children() const {
return const_child_range(&OriginalStmt, &OriginalStmt + 1);
}
};

} // end namespace clang

#endif
1 change: 1 addition & 0 deletions clang/include/clang/AST/StmtVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/DeclNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def Friend : DeclNode<Decl>;
def FriendTemplate : DeclNode<Decl>;
def StaticAssert : DeclNode<Decl>;
def Block : DeclNode<Decl, "blocks">, DeclContext;
def OutlinedFunction : DeclNode<Decl>, DeclContext;
def Captured : DeclNode<Decl>, DeclContext;
def Import : DeclNode<Decl>;
def OMPThreadPrivate : DeclNode<Decl>;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/StmtNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def SwitchCase : StmtNode<Stmt, 1>;
def CaseStmt : StmtNode<SwitchCase>;
def DefaultStmt : StmtNode<SwitchCase>;
def CapturedStmt : StmtNode<Stmt>;
def SYCLKernelCallStmt : StmtNode<Stmt>;

// Statements that might produce a value (for example, as the last non-null
// statement in a GNU statement-expression).
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaSYCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase {
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);

void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body);
};

} // namespace clang
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/Template.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char {
#define EMPTY(DERIVED, BASE)
#define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)

// Decls which use special-case instantiation code.
// Decls which never appear inside a template.
#define OUTLINEDFUNCTION(DERIVED, BASE)

// Decls which use special-case instantiation code.
#define BLOCK(DERIVED, BASE)
#define CAPTURED(DERIVED, BASE)
#define IMPLICITPARAM(DERIVED, BASE)
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Serialization/ASTBitCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,9 @@ enum DeclCode {
/// A BlockDecl record.
DECL_BLOCK,

/// A OutlinedFunctionDecl record.
DECL_OUTLINEDFUNCTION,

/// A CapturedDecl record.
DECL_CAPTURED,

Expand Down Expand Up @@ -1588,6 +1591,9 @@ enum StmtCode {
/// A CapturedStmt record.
STMT_CAPTURED,

/// A SYCLKernelCallStmt record.
STMT_SYCLKERNELCALL,

/// A GCC-style AsmStmt record.
STMT_GCCASM,

Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/ASTStructuralEquivalence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
Expand Down
25 changes: 25 additions & 0 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5440,6 +5440,31 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
return new (C, ID) BlockDecl(nullptr, SourceLocation());
}


OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams)
: Decl(OutlinedFunction, DC, SourceLocation()), DeclContext(OutlinedFunction),
NumParams(NumParams), BodyAndNothrow(nullptr, false) {}

OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C, DeclContext *DC,
unsigned NumParams) {
return new (C, DC, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
OutlinedFunctionDecl(DC, NumParams);
}

OutlinedFunctionDecl *OutlinedFunctionDecl::CreateDeserialized(ASTContext &C,
GlobalDeclID ID,
unsigned NumParams) {
return new (C, ID, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
OutlinedFunctionDecl(nullptr, NumParams);
}

Stmt *OutlinedFunctionDecl::getBody() const { return BodyAndNothrow.getPointer(); }
void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); }

bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); }
void OutlinedFunctionDecl::setNothrow(bool Nothrow) { BodyAndNothrow.setInt(Nothrow); }


CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams)
: Decl(Captured, DC, SourceLocation()), DeclContext(Captured),
NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {}
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/DeclBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ unsigned Decl::getIdentifierNamespaceForKind(Kind DeclKind) {
case PragmaDetectMismatch:
case Block:
case Captured:
case OutlinedFunction:
case TranslationUnit:
case ExternCContext:
case Decomposition:
Expand Down Expand Up @@ -1237,6 +1238,8 @@ template <class T> static Decl *getNonClosureContext(T *D) {
return getNonClosureContext(BD->getParent());
if (auto *CD = dyn_cast<CapturedDecl>(D))
return getNonClosureContext(CD->getParent());
if (auto *OFD = dyn_cast<OutlinedFunctionDecl>(D))
return getNonClosureContext(OFD->getParent());
return nullptr;
}

Expand Down Expand Up @@ -1429,6 +1432,7 @@ DeclContext *DeclContext::getPrimaryContext() {
case Decl::TopLevelStmt:
case Decl::Block:
case Decl::Captured:
case Decl::OutlinedFunction:
case Decl::OMPDeclareReduction:
case Decl::OMPDeclareMapper:
case Decl::RequiresExprBody:
Expand Down
Loading

0 comments on commit d021c2b

Please sign in to comment.