From 3fe4a77287345c303aeb985e24ee237f272e8eca Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 11 Jun 2024 23:58:25 -0700 Subject: [PATCH] Fix crash when using optional type in a generic. (#4341) --- source/slang/slang-ir-lower-optional-type.cpp | 63 +++++++++++-------- tests/bugs/optional-generic.slang | 22 +++++++ tests/bugs/optional.slang | 42 +++++++++++++ 3 files changed, 102 insertions(+), 25 deletions(-) create mode 100644 tests/bugs/optional-generic.slang create mode 100644 tests/bugs/optional.slang diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp index ba2584976f..272f045450 100644 --- a/source/slang/slang-ir-lower-optional-type.cpp +++ b/source/slang/slang-ir-lower-optional-type.cpp @@ -15,6 +15,10 @@ namespace Slang InstWorkList workList; InstHashSet workListSet; + IRGeneric* genericOptionalStructType = nullptr; + IRStructKey* valueKey = nullptr; + IRStructKey* hasValueKey = nullptr; + OptionalTypeLoweringContext(IRModule* inModule) :module(inModule), workList(inModule), workListSet(inModule) {} @@ -24,8 +28,6 @@ namespace Slang IRType* optionalType = nullptr; IRType* valueType = nullptr; IRType* loweredType = nullptr; - IRStructField* valueField = nullptr; - IRStructField* hasValueField = nullptr; }; Dictionary> mapLoweredTypeToOptionalTypeInfo; Dictionary> loweredOptionalTypes; @@ -38,6 +40,34 @@ namespace Slang return type; } + IRInst* getOrCreateGenericOptionalStruct() + { + if (genericOptionalStructType) + return genericOptionalStructType; + IRBuilder builder(module); + builder.setInsertInto(module->getModuleInst()); + + valueKey = builder.createStructKey(); + builder.addNameHintDecoration(valueKey, UnownedStringSlice("value")); + hasValueKey = builder.createStructKey(); + builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue")); + + genericOptionalStructType = builder.emitGeneric(); + builder.addNameHintDecoration(genericOptionalStructType, UnownedStringSlice("_slang_Optional")); + + builder.setInsertInto(genericOptionalStructType); + auto block = builder.emitBlock(); + auto typeParam = builder.emitParam(builder.getTypeKind()); + auto structType = builder.createStructType(); + builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional")); + builder.createStructField(structType, valueKey, (IRType*)typeParam); + builder.createStructField(structType, hasValueKey, builder.getBoolType()); + builder.setInsertInto(block); + builder.emitReturn(structType); + genericOptionalStructType->setFullType(builder.getTypeKind()); + return genericOptionalStructType; + } + bool typeHasNullValue(IRInst* type) { switch (type->getOp()) @@ -78,19 +108,10 @@ namespace Slang } else { - auto structType = builder->createStructType(); - info->loweredType = structType; - builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType")); - - info->valueType = valueType; - auto valueKey = builder->createStructKey(); - builder->addNameHintDecoration(valueKey, UnownedStringSlice("value")); - info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType); - - auto boolType = builder->getBoolType(); - auto hasValueKey = builder->createStructKey(); - builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue")); - info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType); + auto genericType = getOrCreateGenericOptionalStruct(); + IRInst* args[] = { valueType }; + auto specializedType = builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args); + info->loweredType = (IRType*)specializedType; } mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info; loweredOptionalTypes[type] = info; @@ -100,12 +121,6 @@ namespace Slang void addToWorkList( IRInst* inst) { - for (auto ii = inst->getParent(); ii; ii = ii->getParent()) - { - if (as(ii)) - return; - } - if (workListSet.contains(inst)) return; @@ -169,7 +184,7 @@ namespace Slang result = builder->emitFieldExtract( builder->getBoolType(), optionalInst, - loweredOptionalTypeInfo->hasValueField->getKey()); + hasValueKey); } else { @@ -201,11 +216,10 @@ namespace Slang if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType) { SLANG_ASSERT(loweredOptionalTypeInfo); - SLANG_ASSERT(loweredOptionalTypeInfo->valueField); auto getElement = builder->emitFieldExtract( loweredOptionalTypeInfo->valueType, base, - loweredOptionalTypeInfo->valueField->getKey()); + valueKey); inst->replaceUsesWith(getElement); } else @@ -257,7 +271,6 @@ namespace Slang while (workList.getCount() != 0) { IRInst* inst = workList.getLast(); - workList.removeLast(); workListSet.remove(inst); diff --git a/tests/bugs/optional-generic.slang b/tests/bugs/optional-generic.slang new file mode 100644 index 0000000000..16b466273f --- /dev/null +++ b/tests/bugs/optional-generic.slang @@ -0,0 +1,22 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk + + +Optional genFunc(T v) +{ + if (v is int) + return v; + return none; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer + +RWStructuredBuffer buffer; + +[numthreads(1,1,1)] +void computeMain() +{ + // BUF: 2 + buffer[0] = genFunc(2).value; +} + diff --git a/tests/bugs/optional.slang b/tests/bugs/optional.slang new file mode 100644 index 0000000000..3512ba29fd --- /dev/null +++ b/tests/bugs/optional.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk + +interface IFoo +{ + void foo(); +} + +struct S : IFoo { int x; void foo(); } + +struct P +{ + IFoo f; +} +struct Tr +{ + int test(T t, inout P p) + { + const IFoo hit = p.f; + let castResult = hit as S; + if (!castResult.hasValue) + return 0; + return castResult.value.x; + } +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer + +RWStructuredBuffer buffer; + +[numthreads(1,1,1)] +void computeMain() +{ + P p; + S s; + s.x = 2; + p.f = s; + Tr tt; + // BUF: 2 + buffer[0] = tt.test(0, p); +} +