Skip to content

Commit

Permalink
Fix global value inlining for spirv_asm blocks. (shader-slang#4339)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Jun 11, 2024
1 parent 7e79669 commit 5da06d4
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 122 deletions.
5 changes: 4 additions & 1 deletion source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -3190,7 +3190,10 @@ struct IRSPIRVAsmInst : IRInst

IRSPIRVAsmOperand* getOpcodeOperand()
{
const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0));
auto operand = getOperand(0);
if (auto globalRef = as<IRGlobalValueRef>(operand))
operand = globalRef->getValue();
const auto opcodeOperand = cast<IRSPIRVAsmOperand>(operand);
// This must be either:
// - An enum, such as 'OpNop'
// - The __truncate pseudo-instruction
Expand Down
294 changes: 173 additions & 121 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1690,143 +1690,194 @@ struct SPIRVLegalizationContext : public SourceEmitterBase

}

// Opcodes that can exist in global scope, as long as the operands are.
bool isLegalGlobalInst(IRInst* inst)
struct GlobalInstInliningContext
{
switch (inst->getOp())
Dictionary<IRInst*, bool> m_mapGlobalInstToShouldInline;

// Opcodes that can exist in global scope, as long as the operands are.
bool isLegalGlobalInst(IRInst* inst)
{
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MakeVectorFromScalar:
return true;
default:
return false;
switch (inst->getOp())
{
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MakeVectorFromScalar:
return true;
default:
if (as<IRConstant>(inst))
return true;
if (as<IRSPIRVAsmOperand>(inst))
return true;
return false;
}
}
}

// Opcodes that can be inlined into function bodies.
bool isInlinableGlobalInst(IRInst* inst)
{
switch (inst->getOp())
// Opcodes that can be inlined into function bodies.
bool isInlinableGlobalInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
case kIROp_FRem:
case kIROp_IRem:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_And:
case kIROp_Or:
case kIROp_Not:
case kIROp_Neg:
case kIROp_Div:
case kIROp_FieldExtract:
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_GetOffsetPtr:
case kIROp_UpdateElement:
case kIROp_MakeTuple:
case kIROp_GetTupleElement:
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MakeVectorFromScalar:
case kIROp_swizzle:
case kIROp_swizzleSet:
case kIROp_MatrixReshape:
case kIROp_MakeString:
case kIROp_MakeResultError:
case kIROp_MakeResultValue:
case kIROp_GetResultError:
case kIROp_GetResultValue:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_CastIntToPtr:
case kIROp_PtrCast:
case kIROp_CastPtrToBool:
case kIROp_CastPtrToInt:
case kIROp_BitAnd:
case kIROp_BitNot:
case kIROp_BitOr:
case kIROp_BitXor:
case kIROp_BitCast:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_Greater:
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq:
case kIROp_Neq:
case kIROp_Eql:
case kIROp_Call:
case kIROp_SPIRVAsm:
return true;
default:
if (as<IRSPIRVAsmInst>(inst))
return true;
if (as<IRSPIRVAsmOperand>(inst))
return true;
return false;
}
}

bool shouldInlineInstImpl(IRInst* inst)
{
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
case kIROp_FRem:
case kIROp_IRem:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_And:
case kIROp_Or:
case kIROp_Not:
case kIROp_Neg:
case kIROp_Div:
case kIROp_FieldExtract:
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_GetOffsetPtr:
case kIROp_UpdateElement:
case kIROp_MakeTuple:
case kIROp_GetTupleElement:
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MakeVectorFromScalar:
case kIROp_swizzle:
case kIROp_swizzleSet:
case kIROp_MatrixReshape:
case kIROp_MakeString:
case kIROp_MakeResultError:
case kIROp_MakeResultValue:
case kIROp_GetResultError:
case kIROp_GetResultValue:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_CastIntToPtr:
case kIROp_PtrCast:
case kIROp_CastPtrToBool:
case kIROp_CastPtrToInt:
case kIROp_BitAnd:
case kIROp_BitNot:
case kIROp_BitOr:
case kIROp_BitXor:
case kIROp_BitCast:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_Greater:
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq:
case kIROp_Neq:
case kIROp_Eql:
case kIROp_Call:
case kIROp_SPIRVAsm:
if (!isInlinableGlobalInst(inst))
return false;
if (isLegalGlobalInst(inst))
{
for (UInt i = 0; i < inst->getOperandCount(); i++)
if (shouldInlineInst(inst->getOperand(i)))
return true;
return false;
}
return true;
default:
return false;
}
}

bool shouldInlineInst(IRInst* inst)
{
if (!isInlinableGlobalInst(inst))
return false;
if (isLegalGlobalInst(inst))
bool shouldInlineInst(IRInst* inst)
{
for (UInt i = 0; i < inst->getOperandCount(); i++)
if (shouldInlineInst(inst->getOperand(i)))
return true;
return false;
bool result = false;
if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result))
return result;
result = shouldInlineInstImpl(inst);
m_mapGlobalInstToShouldInline[inst] = result;
return result;
}
return true;
}

/// Inline `inst` in the local function body so they can be emitted as a local inst.
///
IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv)
{
if (!shouldInlineInst(inst))
IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst)
{
switch (inst->getOp())
IRInst* result;
if (cloneEnv.mapOldValToNew.tryGetValue(inst, result))
return result;

for (UInt i = 0; i < inst->getOperandCount(); i++)
{
case kIROp_Func:
case kIROp_Specialize:
case kIROp_Generic:
case kIROp_LookupWitness:
return inst;
}
if (as<IRType>(inst))
return inst;

// If we encounter a global value that shouldn't be inlined, e.g. a const literal,
// we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
// can be pinned to the function body.
auto result = builder.emitGlobalValueRef(inst);
auto operand = inst->getOperand(i);
IRBuilder operandBuilder(builder);
setInsertBeforeOutsideASM(operandBuilder, builder.getInsertLoc().getInst());
maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv);
}
result = cloneInstAndOperands(&cloneEnv, &builder, inst);
cloneEnv.mapOldValToNew[inst] = result;
IRBuilder subBuilder(builder);
subBuilder.setInsertInto(result);
for (auto child : inst->getDecorations())
{
cloneInst(&cloneEnv, &subBuilder, child);
}
for (auto child : inst->getChildren())
{
inlineInst(subBuilder, cloneEnv, child);
}
return result;
}

// If the global value is inlinable, we make all its operands avaialble locally, and then copy it
// to the local scope.
ShortList<IRInst*> args;
for (UInt i = 0; i < inst->getOperandCount(); i++)
/// Inline `inst` in the local function body so they can be emitted as a local inst.
///
IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* user, IRInst* inst, IRCloneEnv& cloneEnv)
{
auto operand = inst->getOperand(i);
auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv);
args.add(inlinedOperand);
if (!shouldInlineInst(inst))
{
switch (inst->getOp())
{
case kIROp_Func:
case kIROp_Specialize:
case kIROp_Generic:
case kIROp_LookupWitness:
return inst;
}
if (as<IRType>(inst))
return inst;

// If we encounter a global value that shouldn't be inlined, e.g. a const literal,
// we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
// can be pinned to the function body.
auto result = inst;
bool shouldWrapGlobalRef = true;
if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable())
shouldWrapGlobalRef = false;
else if (as<IRSPIRVAsmOperand>(user) && as<IRSPIRVAsmOperandInst>(user))
shouldWrapGlobalRef = false;
else if (as<IRSPIRVAsmInst>(user))
shouldWrapGlobalRef = false;
if (shouldWrapGlobalRef)
result = builder.emitGlobalValueRef(inst);
cloneEnv.mapOldValToNew[inst] = result;
return result;
}

// If the global value is inlinable, we make all its operands avaialble locally, and then copy it
// to the local scope.
return inlineInst(builder, cloneEnv, inst);
}
auto result = cloneInst(&cloneEnv, &builder, inst);
cloneEnv.mapOldValToNew[inst] = result;
return result;
}
};

void processBranch(IRInst* branch)
{
Expand Down Expand Up @@ -2079,7 +2130,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}

void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
static void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
{
auto parent = beforeInst->getParent();
while (parent)
Expand Down Expand Up @@ -2234,6 +2285,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Inline global values that can't represented by SPIRV constant inst
// to their use sites.
List<IRUse*> globalInstUsesToInline;
GlobalInstInliningContext globalInstInliningContext;

for (auto globalInst : m_module->getGlobalInsts())
{
Expand All @@ -2248,7 +2300,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
sortBlocksInFunc(func);
}

if (isInlinableGlobalInst(globalInst))
if (globalInstInliningContext.isInlinableGlobalInst(globalInst))
{
for (auto use = globalInst->firstUse; use; use = use->nextUse)
{
Expand All @@ -2264,7 +2316,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(user);
setInsertBeforeOutsideASM(builder, user);
IRCloneEnv cloneEnv;
auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv);
auto val = globalInstInliningContext.maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv);
if (val != use->get())
builder.replaceOperand(use, val);
}
Expand Down
22 changes: 22 additions & 0 deletions tests/spirv/static-array-spv-asm.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// Test that we can use intrinsics in global scope constant array, which causes
// the spirv_asm to be inlined in global module scope.
// Our global value inlining pass should be able to clean up the global scope spirv_asm
// blocks and inlining them to use sites.

// CHECK: %main = OpFunction
// CHECK: OpStore

static const uint staticArr[] = {
uint((((uint)round(saturate(1) * 255) << 24) | ((uint)round(saturate(0) * 255) << 16) | ((uint)round(saturate(0) * 255) << 8) | 0xff)),
uint((((uint)round(saturate(1) * 255) << 24) | ((uint)round(saturate(0) * 255) << 16) | ((uint)round(saturate(1) * 255) << 8) | 0xff))
};

RWStructuredBuffer<int> buffer;

[numthreads(1,1,1)]
void main(int i : SV_DispatchThreadID)
{
buffer[0] = staticArr[i] + staticArr[1];
}

0 comments on commit 5da06d4

Please sign in to comment.