Skip to content

Commit

Permalink
Make precompileForTargets work with Slang API (shader-slang#4845)
Browse files Browse the repository at this point in the history
* Make precompileForTargets work with Slang API

precompileForTargets, renamed to precompileForTarget, does not need
an EndToEndCompileRequest and some objects created from it are not
necessary either.

Take only a target enum and a diagnostic blob as input and handle
everything else internally, such as creating the TargetReq with
chosen profile.

Fixes shader-slang#4790

* Update slang-module.cpp

* Update slang-module.cpp
  • Loading branch information
cheneym2 authored Aug 15, 2024
1 parent 99673d7 commit 27b2229
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 67 deletions.
4 changes: 4 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -5444,6 +5444,10 @@ namespace slang
SlangInt32 index) = 0;

virtual SLANG_NO_THROW DeclReflection* SLANG_MCALL getModuleReflection() = 0;

virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
SlangCompileTarget target,
ISlangBlob** outDiagnostics) = 0;
};

#define SLANG_UUID_IModule IModule::getTypeGuid()
Expand Down
11 changes: 11 additions & 0 deletions source/slang-record-replay/record/slang-module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ namespace SlangRecord
return res;
}

SLANG_NO_THROW SlangResult ModuleRecorder::precompileForTarget(
SlangCompileTarget target,
ISlangBlob** outDiagnostics)
{
// TODO: We should record this call
// https://github.com/shader-slang/slang/issues/4853
slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
SlangResult res = m_actualModule->precompileForTarget(target, outDiagnostics);
return res;
}

SLANG_NO_THROW slang::ISession* ModuleRecorder::getSession()
{
slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
Expand Down
3 changes: 3 additions & 0 deletions source/slang-record-replay/record/slang-module.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace SlangRecord
virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override;
virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(
SlangInt32 index) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
SlangCompileTarget target,
ISlangBlob** outDiagnostics) override;

// Interfaces for `IComponentType`
virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;
Expand Down
84 changes: 28 additions & 56 deletions source/slang/slang-compiler-tu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,24 @@

namespace Slang
{
SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTargets(
DiagnosticSink* sink,
EndToEndCompileRequest* endToEndReq,
TargetRequest* targetReq)
SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTarget(
SlangCompileTarget target,
slang::IBlob** outDiagnostics)
{
auto module = getIRModule();
Slang::Session* session = endToEndReq->getSession();
Slang::ASTBuilder* astBuilder = session->getGlobalASTBuilder();
Slang::Linkage* builtinLinkage = session->getBuiltinLinkage();
Slang::Linkage linkage(session, astBuilder, builtinLinkage);

CapabilityName precompileRequirement = CapabilityName::Invalid;
switch (targetReq->getTarget())
if (target != SLANG_DXIL)
{
case CodeGenTarget::DXIL:
linkage.addTarget(Slang::CodeGenTarget::DXIL);
precompileRequirement = CapabilityName::dxil_lib;
break;
default:
assert(!"Unhandled target");
break;
return SLANG_FAIL;
}
SLANG_ASSERT(precompileRequirement != CapabilityName::Invalid);
CodeGenTarget targetEnum = CodeGenTarget(target);

// Ensure precompilation capability requirements are met.
auto targetCaps = targetReq->getTargetCaps();
auto precompileRequirementsCapabilitySet = CapabilitySet(precompileRequirement);
if (targetCaps.atLeastOneSetImpliedInOther(precompileRequirementsCapabilitySet) == CapabilitySet::ImpliesReturnFlags::NotImplied)
{
// If `RestrictiveCapabilityCheck` is true we will error, else we will warn.
// error ...: dxil libraries require $0, entry point compiled with $1.
// warn ...: dxil libraries require $0, entry point compiled with $1, implicitly upgrading capabilities.
maybeDiagnoseWarningOrError(
sink,
targetReq->getOptionSet(),
DiagnosticCategory::Capability,
SourceLoc(),
Diagnostics::incompatibleWithPrecompileLib,
Diagnostics::incompatibleWithPrecompileLibRestrictive,
precompileRequirementsCapabilitySet,
targetCaps);

// add precompile requirements to the cooked targetCaps
targetCaps.join(precompileRequirementsCapabilitySet);
if (targetCaps.isInvalid())
{
sink->diagnose(SourceLoc(), Diagnostics::unknownCapability, targetCaps);
return SLANG_FAIL;
}
else
{
targetReq->setTargetCaps(targetCaps);
}
}
auto module = getIRModule();
auto linkage = getLinkage();

DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);

TargetRequest* targetReq = new TargetRequest(linkage, targetEnum);

List<RefPtr<ComponentType>> allComponentTypes;
allComponentTypes.add(this); // Add Module as a component type
Expand All @@ -72,23 +36,34 @@ namespace Slang
}

auto composite = CompositeComponentType::create(
&linkage,
linkage,
allComponentTypes);

TargetProgram tp(composite, targetReq);
tp.getOrCreateLayout(sink);
tp.getOrCreateLayout(&sink);
Slang::Index const entryPointCount = m_entryPoints.getCount();
tp.getOptionSet().add(CompilerOptionName::GenerateWholeProgram, true);

switch (targetReq->getTarget())
{
case CodeGenTarget::DXIL:
tp.getOptionSet().add(CompilerOptionName::Profile, Profile::RawEnum::DX_Lib_6_6);
break;
}

CodeGenContext::EntryPointIndices entryPointIndices;

entryPointIndices.setCount(entryPointCount);
for (Index i = 0; i < entryPointCount; i++)
entryPointIndices[i] = i;
CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, sink, endToEndReq);
CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, &sink, nullptr);
CodeGenContext codeGenContext(&sharedCodeGenContext);

ComPtr<IArtifact> outArtifact;
SlangResult res = codeGenContext.emitTranslationUnit(outArtifact);

sink.getBlobIfNeeded(outDiagnostics);

if (res != SLANG_OK)
{
return res;
Expand All @@ -105,9 +80,6 @@ namespace Slang
case CodeGenTarget::DXIL:
builder.emitEmbeddedDXIL(blob);
break;
default:
assert(!"Unhandled target");
break;
}

return SLANG_OK;
Expand Down
7 changes: 3 additions & 4 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1482,10 +1482,9 @@ namespace Slang
SlangInt32 index) override;

/// Precompile TU to target language
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTargets(
DiagnosticSink* sink,
EndToEndCompileRequest* endToEndReq,
TargetRequest* targetReq);
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
SlangCompileTarget target,
slang::IBlob** outDiagnostics) override;

virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;

Expand Down
8 changes: 4 additions & 4 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3252,10 +3252,10 @@ SlangResult EndToEndCompileRequest::executeActionsInner()

for (auto translationUnit : frontEndReq->translationUnits)
{
translationUnit->getModule()->precompileForTargets(
getSink(),
this,
targetReq);
SlangCompileTarget target = SlangCompileTarget(targetReq->getTarget());
translationUnit->getModule()->precompileForTarget(
target,
nullptr);
}
}
}
Expand Down
36 changes: 33 additions & 3 deletions tools/gfx-unit-test/precompiled-module-2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace gfx_test
static Slang::Result precompileProgram(
gfx::IDevice* device,
ISlangMutableFileSystem* fileSys,
const char* shaderModuleName)
const char* shaderModuleName,
bool precompileToTarget)
{
Slang::ComPtr<slang::ISession> slangSession;
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
Expand All @@ -34,6 +35,20 @@ namespace gfx_test
if (!module)
return SLANG_FAIL;

if (precompileToTarget)
{
SlangCompileTarget target;
switch (device->getDeviceInfo().deviceType)
{
case gfx::DeviceType::DirectX12:
target = SLANG_DXIL;
break;
default:
return SLANG_FAIL;
}
module->precompileForTarget(target, diagnosticsBlob.writeRef());
}

// Write loaded modules to memory file system.
for (SlangInt i = 0; i < slangSession->getLoadedModuleCount(); i++)
{
Expand All @@ -50,7 +65,7 @@ namespace gfx_test
return SLANG_OK;
}

void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
void precompiledModule2TestImplCommon(IDevice* device, UnitTestContext* context, bool precompileToTarget)
{
Slang::ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
Expand All @@ -63,7 +78,7 @@ namespace gfx_test

ComPtr<IShaderProgram> shaderProgram;
slang::ProgramLayout* slangReflection;
GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported"));
GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported", precompileToTarget));

// Next, load the precompiled slang program.
Slang::ComPtr<slang::ISession> slangSession;
Expand Down Expand Up @@ -168,11 +183,26 @@ namespace gfx_test
Slang::makeArray<float>(3.0f, 3.0f, 3.0f, 3.0f));
}

void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
{
precompiledModule2TestImplCommon(device, context, false);
}

void precompiledTargetModule2TestImpl(IDevice* device, UnitTestContext* context)
{
precompiledModule2TestImplCommon(device, context, true);
}

SLANG_UNIT_TEST(precompiledModule2D3D12)
{
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
}

SLANG_UNIT_TEST(precompiledTargetModule2D3D12)
{
runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
}

SLANG_UNIT_TEST(precompiledModule2Vulkan)
{
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
Expand Down

0 comments on commit 27b2229

Please sign in to comment.