From 33e8bfd43f66613f6f834fb0e1816ef43071f2e4 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 5 Sep 2024 11:53:56 -0700 Subject: [PATCH] Fix SPIRV SV_TessFactor type adaptation logic. (#5010) * Fix SPIRV SV_TessFactor type adaptation logic. * Fix compile error. --- source/slang/slang-ir-glsl-legalize.cpp | 26 +++++++++++-- tests/bugs/gh-4456.slang | 52 +++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 tests/bugs/gh-4456.slang diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index dabf0294f3..d8c1aa91ef 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1197,13 +1197,15 @@ ScalarizedVal createSimpleGLSLGlobalVarying( IRType* type = inType; IRType* peeledRequiredType = nullptr; - + ShortList peeledRequiredArraySizes; + bool peeledRequiredArrayLevelMatchesUserDeclaredType = false; // A system-value semantic might end up needing to override the type // that the user specified. if( systemValueInfo && systemValueInfo->requiredType ) { type = systemValueInfo->requiredType; peeledRequiredType = type; + peeledRequiredArrayLevelMatchesUserDeclaredType = true; // Unpeel `type` using declarators so that it matches `inType`. for (auto dd = declarator; dd; dd = dd->next) { @@ -1214,8 +1216,13 @@ ScalarizedVal createSimpleGLSLGlobalVarying( if (auto arrayType = as(type)) { type = arrayType->getElementType(); + peeledRequiredArraySizes.add(arrayType->getElementCount()); peeledRequiredType = type; } + else + { + peeledRequiredArrayLevelMatchesUserDeclaredType = false; + } break; } } @@ -1305,15 +1312,20 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // Construct the actual type and type-layout for the global variable // IRTypeLayout* typeLayout = inTypeLayout; + Index requiredArraySizeIndex = peeledRequiredArraySizes.getCount() - 1; for( auto dd = declarator; dd; dd = dd->next ) { switch(dd->flavor) { case GlobalVaryingDeclarator::Flavor::array: { + auto elementCount = peeledRequiredArrayLevelMatchesUserDeclaredType + ? peeledRequiredArraySizes[requiredArraySizeIndex] : dd->elementCount; + auto arrayType = builder->getArrayType( type, - dd->elementCount); + elementCount); + requiredArraySizeIndex--; IRArrayTypeLayout::Builder arrayTypeLayoutBuilder(builder, typeLayout); if( auto resInfo = inTypeLayout->findSizeAttr(kind) ) @@ -1321,10 +1333,9 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // TODO: it is kind of gross to be re-running some // of the type layout logic here. - UInt elementCount = (UInt) getIntVal(dd->elementCount); arrayTypeLayoutBuilder.addResourceUsage( kind, - resInfo->getSize() * elementCount); + resInfo->getSize() * getIntVal(elementCount)); } auto arrayTypeLayout = arrayTypeLayoutBuilder.build(); @@ -1774,6 +1785,13 @@ ScalarizedVal adaptType( val = builder->emitSwizzle(fromVector->getElementType(), val, 1, &index); } } + else if (auto fromArray = as(fromType)) + { + if (as(toType)) + { + val = builder->emitElementExtract(fromArray->getElementType(), val, builder->getIntValue(builder->getIntType(), 0)); + } + } // TODO: actually consider what needs to go on here... return ScalarizedVal::value(builder->emitCast( toType, diff --git a/tests/bugs/gh-4456.slang b/tests/bugs/gh-4456.slang new file mode 100644 index 0000000000..05b346e77c --- /dev/null +++ b/tests/bugs/gh-4456.slang @@ -0,0 +1,52 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -lang slang -D__spirv__ -emit-spirv-directly -profile ds_6_5 -entry main + +// CHECK: OpEntryPoint + +struct VSSceneIn { + float3 pos : POSITION; +}; + +struct PSSceneIn { + float4 pos : SV_Position; +}; + +struct HSPerVertexData { + PSSceneIn v; +}; + +struct HSPerPatchData { + float edges[3] : SV_TessFactor; + float inside : SV_InsideTessFactor; +}; + +RaytracingAccelerationStructure AccelerationStructure : register(t0); +RayDesc MakeRayDesc() +{ + RayDesc desc; + desc.Origin = float3(0,0,0); + desc.Direction = float3(1,0,0); + desc.TMin = 0.0f; + desc.TMax = 9999.0; + return desc; +} +void doInitialize(out RayQuery query, RayDesc ray) +{ + query.TraceRayInline(AccelerationStructure,RAY_FLAG_FORCE_NON_OPAQUE,0xFF,ray); +} + +[domain("tri")] PSSceneIn main( + const float3 bary : SV_DomainLocation, + const OutputPatch patch, + const HSPerPatchData perPatchData) +{ + PSSceneIn v; + v.pos = patch[0].v.pos * bary.x + patch[1].v.pos * bary.y + patch[2].v.pos * bary.z + perPatchData.edges[1]; + + RayQuery q; + RayDesc ray = MakeRayDesc(); + + q.TraceRayInline(AccelerationStructure,RAY_FLAG_FORCE_OPAQUE, 0xFF, ray); + doInitialize(q, ray); + + return v; +}