diff --git a/include/slang.h b/include/slang.h index 05ab1a0ce8..3df31f0f0c 100644 --- a/include/slang.h +++ b/include/slang.h @@ -5444,6 +5444,10 @@ namespace slang SlangInt32 index) = 0; virtual SLANG_NO_THROW DeclReflection* SLANG_MCALL getModuleReflection() = 0; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget( + SlangCompileTarget target, + ISlangBlob** outDiagnostics) = 0; }; #define SLANG_UUID_IModule IModule::getTypeGuid() diff --git a/source/slang-record-replay/record/slang-module.cpp b/source/slang-record-replay/record/slang-module.cpp index b59aaa18bb..42dcd9eea8 100644 --- a/source/slang-record-replay/record/slang-module.cpp +++ b/source/slang-record-replay/record/slang-module.cpp @@ -213,6 +213,17 @@ namespace SlangRecord return res; } + SLANG_NO_THROW SlangResult ModuleRecorder::precompileForTarget( + SlangCompileTarget target, + ISlangBlob** outDiagnostics) + { + // TODO: We should record this call + // https://github.com/shader-slang/slang/issues/4853 + slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + SlangResult res = m_actualModule->precompileForTarget(target, outDiagnostics); + return res; + } + SLANG_NO_THROW slang::ISession* ModuleRecorder::getSession() { slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); diff --git a/source/slang-record-replay/record/slang-module.h b/source/slang-record-replay/record/slang-module.h index ca0403c2d6..22339f1cdb 100644 --- a/source/slang-record-replay/record/slang-module.h +++ b/source/slang-record-replay/record/slang-module.h @@ -39,6 +39,9 @@ namespace SlangRecord virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override; virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath( SlangInt32 index) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget( + SlangCompileTarget target, + ISlangBlob** outDiagnostics) override; // Interfaces for `IComponentType` virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override; diff --git a/source/slang/slang-compiler-tu.cpp b/source/slang/slang-compiler-tu.cpp index b007e7b71d..d4f3976a63 100644 --- a/source/slang/slang-compiler-tu.cpp +++ b/source/slang/slang-compiler-tu.cpp @@ -8,60 +8,24 @@ namespace Slang { - SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTargets( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq, - TargetRequest* targetReq) + SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTarget( + SlangCompileTarget target, + slang::IBlob** outDiagnostics) { - auto module = getIRModule(); - Slang::Session* session = endToEndReq->getSession(); - Slang::ASTBuilder* astBuilder = session->getGlobalASTBuilder(); - Slang::Linkage* builtinLinkage = session->getBuiltinLinkage(); - Slang::Linkage linkage(session, astBuilder, builtinLinkage); - - CapabilityName precompileRequirement = CapabilityName::Invalid; - switch (targetReq->getTarget()) + if (target != SLANG_DXIL) { - case CodeGenTarget::DXIL: - linkage.addTarget(Slang::CodeGenTarget::DXIL); - precompileRequirement = CapabilityName::dxil_lib; - break; - default: - assert(!"Unhandled target"); - break; + return SLANG_FAIL; } - SLANG_ASSERT(precompileRequirement != CapabilityName::Invalid); + CodeGenTarget targetEnum = CodeGenTarget(target); - // Ensure precompilation capability requirements are met. - auto targetCaps = targetReq->getTargetCaps(); - auto precompileRequirementsCapabilitySet = CapabilitySet(precompileRequirement); - if (targetCaps.atLeastOneSetImpliedInOther(precompileRequirementsCapabilitySet) == CapabilitySet::ImpliesReturnFlags::NotImplied) - { - // If `RestrictiveCapabilityCheck` is true we will error, else we will warn. - // error ...: dxil libraries require $0, entry point compiled with $1. - // warn ...: dxil libraries require $0, entry point compiled with $1, implicitly upgrading capabilities. - maybeDiagnoseWarningOrError( - sink, - targetReq->getOptionSet(), - DiagnosticCategory::Capability, - SourceLoc(), - Diagnostics::incompatibleWithPrecompileLib, - Diagnostics::incompatibleWithPrecompileLibRestrictive, - precompileRequirementsCapabilitySet, - targetCaps); - - // add precompile requirements to the cooked targetCaps - targetCaps.join(precompileRequirementsCapabilitySet); - if (targetCaps.isInvalid()) - { - sink->diagnose(SourceLoc(), Diagnostics::unknownCapability, targetCaps); - return SLANG_FAIL; - } - else - { - targetReq->setTargetCaps(targetCaps); - } - } + auto module = getIRModule(); + auto linkage = getLinkage(); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + TargetRequest* targetReq = new TargetRequest(linkage, targetEnum); List> allComponentTypes; allComponentTypes.add(this); // Add Module as a component type @@ -72,23 +36,34 @@ namespace Slang } auto composite = CompositeComponentType::create( - &linkage, + linkage, allComponentTypes); TargetProgram tp(composite, targetReq); - tp.getOrCreateLayout(sink); + tp.getOrCreateLayout(&sink); Slang::Index const entryPointCount = m_entryPoints.getCount(); + tp.getOptionSet().add(CompilerOptionName::GenerateWholeProgram, true); + + switch (targetReq->getTarget()) + { + case CodeGenTarget::DXIL: + tp.getOptionSet().add(CompilerOptionName::Profile, Profile::RawEnum::DX_Lib_6_6); + break; + } CodeGenContext::EntryPointIndices entryPointIndices; entryPointIndices.setCount(entryPointCount); for (Index i = 0; i < entryPointCount; i++) entryPointIndices[i] = i; - CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, sink, endToEndReq); + CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, &sink, nullptr); CodeGenContext codeGenContext(&sharedCodeGenContext); ComPtr outArtifact; SlangResult res = codeGenContext.emitTranslationUnit(outArtifact); + + sink.getBlobIfNeeded(outDiagnostics); + if (res != SLANG_OK) { return res; @@ -105,9 +80,6 @@ namespace Slang case CodeGenTarget::DXIL: builder.emitEmbeddedDXIL(blob); break; - default: - assert(!"Unhandled target"); - break; } return SLANG_OK; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 1180d4be22..ac85f175e1 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1482,10 +1482,9 @@ namespace Slang SlangInt32 index) override; /// Precompile TU to target language - virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTargets( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq, - TargetRequest* targetReq); + virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget( + SlangCompileTarget target, + slang::IBlob** outDiagnostics) override; virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 81171fdcc6..f08a6fc441 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3252,10 +3252,10 @@ SlangResult EndToEndCompileRequest::executeActionsInner() for (auto translationUnit : frontEndReq->translationUnits) { - translationUnit->getModule()->precompileForTargets( - getSink(), - this, - targetReq); + SlangCompileTarget target = SlangCompileTarget(targetReq->getTarget()); + translationUnit->getModule()->precompileForTarget( + target, + nullptr); } } } diff --git a/tools/gfx-unit-test/precompiled-module-2.cpp b/tools/gfx-unit-test/precompiled-module-2.cpp index 3da77e05c5..93b9d1b897 100644 --- a/tools/gfx-unit-test/precompiled-module-2.cpp +++ b/tools/gfx-unit-test/precompiled-module-2.cpp @@ -17,7 +17,8 @@ namespace gfx_test static Slang::Result precompileProgram( gfx::IDevice* device, ISlangMutableFileSystem* fileSys, - const char* shaderModuleName) + const char* shaderModuleName, + bool precompileToTarget) { Slang::ComPtr slangSession; SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); @@ -34,6 +35,20 @@ namespace gfx_test if (!module) return SLANG_FAIL; + if (precompileToTarget) + { + SlangCompileTarget target; + switch (device->getDeviceInfo().deviceType) + { + case gfx::DeviceType::DirectX12: + target = SLANG_DXIL; + break; + default: + return SLANG_FAIL; + } + module->precompileForTarget(target, diagnosticsBlob.writeRef()); + } + // Write loaded modules to memory file system. for (SlangInt i = 0; i < slangSession->getLoadedModuleCount(); i++) { @@ -50,7 +65,7 @@ namespace gfx_test return SLANG_OK; } - void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context) + void precompiledModule2TestImplCommon(IDevice* device, UnitTestContext* context, bool precompileToTarget) { Slang::ComPtr transientHeap; ITransientResourceHeap::Desc transientHeapDesc = {}; @@ -63,7 +78,7 @@ namespace gfx_test ComPtr shaderProgram; slang::ProgramLayout* slangReflection; - GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported")); + GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported", precompileToTarget)); // Next, load the precompiled slang program. Slang::ComPtr slangSession; @@ -168,11 +183,26 @@ namespace gfx_test Slang::makeArray(3.0f, 3.0f, 3.0f, 3.0f)); } + void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context) + { + precompiledModule2TestImplCommon(device, context, false); + } + + void precompiledTargetModule2TestImpl(IDevice* device, UnitTestContext* context) + { + precompiledModule2TestImplCommon(device, context, true); + } + SLANG_UNIT_TEST(precompiledModule2D3D12) { runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); } + SLANG_UNIT_TEST(precompiledTargetModule2D3D12) + { + runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + } + SLANG_UNIT_TEST(precompiledModule2Vulkan) { runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);