Skip to content

Commit

Permalink
Fix SPIRV SV_TessFactor type adaptation logic. (shader-slang#5010)
Browse files Browse the repository at this point in the history
* Fix SPIRV SV_TessFactor type adaptation logic.

* Fix compile error.
  • Loading branch information
csyonghe authored Sep 5, 2024
1 parent 879ee3d commit 33e8bfd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
26 changes: 22 additions & 4 deletions source/slang/slang-ir-glsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,13 +1197,15 @@ ScalarizedVal createSimpleGLSLGlobalVarying(

IRType* type = inType;
IRType* peeledRequiredType = nullptr;

ShortList<IRInst*> peeledRequiredArraySizes;
bool peeledRequiredArrayLevelMatchesUserDeclaredType = false;
// A system-value semantic might end up needing to override the type
// that the user specified.
if( systemValueInfo && systemValueInfo->requiredType )
{
type = systemValueInfo->requiredType;
peeledRequiredType = type;
peeledRequiredArrayLevelMatchesUserDeclaredType = true;
// Unpeel `type` using declarators so that it matches `inType`.
for (auto dd = declarator; dd; dd = dd->next)
{
Expand All @@ -1214,8 +1216,13 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
if (auto arrayType = as<IRArrayTypeBase>(type))
{
type = arrayType->getElementType();
peeledRequiredArraySizes.add(arrayType->getElementCount());
peeledRequiredType = type;
}
else
{
peeledRequiredArrayLevelMatchesUserDeclaredType = false;
}
break;
}
}
Expand Down Expand Up @@ -1305,26 +1312,30 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
// Construct the actual type and type-layout for the global variable
//
IRTypeLayout* typeLayout = inTypeLayout;
Index requiredArraySizeIndex = peeledRequiredArraySizes.getCount() - 1;
for( auto dd = declarator; dd; dd = dd->next )
{
switch(dd->flavor)
{
case GlobalVaryingDeclarator::Flavor::array:
{
auto elementCount = peeledRequiredArrayLevelMatchesUserDeclaredType
? peeledRequiredArraySizes[requiredArraySizeIndex] : dd->elementCount;

auto arrayType = builder->getArrayType(
type,
dd->elementCount);
elementCount);
requiredArraySizeIndex--;

IRArrayTypeLayout::Builder arrayTypeLayoutBuilder(builder, typeLayout);
if( auto resInfo = inTypeLayout->findSizeAttr(kind) )
{
// TODO: it is kind of gross to be re-running some
// of the type layout logic here.

UInt elementCount = (UInt) getIntVal(dd->elementCount);
arrayTypeLayoutBuilder.addResourceUsage(
kind,
resInfo->getSize() * elementCount);
resInfo->getSize() * getIntVal(elementCount));
}
auto arrayTypeLayout = arrayTypeLayoutBuilder.build();

Expand Down Expand Up @@ -1774,6 +1785,13 @@ ScalarizedVal adaptType(
val = builder->emitSwizzle(fromVector->getElementType(), val, 1, &index);
}
}
else if (auto fromArray = as<IRArrayTypeBase>(fromType))
{
if (as<IRBasicType>(toType))
{
val = builder->emitElementExtract(fromArray->getElementType(), val, builder->getIntValue(builder->getIntType(), 0));
}
}
// TODO: actually consider what needs to go on here...
return ScalarizedVal::value(builder->emitCast(
toType,
Expand Down
52 changes: 52 additions & 0 deletions tests/bugs/gh-4456.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -lang slang -D__spirv__ -emit-spirv-directly -profile ds_6_5 -entry main

// CHECK: OpEntryPoint

struct VSSceneIn {
float3 pos : POSITION;
};

struct PSSceneIn {
float4 pos : SV_Position;
};

struct HSPerVertexData {
PSSceneIn v;
};

struct HSPerPatchData {
float edges[3] : SV_TessFactor;
float inside : SV_InsideTessFactor;
};

RaytracingAccelerationStructure AccelerationStructure : register(t0);
RayDesc MakeRayDesc()
{
RayDesc desc;
desc.Origin = float3(0,0,0);
desc.Direction = float3(1,0,0);
desc.TMin = 0.0f;
desc.TMax = 9999.0;
return desc;
}
void doInitialize(out RayQuery<RAY_FLAG_FORCE_OPAQUE> query, RayDesc ray)
{
query.TraceRayInline(AccelerationStructure,RAY_FLAG_FORCE_NON_OPAQUE,0xFF,ray);
}

[domain("tri")] PSSceneIn main(
const float3 bary : SV_DomainLocation,
const OutputPatch<HSPerVertexData, 3> patch,
const HSPerPatchData perPatchData)
{
PSSceneIn v;
v.pos = patch[0].v.pos * bary.x + patch[1].v.pos * bary.y + patch[2].v.pos * bary.z + perPatchData.edges[1];

RayQuery<RAY_FLAG_FORCE_OPAQUE> q;
RayDesc ray = MakeRayDesc();

q.TraceRayInline(AccelerationStructure,RAY_FLAG_FORCE_OPAQUE, 0xFF, ray);
doInitialize(q, ray);

return v;
}

0 comments on commit 33e8bfd

Please sign in to comment.