From 7c2ff54758d26b73074fd14143ecd843ba685e0d Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 4 Nov 2024 17:37:50 -0800 Subject: [PATCH] Various WGSL fixes. (#5490) * [WGSL] make sure switch has a default label. * Various WGSL fixes. * Update rhi submodule commit * format code * Remove unnecessary DISABLE_TEST directive on not applicable test. * Matrix comp mul + `select`. * Legalize binary ops for wgsl. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- external/slang-rhi | 2 +- source/slang/slang-emit-wgsl.cpp | 126 +++++++++++++++++- source/slang/slang-emit-wgsl.h | 2 + source/slang/slang-ir-insts.h | 1 + source/slang/slang-ir-wgsl-legalize.cpp | 90 ++++++++++++- source/slang/slang-ir.cpp | 11 ++ tests/autodiff-dstdlib/dstdlib-abs.slang | 2 +- tests/autodiff/matrix-arithmetic-fwd.slang | 2 +- .../reverse-loop-checkpoint-test.slang | 1 + tests/bugs/nested-switch.slang | 2 +- .../compute-sampler-feedback.slang | 1 - tests/ir/string-literal-hash.slang | 2 +- .../constants/constexpr-loop.slang | 2 +- tests/library/linked.spirv | Bin 816 -> 0 bytes 14 files changed, 231 insertions(+), 13 deletions(-) delete mode 100644 tests/library/linked.spirv diff --git a/external/slang-rhi b/external/slang-rhi index 93c2ba8f68..10ab9c69fb 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit 93c2ba8f68edee6732372ce4505bfc2a8640a1ba +Subproject commit 10ab9c69fb0f1e3f476c7fd66ca7f3bedffebe55 diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 4aca03a61b..d8ec017763 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) } } +static bool isStaticConst(IRInst* inst) +{ + if (inst->getParent()->getOp() == kIROp_Module) + { + return true; + } + switch (inst->getOp()) + { + case kIROp_MakeVector: + case kIROp_swizzle: + case kIROp_swizzleSet: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_BitCast: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isStaticConst(inst->getOperand(i))) + return false; + } + return true; + } + } + return false; +} + void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl) { switch (varDecl->getOp()) @@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl) case kIROp_GlobalVar: case kIROp_Var: m_writer->emit("var"); break; default: - if (as(varDecl->getParent())) - { + if (isStaticConst(varDecl)) m_writer->emit("const"); - } else - { m_writer->emit("var"); - } break; } @@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst) } } +bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) +{ + bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst); + if (result) + { + // If inst is a matrix, and is used in a component-wise multiply, + // we need to not fold it. + if (as(inst->getDataType())) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (user->getOp() == kIROp_Mul) + { + if (as(user->getOperand(0)->getDataType()) && + as(user->getOperand(1)->getDataType())) + { + return false; + } + } + } + } + } + return result; +} + + bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { EmitOpInfo outerPrec = inOuterPrec; @@ -1126,6 +1177,71 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return true; } break; + + case kIROp_GetStringHash: + { + auto getStringHashInst = as(inst); + auto stringLit = getStringHashInst->getStringLit(); + + if (stringLit) + { + auto slice = stringLit->getStringSlice(); + emitType(inst->getDataType()); + m_writer->emit("("); + m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash); + m_writer->emit(")"); + } + else + { + // Couldn't handle + diagnoseUnhandledInst(inst); + } + return true; + } + + case kIROp_Mul: + { + if (!as(inst->getOperand(0)->getDataType()) || + !as(inst->getOperand(1)->getDataType())) + { + return false; + } + // Mul(m1, m2) should be translated to component-wise multiplication in WGSL. + auto matrixType = as(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + emitType(inst->getDataType()); + m_writer->emit("("); + for (IRIntegerValue i = 0; i < rowCount; i++) + { + if (i != 0) + { + m_writer->emit(", "); + } + emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix)); + m_writer->emit("["); + m_writer->emit(i); + m_writer->emit("] * "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix)); + m_writer->emit("["); + m_writer->emit(i); + m_writer->emit("]"); + } + m_writer->emit(")"); + + return true; + } + + case kIROp_Select: + { + m_writer->emit("select("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } } return false; diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index 70df65933c..1a8ec2fd5a 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -50,6 +50,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter void emit(const AddressSpace addressSpace); + virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE; + private: // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns void emitMatrixType( diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index c44211c1c0..9a081f9de5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4021,6 +4021,7 @@ struct IRBuilder IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector(IRType* type, UInt argCount, IRInst* const* args); IRInst* emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue); + IRInst* emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue); IRInst* emitMakeVector(IRType* type, List const& args) { diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index c97a8a89f6..96eb13be4c 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -51,6 +51,8 @@ struct LegalizeWGSLEntryPointContext String* optionalSemanticIndex, IRInst* parentVar); void legalizeCall(IRCall* call); + void legalizeSwitch(IRSwitch* switchInst); + void legalizeBinaryOp(IRInst* inst); void processInst(IRInst* inst); }; @@ -349,11 +351,97 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call) } } +void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst) +{ + // WGSL Requires all switch statements to contain a default case. + // If the switch statement does not contain a default case, we will add one. + if (switchInst->getDefaultLabel() != switchInst->getBreakLabel()) + return; + IRBuilder builder(switchInst); + auto defaultBlock = builder.createBlock(); + builder.setInsertInto(defaultBlock); + builder.emitBranch(switchInst->getBreakLabel()); + defaultBlock->insertBefore(switchInst->getBreakLabel()); + List cases; + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + { + cases.add(switchInst->getCaseValue(i)); + cases.add(switchInst->getCaseLabel(i)); + } + builder.setInsertBefore(switchInst); + auto newSwitch = builder.emitSwitch( + switchInst->getCondition(), + switchInst->getBreakLabel(), + defaultBlock, + (UInt)cases.getCount(), + cases.getBuffer()); + switchInst->transferDecorationsTo(newSwitch); + switchInst->removeAndDeallocate(); +} + +void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst) +{ + auto isVectorOrMatrix = [](IRType* type) + { + switch (type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: return true; + default: return false; + } + }; + if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && + as(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newRhs = builder.emitMakeCompositeFromScalar( + inst->getOperand(0)->getDataType(), + inst->getOperand(1)); + builder.replaceOperand(inst->getOperands() + 1, newRhs); + } + else if ( + as(inst->getOperand(0)->getDataType()) && + isVectorOrMatrix(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newLhs = builder.emitMakeCompositeFromScalar( + inst->getOperand(1)->getDataType(), + inst->getOperand(0)); + builder.replaceOperand(inst->getOperands(), newLhs); + } +} + void LegalizeWGSLEntryPointContext::processInst(IRInst* inst) { switch (inst->getOp()) { - case kIROp_Call: legalizeCall(static_cast(inst)); break; + case kIROp_Call: legalizeCall(static_cast(inst)); break; + case kIROp_Switch: legalizeSwitch(as(inst)); break; + + // For all binary operators, make sure both side of the operator have the same type + // (vector-ness and matrix-ness). + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: legalizeBinaryOp(inst); break; + default: for (auto child : inst->getModifiableChildren()) processInst(child); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 49273163e8..3bd31d6e96 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4162,6 +4162,17 @@ IRInst* IRBuilder::emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue) return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue); } +IRInst* IRBuilder::emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue) +{ + switch (type->getOp()) + { + case kIROp_VectorType: return emitMakeVectorFromScalar(type, scalarValue); + case kIROp_MatrixType: return emitMakeMatrixFromScalar(type, scalarValue); + case kIROp_ArrayType: return emitMakeArrayFromElement(type, scalarValue); + default: SLANG_UNEXPECTED("unhandled composite type"); UNREACHABLE_RETURN(nullptr); + } +} + IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst) { return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst); diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang b/tests/autodiff-dstdlib/dstdlib-abs.slang index c0878bfb46..d11f06b31a 100644 --- a/tests/autodiff-dstdlib/dstdlib-abs.slang +++ b/tests/autodiff-dstdlib/dstdlib-abs.slang @@ -1,6 +1,6 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu +//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang index 0dd1936af7..0c2db76e92 100644 --- a/tests/autodiff/matrix-arithmetic-fwd.slang +++ b/tests/autodiff/matrix-arithmetic-fwd.slang @@ -1,6 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang index 19316a786d..8191608fd4 100644 --- a/tests/autodiff/reverse-loop-checkpoint-test.slang +++ b/tests/autodiff/reverse-loop-checkpoint-test.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj -output-using-type //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none //DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates diff --git a/tests/bugs/nested-switch.slang b/tests/bugs/nested-switch.slang index 485a83e1fe..90abe70d5f 100644 --- a/tests/bugs/nested-switch.slang +++ b/tests/bugs/nested-switch.slang @@ -3,7 +3,7 @@ //TEST(compute):COMPARE_COMPUTE: -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj //TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu +//TEST(compute):COMPARE_COMPUTE:-wgpu int test(int t, int r) { diff --git a/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang b/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang index a7fc8731ce..77e7c20509 100644 --- a/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang +++ b/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang @@ -1,5 +1,4 @@ //TEST:COMPILE: -entry computeMain -stage compute -target callable tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu // Not available on non PS shader // dx.op.writeSamplerFeedback WriteSamplerFeedback diff --git a/tests/ir/string-literal-hash.slang b/tests/ir/string-literal-hash.slang index 678a8d9c7d..2d61a84c19 100644 --- a/tests/ir/string-literal-hash.slang +++ b/tests/ir/string-literal-hash.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE: -shaderobj //TEST(compute):COMPARE_COMPUTE: -vk -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu +//TEST(compute):COMPARE_COMPUTE:-wgpu // Note: disabled on CPU target until we can fill // in a more correct/complete `String` and `getStringHash` diff --git a/tests/language-feature/constants/constexpr-loop.slang b/tests/language-feature/constants/constexpr-loop.slang index 81b0a5c17a..7af9c60b24 100644 --- a/tests/language-feature/constants/constexpr-loop.slang +++ b/tests/language-feature/constants/constexpr-loop.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu +//TEST(compute):COMPARE_COMPUTE_EX: -wgpu -compute -output-using-type //TEST_INPUT: set g_texture = Texture2D(size=8, content = one) //TEST_INPUT: set g_sampler = Sampler diff --git a/tests/library/linked.spirv b/tests/library/linked.spirv deleted file mode 100644 index 7ea385e714d710c0345885b3be9e5fb3d5531b6b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 816 zcmX|9O-sW-5S=z@tM#MS+IlOky$D_uMHEpe6;eVvFlX%5mMM?$P6_HDTt1d-(tDC(5yDDWFg?G`_@HR?S9%ZSzkCO*k z7f