Skip to content

Commit

Permalink
Precompiled SPIR-V import support (shader-slang#5048)
Browse files Browse the repository at this point in the history
* Precompiled SPIR-V import support

Adds appropriate linkage and function declaration syntax
for SPIR-V functions that are declared, to be imported
from another SPIR-V module.

Unlike DXIL, stripping the Slang IR for a function down
to a declaration requires retaining a block of parameters,
as the function declaration must be emitted to SPIR-V
with the same parameters as a definition. Because that
thwarts the logic in Slang to tell the difference between
a declaration and definition, and explicit decoration is
introduced to explicitly mark functions which need to be
treated as declarations during emit phase.

Fixes shader-slang#4992

Co-authored-by: Yong He <[email protected]>
  • Loading branch information
cheneym2 and csyonghe authored Oct 29, 2024
1 parent 99c728f commit 613a29a
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 21 deletions.
107 changes: 95 additions & 12 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "slang-ir-call-graph.h"
#include "slang-ir-insts.h"
#include "slang-ir-layout.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-spirv-legalize.h"
#include "slang-ir-spirv-snippet.h"
#include "slang-ir-util.h"
Expand Down Expand Up @@ -2628,14 +2629,75 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
/// Emit a declaration for the given `irFunc`
SpvInst* emitFuncDeclaration(IRFunc* irFunc)
{
if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration))
return nullptr;
// For now we aren't handling function declarations;
// we expect to deal only with fully linked modules.
// [2.4: Logical Layout of a Module]
//
// > All function declarations("declarations" are functions without a
// body; there is no forward declaration to a function with a body).
//
auto section = getSection(SpvLogicalSectionID::FunctionDeclarations);

// > A function declaration is as follows.
// > * Function declaration, using OpFunction.
// > * Function parameter declarations, using OpFunctionParameter.
// > * Function end, using OpFunctionEnd.
//

// [3.24. Function Control]
//
// TODO: We should eventually support emitting the "function control"
// mask to include inline and other hint bits based on decorations
// set on `irFunc`.
//
SpvFunctionControlMask spvFunctionControl = SpvFunctionControlMaskNone;

// [3.32.9. Function Instructions]
//
// > OpFunction
//
// Note that the type <id> of a SPIR-V function uses the
// *result* type of the function, while the actual function
// type is given as a later operand. Slan IR instead uses
// the type of a function instruction store, you know, its *type*.
//
SpvInst* spvFunc = emitOpFunction(
section,
irFunc,
irFunc->getDataType()->getResultType(),
spvFunctionControl,
irFunc->getDataType());

// > OpFunctionParameter
//
// Though parameters always belong to blocks in Slang, there are no
// blocks in a function declaration, so we will emit the parameters
// as derived from the function's type.
//
auto funcType = irFunc->getDataType();
auto paramCount = funcType->getParamCount();
for (UInt pp = 0; pp < paramCount; ++pp)
{
auto paramType = funcType->getParamType(pp);
SpvInst* spvParam = emitOpFunctionParameter(spvFunc, nullptr, paramType);
maybeEmitPointerDecoration(spvParam, paramType, false, kIROp_Param);
}

// [3.32.9. Function Instructions]
//
// > OpFunctionEnd
//
// In the SPIR-V encoding a function is logically the parent of any
// instructions up to a matching `OpFunctionEnd`. In our intermediate
// structure we will make the `OpFunctionEnd` be the last child of
// the `OpFunction`.
//
m_sink->diagnose(irFunc, Diagnostics::internalCompilerError);
SLANG_UNEXPECTED("function declaration in SPIR-V emit");
UNREACHABLE_RETURN(nullptr);
emitOpFunctionEnd(spvFunc, nullptr);

// We will emit any decorations pertinent to the function to the
// appropriate section of the module.
//
emitDecorations(irFunc, getID(spvFunc));

return spvFunc;
}

/// Emit a SPIR-V function definition for the Slang IR function `irFunc`.
Expand Down Expand Up @@ -4358,6 +4420,21 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvLinkageTypeExport);
break;
}
case kIROp_DownstreamModuleImportDecoration:
{
requireSPIRVCapability(SpvCapabilityLinkage);
auto name =
decoration->getParent()->findDecoration<IRExportDecoration>()->getMangledName();
emitInst(
getSection(SpvLogicalSectionID::Annotations),
decoration,
SpvOpDecorate,
dstID,
SpvDecorationLinkageAttributes,
name,
SpvLinkageTypeImport);
break;
}
// ...
}

Expand Down Expand Up @@ -5019,9 +5096,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return nullptr;
}

void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
void maybeEmitPointerDecoration(SpvInst* varInst, IRType* type, bool isVar, IROp op)
{
auto ptrType = as<IRPtrType>(unwrapArray(inst->getDataType()));
auto ptrType = as<IRPtrType>(unwrapArray(type));
if (!ptrType)
return;
if (addressSpaceToStorageClass(ptrType->getAddressSpace()) ==
Expand All @@ -5033,7 +5110,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
(as<IRVar>(inst) ? SpvDecorationAliasedPointer : SpvDecorationAliased));
(isVar ? SpvDecorationAliasedPointer : SpvDecorationAliased));
}
else
{
Expand All @@ -5049,14 +5126,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
(inst->getOp() == kIROp_GlobalVar || inst->getOp() == kIROp_Var ||
inst->getOp() == kIROp_DebugVar
(op == kIROp_GlobalVar || op == kIROp_Var || op == kIROp_DebugVar
? SpvDecorationAliasedPointer
: SpvDecorationAliased));
}
}
}

void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
{
maybeEmitPointerDecoration(varInst, inst->getDataType(), as<IRVar>(inst), inst->getOp());
}

SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
{
auto paramSpvInst = emitOpFunctionParameter(parent, inst, inst->getFullType());
Expand Down Expand Up @@ -7534,6 +7615,8 @@ SlangResult emitSPIRVFromIR(
}
#endif

removeAvailableInDownstreamModuleDecorations(CodeGenTarget::SPIRV, irModule);

auto shouldPreserveParams = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::PreserveParameters);
auto generateWholeProgram = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(PublicDecoration, public, 0, 0)
INST(HLSLExportDecoration, hlslExport, 0, 0)
INST(DownstreamModuleExportDecoration, downstreamModuleExport, 0, 0)
INST(DownstreamModuleImportDecoration, downstreamModuleImport, 0, 0)
INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0)
INST(OutputControlPointsDecoration, outputControlPoints, 1, 0)
INST(OutputTopologyDecoration, outputTopology, 1, 0)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ IR_SIMPLE_DECORATION(HLSLMeshPayloadDecoration)
IR_SIMPLE_DECORATION(GlobalInputDecoration)
IR_SIMPLE_DECORATION(GlobalOutputDecoration)
IR_SIMPLE_DECORATION(DownstreamModuleExportDecoration)
IR_SIMPLE_DECORATION(DownstreamModuleImportDecoration)

struct IRAvailableInDownstreamIRDecoration : IRDecoration
{
Expand Down
9 changes: 4 additions & 5 deletions source/slang/slang-ir-redundancy-removal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
void removeAvailableInDownstreamModuleDecorations(CodeGenTarget target, IRModule* module)
{
List<IRInst*> toRemove;
auto builder = IRBuilder(module);
for (auto globalInst : module->getGlobalInsts())
{
if (auto funcInst = as<IRFunc>(globalInst))
Expand All @@ -181,13 +182,11 @@ void removeAvailableInDownstreamModuleDecorations(CodeGenTarget target, IRModule
(dec->getTarget() == target))
{
// Gut the function definition, turning it into a declaration
for (auto inst : funcInst->getChildren())
for (auto block : funcInst->getBlocks())
{
if (inst->getOp() == kIROp_Block)
{
toRemove.add(inst);
}
toRemove.add(block);
}
builder.addDecoration(funcInst, kIROp_DownstreamModuleImportDecoration);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/library/export-library-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public int normalFuncUsesGeneric(int a)
return genericFunc(obj);
}

public int normalFunc(int a)
public int normalFunc(int a, float b)
{
return a - 2;
return a - floor(b);
}
10 changes: 10 additions & 0 deletions tests/library/module-library-pointer-param.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//TEST_IGNORE_FILE:

// module-library-pointer-param.slang

module "module-library-pointer-param";

public int ptrFunc(int* a)
{
return *a;
}
2 changes: 1 addition & 1 deletion tests/library/precompiled-dxil-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ struct Attributes
[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = normalFunc(x * y) + normalFuncUsesGeneric(y);
payload.val = normalFunc(floor(x * y), x) + normalFuncUsesGeneric(y);
}
2 changes: 1 addition & 1 deletion tests/library/precompiled-spirv-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ struct Attributes
[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = normalFunc(x * y) + normalFuncUsesGeneric(y);
payload.val = normalFunc(floor(x * y), x) + normalFuncUsesGeneric(y);
}
31 changes: 31 additions & 0 deletions tests/library/precompiled-spirv-pointer-param.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// precompiled-spirv-pointer-param.slang

// A test that uses slang-modules with embedded precompiled SPIRV and a library containing
// a function with a pointer parameter.
// The test compiles a library slang (module-library-pointer-param.slang) with -embed-downstream-ir then links the
// library to entrypoint slang (this file).
// The test passes if there is no errror thrown.
// TODO: Check if final linkage used only the precompiled spirv.

//TEST:COMPILE: tests/library/module-library-pointer-param.slang -o tests/library/module-library-pointer-param.slang-module -target spirv -embed-downstream-ir -incomplete-library
//TEST:COMPILE: tests/library/precompiled-spirv-pointer-param.slang -target spirv -stage anyhit -entry anyhit -o tests/library/linked.spirv

import "module-library-pointer-param";

struct Payload
{
int val;
}

struct Attributes
{
float2 bary;
}

[vk::push_constant] int* g_int;

[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = ptrFunc(g_int);
}

0 comments on commit 613a29a

Please sign in to comment.