Skip to content

Commit

Permalink
Add API method to specialize function reference with argument types (s…
Browse files Browse the repository at this point in the history
…hader-slang#4966)

* Add `FunctionReflection::specializeWithArgTypes()`

* Update slang.cpp

* Use a shared semantics context on linkage

Improve performance on reflection queries

* Try to fix linux/mac compile errors
  • Loading branch information
saipraveenb25 authored Sep 16, 2024
1 parent c46ca4c commit d866c0b
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 53 deletions.
6 changes: 6 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,7 @@ extern "C"
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func);
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func);
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic);
SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes);

// Abstract Decl Reflection

Expand Down Expand Up @@ -3587,6 +3588,11 @@ namespace slang
{
return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic);
}

FunctionReflection* specializeWithArgTypes(unsigned int argCount, TypeReflection* const* types)
{
return (FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types);
}
};

struct GenericReflection
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ namespace Slang
};

/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
struct SharedSemanticsContext : public RefObject
{
Linkage* m_linkage = nullptr;

Expand Down
14 changes: 13 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
namespace Slang
{
struct PathInfo;
struct IncludeHandler;
struct IncludeHandler;
struct SharedSemanticsContext;

class ProgramLayout;
class PtrType;
class TargetProgram;
Expand Down Expand Up @@ -2170,6 +2172,11 @@ namespace Slang
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink);

DeclRef<Decl> specializeWithArgTypes(
DeclRef<Decl> funcDeclRef,
List<Type*> argTypes,
DiagnosticSink* sink);

DiagnosticSink::Flags diagnosticSinkFlags = 0;

Expand All @@ -2183,6 +2190,9 @@ namespace Slang
m_retainedSession = nullptr;
}

// Get shared semantics information for reflection purposes.
SharedSemanticsContext* getSemanticsForReflection();

private:
/// The global Slang library session that this linkage is a child of
Session* m_session = nullptr;
Expand Down Expand Up @@ -2236,6 +2246,8 @@ namespace Slang

List<Type*> m_specializedTypes;

RefPtr<SharedSemanticsContext> m_semanticsForReflection;

};

/// Shared functionality between front- and back-end compile requests.
Expand Down
69 changes: 55 additions & 14 deletions source/slang/slang-reflection-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);

auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
auto result = program->findDeclFromString(name, &sink);

if (auto genericDeclRef = result.as<GenericDecl>())
{
auto innerDeclRef = substituteDeclRef(
SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
}

if (auto funcDeclRef = result.as<FunctionDeclBase>())
return convert(funcDeclRef);
}
Expand Down Expand Up @@ -924,35 +933,36 @@ SLANG_API bool spReflection_isSubType(
}
}

SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
DeclRef<Decl> getInnermostGenericParent(DeclRef<Decl> declRef)
{
auto decl = declRef.getDecl();
auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder();
auto parentDecl = decl;
while(parentDecl)
{
if(parentDecl->parentDecl && as<GenericDecl>(parentDecl->parentDecl))
return convertDeclToGeneric(
substituteDeclRef(
return substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))));
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)));
parentDecl = parentDecl->parentDecl;
}

return nullptr;
return DeclRef<Decl>();
}

SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type)
{
auto slangType = convert(type);
if (auto declRefType = as<DeclRefType>(slangType))
{
return getInnermostGenericParent(declRefType->getDeclRef());
return convertDeclToGeneric(
getInnermostGenericParent(declRefType->getDeclRef()));
}
else if (auto genericDeclRefType = as<GenericDeclRefType>(slangType))
{
return getInnermostGenericParent(genericDeclRefType->getDeclRef());
return convertDeclToGeneric(
getInnermostGenericParent(genericDeclRefType->getDeclRef()));
}

return nullptr;
Expand Down Expand Up @@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV
SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var)
{
auto declRef = convert(var);
return getInnermostGenericParent(declRef);
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}

SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic)
Expand Down Expand Up @@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
auto declRef = convert(func);
return getInnermostGenericParent(declRef);
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}

SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
Expand All @@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla
return convert(substDeclRef.as<FunctionDeclBase>());
}

SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
SlangReflectionFunction* func,
SlangInt argTypeCount,
SlangReflectionType* const* argTypes)
{
auto declRef = convert(func);
if (!declRef)
return nullptr;


auto linkage = getModule(declRef.getDecl())->getLinkage();

List<Type*> argTypeList;
for (SlangInt ii = 0; ii < argTypeCount; ++ii)
{
auto argType = convert(argTypes[ii]);
argTypeList.add(argType);
}

try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
}
catch (...)
{
return nullptr;
}
}

// Abstract decl reflection

SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
Expand Down Expand Up @@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S

auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();

return getInnermostGenericParent(
substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))));
return convertDeclToGeneric(
getInnermostGenericParent(
substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))));
}

SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam)
Expand Down
100 changes: 67 additions & 33 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "slang-type-layout.h"
#include "slang-lookup.h"

#
#include "slang-options.h"

#include "slang-repro.h"
Expand Down Expand Up @@ -1069,8 +1068,12 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules)
mapNameToLoadedModules.add(nameToMod);
}

m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr);
}

SharedSemanticsContext* Linkage::getSemanticsForReflection() { return m_semanticsForReflection.get(); }

ISlangUnknown* Linkage::getInterface(const Guid& guid)
{
if(guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid())
Expand Down Expand Up @@ -1348,18 +1351,11 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
return asExternal(specializedType);
}


DeclRef<Decl> Linkage::specializeGeneric(
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink)
DeclRef<GenericDecl> getGenericParentDeclRef(
ASTBuilder* astBuilder,
SemanticsVisitor* visitor,
DeclRef<Decl> declRef)
{
SLANG_AST_BUILDER_RAII(getASTBuilder());
SLANG_ASSERT(declRef);

SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink);
SemanticsVisitor visitor(&sharedSemanticsContext);

// Create substituted parent decl ref.
auto decl = declRef.getDecl();

Expand All @@ -1369,9 +1365,58 @@ DeclRef<Decl> Linkage::specializeGeneric(
}

auto genericDecl = as<GenericDecl>(decl);
auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as<GenericDecl>();
genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as<GenericDecl>();
auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as<GenericDecl>();
return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as<GenericDecl>();
}

DeclRef<Decl> Linkage::specializeWithArgTypes(
DeclRef<Decl> funcDeclRef,
List<Type*> argTypes,
DiagnosticSink* sink)
{
SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(sink);

ASTBuilder* astBuilder = getASTBuilder();

List<Expr*> argExprs;
for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa)
{
auto argType = argTypes[aa];

// Create an 'empty' expr with the given type. Ideally, the expression itself should not matter
// only its checked type.
//
auto argExpr = astBuilder->create<VarExpr>();
argExpr->type = argType;
argExprs.add(argExpr);
}

// Construct invoke expr.
auto invokeExpr = astBuilder->create<InvokeExpr>();
auto declRefExpr = astBuilder->create<DeclRefExpr>();

declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef);
invokeExpr->functionExpr = declRefExpr;
invokeExpr->arguments = argExprs;

auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr);
return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef;
}


DeclRef<Decl> Linkage::specializeGeneric(
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink)
{
SLANG_AST_BUILDER_RAII(getASTBuilder());
SLANG_ASSERT(declRef);

SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(sink);

auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef);

DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>();
declRefExpr->declRef = genericDeclRef;
Expand Down Expand Up @@ -1561,8 +1606,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy

try
{
SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink);
SemanticsVisitor visitor(&sharedSemanticsContext);
SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(&sink);

auto witness =
visitor.isSubtype((Slang::Type*)type, (Slang::Type*)interfaceType, IsSubTypeOptions::None);
if (auto subtypeWitness = as<SubtypeWitness>(witness))
Expand Down Expand Up @@ -2318,12 +2364,8 @@ DeclRef<Decl> ComponentType::findDeclFromString(

Expr* expr = linkage->parseTermString(name, scope);

SharedSemanticsContext sharedSemanticsContext(
linkage,
nullptr,
sink);
SemanticsContext context(&sharedSemanticsContext);
context = context.allowStaticReferenceToNonStaticMember();
SemanticsContext context(linkage->getSemanticsForReflection());
context = context.allowStaticReferenceToNonStaticMember().withSink(sink);

SemanticsVisitor visitor(context);

Expand Down Expand Up @@ -2377,12 +2419,8 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(

Expr* expr = linkage->parseTermString(name, scope);

SharedSemanticsContext sharedSemanticsContext(
linkage,
nullptr,
sink);
SemanticsContext context(&sharedSemanticsContext);
context = context.allowStaticReferenceToNonStaticMember();
SemanticsContext context(linkage->getSemanticsForReflection());
context = context.allowStaticReferenceToNonStaticMember().withSink(sink);

SemanticsVisitor visitor(context);

Expand Down Expand Up @@ -2433,11 +2471,7 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(

bool ComponentType::isSubType(Type* subType, Type* superType)
{
SharedSemanticsContext sharedSemanticsContext(
getLinkage(),
nullptr,
nullptr);
SemanticsContext context(&sharedSemanticsContext);
SemanticsContext context(getLinkage()->getSemanticsForReflection());
SemanticsVisitor visitor(context);

return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr);
Expand Down
Loading

0 comments on commit d866c0b

Please sign in to comment.