Skip to content

Commit

Permalink
Fix crash when using optional type in a generic. (shader-slang#4341)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Jun 12, 2024
1 parent 5da06d4 commit 3fe4a77
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 25 deletions.
63 changes: 38 additions & 25 deletions source/slang/slang-ir-lower-optional-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ namespace Slang
InstWorkList workList;
InstHashSet workListSet;

IRGeneric* genericOptionalStructType = nullptr;
IRStructKey* valueKey = nullptr;
IRStructKey* hasValueKey = nullptr;

OptionalTypeLoweringContext(IRModule* inModule)
:module(inModule), workList(inModule), workListSet(inModule)
{}
Expand All @@ -24,8 +28,6 @@ namespace Slang
IRType* optionalType = nullptr;
IRType* valueType = nullptr;
IRType* loweredType = nullptr;
IRStructField* valueField = nullptr;
IRStructField* hasValueField = nullptr;
};
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo;
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes;
Expand All @@ -38,6 +40,34 @@ namespace Slang
return type;
}

IRInst* getOrCreateGenericOptionalStruct()
{
if (genericOptionalStructType)
return genericOptionalStructType;
IRBuilder builder(module);
builder.setInsertInto(module->getModuleInst());

valueKey = builder.createStructKey();
builder.addNameHintDecoration(valueKey, UnownedStringSlice("value"));
hasValueKey = builder.createStructKey();
builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));

genericOptionalStructType = builder.emitGeneric();
builder.addNameHintDecoration(genericOptionalStructType, UnownedStringSlice("_slang_Optional"));

builder.setInsertInto(genericOptionalStructType);
auto block = builder.emitBlock();
auto typeParam = builder.emitParam(builder.getTypeKind());
auto structType = builder.createStructType();
builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional"));
builder.createStructField(structType, valueKey, (IRType*)typeParam);
builder.createStructField(structType, hasValueKey, builder.getBoolType());
builder.setInsertInto(block);
builder.emitReturn(structType);
genericOptionalStructType->setFullType(builder.getTypeKind());
return genericOptionalStructType;
}

bool typeHasNullValue(IRInst* type)
{
switch (type->getOp())
Expand Down Expand Up @@ -78,19 +108,10 @@ namespace Slang
}
else
{
auto structType = builder->createStructType();
info->loweredType = structType;
builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType"));

info->valueType = valueType;
auto valueKey = builder->createStructKey();
builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);

auto boolType = builder->getBoolType();
auto hasValueKey = builder->createStructKey();
builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType);
auto genericType = getOrCreateGenericOptionalStruct();
IRInst* args[] = { valueType };
auto specializedType = builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args);
info->loweredType = (IRType*)specializedType;
}
mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info;
loweredOptionalTypes[type] = info;
Expand All @@ -100,12 +121,6 @@ namespace Slang
void addToWorkList(
IRInst* inst)
{
for (auto ii = inst->getParent(); ii; ii = ii->getParent())
{
if (as<IRGeneric>(ii))
return;
}

if (workListSet.contains(inst))
return;

Expand Down Expand Up @@ -169,7 +184,7 @@ namespace Slang
result = builder->emitFieldExtract(
builder->getBoolType(),
optionalInst,
loweredOptionalTypeInfo->hasValueField->getKey());
hasValueKey);
}
else
{
Expand Down Expand Up @@ -201,11 +216,10 @@ namespace Slang
if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType)
{
SLANG_ASSERT(loweredOptionalTypeInfo);
SLANG_ASSERT(loweredOptionalTypeInfo->valueField);
auto getElement = builder->emitFieldExtract(
loweredOptionalTypeInfo->valueType,
base,
loweredOptionalTypeInfo->valueField->getKey());
valueKey);
inst->replaceUsesWith(getElement);
}
else
Expand Down Expand Up @@ -257,7 +271,6 @@ namespace Slang
while (workList.getCount() != 0)
{
IRInst* inst = workList.getLast();

workList.removeLast();
workListSet.remove(inst);

Expand Down
22 changes: 22 additions & 0 deletions tests/bugs/optional-generic.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk


Optional<T> genFunc<T : IArithmetic>(T v)
{
if (v is int)
return v;
return none;
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer

RWStructuredBuffer<int> buffer;

[numthreads(1,1,1)]
void computeMain()
{
// BUF: 2
buffer[0] = genFunc(2).value;
}

42 changes: 42 additions & 0 deletions tests/bugs/optional.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk

interface IFoo
{
void foo();
}

struct S : IFoo { int x; void foo(); }

struct P
{
IFoo f;
}
struct Tr
{
int test<T:IArithmetic>(T t, inout P p)
{
const IFoo hit = p.f;
let castResult = hit as S;
if (!castResult.hasValue)
return 0;
return castResult.value.x;
}
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer

RWStructuredBuffer<int> buffer;

[numthreads(1,1,1)]
void computeMain()
{
P p;
S s;
s.x = 2;
p.f = s;
Tr tt;
// BUF: 2
buffer[0] = tt.test(0, p);
}

0 comments on commit 3fe4a77

Please sign in to comment.