diff --git a/include/slang.h b/include/slang.h index 777cd406b3..632bcbcbc5 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2589,6 +2589,7 @@ extern "C" SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic); + SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes); // Abstract Decl Reflection @@ -3587,6 +3588,11 @@ namespace slang { return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic); } + + FunctionReflection* specializeWithArgTypes(unsigned int argCount, TypeReflection* const* types) + { + return (FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types); + } }; struct GenericReflection diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index dc4568f8a2..ad3539a217 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -598,7 +598,7 @@ namespace Slang }; /// Shared state for a semantics-checking session. - struct SharedSemanticsContext + struct SharedSemanticsContext : public RefObject { Linkage* m_linkage = nullptr; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 62e4c5f4a8..0c788ae182 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -40,7 +40,9 @@ namespace Slang { struct PathInfo; - struct IncludeHandler; + struct IncludeHandler; + struct SharedSemanticsContext; + class ProgramLayout; class PtrType; class TargetProgram; @@ -2170,6 +2172,11 @@ namespace Slang DeclRef declRef, List argExprs, DiagnosticSink* sink); + + DeclRef specializeWithArgTypes( + DeclRef funcDeclRef, + List argTypes, + DiagnosticSink* sink); DiagnosticSink::Flags diagnosticSinkFlags = 0; @@ -2183,6 +2190,9 @@ namespace Slang m_retainedSession = nullptr; } + // Get shared semantics information for reflection purposes. + SharedSemanticsContext* getSemanticsForReflection(); + private: /// The global Slang library session that this linkage is a child of Session* m_session = nullptr; @@ -2236,6 +2246,8 @@ namespace Slang List m_specializedTypes; + RefPtr m_semanticsForReflection; + }; /// Shared functionality between front- and back-end compile requests. diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index efa9a20a9e..38129babf5 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti programLayout->getTargetReq()->getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); + auto astBuilder = program->getLinkage()->getASTBuilder(); try { auto result = program->findDeclFromString(name, &sink); + + if (auto genericDeclRef = result.as()) + { + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + } + if (auto funcDeclRef = result.as()) return convert(funcDeclRef); } @@ -924,7 +933,7 @@ SLANG_API bool spReflection_isSubType( } } -SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) +DeclRef getInnermostGenericParent(DeclRef declRef) { auto decl = declRef.getDecl(); auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder(); @@ -932,15 +941,14 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) while(parentDecl) { if(parentDecl->parentDecl && as(parentDecl->parentDecl)) - return convertDeclToGeneric( - substituteDeclRef( + return substituteDeclRef( SubstitutionSet(declRef), astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)))); + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))); parentDecl = parentDecl->parentDecl; } - return nullptr; + return DeclRef(); } SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type) @@ -948,11 +956,13 @@ SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangRefl auto slangType = convert(type); if (auto declRefType = as(slangType)) { - return getInnermostGenericParent(declRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(declRefType->getDeclRef())); } else if (auto genericDeclRefType = as(slangType)) { - return getInnermostGenericParent(genericDeclRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(genericDeclRefType->getDeclRef())); } return nullptr; @@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var) { auto declRef = convert(var); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic) @@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) { auto declRef = convert(func); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) @@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla return convert(substDeclRef.as()); } +SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( + SlangReflectionFunction* func, + SlangInt argTypeCount, + SlangReflectionType* const* argTypes) +{ + auto declRef = convert(func); + if (!declRef) + return nullptr; + + + auto linkage = getModule(declRef.getDecl())->getLinkage(); + + List argTypeList; + for (SlangInt ii = 0; ii < argTypeCount; ++ii) + { + auto argType = convert(argTypes[ii]); + argTypeList.add(argType); + } + + try + { + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as()); + } + catch (...) + { + return nullptr; + } +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) @@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); - return getInnermostGenericParent( - substituteDeclRef( - SubstitutionSet(declRef), - astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))); + return convertDeclToGeneric( + getInnermostGenericParent( + substituteDeclRef( + SubstitutionSet(declRef), + astBuilder, + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))))); } SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index c78348a869..6c152cdddc 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -28,7 +28,6 @@ #include "slang-type-layout.h" #include "slang-lookup.h" -# #include "slang-options.h" #include "slang-repro.h" @@ -1069,8 +1068,12 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules) mapNameToLoadedModules.add(nameToMod); } + + m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr); } +SharedSemanticsContext* Linkage::getSemanticsForReflection() { return m_semanticsForReflection.get(); } + ISlangUnknown* Linkage::getInterface(const Guid& guid) { if(guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid()) @@ -1348,18 +1351,11 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( return asExternal(specializedType); } - -DeclRef Linkage::specializeGeneric( - DeclRef declRef, - List argExprs, - DiagnosticSink* sink) +DeclRef getGenericParentDeclRef( + ASTBuilder* astBuilder, + SemanticsVisitor* visitor, + DeclRef declRef) { - SLANG_AST_BUILDER_RAII(getASTBuilder()); - SLANG_ASSERT(declRef); - - SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); - // Create substituted parent decl ref. auto decl = declRef.getDecl(); @@ -1369,9 +1365,58 @@ DeclRef Linkage::specializeGeneric( } auto genericDecl = as(decl); - auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as(); - genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as(); + auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as(); + return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as(); +} + +DeclRef Linkage::specializeWithArgTypes( + DeclRef funcDeclRef, + List argTypes, + DiagnosticSink* sink) +{ + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(sink); + + ASTBuilder* astBuilder = getASTBuilder(); + List argExprs; + for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa) + { + auto argType = argTypes[aa]; + + // Create an 'empty' expr with the given type. Ideally, the expression itself should not matter + // only its checked type. + // + auto argExpr = astBuilder->create(); + argExpr->type = argType; + argExprs.add(argExpr); + } + + // Construct invoke expr. + auto invokeExpr = astBuilder->create(); + auto declRefExpr = astBuilder->create(); + + declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef); + invokeExpr->functionExpr = declRefExpr; + invokeExpr->arguments = argExprs; + + auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); + return as(as(checkedInvokeExpr)->functionExpr)->declRef; +} + + +DeclRef Linkage::specializeGeneric( + DeclRef declRef, + List argExprs, + DiagnosticSink* sink) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + SLANG_ASSERT(declRef); + + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(sink); + + auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef); DeclRefExpr* declRefExpr = getASTBuilder()->create(); declRefExpr->declRef = genericDeclRef; @@ -1561,8 +1606,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy try { - SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink); - SemanticsVisitor visitor(&sharedSemanticsContext); + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(&sink); + auto witness = visitor.isSubtype((Slang::Type*)type, (Slang::Type*)interfaceType, IsSubTypeOptions::None); if (auto subtypeWitness = as(witness)) @@ -2318,12 +2364,8 @@ DeclRef ComponentType::findDeclFromString( Expr* expr = linkage->parseTermString(name, scope); - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsContext context(&sharedSemanticsContext); - context = context.allowStaticReferenceToNonStaticMember(); + SemanticsContext context(linkage->getSemanticsForReflection()); + context = context.allowStaticReferenceToNonStaticMember().withSink(sink); SemanticsVisitor visitor(context); @@ -2377,12 +2419,8 @@ DeclRef ComponentType::findDeclFromStringInType( Expr* expr = linkage->parseTermString(name, scope); - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsContext context(&sharedSemanticsContext); - context = context.allowStaticReferenceToNonStaticMember(); + SemanticsContext context(linkage->getSemanticsForReflection()); + context = context.allowStaticReferenceToNonStaticMember().withSink(sink); SemanticsVisitor visitor(context); @@ -2433,11 +2471,7 @@ DeclRef ComponentType::findDeclFromStringInType( bool ComponentType::isSubType(Type* subType, Type* superType) { - SharedSemanticsContext sharedSemanticsContext( - getLinkage(), - nullptr, - nullptr); - SemanticsContext context(&sharedSemanticsContext); + SemanticsContext context(getLinkage()->getSemanticsForReflection()); SemanticsVisitor visitor(context); return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr); diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index d98ea0423f..89579c5850 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -64,12 +64,16 @@ SLANG_UNIT_TEST(declTreeReflection) struct MyGenericType { T z; + + __init(T _z) { z = _z; } + T g() { return z; } U h(U x, out T y) { y = z; return x; } T j(T x, out int o) { o = N; return x; } - } + U q(U x, T y) { return x; } + } namespace MyNamespace { @@ -79,6 +83,8 @@ SLANG_UNIT_TEST(declTreeReflection) } } + T foo(T t, U u) { return t; } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -110,7 +116,7 @@ SLANG_UNIT_TEST(declTreeReflection) auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 9); // First declaration should be a struct with 1 variable auto firstDecl = moduleDeclReflection->getChild(0); @@ -379,6 +385,59 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false); } + // Check specializeWithArgTypes() + { + auto unspecializedFoo = compositeProgram->getLayout()->findFunctionByName("foo"); + SLANG_CHECK(unspecializedFoo != nullptr); + + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + auto uintType = compositeProgram->getLayout()->findTypeByName("uint"); + SLANG_CHECK(uintType != nullptr); + + List argTypes; + argTypes.add(floatType); + argTypes.add(uintType); + + slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); + SLANG_CHECK(specializedFoo != nullptr); + + SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float"); + SLANG_CHECK(specializedFoo->getParameterCount() == 2); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float"); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint"); + } + + // Check specializeArgTypes on member method looked up through a specialized type + { + auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + SLANG_CHECK(specializedType != nullptr); + + auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h"); + SLANG_CHECK(unspecializedMethod != nullptr); + + // Specialize the method with float + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + List argTypes; + argTypes.add(floatType); + argTypes.add(halfType); + + auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes( + argTypes.getCount(), + argTypes.getBuffer()); + SLANG_CHECK(specializedMethodWithFloat != nullptr); + SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float"); + } + // Check iterators { unsigned int count = 0; @@ -386,7 +445,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 8); + SLANG_CHECK(count == 9); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind()) @@ -407,7 +466,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 1); + SLANG_CHECK(count == 2); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind())