Skip to content

Commit

Permalink
[SYCL] Add support for __registered_kernels__ (#16485)
Browse files Browse the repository at this point in the history
This change adds support for a new attribute
``__sycl_detail::__registered_kernels__``, which appears at translation
unit scope.  The parameter for this attribute is a list of pairs like:

```
[[__sycl_detail__::__registered_kernels__(
  {"foo", foo},
  {"(void(*)(int, int*))iota", (void(*)(int, int*))iota},
  {"kernel<float>", kernel<float>}
)]];
```

The first element in each pair is a string, and the second element is a
constant expression for a pointer to a SYCL free function kernel.

The change creates the kernel's wrapper function and generates
module-level metadata of the form:

```
!sycl_registered_kernels = !{!0, !1}
!0 = !{!"foo", !"mangled-name-of-wrapper-for-foo"}
!1 = !{!"kernel<float>", !"mangled-name-of-wrapper-for-kernel<float>"}
```

where the first string in each pair is the first element of the
corresponding pair in ``__registered_kernels__`` and the second string is the
mangled name of the wrapper corresponding to the free function.
  • Loading branch information
premanandrao authored Jan 22, 2025
1 parent 8af1eb3 commit 745423b
Show file tree
Hide file tree
Showing 11 changed files with 411 additions and 20 deletions.
16 changes: 16 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,22 @@ def SYCLAddIRAnnotationsMember : InheritableAttr {
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
}

// Attribute spelled [[__sycl_detail__::__registered_kernels__(...)]]. It takes
// a variadic list of expression arguments and is diagnosed as applying at
// translation unit scope (empty subject list). Semantic handling is done by
// SemaSYCL::handleSYCLRegisteredKernels.
def SYCLRegisteredKernels : InheritableAttr {
let Spellings = [CXX11<"__sycl_detail__", "__registered_kernels__">];
let Args = [VariadicExprArgument<"Args">];
// Active for SYCL device compilation; silently ignored for SYCL host.
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
let Subjects = SubjectList<[Empty], ErrorDiag, "Translation Unit Scope">;
// NOTE(review): pulls in the member code shared with the add_ir_* attributes;
// confirm those helpers are actually needed for this attribute.
let AdditionalMembers = SYCLAddIRAttrCommonMembers.MemberCode;
// NOTE(review): this reuses the documentation record of
// SYCLAddIRAnnotationsMember; confirm whether a dedicated doc entry (or
// InternalOnly) is intended.
let Documentation = [SYCLAddIRAnnotationsMemberDocs];
}

// Internal attribute with no source spelling. It is attached implicitly to a
// function and carries, in RegName, the name that function was registered
// under via __registered_kernels__.
def SYCLRegisteredKernelName : InheritableAttr {
let Spellings = [];
let Subjects = SubjectList<[Function]>;
let Args = [StringArgument<"RegName">];
let Documentation = [InternalOnly];
}

def C11NoReturn : InheritableAttr {
let Spellings = [CustomKeyword<"_Noreturn">];
let Subjects = SubjectList<[Function], ErrorDiag>;
Expand Down
14 changes: 14 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12546,6 +12546,20 @@ def err_sycl_special_type_num_init_method : Error<
def warn_launch_bounds_is_cuda_specific : Warning<
"%0 attribute ignored, only applicable when targeting Nvidia devices">,
InGroup<IgnoredAttributes>;
// Diagnostics for the '__registered_kernels__' attribute.
// The attribute was written with no arguments.
def err_registered_kernels_num_of_args : Error<
"'__registered_kernels__' attribute must have at least one argument">;
// An attribute argument is not an initializer list.
def err_registered_kernels_init_list : Error<
"argument to the '__registered_kernels__' attribute must be an "
"initializer list expression">;
// An initializer list argument does not have exactly two elements.
def err_registered_kernels_init_list_pair_values : Error<
"each initializer list argument to the '__registered_kernels__' attribute "
"must contain a pair of values">;
// The second element of a pair could not be resolved to a function;
// %0 is the registration name given as the first element.
def err_registered_kernels_resolve_function : Error<
"unable to resolve free function kernel '%0'">;
// The same free function is registered again; %0 is the name already
// attached, %1 is the newly requested name.
def err_registered_kernels_name_already_registered : Error<
"free function kernel has already been registered with '%0'; cannot register with '%1'">;
// The resolved function is not a SYCL free function kernel; %0 is the
// registration name given for it.
def err_not_sycl_free_function : Error<
"attempting to register a function that is not a SYCL free function as '%0'">;

def warn_cuda_maxclusterrank_sm_90 : Warning<
"'maxclusterrank' requires sm_90 or higher, CUDA arch provided: %0, ignoring "
Expand Down
14 changes: 12 additions & 2 deletions clang/include/clang/Sema/SemaSYCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ class SemaSYCL : public SemaBase {
// We need to store the list of the sycl_kernel functions and their associated
// generated OpenCL Kernels so we can go back and re-name these after the
// fact.
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>
SyclKernelsToOpenCLKernels;
using KernelFDPairs =
llvm::SmallVector<std::pair<const FunctionDecl *, FunctionDecl *>>;
KernelFDPairs SyclKernelsToOpenCLKernels;

// Used to suppress diagnostics during kernel construction, since these were
// already emitted earlier. Diagnosing during Kernel emissions also skips the
Expand Down Expand Up @@ -296,11 +297,15 @@ class SemaSYCL : public SemaBase {
llvm::DenseSet<QualType> Visited,
ValueDecl *DeclToCheck);

// Read-only view of the (SYCL kernel function, generated OpenCL kernel) pairs
// recorded so far. Const-qualified so it can also be called through a const
// SemaSYCL reference; it never modifies the list.
const KernelFDPairs &getKernelFDPairs() const {
  return SyclKernelsToOpenCLKernels;
}

void addSyclOpenCLKernel(const FunctionDecl *SyclKernel,
FunctionDecl *OpenCLKernel) {
SyclKernelsToOpenCLKernels.emplace_back(SyclKernel, OpenCLKernel);
}

// Constructs the kernel for the SYCL free function FD. When NameStr is
// non-empty it is the name the kernel is being registered under via
// __registered_kernels__, and it is attached to the generated kernel; the
// default empty string means a regular free function kernel construction
// with no registered name.
void constructFreeFunctionKernel(FunctionDecl *FD, StringRef NameStr = "");

void addSyclDeviceDecl(Decl *d) { SyclDeviceDecls.insert(d); }
llvm::SetVector<Decl *> &syclDeviceDecls() { return SyclDeviceDecls; }

Expand Down Expand Up @@ -480,6 +485,7 @@ class SemaSYCL : public SemaBase {
void handleSYCLIntelMaxWorkGroupsPerMultiprocessor(Decl *D,
const ParsedAttr &AL);
void handleSYCLScopeAttr(Decl *D, const ParsedAttr &AL);
void handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &AL);

void checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D);

Expand Down Expand Up @@ -655,6 +661,10 @@ class SemaSYCL : public SemaBase {
void addIntelReqdSubGroupSizeAttr(Decl *D, const AttributeCommonInfo &CI,
Expr *E);
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);

// Used to check whether the function represented by FD is a SYCL
// free function kernel or not.
bool isFreeFunction(const FunctionDecl *FD);
};

} // namespace clang
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,12 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,

llvm::LLVMContext &Context = getLLVMContext();

// For SYCL device compilation, remember the (registered name, function name)
// pair of every function carrying a registered kernel name so the module can
// emit it as metadata later. getAttr<> returns null when the attribute is
// absent, so one lookup replaces the hasAttr<>/getAttr<> pair.
if (getLangOpts().SYCLIsDevice)
  if (const auto *RegNameAttr = FD->getAttr<SYCLRegisteredKernelNameAttr>())
    CGM.SYCLAddRegKernelNamePairs(RegNameAttr->getRegName(),
                                  FD->getNameAsString());

if (FD->hasAttr<OpenCLKernelAttr>() || FD->hasAttr<CUDAGlobalAttr>())
CGM.GenKernelArgMetadata(Fn, FD, this);

Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,19 @@ void CodeGenModule::Release() {
AspectEnumValsMD->addOperand(
getAspectEnumValueMD(Context, TheModule.getContext(), ECD));
}

// Emit the collected registered-kernel name pairs. Each pair becomes a
// two-element tuple; all of those tuples are wrapped in one node that is
// added as a single operand of the named metadata "sycl_registered_kernels".
if (!SYCLRegKernelNames.empty()) {
  llvm::LLVMContext &Ctx = TheModule.getContext();
  std::vector<llvm::Metadata *> Nodes;
  // The final size is known up front, so allocate once.
  Nodes.reserve(SYCLRegKernelNames.size());
  // Bind by const reference; there is no reason to copy each pair.
  for (const auto &MDKernelNames : SYCLRegKernelNames) {
    llvm::Metadata *Vals[] = {MDKernelNames.first, MDKernelNames.second};
    Nodes.push_back(llvm::MDTuple::get(Ctx, Vals));
  }

  llvm::NamedMDNode *SYCLRegKernelsMD =
      TheModule.getOrInsertNamedMetadata("sycl_registered_kernels");
  SYCLRegKernelsMD->addOperand(llvm::MDNode::get(Ctx, Nodes));
}
}

// HLSL related end of code gen work items.
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ class CodeGenModule : public CodeGenTypeCache {
/// handled differently than regular annotations so they cannot share map.
llvm::DenseMap<unsigned, llvm::Constant *> SYCLAnnotationArgs;

/// Pairs of metadata strings collected through SYCLAddRegKernelNamePairs and
/// emitted as the "sycl_registered_kernels" named metadata when the module is
/// released.
using MetadataPair = std::pair<llvm::Metadata *, llvm::Metadata *>;
SmallVector<MetadataPair, 4> SYCLRegKernelNames;

llvm::StringMap<llvm::GlobalVariable *> CFConstantStringMap;

llvm::DenseMap<llvm::Constant *, llvm::GlobalVariable *> ConstantStringMap;
Expand Down Expand Up @@ -1483,6 +1486,12 @@ class CodeGenModule : public CodeGenTypeCache {
llvm::Constant *EmitSYCLAnnotationArgs(
SmallVectorImpl<std::pair<std::string, std::string>> &Pairs);

/// Record one registered-kernel entry: both strings are turned into metadata
/// strings in this module's LLVM context and appended, as a pair, to
/// SYCLRegKernelNames.
void SYCLAddRegKernelNamePairs(StringRef First, StringRef Second) {
  llvm::Metadata *FirstMD = llvm::MDString::get(getLLVMContext(), First);
  llvm::Metadata *SecondMD = llvm::MDString::get(getLLVMContext(), Second);
  SYCLRegKernelNames.emplace_back(FirstMD, SecondMD);
}

/// Add attributes from add_ir_attributes_global_variable on TND to GV.
void AddGlobalSYCLIRAttributes(llvm::GlobalVariable *GV,
const RecordDecl *RD);
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7479,6 +7479,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_SYCLAddIRAnnotationsMember:
S.SYCL().handleSYCLAddIRAnnotationsMemberAttr(D, AL);
break;
case ParsedAttr::AT_SYCLRegisteredKernels:
S.SYCL().handleSYCLRegisteredKernels(D, AL);
break;

// Swift attributes.
case ParsedAttr::AT_SwiftAsyncName:
Expand Down
78 changes: 60 additions & 18 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1148,10 +1148,10 @@ static target getAccessTarget(QualType FieldTy,

// FIXME: Free functions must have void return type and be declared at file
// scope, outside any namespaces.
static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
bool SemaSYCL::isFreeFunction(const FunctionDecl *FD) {
for (auto *IRAttr : FD->specific_attrs<SYCLAddIRAttributesFunctionAttr>()) {
SmallVector<std::pair<std::string, std::string>, 4> NameValuePairs =
IRAttr->getAttributeNameValuePairs(SemaSYCLRef.getASTContext());
IRAttr->getAttributeNameValuePairs(getASTContext());
for (const auto &NameValuePair : NameValuePairs) {
if (NameValuePair.first == "sycl-nd-range-kernel" ||
NameValuePair.first == "sycl-single-task-kernel") {
Expand Down Expand Up @@ -5291,7 +5291,7 @@ void SemaSYCL::SetSYCLKernelNames() {
SyclKernelsToOpenCLKernels) {
std::string CalculatedName, StableName;
StringRef KernelName;
if (isFreeFunction(*this, Pair.first)) {
if (isFreeFunction(Pair.first)) {
std::tie(CalculatedName, StableName) =
constructFreeFunctionKernelName(*this, Pair.first, *MangleCtx);
KernelName = CalculatedName;
Expand Down Expand Up @@ -5414,24 +5414,66 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
}
}

void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) {
SyclKernelArgsSizeChecker argsSizeChecker(SemaSYCLRef, FD->getLocation(),
// Attaches an implicit SYCLRegisteredKernelNameAttr holding Str to FD, using
// Loc as the attribute location. Does nothing when Str is empty.
static void addRegisteredKernelName(SemaSYCL &S, StringRef Str,
                                    FunctionDecl *FD, SourceLocation Loc) {
  if (Str.empty())
    return;

  auto *NameAttr =
      SYCLRegisteredKernelNameAttr::CreateImplicit(S.getASTContext(), Str, Loc);
  FD->addAttr(NameAttr);
}

// Looks FD up among the free function entries already recorded in S.
//
// Returns true when FD has no entry yet, meaning the caller still has to
// construct its kernel. Returns false when an entry exists; in that case a
// non-empty Str is either attached to the entry's generated kernel as its
// registered name or, if that kernel already carries one, rejected with a
// diagnostic. An empty Str with an existing entry is a plain repeat of a
// regular construction call and needs no further action.
static bool checkAndAddRegisteredKernelName(SemaSYCL &S, FunctionDecl *FD,
                                            StringRef Str) {
  for (const auto &Entry : S.getKernelFDPairs()) {
    if (Entry.first != FD)
      continue;

    // Nothing to attach for an unnamed (regular) construction request.
    if (Str.empty())
      return false;

    FunctionDecl *GeneratedKernel = Entry.second;
    if (const auto *Existing =
            GeneratedKernel->getAttr<SYCLRegisteredKernelNameAttr>()) {
      // The kernel may be registered under one name only.
      S.Diag(FD->getLocation(),
             diag::err_registered_kernels_name_already_registered)
          << Existing->getRegName() << Str;
      return false;
    }

    addRegisteredKernelName(S, Str, GeneratedKernel, FD->getLocation());
    return false;
  }
  return true;
}

void SemaSYCL::constructFreeFunctionKernel(FunctionDecl *FD,
StringRef NameStr) {
if (!checkAndAddRegisteredKernelName(*this, FD, NameStr))
return;

SyclKernelArgsSizeChecker argsSizeChecker(*this, FD->getLocation(),
false /*IsSIMDKernel*/);
SyclKernelDeclCreator kernel_decl(SemaSYCLRef, FD->getLocation(),
FD->isInlined(), false /*IsSIMDKernel */,
FD);
SyclKernelDeclCreator kernel_decl(*this, FD->getLocation(), FD->isInlined(),
false /*IsSIMDKernel */, FD);

FreeFunctionKernelBodyCreator kernel_body(SemaSYCLRef, kernel_decl, FD);
FreeFunctionKernelBodyCreator kernel_body(*this, kernel_decl, FD);

SyclKernelIntHeaderCreator int_header(
SemaSYCLRef, SemaSYCLRef.getSyclIntegrationHeader(), FD->getType(), FD);
SyclKernelIntHeaderCreator int_header(*this, getSyclIntegrationHeader(),
FD->getType(), FD);

SyclKernelIntFooterCreator int_footer(SemaSYCLRef,
SemaSYCLRef.getSyclIntegrationFooter());
KernelObjVisitor Visitor{SemaSYCLRef};
SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter());
KernelObjVisitor Visitor{*this};

Visitor.VisitFunctionParameters(FD, argsSizeChecker, kernel_decl, kernel_body,
int_header, int_footer);

assert(getKernelFDPairs().back().first == FD &&
"OpenCL Kernel not found for free function entry");
// Register the kernel name with the OpenCL kernel generated for the
// free function.
addRegisteredKernelName(*this, NameStr, getKernelFDPairs().back().second,
FD->getLocation());
}

// Figure out the sub-group for the this function. First we check the
Expand Down Expand Up @@ -5717,7 +5759,7 @@ void SemaSYCL::MarkDevices() {
}

void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (isFreeFunction(*this, FD)) {
if (isFreeFunction(FD)) {
SyclKernelDecompMarker DecompMarker(*this);
SyclKernelFieldChecker FieldChecker(*this);
SyclKernelUnionChecker UnionChecker(*this);
Expand All @@ -5736,7 +5778,7 @@ void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
if (!FieldChecker.isValid() || !UnionChecker.isValid())
return;

ConstructFreeFunctionKernel(*this, FD);
constructFreeFunctionKernel(FD);
}
}

Expand Down Expand Up @@ -6621,7 +6663,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
unsigned ShimCounter = 1;
int FreeFunctionCount = 0;
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
if (!S.isFreeFunction(K.SyclKernel))
continue;
++FreeFunctionCount;
// Generate forward declaration for free function.
Expand Down Expand Up @@ -6739,7 +6781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
}
ShimCounter = 1;
for (const KernelDesc &K : KernelDescs) {
if (!isFreeFunction(S, K.SyclKernel))
if (!S.isFreeFunction(K.SyclKernel))
continue;

O << "\n// Definition of kernel_id of " << K.Name << "\n";
Expand Down
65 changes: 65 additions & 0 deletions clang/lib/Sema/SemaSYCLDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3162,3 +3162,68 @@ void SemaSYCL::checkSYCLAddIRAttributesFunctionAttrConflicts(Decl *D) {
Diag(Attr->getLoc(), diag::warn_sycl_old_and_new_kernel_attributes)
<< Attr;
}

// Semantic handler for the __registered_kernels__ attribute.
//
// Each attribute argument must be an initializer list holding exactly two
// values: a string (the name to register) and an expression that resolves to
// a SYCL free function. For every well-formed entry the free function kernel
// is constructed with that name. Processing stops at the first entry that
// fails a check, after a diagnostic has been issued for it; entries handled
// before that point are not undone.
//
// D is the declaration the attribute was parsed on and is not used here.
// A is the parsed attribute carrying the list of pairs.
void SemaSYCL::handleSYCLRegisteredKernels(Decl *D, const ParsedAttr &A) {
  // Check for SYCL device compilation context.
  if (!getLangOpts().SYCLIsDevice)
    return;

  // When declared, we expect at least one item in the list.
  const unsigned NumArgs = A.getNumArgs();
  if (NumArgs == 0) {
    Diag(A.getLoc(), diag::err_registered_kernels_num_of_args);
    return;
  }

  // Traverse through the items in the list.
  for (unsigned I = 0; I < NumArgs; ++I) {
    assert(A.isArgExpr(I) && "Expected expression argument");
    // Each item in the list must be an initializer list expression. A single
    // dyn_cast both tests and converts, instead of isa<> followed by cast<>.
    Expr *ArgExpr = A.getArgAsExpr(I);
    auto *ArgListE = dyn_cast<InitListExpr>(ArgExpr);
    if (!ArgListE) {
      Diag(ArgExpr->getExprLoc(), diag::err_registered_kernels_init_list);
      return;
    }

    // Each init-list expression must have a pair of values.
    if (ArgListE->getNumInits() != 2) {
      Diag(ArgExpr->getExprLoc(),
           diag::err_registered_kernels_init_list_pair_values);
      return;
    }

    // The first value of the pair must be a string.
    Expr *FirstExpr = ArgListE->getInit(0);
    StringRef CurStr;
    SourceLocation Loc = FirstExpr->getExprLoc();
    if (!SemaRef.checkStringLiteralArgumentAttr(A, FirstExpr, CurStr, &Loc))
      return;

    // Resolve the FunctionDecl from the second value of the pair. Loc is
    // re-pointed at the function expression for the diagnostic below.
    Expr *SecondE = ArgListE->getInit(1);
    FunctionDecl *FD = nullptr;
    if (auto *ULE = dyn_cast<UnresolvedLookupExpr>(SecondE)) {
      // Presumably a template-id such as kernel<float>; it must name a single
      // function template specialization.
      FD = SemaRef.ResolveSingleFunctionTemplateSpecialization(ULE, true);
      Loc = ULE->getExprLoc();
    } else {
      // Look through parentheses and casts for a direct function reference.
      SecondE = SecondE->IgnoreParenCasts();
      if (auto *DRE = dyn_cast<DeclRefExpr>(SecondE))
        FD = dyn_cast<FunctionDecl>(DRE->getDecl());
      Loc = SecondE->getExprLoc();
    }

    // Issue a diagnostic if we are unable to resolve the FunctionDecl.
    if (!FD) {
      Diag(Loc, diag::err_registered_kernels_resolve_function) << CurStr;
      return;
    }

    // Issue a diagnostic if the FunctionDecl is not a SYCL free function.
    if (!isFreeFunction(FD)) {
      Diag(FD->getLocation(), diag::err_not_sycl_free_function) << CurStr;
      return;
    }

    // Construct a free function kernel.
    constructFreeFunctionKernel(FD, CurStr);
  }
}
Loading

0 comments on commit 745423b

Please sign in to comment.