From 27a038399efc85a2bf3ccf536b510b90176eb545 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 22 Feb 2024 11:14:54 +0100 Subject: [PATCH 01/47] Rename VCLLVM to Pallas, Update to LLVM 17, Implement Stack Allocation --- .clang-format | 236 +++++++++ .github/workflows/release.yml | 6 +- .github/workflows/scalatest.yml | 22 +- build.sc | 223 ++++---- src/col/vct/col/ast/Node.scala | 206 ++++++-- .../expr/context/AmbiguousResultImpl.scala | 4 +- src/col/vct/col/ast/expr/op/BinExprImpl.scala | 13 + .../family/coercion/CoerceLLVMArrayImpl.scala | 8 + .../coercion/CoerceLLVMIntIntImpl.scala | 9 + .../coercion/CoerceLLVMPointerImpl.scala | 9 + .../ast/family/coercion/CoercionImpl.scala | 3 + .../col/ast/lang/llvm/LLVMAllocAImpl.scala | 9 + ...LLVMAmbiguousFunctionInvocationImpl.scala} | 18 +- .../ast/lang/llvm/LLVMArrayValueImpl.scala | 11 + .../vct/col/ast/lang/llvm/LLVMExprImpl.scala | 7 + .../lang/llvm/LLVMFunctionContractImpl.scala | 9 + .../llvm/LLVMFunctionDefinitionImpl.scala | 16 + ...scala => LLVMFunctionInvocationImpl.scala} | 8 +- .../llvm/LLVMFunctionPointerValueImpl.scala | 12 + .../lang/llvm/LLVMGetElementPointerImpl.scala | 11 + .../llvm/LLVMGlobalSpecificationImpl.scala | 12 + .../lang/llvm/LLVMGlobalVariableImpl.scala | 10 + .../ast/lang/llvm/LLVMIntegerValueImpl.scala | 11 + .../vct/col/ast/lang/llvm/LLVMLoadImpl.scala | 9 + ...lvmLocalImpl.scala => LLVMLocalImpl.scala} | 8 +- .../ast/lang/llvm/LLVMLoopContractImpl.scala | 9 + .../vct/col/ast/lang/llvm/LLVMLoopImpl.scala | 9 + .../ast/lang/llvm/LLVMLoopInvariantImpl.scala | 9 + .../ast/lang/llvm/LLVMMemoryAcquireImpl.scala | 10 + .../llvm/LLVMMemoryAcquireReleaseImpl.scala | 10 + .../lang/llvm/LLVMMemoryMonotonicImpl.scala | 10 + .../lang/llvm/LLVMMemoryNotAtomicImpl.scala | 10 + .../lang/llvm/LLVMMemoryOrderingImpl.scala | 10 + .../ast/lang/llvm/LLVMMemoryReleaseImpl.scala | 10 + ...LLVMMemorySequentiallyConsistentImpl.scala | 11 + .../lang/llvm/LLVMMemoryUnorderedImpl.scala | 10 + .../ast/lang/llvm/LLVMPointerValueImpl.scala | 23 + .../ast/lang/llvm/LLVMRawArrayValueImpl.scala | 11 + .../lang/llvm/LLVMRawVectorValueImpl.scala | 11 + .../ast/lang/llvm/LLVMSignExtendImpl.scala | 11 + ...nImpl.scala => LLVMSpecFunctionImpl.scala} | 10 +- .../col/ast/lang/llvm/LLVMStatementImpl.scala | 7 + .../vct/col/ast/lang/llvm/LLVMStoreImpl.scala | 8 + .../ast/lang/llvm/LLVMStructValueImpl.scala | 11 + .../col/ast/lang/llvm/LLVMTArrayImpl.scala | 10 + .../col/ast/lang/llvm/LLVMTFunctionImpl.scala | 10 + .../vct/col/ast/lang/llvm/LLVMTIntImpl.scala | 10 + .../col/ast/lang/llvm/LLVMTMetadataImpl.scala | 9 + .../col/ast/lang/llvm/LLVMTPointerImpl.scala | 9 + .../col/ast/lang/llvm/LLVMTStructImpl.scala | 10 + .../col/ast/lang/llvm/LLVMTVectorImpl.scala | 10 + .../col/ast/lang/llvm/LLVMTruncateImpl.scala | 11 + .../ast/lang/llvm/LLVMVectorValueImpl.scala | 11 + .../ast/lang/llvm/LLVMZeroExtendImpl.scala | 11 + .../llvm/LLVMZeroedAggregateValueImpl.scala | 12 + .../vct/col/ast/lang/llvm/LlvmExprImpl.scala | 7 - .../lang/llvm/LlvmFunctionContractImpl.scala | 9 - .../llvm/LlvmFunctionDefinitionImpl.scala | 16 - .../col/ast/lang/llvm/LlvmGlobalImpl.scala | 12 - .../ast/lang/llvm/LlvmLoopContractImpl.scala | 9 - .../vct/col/ast/lang/llvm/LlvmLoopImpl.scala | 9 - .../ast/lang/llvm/LlvmLoopInvariantImpl.scala | 9 - .../statement/exceptional/ReturnImpl.scala | 2 +- src/col/vct/col/resolve/Resolve.scala | 78 ++- .../ctx/ReferenceResolutionContext.scala | 1 + src/col/vct/col/resolve/ctx/Referrable.scala | 40 +- src/col/vct/col/resolve/lang/LLVM.scala | 8 +- .../vct/col/serialize/SerializeOrigin.scala | 2 +- .../vct/col/typerules/CoercingRewriter.scala | 70 ++- src/col/vct/col/typerules/CoercionUtils.scala | 82 +++ src/col/vct/col/typerules/Types.scala | 3 + src/col/vct/col/util/AstBuildHelpers.scala | 8 +- src/llvm/include/Origin/ContextDeriver.h | 48 +- src/llvm/include/Origin/OriginProvider.h | 62 ++- .../include/Origin/PreferredNameDeriver.h | 25 +- .../include/Origin/ShortPositionDeriver.h | 32 +- .../Passes/Function/FunctionBodyTransformer.h | 314 ++++++------ .../Function/FunctionContractDeclarer.h | 108 ++-- .../Passes/Function/FunctionDeclarer.h | 173 ++++--- .../include/Passes/Function/PureAssigner.h | 35 +- .../Passes/Module/GlobalVariableDeclarer.h | 17 + .../Passes/Module/ModuleSpecCollector.h | 32 +- .../include/Passes/Module/ProtobufPrinter.h | 16 + .../include/Passes/Module/RootContainer.h | 28 ++ src/llvm/include/Transform/BlockTransform.h | 94 ++-- .../Transform/Instruction/BinaryOpTransform.h | 18 +- .../Transform/Instruction/CastOpTransform.h | 26 +- .../Instruction/FuncletPadOpTransform.h | 18 +- .../Transform/Instruction/MemoryOpTransform.h | 33 +- .../Transform/Instruction/OtherOpTransform.h | 65 +-- .../Transform/Instruction/TermOpTransform.h | 34 +- .../Transform/Instruction/UnaryOpTransform.h | 17 +- src/llvm/include/Transform/Transform.h | 143 +++--- src/llvm/include/Util/Constants.h | 18 +- src/llvm/include/Util/Exceptions.h | 85 +++- src/llvm/lib/Origin/ContextDeriver.cpp | 81 +++ src/llvm/lib/Origin/OriginProvider.cpp | 343 +++++++++++++ src/llvm/lib/Origin/PreferredNameDeriver.cpp | 102 ++++ src/llvm/lib/Origin/ShortPositionDeriver.cpp | 49 ++ .../Function/FunctionBodyTransformer.cpp | 185 +++++++ .../Function/FunctionContractDeclarer.cpp | 87 ++++ .../lib/Passes/Function/FunctionDeclarer.cpp | 137 +++++ src/llvm/lib/Passes/Function/PureAssigner.cpp | 59 +++ .../Passes/Module/GlobalVariableDeclarer.cpp | 49 ++ .../lib/Passes/Module/ModuleSpecCollector.cpp | 46 ++ .../lib/Passes/Module/ProtobufPrinter.cpp | 30 ++ src/llvm/lib/Passes/Module/RootContainer.cpp | 26 + src/llvm/lib/Plugin.cpp | 69 +++ src/llvm/lib/Transform/BlockTransform.cpp | 77 +++ .../Instruction/BinaryOpTransform.cpp | 82 +++ .../Transform/Instruction/CastOpTransform.cpp | 87 ++++ .../Instruction/FuncletPadOpTransform.cpp | 13 + .../Instruction/MemoryOpTransform.cpp | 149 ++++++ .../Instruction/OtherOpTransform.cpp | 190 +++++++ .../Transform/Instruction/TermOpTransform.cpp | 141 ++++++ .../Instruction/UnaryOpTransform.cpp | 10 + src/llvm/lib/Transform/Transform.cpp | 306 +++++++++++ src/llvm/lib/Util/Exceptions.cpp | 109 ++++ src/llvm/lib/origin/ContextDeriver.cpp | 74 --- src/llvm/lib/origin/OriginProvider.cpp | 216 -------- src/llvm/lib/origin/PreferredNameDeriver.cpp | 69 --- src/llvm/lib/origin/ShortPositionDeriver.cpp | 43 -- .../Function/FunctionBodyTransformer.cpp | 149 ------ .../Function/FunctionContractDeclarer.cpp | 90 ---- .../lib/passes/Function/FunctionDeclarer.cpp | 120 ----- src/llvm/lib/passes/Function/PureAssigner.cpp | 50 -- .../lib/passes/Module/ModuleSpecCollector.cpp | 40 -- src/llvm/lib/transform/BlockTransform.cpp | 59 --- .../Instruction/BinaryOpTransform.cpp | 46 -- .../transform/Instruction/CastOpTransform.cpp | 14 - .../Instruction/FuncletPadOpTransform.cpp | 15 - .../Instruction/MemoryOpTransform.cpp | 15 - .../Instruction/OtherOpTransform.cpp | 151 ------ .../transform/Instruction/TermOpTransform.cpp | 120 ----- .../Instruction/UnaryOpTransform.cpp | 12 - src/llvm/lib/transform/Transform.cpp | 101 ---- src/llvm/lib/util/Exceptions.cpp | 52 -- src/llvm/tools/vcllvm/VCLLVM.cpp | 145 ------ src/main/vct/main/stages/Parsing.scala | 2 +- src/main/vct/main/stages/Resolution.scala | 46 +- src/main/vct/main/stages/Transformation.scala | 2 + src/main/vct/options/Options.scala | 6 + src/main/vct/resources/Resources.scala | 2 +- src/parsers/antlr4/LangPVLLexer.g4 | 1 + src/parsers/antlr4/LangPVLParser.g4 | 2 + .../vct/parsers/parser/ColLLVMParser.scala | 41 +- .../parsers/transform/LLVMContractToCol.scala | 6 +- .../vct/parsers/transform/PVLToCol.scala | 6 +- src/rewrite/vct/rewrite/ClassToRef.scala | 4 +- .../rewrite/DesugarPermissionOperators.scala | 2 +- .../vct/rewrite/EncodeCurrentThread.scala | 2 +- .../ResolveExpressionSideEffects.scala | 8 +- src/rewrite/vct/rewrite/TrivialAddrOf.scala | 9 +- .../vct/rewrite/VariableToPointer.scala | 213 ++++++++ .../vct/rewrite/lang/LangLLVMToCol.scala | 475 ++++++++++++++++-- .../vct/rewrite/lang/LangSpecificToCol.scala | 40 +- 156 files changed, 5244 insertions(+), 2640 deletions(-) create mode 100644 .clang-format create mode 100644 src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala create mode 100644 src/col/vct/col/ast/family/coercion/CoerceLLVMIntIntImpl.scala create mode 100644 src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala rename src/col/vct/col/ast/lang/llvm/{LlvmAmbiguousFunctionInvocationImpl.scala => LLVMAmbiguousFunctionInvocationImpl.scala} (51%) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMExprImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMFunctionContractImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala rename src/col/vct/col/ast/lang/llvm/{LlvmFunctionInvocationImpl.scala => LLVMFunctionInvocationImpl.scala} (64%) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMGlobalSpecificationImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala rename src/col/vct/col/ast/lang/llvm/{LlvmLocalImpl.scala => LLVMLocalImpl.scala} (55%) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMRawArrayValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMRawVectorValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala rename src/col/vct/col/ast/lang/llvm/{LlvmSpecFunctionImpl.scala => LLVMSpecFunctionImpl.scala} (86%) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMStatementImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTFunctionImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmExprImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmFunctionContractImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmFunctionDefinitionImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmGlobalImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmLoopContractImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmLoopImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LlvmLoopInvariantImpl.scala create mode 100644 src/llvm/include/Passes/Module/GlobalVariableDeclarer.h create mode 100644 src/llvm/include/Passes/Module/ProtobufPrinter.h create mode 100644 src/llvm/include/Passes/Module/RootContainer.h create mode 100644 src/llvm/lib/Origin/ContextDeriver.cpp create mode 100644 src/llvm/lib/Origin/OriginProvider.cpp create mode 100644 src/llvm/lib/Origin/PreferredNameDeriver.cpp create mode 100644 src/llvm/lib/Origin/ShortPositionDeriver.cpp create mode 100644 src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp create mode 100644 src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp create mode 100644 src/llvm/lib/Passes/Function/FunctionDeclarer.cpp create mode 100644 src/llvm/lib/Passes/Function/PureAssigner.cpp create mode 100644 src/llvm/lib/Passes/Module/GlobalVariableDeclarer.cpp create mode 100644 src/llvm/lib/Passes/Module/ModuleSpecCollector.cpp create mode 100644 src/llvm/lib/Passes/Module/ProtobufPrinter.cpp create mode 100644 src/llvm/lib/Passes/Module/RootContainer.cpp create mode 100644 src/llvm/lib/Plugin.cpp create mode 100644 src/llvm/lib/Transform/BlockTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/CastOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/FuncletPadOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/TermOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Instruction/UnaryOpTransform.cpp create mode 100644 src/llvm/lib/Transform/Transform.cpp create mode 100644 src/llvm/lib/Util/Exceptions.cpp delete mode 100644 src/llvm/lib/origin/ContextDeriver.cpp delete mode 100644 src/llvm/lib/origin/OriginProvider.cpp delete mode 100644 src/llvm/lib/origin/PreferredNameDeriver.cpp delete mode 100644 src/llvm/lib/origin/ShortPositionDeriver.cpp delete mode 100644 src/llvm/lib/passes/Function/FunctionBodyTransformer.cpp delete mode 100644 src/llvm/lib/passes/Function/FunctionContractDeclarer.cpp delete mode 100644 src/llvm/lib/passes/Function/FunctionDeclarer.cpp delete mode 100644 src/llvm/lib/passes/Function/PureAssigner.cpp delete mode 100644 src/llvm/lib/passes/Module/ModuleSpecCollector.cpp delete mode 100644 src/llvm/lib/transform/BlockTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/BinaryOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/CastOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/FuncletPadOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/MemoryOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/OtherOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/TermOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Instruction/UnaryOpTransform.cpp delete mode 100644 src/llvm/lib/transform/Transform.cpp delete mode 100644 src/llvm/lib/util/Exceptions.cpp delete mode 100644 src/llvm/tools/vcllvm/VCLLVM.cpp create mode 100644 src/rewrite/vct/rewrite/VariableToPointer.scala diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..15f70246e5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,236 @@ +--- +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignArrayOfStructures: None +AlignConsecutiveAssignments: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: true +AlignConsecutiveBitFields: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveDeclarations: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveMacros: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + PadOperators: false +AlignConsecutiveShortCaseStatements: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCaseColons: false +AlignEscapedNewlines: Right +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 0 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortEnumsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: MultiLine +AttributeMacros: + - __capability +BinPackArguments: true +BinPackParameters: true +BitFieldColonSpacing: Both +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterExternBlock: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakAfterAttributes: Never +BreakAfterJavaFieldAnnotations: false +BreakArrays: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: Always +BreakBeforeBraces: Attach +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakInheritanceList: BeforeColon +BreakStringLiterals: true +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IfMacros: + - KJ_IF_MAYBE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 1 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '(Test)?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: false +IndentExternBlock: AfterExternBlock +IndentGotoLabels: true +IndentPPDirectives: None +IndentRequiresClause: true +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: false +InsertNewlineAtEOF: false +InsertTrailingCommas: None +IntegerLiteralSeparator: + Binary: 0 + BinaryMinDigits: 0 + Decimal: 0 + DecimalMinDigits: 0 + Hex: 0 + HexMinDigits: 0 +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +KeepEmptyLinesAtEOF: false +LambdaBodyIndentation: Signature +LineEnding: DeriveLF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PackConstructorInitializers: BinPack +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakOpenParenthesis: 0 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyIndentedWhitespace: 0 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +PPIndentWidth: -1 +QualifierAlignment: Leave +ReferenceAlignment: Pointer +ReflowComments: true +RemoveBracesLLVM: false +RemoveParentheses: Leave +RemoveSemicolon: false +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 1 +SortIncludes: CaseSensitive +SortJavaStaticImport: Before +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceAroundPointerQualifiers: Default +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeJsonColon: false +SpaceBeforeParens: ControlStatements +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDefinitionName: false + AfterFunctionDeclarationName: false + AfterIfMacros: true + AfterOverloadedOperator: false + AfterRequiresInClause: false + AfterRequiresInExpression: false + BeforeNonEmptyParentheses: false +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: Never +SpacesInContainerLiterals: true +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParens: Never +SpacesInParensOptions: + InCStyleCasts: false + InConditionalStatements: false + InEmptyParentheses: false + Other: false +SpacesInSquareBrackets: false +Standard: Latest +StatementAttributeLikeMacros: + - Q_EMIT +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 8 +UseTab: Never +VerilogBreakBetweenInstancePorts: true +WhitespaceSensitiveMacros: + - BOOST_PP_STRINGIZE + - CF_SWIFT_NAME + - NS_SWIFT_NAME + - PP_STRINGIZE + - STRINGIZE +... + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8ae38ee0e7..35d4481090 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: jobs: Release: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Checkout VerCors uses: actions/checkout@v2 @@ -22,8 +22,8 @@ jobs: uses: actions/setup-java@v1 with: java-version: 17 - - name: Enable VCLLVM compilation - run: touch .include-vcllvm + - name: Enable Pallas compilation + run: touch .include-pallas - name: Build Release run: ./mill -j 0 vercors.main.release - name: Set Properties diff --git a/.github/workflows/scalatest.yml b/.github/workflows/scalatest.yml index 826213d9d2..632e6fe244 100644 --- a/.github/workflows/scalatest.yml +++ b/.github/workflows/scalatest.yml @@ -16,7 +16,7 @@ concurrency: jobs: Compile: if: (github.event_name != 'pull_request' || github.event.pull_request.head.repo.fork) - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Checkout VerCors uses: actions/checkout@v2 @@ -32,20 +32,20 @@ jobs: restore-keys: | vercors-ci-ubuntu-${{ hashFiles('build.sc') }} vercors-ci-ubuntu - - name: Enable VCLLVM compilation - run: touch .include-vcllvm + - name: Enable Pallas compilation + run: touch .include-pallas - name: Compile - run: ./mill -j 0 vercors.allTests.assembly + vercors.vcllvm.compile + run: ./mill -j 0 vercors.allTests.assembly + vercors.pallas.compile - name: Upload VerCors uses: actions/upload-artifact@v3 with: name: allTests path: out/vercors/allTests/assembly.dest/out.jar - - name: Upload VCLLVM + - name: Upload Pallas uses: actions/upload-artifact@v3 with: - name: vcllvm - path: out/vercors/vcllvm/compile.dest/vcllvm + name: pallas + path: out/vercors/pallas/compile.dest/pallas - name: Delete Uncached Files run: | find out -type f -name upstreamAssembly.json -print -exec rm -rf {} + @@ -130,7 +130,7 @@ jobs: matrix: batch: ["-n MATRIX[0]", "-n MATRIX[1]", "-n MATRIX[2]", "-n MATRIX[3]", "-n MATRIX[4]", "-n MATRIX[5]", "-n MATRIX[6]", "-n MATRIX[7]", "-l MATRIX"] - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Checkout VerCors uses: actions/checkout@v2 @@ -143,13 +143,11 @@ jobs: with: name: allTests path: '.' - - name: Download VCLLVM + - name: Download Pallas uses: actions/download-artifact@v3 with: - name: vcllvm + name: pallas path: '.' - - name: Make VCLLVM executable - run: chmod +x vcllvm - name: ls run: ls -lasFhR - name: Run scalatest diff --git a/build.sc b/build.sc index 23a2c215fa..4304306658 100644 --- a/build.sc +++ b/build.sc @@ -4,6 +4,7 @@ import $ivy.`com.lihaoyi::mill-contrib-buildinfo:` import util._ import os._ import mill.{util => _, _} +import mill.api.Result import scalalib.{JavaModule => _, ScalaModule => _, _} import contrib.buildinfo.BuildInfo import me.pieterbos.mill.cpp.options.implicits._ @@ -16,6 +17,32 @@ import vct.col.ast.structure.{AllFamilies, FamilyDefinition, Name, NodeDefinitio import scala.util.control.NonFatal +trait CppSharedModule extends CppModule { + def executableOptions: T[CppExecutableOptions] = T { + CppExecutableOptions( + transitiveDynamicObjects().map(_.path), + transitiveSystemLibraryDeps(), + Seq("-shared", "-fPIC"), + Nil, + ) + } + + def compile: T[PathRef] = T { +// def temp = //transitiveStaticObjects()).map(_.path)) + print("Definitely a new file") + print(compileOnly() ++ T.traverse(moduleDeps){ + case it: CppModule => it.compileOnly + case it: LinkableModule => it.staticObjects + case _ => T.task { Result.Success(Seq.empty) } + }().flatten) + PathRef(toolchain.linkExecutable((compileOnly() ++ T.traverse(moduleDeps){ + case it: CppModule => it.compileOnly + case it: LinkableModule => it.staticObjects + case _ => T.task { Result.Success(Seq.empty) } + }().flatten).map(_.path), T.dest, name(), executableOptions())) + } +} + object external extends Module { object z3 extends Module { def url = T { "https://www.sosy-lab.org/ivy/org.sosy_lab/javasmt-solver-z3/com.microsoft.z3-4.8.7.jar" } @@ -406,24 +433,25 @@ object vercors extends Module { ivy"org.apache.logging.log4j:log4j-to-slf4j:2.23.1", ) override def moduleDeps = Seq(hre, col, serialize) + override def unmanagedClasspath = super.unmanagedClasspath() ++ Agg(PathRef(pallas.compile().path / os.up)) - val includeVcllvmCross = interp.watchValue { - if(os.exists(settings.root / ".include-vcllvm")) { - Seq("vcllvm") + val includePallasCross = interp.watchValue { + if(os.exists(settings.root / ".include-pallas")) { + Seq("pallas") } else { Seq.empty[String] } } - - object vcllvmDep extends Cross[VcllvmDep](includeVcllvmCross) - trait VcllvmDep extends Cross.Module[String] { + + object pallasDep extends Cross[PallasDep](includePallasCross) + trait PallasDep extends Cross.Module[String] { def path = T { - vcllvm.compile().path / os.up + pallas.compile().path / os.up } } override def bareResourcePaths = T { - T.traverse(includeVcllvmCross.map(vcllvmDep(_)))(_.path)() + T.traverse(includePallasCross.map(pallasDep(_)))(_.path)() } trait GenModule extends Module { @@ -636,149 +664,81 @@ object vercors extends Module { } } - object vcllvm extends CppExecutableModule { - outer => - def root: T[os.Path] = T { - settings.src / "llvm" - } + object pallas extends CppSharedModule { outer => + def root: T[os.Path] = T { settings.src / "llvm" } object llvm extends LinkableModule { def moduleDeps = Nil - - def systemLibraryDeps = T { - Seq("LLVM-15") - } - - def staticObjects = T { - Seq.empty[PathRef] - } - - def dynamicObjects = T { - Seq.empty[PathRef] - } - + def systemLibraryDeps = T { Seq("LLVM-17") } + def staticObjects = T { Seq.empty[PathRef] } + def dynamicObjects = T { Seq.empty[PathRef] } def exportIncludePaths = T.sources( - os.Path("/usr/include/llvm-15"), - os.Path("/usr/include/llvm-c-15"), - os.Path("/usr/local/opt/llvm/include") + os.Path("/usr/include/llvm-17"), + os.Path("/usr/include/llvm-c-17"), ) } - object json extends LinkableModule { - def moduleDeps = Nil - - def systemLibraryDeps = T { - Seq.empty[String] - } - - def staticObjects = T { - Seq.empty[PathRef] - } - - def dynamicObjects = T { - Seq.empty[PathRef] - } - - def exportIncludePaths = T { - os.write(T.dest / "json.tar.xz", requests.get.stream("https://github.com/nlohmann/json/releases/download/v3.11.2/json.tar.xz")) - os.proc("tar", "-xf", T.dest / "json.tar.xz").call(cwd = T.dest) - Seq(PathRef(T.dest / "json" / "include")) - } - } - object origin extends CppModule { - def moduleDeps = Seq(llvm, json, proto, protobuf.libprotobuf) - - def sources = T.sources(vcllvm.root() / "lib" / "origin") - - def includePaths = T.sources(vcllvm.root() / "include") - - override def unixToolchain = GccCompatible("g++", "ar") - + override def moduleDeps = Seq(llvm, proto, proto.protobuf.libprotobuf) + override def sources = T.sources(pallas.root() / "lib" / "Origin") + override def includePaths = T.sources(pallas.root() / "include") + override def compileOptions: T[Seq[String]] = Seq("-fPIC") } - object passes extends CppModule { - def moduleDeps = Seq(llvm, proto, protobuf.libprotobuf) - - def sources = T.sources(vcllvm.root() / "lib" / "passes") - - def includePaths = T.sources(vcllvm.root() / "include") - - override def unixToolchain = GccCompatible("g++", "ar") - + override def moduleDeps = Seq(llvm, proto, util, origin, transform, proto.protobuf.libprotobuf) + override def sources = T.sources(pallas.root() / "lib" / "Passes") + override def includePaths = T.sources(pallas.root() / "include") + override def compileOptions: T[Seq[String]] = Seq("-fPIC") } - object transform extends CppModule { - def moduleDeps = Seq(llvm, proto, protobuf.libprotobuf) - - def sources = T.sources(vcllvm.root() / "lib" / "transform") - - def includePaths = T.sources(vcllvm.root() / "include") - - override def unixToolchain = GccCompatible("g++", "ar") - + override def moduleDeps = Seq(llvm, proto, util, origin, proto.protobuf.libprotobuf) + override def sources = T.sources(pallas.root() / "lib" / "Transform") + override def includePaths = T.sources(pallas.root() / "include") + override def compileOptions: T[Seq[String]] = Seq("-fPIC") } - object util extends CppModule { - def moduleDeps = Seq(llvm, proto, protobuf.libprotobuf) - - def sources = T.sources(vcllvm.root() / "lib" / "util") - - def includePaths = T.sources(vcllvm.root() / "include") - - override def unixToolchain = GccCompatible("g++", "ar") - + override def moduleDeps = Seq(llvm, proto, origin, proto.protobuf.libprotobuf) + override def sources = T.sources(pallas.root() / "lib" / "Util") + override def includePaths = T.sources(pallas.root() / "include") + override def compileOptions: T[Seq[String]] = Seq("-fPIC") + } + object plugin extends CppModule { + override def moduleDeps = Seq(llvm, proto, passes, transform, proto.protobuf.libprotobuf) + override def sources = T.sources(pallas.root() / "lib" / "Plugin.cpp") + override def includePaths = T.sources(pallas.root() / "include") + override def compileOptions: T[Seq[String]] = Seq("-fPIC") } - override def unixToolchain = GccCompatible("g++", "ar") - - def moduleDeps = Seq(origin, passes, transform, util, llvm, proto, protobuf.libprotobuf) - - def sources = T.sources(vcllvm.root() / "tools" / "vcllvm") - - def includePaths = T.sources(vcllvm.root() / "include") - - object protobuf extends CMakeModule { - object protobufGit extends GitModule { - override def url: T[String] = "https://github.com/protocolbuffers/protobuf" - - override def commitish: T[String] = "v25.2" - - override def fetchSubmodulesRecursively = true - } - - override def root = T.source(protobufGit.repo()) - - override def jobs = T { - 2 - } + object proto extends CppModule { + object protobuf extends CMakeModule { + object protobufGit extends GitModule { + override def url: T[String] = "https://github.com/protocolbuffers/protobuf" + override def commitish: T[String] = "v25.2" + override def fetchSubmodulesRecursively = true + } + override def root = T.source(protobufGit.repo()) + override def jobs = T { 2 } - override def cMakeBuild: T[PathRef] = T { - os.proc("cmake", "-B", T.dest, "-Dprotobuf_BUILD_TESTS=OFF", "-DABSL_PROPAGATE_CXX_STD=ON", "-S", root().path).call(cwd = T.dest) - os.proc("make", "-j", jobs(), "all").call(cwd = T.dest) - PathRef(T.dest) - } + override def cMakeBuild: T[PathRef] = T { + os.proc("cmake", "-B", T.dest, "-D", "protobuf_BUILD_TESTS=OFF", "-D", "ABSL_PROPAGATE_CXX_STD=ON", "-D", "CMAKE_POSITION_INDEPENDENT_CODE=ON", "-D", "CMAKE_CXX_FLAGS=-fPIC", "-D", "CMAKE_C_FLAGS=-fPIC", "-S", root().path).call(cwd = T.dest) + os.proc("make", "-j", jobs(), "all").call(cwd = T.dest) + PathRef(T.dest) + } - object libprotobuf extends CMakeLibrary { - def target = T { - "libprotobuf" + object libprotobuf extends CMakeLibrary { + def target = T { "libprotobuf" } } - } - object protoc extends CMakeExecutable { - def target = T { - "protoc" + object protoc extends CMakeExecutable { + def target = T { "protoc" } } } - } - object proto extends CppModule { def protoPath = T.sources( vercors.col.helpers.megacol().path / os.up / os.up / os.up / os.up, settings.src / "serialize", serialize.scalaPBUnpackProto().path ) - def generate = T { os.proc(protobuf.protoc.executable().path, protoPath().map(p => "-I=" + p.path.toString), @@ -789,16 +749,10 @@ object vercors extends Module { ).call() T.dest } - override def moduleDeps = Seq(protobuf.libprotobuf) - - override def sources = T { - Seq(PathRef(generate())) - } - - override def includePaths = T { - Seq(PathRef(generate())) - } + override def sources = T { Seq(PathRef(generate())) } + override def includePaths = T { Seq(PathRef(generate())) } + override def compileOptions: T[Seq[String]] = T { Seq("-fPIC") } def precompileHeaders: T[PathRef] = T { def isHiddenFile(path: os.Path): Boolean = path.last.startsWith(".") @@ -837,9 +791,10 @@ object vercors extends Module { override def exportIncludePaths: T[Seq[PathRef]] = T { Seq(precompileHeaders()) } - - override def unixToolchain = GccCompatible("g++", "ar") } + + override def moduleDeps = Seq(origin, passes, transform, util, llvm, plugin, proto, proto.protobuf.libprotobuf) + override def compileOptions: T[Seq[String]] = T { Seq("-fPIC") } } object allTests extends ScalaModule with ReleaseModule { diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 838740745f..12009f3f91 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1128,6 +1128,15 @@ final case class CoerceRatZFrac[G]()(implicit val o: Origin) final case class CoerceZFracFrac[G]()(implicit val o: Origin) extends Coercion[G] with CoerceZFracFracImpl[G] +final case class CoerceLLVMIntInt[G]()(implicit val o: Origin) + extends Coercion[G] with CoerceLLVMIntIntImpl[G] +final case class CoerceLLVMPointer[G](from: Option[Type[G]], to: Type[G])( + implicit val o: Origin +) extends Coercion[G] with CoerceLLVMPointerImpl[G] +final case class CoerceLLVMArray[G](source: Type[G], target: Type[G])( + implicit val o: Origin +) extends Coercion[G] with CoerceLLVMArrayImpl[G] + @family sealed trait Expr[G] extends NodeFamily[G] with ExprImpl[G] @@ -3422,28 +3431,37 @@ final case class BipTransitionSynchronization[G]( extends GlobalDeclaration[G] with BipTransitionSynchronizationImpl[G] @family -final class LlvmFunctionContract[G]( +final class LLVMFunctionContract[G]( + val name: String, val value: String, val variableRefs: Seq[(String, Ref[G, Variable[G]])], - val invokableRefs: Seq[(String, Ref[G, LlvmCallable[G]])], + val invokableRefs: Seq[(String, Ref[G, LLVMCallable[G]])], )(val blame: Blame[NontrivialUnsatisfiable])(implicit val o: Origin) - extends NodeFamily[G] with LlvmFunctionContractImpl[G] { + extends NodeFamily[G] with LLVMFunctionContractImpl[G] { var data: Option[ApplicableContract[G]] = None } -sealed trait LlvmCallable[G] extends GlobalDeclaration[G] + +final case class LLVMGlobalVariable[G]( + variableType: Type[G], + value: Option[Expr[G]], + constant: Boolean, +)(implicit val o: Origin) + extends GlobalDeclaration[G] with LLVMGlobalVariableImpl[G] + +sealed trait LLVMCallable[G] extends GlobalDeclaration[G] @scopes[LabelDecl] -final class LlvmFunctionDefinition[G]( +final class LLVMFunctionDefinition[G]( val returnType: Type[G], val args: Seq[Variable[G]], - val functionBody: Statement[G], - val contract: LlvmFunctionContract[G], + val functionBody: Option[Statement[G]], + val contract: LLVMFunctionContract[G], val pure: Boolean = false, )(val blame: Blame[CallableFailure])(implicit val o: Origin) - extends LlvmCallable[G] + extends LLVMCallable[G] with Applicable[G] - with LlvmFunctionDefinitionImpl[G] + with LLVMFunctionDefinitionImpl[G] @scopes[LabelDecl] -final class LlvmSpecFunction[G]( +final class LLVMSpecFunction[G]( val name: String, val returnType: Type[G], val args: Seq[Variable[G]], @@ -3453,51 +3471,179 @@ final class LlvmSpecFunction[G]( val inline: Boolean = false, val threadLocal: Boolean = false, )(val blame: Blame[ContractedFailure])(implicit val o: Origin) - extends LlvmCallable[G] + extends LLVMCallable[G] with AbstractFunction[G] - with LlvmSpecFunctionImpl[G] -final case class LlvmFunctionInvocation[G]( - ref: Ref[G, LlvmFunctionDefinition[G]], + with LLVMSpecFunctionImpl[G] +final case class LLVMFunctionInvocation[G]( + ref: Ref[G, LLVMFunctionDefinition[G]], args: Seq[Expr[G]], givenMap: Seq[(Ref[G, Variable[G]], Expr[G])], yields: Seq[(Expr[G], Ref[G, Variable[G]])], )(val blame: Blame[InvocationFailure])(implicit val o: Origin) - extends Apply[G] with LlvmFunctionInvocationImpl[G] -final case class LlvmLoop[G]( + extends Apply[G] with LLVMFunctionInvocationImpl[G] + +final case class LLVMLoop[G]( cond: Expr[G], - contract: LlvmLoopContract[G], + contract: LLVMLoopContract[G], body: Statement[G], )(implicit val o: Origin) - extends CompositeStatement[G] with LlvmLoopImpl[G] + extends CompositeStatement[G] with LLVMLoopImpl[G] + @family -sealed trait LlvmLoopContract[G] - extends NodeFamily[G] with LlvmLoopContractImpl[G] -final case class LlvmLoopInvariant[G]( +sealed trait LLVMLoopContract[G] + extends NodeFamily[G] with LLVMLoopContractImpl[G] + +final case class LLVMLoopInvariant[G]( value: String, references: Seq[(String, Ref[G, Declaration[G]])], )(val blame: Blame[LoopInvariantFailure])(implicit val o: Origin) - extends LlvmLoopContract[G] with LlvmLoopInvariantImpl[G] -sealed trait LlvmExpr[G] extends Expr[G] with LlvmExprImpl[G] -final case class LlvmLocal[G](name: String)( + extends LLVMLoopContract[G] with LLVMLoopInvariantImpl[G] + +sealed trait LLVMStatement[G] extends Statement[G] with LLVMStatementImpl[G] + +sealed trait LLVMExpr[G] extends Expr[G] with LLVMExprImpl[G] + +final case class LLVMLocal[G](name: String)( val blame: Blame[DerefInsufficientPermission] )(implicit val o: Origin) - extends LlvmExpr[G] with LlvmLocalImpl[G] { + extends LLVMExpr[G] with LLVMLocalImpl[G] { var ref: Option[Ref[G, Variable[G]]] = None } -final case class LlvmAmbiguousFunctionInvocation[G]( +final case class LLVMAmbiguousFunctionInvocation[G]( name: String, args: Seq[Expr[G]], givenMap: Seq[(Ref[G, Variable[G]], Expr[G])], yields: Seq[(Expr[G], Ref[G, Variable[G]])], )(val blame: Blame[InvocationFailure])(implicit val o: Origin) - extends LlvmExpr[G] with LlvmAmbiguousFunctionInvocationImpl[G] { - var ref: Option[Ref[G, LlvmCallable[G]]] = None + extends LLVMExpr[G] with LLVMAmbiguousFunctionInvocationImpl[G] { + var ref: Option[Ref[G, LLVMCallable[G]]] = None } -final class LlvmGlobal[G](val value: String)(implicit val o: Origin) - extends GlobalDeclaration[G] with LlvmGlobalImpl[G] { +final case class LLVMAllocA[G](allocationType: Type[G], numElements: Expr[G])( + implicit val o: Origin +) extends LLVMExpr[G] with LLVMAllocAImpl[G] + +final case class LLVMLoad[G]( + loadType: Type[G], + pointer: Expr[G], + ordering: LLVMMemoryOrdering[G], +)(implicit val o: Origin) + extends LLVMExpr[G] with LLVMLoadImpl[G] + +final case class LLVMStore[G]( + value: Expr[G], + pointer: Expr[G], + ordering: LLVMMemoryOrdering[G], +)(implicit val o: Origin) + extends LLVMStatement[G] with LLVMStoreImpl[G] + +final case class LLVMGetElementPointer[G]( + structureType: Type[G], + resultType: Type[G], + pointer: Expr[G], + indices: Seq[Expr[G]], +)(implicit val o: Origin) + extends LLVMExpr[G] with LLVMGetElementPointerImpl[G] + +final case class LLVMSignExtend[G]( + inputType: Type[G], + outputType: Type[G], + value: Expr[G], +)(implicit val o: Origin) + extends LLVMExpr[G] with LLVMSignExtendImpl[G] + +final case class LLVMZeroExtend[G]( + inputType: Type[G], + outputType: Type[G], + value: Expr[G], +)(implicit val o: Origin) + extends LLVMExpr[G] with LLVMZeroExtendImpl[G] + +final case class LLVMTruncate[G]( + inputType: Type[G], + outputType: Type[G], + value: Expr[G], +)(implicit val o: Origin) + extends LLVMExpr[G] with LLVMTruncateImpl[G] + +final class LLVMGlobalSpecification[G](val value: String)( + implicit val o: Origin +) extends GlobalDeclaration[G] with LLVMGlobalSpecificationImpl[G] { var data: Option[Seq[GlobalDeclaration[G]]] = None } + +@family +sealed trait LLVMMemoryOrdering[G] + extends NodeFamily[G] with LLVMMemoryOrderingImpl[G] + +final case class LLVMMemoryNotAtomic[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryNotAtomicImpl[G] +final case class LLVMMemoryUnordered[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryUnorderedImpl[G] +final case class LLVMMemoryMonotonic[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryMonotonicImpl[G] +final case class LLVMMemoryAcquire[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryAcquireImpl[G] +final case class LLVMMemoryRelease[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryReleaseImpl[G] +final case class LLVMMemoryAcquireRelease[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemoryAcquireReleaseImpl[G] +final case class LLVMMemorySequentiallyConsistent[G]()(implicit val o: Origin) + extends LLVMMemoryOrdering[G] with LLVMMemorySequentiallyConsistentImpl[G] + +final case class LLVMIntegerValue[G](value: BigInt, integerType: Type[G])( + implicit val o: Origin +) extends ConstantInt[G] with LLVMExpr[G] with LLVMIntegerValueImpl[G] +final case class LLVMPointerValue[G](value: Ref[G, Declaration[G]])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMPointerValueImpl[G] +// TODO: The LLVMFunctionPointerValue references a GlobalDeclaration instead of an LLVMFunctionDefinition because there is no other COL node we can use as a function pointer literal +final case class LLVMFunctionPointerValue[G]( + value: Ref[G, GlobalDeclaration[G]] +)(implicit val o: Origin) + extends Constant[G] with LLVMExpr[G] with LLVMFunctionPointerValueImpl[G] +final case class LLVMStructValue[G](value: Seq[Expr[G]], structType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMStructValueImpl[G] +final case class LLVMArrayValue[G](value: Seq[Expr[G]], arrayType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMArrayValueImpl[G] +final case class LLVMRawArrayValue[G](value: String, arrayType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMRawArrayValueImpl[G] +final case class LLVMVectorValue[G](value: Seq[Expr[G]], vectorType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMVectorValueImpl[G] +final case class LLVMRawVectorValue[G](value: String, vectorType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMRawVectorValueImpl[G] +final case class LLVMZeroedAggregateValue[G](aggregateType: Type[G])( + implicit val o: Origin +) extends Constant[G] with LLVMExpr[G] with LLVMZeroedAggregateValueImpl[G] + +final case class LLVMTInt[G](bitWidth: Int)( + implicit val o: Origin = DiagnosticOrigin +) extends IntType[G] with LLVMTIntImpl[G] +final case class LLVMTFunction[G]()(implicit val o: Origin = DiagnosticOrigin) + extends Type[G] with LLVMTFunctionImpl[G] +final case class LLVMTPointer[G](innerType: Option[Type[G]])( + implicit val o: Origin = DiagnosticOrigin +) extends Type[G] with LLVMTPointerImpl[G] +final case class LLVMTMetadata[G]()(implicit val o: Origin = DiagnosticOrigin) + extends Type[G] with LLVMTMetadataImpl[G] +final case class LLVMTStruct[G]( + name: Option[String], + packed: Boolean, + elements: Seq[Type[G]], +)(implicit val o: Origin = DiagnosticOrigin) + extends Type[G] with LLVMTStructImpl[G] +final case class LLVMTArray[G](numElements: Long, elementType: Type[G])( + implicit val o: Origin = DiagnosticOrigin +) extends Type[G] with LLVMTArrayImpl[G] +final case class LLVMTVector[G](numElements: Long, elementType: Type[G])( + implicit val o: Origin = DiagnosticOrigin +) extends Type[G] with LLVMTVectorImpl[G] + sealed trait PVLType[G] extends Type[G] with PVLTypeImpl[G] final case class PVLNamedType[G](name: String, typeArgs: Seq[Type[G]])( implicit val o: Origin = DiagnosticOrigin diff --git a/src/col/vct/col/ast/expr/context/AmbiguousResultImpl.scala b/src/col/vct/col/ast/expr/context/AmbiguousResultImpl.scala index eefa45fd0b..ec83c193b7 100644 --- a/src/col/vct/col/ast/expr/context/AmbiguousResultImpl.scala +++ b/src/col/vct/col/ast/expr/context/AmbiguousResultImpl.scala @@ -37,8 +37,8 @@ trait AmbiguousResultImpl[G] case RefProcedure(decl) => decl.returnType case RefJavaMethod(decl) => decl.returnType case RefJavaAnnotationMethod(decl) => decl.returnType - case RefLlvmFunctionDefinition(decl) => decl.returnType - case RefLlvmSpecFunction(decl) => decl.returnType + case RefLLVMFunctionDefinition(decl) => decl.returnType + case RefLLVMSpecFunction(decl) => decl.returnType case RefInstanceFunction(decl) => decl.returnType case RefInstanceMethod(decl) => decl.returnType case RefInstanceOperatorMethod(decl) => decl.returnType diff --git a/src/col/vct/col/ast/expr/op/BinExprImpl.scala b/src/col/vct/col/ast/expr/op/BinExprImpl.scala index 7bbf539079..326e807069 100644 --- a/src/col/vct/col/ast/expr/op/BinExprImpl.scala +++ b/src/col/vct/col/ast/expr/op/BinExprImpl.scala @@ -7,6 +7,7 @@ import vct.col.ast.{ IntType, TBool, TCInt, + LLVMTInt, TInt, TProcess, TRational, @@ -35,6 +36,12 @@ object BinOperatorTypes { CoercionUtils.getCoercion(lt, TBool()).isDefined && CoercionUtils.getCoercion(rt, TBool()).isDefined + def isLLVMIntOp[G](lt: Type[G], rt: Type[G]): Boolean = + (lt, rt) match { + case (LLVMTInt(_), LLVMTInt(_)) => true + case _ => false + } + def isStringOp[G](lt: Type[G], rt: Type[G]): Boolean = CoercionUtils.getCoercion(lt, TString()).isDefined @@ -78,6 +85,8 @@ object BinOperatorTypes { def getIntType[G](lt: Type[G], rt: Type[G]): IntType[G] = if (isCIntOp(lt, rt)) TCInt() + else if (isLLVMIntOp(lt, rt)) + Types.leastCommonSuperType(lt, rt).asInstanceOf[LLVMTInt[G]] else TInt() @@ -91,6 +100,8 @@ object BinOperatorTypes { def getNumericType[G](lt: Type[G], rt: Type[G], o: Origin): Type[G] = { if (isCIntOp(lt, rt)) TCInt[G]() + else if (isLLVMIntOp(lt, rt)) + Types.leastCommonSuperType(lt, rt).asInstanceOf[LLVMTInt[G]] else if (isIntOp(lt, rt)) TInt[G]() else @@ -113,6 +124,8 @@ trait BinExprImpl[G] { def isRationalOp: Boolean = BinOperatorTypes.isRationalOp(left.t, right.t) + def isLLVMIntOp: Boolean = BinOperatorTypes.isLLVMIntOp(left.t, right.t) + def isBoolOp: Boolean = BinOperatorTypes.isBoolOp(left.t, right.t) def isStringOp: Boolean = BinOperatorTypes.isStringOp(left.t, right.t) diff --git a/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala new file mode 100644 index 0000000000..a77db8453b --- /dev/null +++ b/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala @@ -0,0 +1,8 @@ +package vct.col.ast.family.coercion + +import vct.col.ast.ops.CoerceLLVMArrayOps +import vct.col.ast.CoerceLLVMArray + +trait CoerceLLVMArrayImpl[G] extends CoerceLLVMArrayOps[G] { + this: CoerceLLVMArray[G] => +} diff --git a/src/col/vct/col/ast/family/coercion/CoerceLLVMIntIntImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceLLVMIntIntImpl.scala new file mode 100644 index 0000000000..d0d67d7cd2 --- /dev/null +++ b/src/col/vct/col/ast/family/coercion/CoerceLLVMIntIntImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.family.coercion + +import vct.col.ast.{CoerceLLVMIntInt, TInt} +import vct.col.ast.ops.CoerceLLVMIntIntOps + +trait CoerceLLVMIntIntImpl[G] extends CoerceLLVMIntIntOps[G] { + this: CoerceLLVMIntInt[G] => + override def target: TInt[G] = TInt() +} diff --git a/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala new file mode 100644 index 0000000000..e686d1c593 --- /dev/null +++ b/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.family.coercion + +import vct.col.ast.{CoerceLLVMPointer, TPointer} +import vct.col.ast.ops.CoerceLLVMPointerOps + +trait CoerceLLVMPointerImpl[G] extends CoerceLLVMPointerOps[G] { + this: CoerceLLVMPointer[G] => + override def target: TPointer[G] = TPointer(to) +} diff --git a/src/col/vct/col/ast/family/coercion/CoercionImpl.scala b/src/col/vct/col/ast/family/coercion/CoercionImpl.scala index 738aae3fdc..998074fa51 100644 --- a/src/col/vct/col/ast/family/coercion/CoercionImpl.scala +++ b/src/col/vct/col/ast/family/coercion/CoercionImpl.scala @@ -87,5 +87,8 @@ trait CoercionImpl[G] extends CoercionFamilyOps[G] { case CoerceCFloatFloat(_, _) => true case CoerceDecreasePrecision(_, _) => false case CoerceCFloatCInt(_) => false + + case CoerceLLVMIntInt() => true + case CoerceLLVMPointer(_, _) => true } } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala new file mode 100644 index 0000000000..b14e42e820 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.ops.LLVMAllocAOps +import vct.col.ast.{LLVMAllocA, Type, LLVMTPointer} + +trait LLVMAllocAImpl[G] extends LLVMAllocAOps[G] { + this: LLVMAllocA[G] => + override val t: Type[G] = LLVMTPointer(Some(this.allocationType)) +} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmAmbiguousFunctionInvocationImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMAmbiguousFunctionInvocationImpl.scala similarity index 51% rename from src/col/vct/col/ast/lang/llvm/LlvmAmbiguousFunctionInvocationImpl.scala rename to src/col/vct/col/ast/lang/llvm/LLVMAmbiguousFunctionInvocationImpl.scala index cb6b12eba3..b0217cf271 100644 --- a/src/col/vct/col/ast/lang/llvm/LlvmAmbiguousFunctionInvocationImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMAmbiguousFunctionInvocationImpl.scala @@ -1,21 +1,21 @@ package vct.col.ast.lang.llvm import vct.col.ast.{ - LlvmAmbiguousFunctionInvocation, - LlvmFunctionDefinition, - LlvmSpecFunction, + LLVMAmbiguousFunctionInvocation, + LLVMFunctionDefinition, + LLVMSpecFunction, Type, } import vct.col.print.{Ctx, Doc, DocUtil, Group, Precedence, Text} -import vct.col.ast.ops.LlvmAmbiguousFunctionInvocationOps +import vct.col.ast.ops.LLVMAmbiguousFunctionInvocationOps -trait LlvmAmbiguousFunctionInvocationImpl[G] - extends LlvmAmbiguousFunctionInvocationOps[G] { - this: LlvmAmbiguousFunctionInvocation[G] => +trait LLVMAmbiguousFunctionInvocationImpl[G] + extends LLVMAmbiguousFunctionInvocationOps[G] { + this: LLVMAmbiguousFunctionInvocation[G] => override lazy val t: Type[G] = ref.get.decl match { - case func: LlvmFunctionDefinition[G] => func.returnType - case func: LlvmSpecFunction[G] => func.returnType + case func: LLVMFunctionDefinition[G] => func.returnType + case func: LLVMSpecFunction[G] => func.returnType } override def precedence: Int = Precedence.POSTFIX diff --git a/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala new file mode 100644 index 0000000000..d7be0dbc2e --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMArrayValue} +import vct.col.ast.ops.LLVMArrayValueOps +import vct.col.print._ + +trait LLVMArrayValueImpl[G] extends LLVMArrayValueOps[G] { + this: LLVMArrayValue[G] => + override def t: Type[G] = arrayType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMExprImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMExprImpl.scala new file mode 100644 index 0000000000..da49321904 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMExprImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMExpr +trait LLVMExprImpl[G] { + this: LLVMExpr[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionContractImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionContractImpl.scala new file mode 100644 index 0000000000..481e5689c7 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionContractImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMFunctionContract +import vct.col.ast.ops.{LLVMFunctionContractOps, LLVMFunctionContractFamilyOps} + +trait LLVMFunctionContractImpl[G] + extends LLVMFunctionContractOps[G] with LLVMFunctionContractFamilyOps[G] { + this: LLVMFunctionContract[G] => +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala new file mode 100644 index 0000000000..326a86a142 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala @@ -0,0 +1,16 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.declaration.category.ApplicableImpl +import vct.col.ast.{Declaration, LLVMFunctionDefinition, Statement} +import vct.col.ast.util.Declarator +import vct.col.ast.ops.LLVMFunctionDefinitionOps + +trait LLVMFunctionDefinitionImpl[G] + extends Declarator[G] + with ApplicableImpl[G] + with LLVMFunctionDefinitionOps[G] { + this: LLVMFunctionDefinition[G] => + override def declarations: Seq[Declaration[G]] = args + + override def body: Option[Statement[G]] = functionBody +} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmFunctionInvocationImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala similarity index 64% rename from src/col/vct/col/ast/lang/llvm/LlvmFunctionInvocationImpl.scala rename to src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala index 40f41dd1b4..e92e32fb16 100644 --- a/src/col/vct/col/ast/lang/llvm/LlvmFunctionInvocationImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala @@ -1,11 +1,11 @@ package vct.col.ast.lang.llvm -import vct.col.ast.LlvmFunctionInvocation +import vct.col.ast.LLVMFunctionInvocation import vct.col.print.{Ctx, Doc, DocUtil, Empty, Group, Precedence, Text} -import vct.col.ast.ops.LlvmFunctionInvocationOps +import vct.col.ast.ops.LLVMFunctionInvocationOps -trait LlvmFunctionInvocationImpl[G] extends LlvmFunctionInvocationOps[G] { - this: LlvmFunctionInvocation[G] => +trait LLVMFunctionInvocationImpl[G] extends LLVMFunctionInvocationOps[G] { + this: LLVMFunctionInvocation[G] => override def precedence: Int = Precedence.POSTFIX override def layout(implicit ctx: Ctx): Doc = diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala new file mode 100644 index 0000000000..4513befdb2 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala @@ -0,0 +1,12 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMTPointer, LLVMFunctionPointerValue} +import vct.col.ast.ops.LLVMFunctionPointerValueOps +import vct.col.print._ + +trait LLVMFunctionPointerValueImpl[G] extends LLVMFunctionPointerValueOps[G] { + this: LLVMFunctionPointerValue[G] => + // TODO: Do we want a separate type for function pointers? For now we don't support function pointers anyway + override def t: Type[G] = LLVMTPointer(None) + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala new file mode 100644 index 0000000000..795dc90cc6 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMGetElementPointer, LLVMTPointer, Type} +import vct.col.ast.ops.LLVMGetElementPointerOps +import vct.col.print._ + +trait LLVMGetElementPointerImpl[G] extends LLVMGetElementPointerOps[G] { + this: LLVMGetElementPointer[G] => + override def t: Type[G] = LLVMTPointer(Some(resultType)) + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMGlobalSpecificationImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMGlobalSpecificationImpl.scala new file mode 100644 index 0000000000..6a6d6274e8 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMGlobalSpecificationImpl.scala @@ -0,0 +1,12 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMGlobalSpecification +import vct.col.print.{Ctx, Doc, Text} +import vct.col.ast.ops.LLVMGlobalSpecificationOps + +trait LLVMGlobalSpecificationImpl[G] extends LLVMGlobalSpecificationOps[G] { + this: LLVMGlobalSpecification[G] => + + override def layout(implicit ctx: Ctx): Doc = Text(value) + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala new file mode 100644 index 0000000000..80a7d93b43 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMGlobalVariable +import vct.col.ast.ops.LLVMGlobalVariableOps +import vct.col.print._ + +trait LLVMGlobalVariableImpl[G] extends LLVMGlobalVariableOps[G] { + this: LLVMGlobalVariable[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala new file mode 100644 index 0000000000..6edc11d6b0 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMIntegerValue} +import vct.col.ast.ops.LLVMIntegerValueOps +import vct.col.print._ + +trait LLVMIntegerValueImpl[G] extends LLVMIntegerValueOps[G] { + this: LLVMIntegerValue[G] => + override def t: Type[G] = integerType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala new file mode 100644 index 0000000000..af05a52038 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMLoad, Type} +import vct.col.ast.ops.LLVMLoadOps + +trait LLVMLoadImpl[G] extends LLVMLoadOps[G] { + this: LLVMLoad[G] => + override val t: Type[G] = this.loadType +} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmLocalImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLocalImpl.scala similarity index 55% rename from src/col/vct/col/ast/lang/llvm/LlvmLocalImpl.scala rename to src/col/vct/col/ast/lang/llvm/LLVMLocalImpl.scala index 0baa756520..0a02d992fb 100644 --- a/src/col/vct/col/ast/lang/llvm/LlvmLocalImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMLocalImpl.scala @@ -1,11 +1,11 @@ package vct.col.ast.lang.llvm -import vct.col.ast.{LlvmLocal, Type} +import vct.col.ast.{LLVMLocal, Type} import vct.col.print.{Ctx, Doc, Text} -import vct.col.ast.ops.LlvmLocalOps +import vct.col.ast.ops.LLVMLocalOps -trait LlvmLocalImpl[G] extends LlvmLocalOps[G] { - this: LlvmLocal[G] => +trait LLVMLocalImpl[G] extends LLVMLocalOps[G] { + this: LLVMLocal[G] => override lazy val t: Type[G] = ref.get.decl.t override def layout(implicit ctx: Ctx): Doc = Text(name) diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala new file mode 100644 index 0000000000..17bcc24632 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMLoopContract +import vct.col.ast.ops.LLVMLoopContractFamilyOps + +trait LLVMLoopContractImpl[G] extends LLVMLoopContractFamilyOps[G] { + this: LLVMLoopContract[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala new file mode 100644 index 0000000000..49c130ba3a --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMLoop +import vct.col.ast.ops.LLVMLoopOps + +trait LLVMLoopImpl[G] extends LLVMLoopOps[G] { + this: LLVMLoop[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala new file mode 100644 index 0000000000..ee4eeb2c2e --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMLoopInvariant +import vct.col.ast.ops.LLVMLoopInvariantOps + +trait LLVMLoopInvariantImpl[G] extends LLVMLoopInvariantOps[G] { + this: LLVMLoopInvariant[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala new file mode 100644 index 0000000000..4aaafdb0b5 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryAcquire +import vct.col.ast.ops.LLVMMemoryAcquireOps +import vct.col.print._ + +trait LLVMMemoryAcquireImpl[G] extends LLVMMemoryAcquireOps[G] { + this: LLVMMemoryAcquire[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala new file mode 100644 index 0000000000..87fd27b4ea --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryAcquireRelease +import vct.col.ast.ops.LLVMMemoryAcquireReleaseOps +import vct.col.print._ + +trait LLVMMemoryAcquireReleaseImpl[G] extends LLVMMemoryAcquireReleaseOps[G] { + this: LLVMMemoryAcquireRelease[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala new file mode 100644 index 0000000000..b3d700bc42 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryMonotonic +import vct.col.ast.ops.LLVMMemoryMonotonicOps +import vct.col.print._ + +trait LLVMMemoryMonotonicImpl[G] extends LLVMMemoryMonotonicOps[G] { + this: LLVMMemoryMonotonic[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala new file mode 100644 index 0000000000..796ceb4c7e --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryNotAtomic +import vct.col.ast.ops.LLVMMemoryNotAtomicOps +import vct.col.print._ + +trait LLVMMemoryNotAtomicImpl[G] extends LLVMMemoryNotAtomicOps[G] { + this: LLVMMemoryNotAtomic[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala new file mode 100644 index 0000000000..b4ece7c46b --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryOrdering +import vct.col.ast.ops.LLVMMemoryOrderingFamilyOps +import vct.col.print._ + +trait LLVMMemoryOrderingImpl[G] extends LLVMMemoryOrderingFamilyOps[G] { + this: LLVMMemoryOrdering[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala new file mode 100644 index 0000000000..e856bc3b35 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryRelease +import vct.col.ast.ops.LLVMMemoryReleaseOps +import vct.col.print._ + +trait LLVMMemoryReleaseImpl[G] extends LLVMMemoryReleaseOps[G] { + this: LLVMMemoryRelease[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala new file mode 100644 index 0000000000..9576bcb1e6 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemorySequentiallyConsistent +import vct.col.ast.ops.LLVMMemorySequentiallyConsistentOps +import vct.col.print._ + +trait LLVMMemorySequentiallyConsistentImpl[G] + extends LLVMMemorySequentiallyConsistentOps[G] { + this: LLVMMemorySequentiallyConsistent[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala new file mode 100644 index 0000000000..7f929bf5fb --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMMemoryUnordered +import vct.col.ast.ops.LLVMMemoryUnorderedOps +import vct.col.print._ + +trait LLVMMemoryUnorderedImpl[G] extends LLVMMemoryUnorderedOps[G] { + this: LLVMMemoryUnordered[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala new file mode 100644 index 0000000000..890bab6ca6 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala @@ -0,0 +1,23 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{ + LLVMGlobalVariable, + LLVMPointerValue, + LLVMTPointer, + Type, + HeapVariable, +} +import vct.col.ast.ops.LLVMPointerValueOps +import vct.col.print._ + +trait LLVMPointerValueImpl[G] extends LLVMPointerValueOps[G] { + this: LLVMPointerValue[G] => + override lazy val t: Type[G] = { + value.decl match { + case LLVMGlobalVariable(variableType, _, _) => + LLVMTPointer(Some(variableType)) + case v: HeapVariable[G] => LLVMTPointer(Some(v.t)) + } + } + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMRawArrayValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMRawArrayValueImpl.scala new file mode 100644 index 0000000000..e923da0f96 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMRawArrayValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMRawArrayValue} +import vct.col.ast.ops.LLVMRawArrayValueOps +import vct.col.print._ + +trait LLVMRawArrayValueImpl[G] extends LLVMRawArrayValueOps[G] { + this: LLVMRawArrayValue[G] => + override def t: Type[G] = arrayType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMRawVectorValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMRawVectorValueImpl.scala new file mode 100644 index 0000000000..738f4add09 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMRawVectorValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMRawVectorValue} +import vct.col.ast.ops.LLVMRawVectorValueOps +import vct.col.print._ + +trait LLVMRawVectorValueImpl[G] extends LLVMRawVectorValueOps[G] { + this: LLVMRawVectorValue[G] => + override def t: Type[G] = vectorType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala new file mode 100644 index 0000000000..5fb7c7b1c3 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMSignExtend, Type} +import vct.col.ast.ops.LLVMSignExtendOps +import vct.col.print._ + +trait LLVMSignExtendImpl[G] extends LLVMSignExtendOps[G] { + this: LLVMSignExtend[G] => + override def t: Type[G] = outputType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmSpecFunctionImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMSpecFunctionImpl.scala similarity index 86% rename from src/col/vct/col/ast/lang/llvm/LlvmSpecFunctionImpl.scala rename to src/col/vct/col/ast/lang/llvm/LLVMSpecFunctionImpl.scala index fe9430a026..a6b204bd7e 100644 --- a/src/col/vct/col/ast/lang/llvm/LlvmSpecFunctionImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMSpecFunctionImpl.scala @@ -1,18 +1,18 @@ package vct.col.ast.lang.llvm -import vct.col.ast.LlvmSpecFunction +import vct.col.ast.LLVMSpecFunction import vct.col.ast.declaration.category.AbstractFunctionImpl import vct.col.ast.declaration.global.GlobalDeclarationImpl import vct.col.print.{Ctx, Doc, Empty, Group, Show, Text} import scala.collection.immutable.ListMap -import vct.col.ast.ops.LlvmSpecFunctionOps +import vct.col.ast.ops.LLVMSpecFunctionOps -trait LlvmSpecFunctionImpl[G] +trait LLVMSpecFunctionImpl[G] extends GlobalDeclarationImpl[G] with AbstractFunctionImpl[G] - with LlvmSpecFunctionOps[G] { - this: LlvmSpecFunction[G] => + with LLVMSpecFunctionOps[G] { + this: LLVMSpecFunction[G] => def layoutModifiers(implicit ctx: Ctx): Seq[Doc] = ListMap(inline -> "inline", threadLocal -> "thread_local").filter(_._1) diff --git a/src/col/vct/col/ast/lang/llvm/LLVMStatementImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMStatementImpl.scala new file mode 100644 index 0000000000..3e8a6f6a5c --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMStatementImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMStatement +trait LLVMStatementImpl[G] { + this: LLVMStatement[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala new file mode 100644 index 0000000000..9b3b82337d --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala @@ -0,0 +1,8 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMStore +import vct.col.ast.ops.LLVMStoreOps + +trait LLVMStoreImpl[G] extends LLVMStoreOps[G] { + this: LLVMStore[G] => +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala new file mode 100644 index 0000000000..a37d6ab04e --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMStructValue, Type} +import vct.col.ast.ops.LLVMStructValueOps +import vct.col.print._ + +trait LLVMStructValueImpl[G] extends LLVMStructValueOps[G] { + this: LLVMStructValue[G] => + override def t: Type[G] = structType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala new file mode 100644 index 0000000000..fbceb6ba2b --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTArray +import vct.col.ast.ops.LLVMTArrayOps +import vct.col.print._ + +trait LLVMTArrayImpl[G] extends LLVMTArrayOps[G] { + this: LLVMTArray[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTFunctionImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTFunctionImpl.scala new file mode 100644 index 0000000000..d96fa61d59 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTFunctionImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTFunction +import vct.col.ast.ops.LLVMTFunctionOps +import vct.col.print._ + +trait LLVMTFunctionImpl[G] extends LLVMTFunctionOps[G] { + this: LLVMTFunction[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala new file mode 100644 index 0000000000..697c024544 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTInt +import vct.col.ast.ops.LLVMTIntOps +import vct.col.print._ + +trait LLVMTIntImpl[G] extends LLVMTIntOps[G] { + this: LLVMTInt[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala new file mode 100644 index 0000000000..bbf522bc87 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTMetadata +import vct.col.ast.ops.LLVMTMetadataOps + +trait LLVMTMetadataImpl[G] extends LLVMTMetadataOps[G] { + this: LLVMTMetadata[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala new file mode 100644 index 0000000000..770fa9295c --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTPointer +import vct.col.ast.ops.LLVMTPointerOps + +trait LLVMTPointerImpl[G] extends LLVMTPointerOps[G] { + this: LLVMTPointer[G] => + +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala new file mode 100644 index 0000000000..cb839186b9 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTStruct +import vct.col.ast.ops.LLVMTStructOps +import vct.col.print._ + +trait LLVMTStructImpl[G] extends LLVMTStructOps[G] { + this: LLVMTStruct[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala new file mode 100644 index 0000000000..6bc89d8b85 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMTVector +import vct.col.ast.ops.LLVMTVectorOps +import vct.col.print._ + +trait LLVMTVectorImpl[G] extends LLVMTVectorOps[G] { + this: LLVMTVector[G] => + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala new file mode 100644 index 0000000000..25344cf63a --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMTruncate, Type} +import vct.col.ast.ops.LLVMTruncateOps +import vct.col.print._ + +trait LLVMTruncateImpl[G] extends LLVMTruncateOps[G] { + this: LLVMTruncate[G] => + override def t: Type[G] = outputType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala new file mode 100644 index 0000000000..3661272542 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMVectorValue} +import vct.col.print._ +import vct.col.ast.ops.LLVMVectorValueOps + +trait LLVMVectorValueImpl[G] extends LLVMVectorValueOps[G] { + this: LLVMVectorValue[G] => + override def t: Type[G] = vectorType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala new file mode 100644 index 0000000000..eca0bfddc0 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{LLVMZeroExtend, Type} +import vct.col.ast.ops.LLVMZeroExtendOps +import vct.col.print._ + +trait LLVMZeroExtendImpl[G] extends LLVMZeroExtendOps[G] { + this: LLVMZeroExtend[G] => + override def t: Type[G] = outputType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala new file mode 100644 index 0000000000..e8217e0c40 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala @@ -0,0 +1,12 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.{Type, LLVMZeroedAggregateValue} +import vct.col.ast.ops.LLVMZeroedAggregateValueOps +import vct.col.print._ + +trait LLVMZeroedAggregateValueImpl[G] extends LLVMZeroedAggregateValueOps[G] { + this: LLVMZeroedAggregateValue[G] => + override def value: Unit = () + override def t: Type[G] = aggregateType + // override def layout(implicit ctx: Ctx): Doc = ??? +} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmExprImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmExprImpl.scala deleted file mode 100644 index 6b4c7524e8..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmExprImpl.scala +++ /dev/null @@ -1,7 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmExpr -trait LlvmExprImpl[G] { - this: LlvmExpr[G] => - -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmFunctionContractImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmFunctionContractImpl.scala deleted file mode 100644 index 48c8828248..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmFunctionContractImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmFunctionContract -import vct.col.ast.ops.{LlvmFunctionContractOps, LlvmFunctionContractFamilyOps} - -trait LlvmFunctionContractImpl[G] - extends LlvmFunctionContractOps[G] with LlvmFunctionContractFamilyOps[G] { - this: LlvmFunctionContract[G] => -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmFunctionDefinitionImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmFunctionDefinitionImpl.scala deleted file mode 100644 index 89b836798c..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmFunctionDefinitionImpl.scala +++ /dev/null @@ -1,16 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.declaration.category.ApplicableImpl -import vct.col.ast.{Declaration, LlvmFunctionDefinition, Statement} -import vct.col.ast.util.Declarator -import vct.col.ast.ops.LlvmFunctionDefinitionOps - -trait LlvmFunctionDefinitionImpl[G] - extends Declarator[G] - with ApplicableImpl[G] - with LlvmFunctionDefinitionOps[G] { - this: LlvmFunctionDefinition[G] => - override def declarations: Seq[Declaration[G]] = args - - override def body: Option[Statement[G]] = Some(functionBody) -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmGlobalImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmGlobalImpl.scala deleted file mode 100644 index 3a92e81e9e..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmGlobalImpl.scala +++ /dev/null @@ -1,12 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmGlobal -import vct.col.print.{Ctx, Doc, Text} -import vct.col.ast.ops.LlvmGlobalOps - -trait LlvmGlobalImpl[G] extends LlvmGlobalOps[G] { - this: LlvmGlobal[G] => - - override def layout(implicit ctx: Ctx): Doc = Text(value) - -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmLoopContractImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmLoopContractImpl.scala deleted file mode 100644 index 95f2c492c5..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmLoopContractImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmLoopContract -import vct.col.ast.ops.LlvmLoopContractFamilyOps - -trait LlvmLoopContractImpl[G] extends LlvmLoopContractFamilyOps[G] { - this: LlvmLoopContract[G] => - -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmLoopImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmLoopImpl.scala deleted file mode 100644 index 29b2aca394..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmLoopImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmLoop -import vct.col.ast.ops.LlvmLoopOps - -trait LlvmLoopImpl[G] extends LlvmLoopOps[G] { - this: LlvmLoop[G] => - -} diff --git a/src/col/vct/col/ast/lang/llvm/LlvmLoopInvariantImpl.scala b/src/col/vct/col/ast/lang/llvm/LlvmLoopInvariantImpl.scala deleted file mode 100644 index 30f26a31b1..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LlvmLoopInvariantImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LlvmLoopInvariant -import vct.col.ast.ops.LlvmLoopInvariantOps - -trait LlvmLoopInvariantImpl[G] extends LlvmLoopInvariantOps[G] { - this: LlvmLoopInvariant[G] => - -} diff --git a/src/col/vct/col/ast/statement/exceptional/ReturnImpl.scala b/src/col/vct/col/ast/statement/exceptional/ReturnImpl.scala index c0a39fd96e..f260c33f03 100644 --- a/src/col/vct/col/ast/statement/exceptional/ReturnImpl.scala +++ b/src/col/vct/col/ast/statement/exceptional/ReturnImpl.scala @@ -13,7 +13,7 @@ trait ReturnImpl[G] extends ExceptionalStatementImpl[G] with ReturnOps[G] { _: InstanceOperatorMethod[G] => () case _: JavaMethod[G] | _: CFunctionDefinition[G] | - _: CPPFunctionDefinition[G] | _: LlvmFunctionDefinition[G] => + _: CPPFunctionDefinition[G] | _: LLVMFunctionDefinition[G] => () case _: BipTransition[G] | _: BipGuard[G] | _: BipOutgoingData[G] => () } diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index 1bd8310820..4c3f8ec18f 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -29,6 +29,8 @@ import vct.col.resolve.lang.JavaAnnotationData.{ BipTransition, } import vct.col.rewrite.InitialGeneration +import vct.col.util.AstBuildHelpers.{ExprBuildHelpers, VarBuildHelpers} +import vct.col.util.Substitute import vct.result.VerificationError.{Unreachable, UserError} import scala.collection.immutable.{AbstractSeq, LinearSeq} @@ -42,11 +44,14 @@ case object Resolve { trait SpecContractParser { def parse[G]( - input: LlvmFunctionContract[G], + input: LLVMFunctionContract[G], o: Origin, ): ApplicableContract[G] - def parse[G](input: LlvmGlobal[G], o: Origin): Seq[GlobalDeclaration[G]] + def parse[G]( + input: LLVMGlobalSpecification[G], + o: Origin, + ): Seq[GlobalDeclaration[G]] } def extractLiteral(e: Expr[_]): Option[String] = @@ -66,12 +71,14 @@ case object Resolve { case class MalformedBipAnnotation(n: Node[_], err: String) extends UserError { override def code: String = "badBipAnnotation" + override def text: String = n.o.messageInContext(s"Malformed JavaBIP annotation: $err") } case class UnexpectedComplicatedExpression(e: Expr[_]) extends UserError { override def code: String = "unexpectedComplicatedExpression" + override def text: String = e.o.messageInContext( "This expression must either be a string literal or trivially resolve to one" @@ -95,8 +102,10 @@ case object Resolve { case object ResolveTypes { sealed trait JavaClassPathEntry + case object JavaClassPathEntry { case object SourcePackageRoot extends JavaClassPathEntry + case class Path(root: java.nio.file.Path) extends JavaClassPathEntry } @@ -290,8 +299,9 @@ case object ResolveReferences extends LazyLogging { program: Program[G], jp: SpecExprParser, lsp: SpecContractParser, + importedDeclarations: Seq[GlobalDeclaration[G]], ): Seq[CheckError] = { - resolve(program, ReferenceResolutionContext[G](jp, lsp)) + resolve(program, ReferenceResolutionContext[G](jp, lsp, importedDeclarations)) } def resolve[G]( @@ -338,7 +348,7 @@ case object ResolveReferences extends LazyLogging { Right(ctx.copy(inGpuKernel = true)) case p: Program[G] => p.declarations.foreach { - case glob: LlvmGlobal[G] => + case glob: LLVMGlobalSpecification[G] => val decls = ctx.llvmSpecParser.parse(glob, glob.o) glob.data = Some(decls) case _ => @@ -539,10 +549,10 @@ case object ResolveReferences extends LazyLogging { CPP.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body) ++ func.contract.givenArgs ++ func.contract.yieldsArgs ) - case func: LlvmFunctionDefinition[G] => - ctx.copy(currentResult = Some(RefLlvmFunctionDefinition(func))) - case func: LlvmSpecFunction[G] => - ctx.copy(currentResult = Some(RefLlvmSpecFunction(func))) + case func: LLVMFunctionDefinition[G] => + ctx.copy(currentResult = Some(RefLLVMFunctionDefinition(func))) + case func: LLVMSpecFunction[G] => + ctx.copy(currentResult = Some(RefLLVMSpecFunction(func))) .declare(func.args) case par: ParStatement[G] => ctx.declare(scanBlocks(par.impl).map(_.decl)) case Scope(locals, body) => @@ -1099,20 +1109,55 @@ case object ResolveReferences extends LazyLogging { ) => portName.data = Some((cls, getLit(name))) - case contract: LlvmFunctionContract[G] => + case contract: LLVMFunctionContract[G] => + implicit val o: Origin = contract.o + val llvmFunction = + ctx.currentResult.get.asInstanceOf[RefLLVMFunctionDefinition[G]].decl val applicableContract = ctx.llvmSpecParser.parse(contract, contract.o) - contract.data = Some(applicableContract) - resolve(applicableContract, ctx) - case local: LlvmLocal[G] => + val importedDecl = ctx.importedDeclarations.find { + case procedure: Procedure[G] => + contract.name == procedure.o.get[SourceName].name + } + if (importedDecl.isDefined) { + val importedProcedure = importedDecl.get.asInstanceOf[Procedure[G]] + val importedContract = importedProcedure.contract + val substitute = Substitute[G]( + ((Result[G](importedProcedure.ref) -> AmbiguousResult[G]()) +: + importedProcedure.args.zipWithIndex.map { case (l, idx) => + Local[G](l.ref) -> Local[G](llvmFunction.args(idx).ref) + }).toMap + ) + val substitutedContract = substitute.dispatch(importedContract) + contract.data = Some( + ApplicableContract[G]( + SplitAccountedPredicate( + applicableContract.requires, + substitutedContract.requires, + ), + SplitAccountedPredicate( + applicableContract.ensures, + substitutedContract.ensures, + ), + applicableContract.contextEverywhere &* + substitutedContract.contextEverywhere, + applicableContract.signals ++ substitutedContract.signals, + applicableContract.givenArgs ++ substitutedContract.givenArgs, + applicableContract.yieldsArgs ++ substitutedContract.yieldsArgs, + applicableContract.decreases.orElse(substitutedContract.decreases), + )(contract.blame) + ) + } else { contract.data = Some(applicableContract) } + resolve(contract.data.get, ctx) + case local: LLVMLocal[G] => local.ref = ctx.currentResult.get match { - case RefLlvmFunctionDefinition(decl) => + case RefLLVMFunctionDefinition(decl) => decl.contract.variableRefs .find(ref => ref._1 == local.name) match { case Some(ref) => Some(ref._2) case None => throw NoSuchNameError("local", local.name, local) } - case RefLlvmSpecFunction(_) => + case RefLLVMSpecFunction(_) => Some( Spec.findLocal(local.name, ctx) .getOrElse(throw NoSuchNameError("local", local.name, local)) @@ -1120,13 +1165,14 @@ case object ResolveReferences extends LazyLogging { ) case _ => None } - case inv: LlvmAmbiguousFunctionInvocation[G] => + case inv: LLVMAmbiguousFunctionInvocation[G] => inv.ref = LLVM.findCallable(inv.name, ctx) match { case Some(callable) => Some(callable.ref) case None => throw NoSuchNameError("function", inv.name, inv) } - case glob: LlvmGlobal[G] => glob.data.get.foreach(resolve(_, ctx)) + case glob: LLVMGlobalSpecification[G] => + glob.data.get.foreach(resolve(_, ctx)) case comm: PVLCommunicate[G] => /* Endpoint contexts for communicate are resolved early, because otherwise \sender, \receiver, \msg cannot be typed. */ diff --git a/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala b/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala index 8a7e0c2a3e..db2e9ca2f9 100644 --- a/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala +++ b/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala @@ -12,6 +12,7 @@ import scala.collection.mutable case class ReferenceResolutionContext[G]( javaParser: SpecExprParser, llvmSpecParser: SpecContractParser, + importedDeclarations: Seq[GlobalDeclaration[G]], stack: Seq[Seq[Referrable[G]]] = Nil, topLevelJavaDeref: Option[JavaDeref[G]] = None, externallyLoadedElements: mutable.ArrayBuffer[GlobalDeclaration[G]] = diff --git a/src/col/vct/col/resolve/ctx/Referrable.scala b/src/col/vct/col/resolve/ctx/Referrable.scala index b146df4d06..e11ee1b017 100644 --- a/src/col/vct/col/resolve/ctx/Referrable.scala +++ b/src/col/vct/col/resolve/ctx/Referrable.scala @@ -92,9 +92,6 @@ sealed trait Referrable[G] { case RefProverType(decl) => Referrable.originName(decl) case RefProverFunction(decl) => Referrable.originName(decl) case RefJavaBipGuard(decl) => Referrable.originName(decl) - case RefLlvmFunctionDefinition(decl) => Referrable.originName(decl) - case RefLlvmGlobal(decl, i) => Referrable.originName(decl.data.get(i)) - case RefLlvmSpecFunction(decl) => Referrable.originName(decl) case RefBipComponent(decl) => Referrable.originName(decl) case RefBipGlue(decl) => "" case RefBipGuard(decl) => Referrable.originName(decl) @@ -110,7 +107,10 @@ sealed trait Referrable[G] { case RefPVLEndpoint(decl) => decl.name case RefPVLChoreography(decl) => decl.name case RefPVLChorRun(_) => "" - + case RefLLVMFunctionDefinition(decl) => Referrable.originName(decl) + case RefLLVMGlobalSpecification(decl, i) => Referrable.originName(decl) + case RefLLVMGlobalVariable(decl) => Referrable.originName(decl) + case RefLLVMSpecFunction(decl) => Referrable.originName(decl) case RefJavaBipGlueContainer() => "" case PVLBuiltinInstanceMethod(_) => "" case BuiltinField(_) => "" @@ -201,13 +201,15 @@ case object Referrable { case decl: PVLConstructor[G] => RefPVLConstructor(decl) case decl: Choreography[G] => RefChoreography(decl) case decl: Endpoint[G] => RefEndpoint(decl) - case decl: LlvmFunctionDefinition[G] => RefLlvmFunctionDefinition(decl) - case decl: LlvmGlobal[G] => + case decl: LLVMFunctionDefinition[G] => RefLLVMFunctionDefinition(decl) + case decl: LLVMGlobalSpecification[G] => decl.data match { - case Some(data) => return data.indices.map(RefLlvmGlobal(decl, _)) - case None => RefLlvmGlobal(decl, -1) + case Some(data) => + return data.indices.map(RefLLVMGlobalSpecification(decl, _)) + case None => RefLLVMGlobalSpecification(decl, -1) } - case decl: LlvmSpecFunction[G] => RefLlvmSpecFunction(decl) + case decl: LLVMSpecFunction[G] => RefLLVMSpecFunction(decl) + case decl: LLVMGlobalVariable[G] => RefLLVMGlobalVariable(decl) case decl: ProverType[G] => RefProverType(decl) case decl: ProverFunction[G] => RefProverFunction(decl) case decl: JavaBipGlueContainer[G] => RefJavaBipGlueContainer() @@ -282,7 +284,7 @@ sealed trait JavaInvocationTarget[G] extends Referrable[G] sealed trait CInvocationTarget[G] extends Referrable[G] sealed trait CPPInvocationTarget[G] extends Referrable[G] sealed trait PVLInvocationTarget[G] extends Referrable[G] -sealed trait LlvmInvocationTarget[G] extends Referrable[G] +sealed trait LLVMInvocationTarget[G] extends Referrable[G] sealed trait SpecInvocationTarget[G] extends JavaInvocationTarget[G] with CNameTarget[G] @@ -292,7 +294,7 @@ sealed trait SpecInvocationTarget[G] with CPPDerefTarget[G] with CPPInvocationTarget[G] with PVLInvocationTarget[G] - with LlvmInvocationTarget[G] + with LLVMInvocationTarget[G] sealed trait ThisTarget[G] extends Referrable[G] @@ -454,9 +456,14 @@ case class RefJavaBipGuard[G](decl: JavaMethod[G]) extends Referrable[G] with JavaNameTarget[G] case class RefJavaBipGlueContainer[G]() extends Referrable[G] // Bip glue jobs are not actually referrable -case class RefLlvmFunctionDefinition[G](decl: LlvmFunctionDefinition[G]) - extends Referrable[G] with LlvmInvocationTarget[G] with ResultTarget[G] -case class RefLlvmGlobal[G](decl: LlvmGlobal[G], idx: Int) extends Referrable[G] +case class RefLLVMFunctionDefinition[G](decl: LLVMFunctionDefinition[G]) + extends Referrable[G] with LLVMInvocationTarget[G] with ResultTarget[G] +case class RefLLVMGlobalSpecification[G]( + decl: LLVMGlobalSpecification[G], + idx: Int, +) extends Referrable[G] +case class RefLLVMGlobalVariable[G](decl: LLVMGlobalVariable[G]) + extends Referrable[G] case class RefBipComponent[G](decl: BipComponent[G]) extends Referrable[G] case class RefBipGlue[G](decl: BipGlue[G]) extends Referrable[G] case class RefBipGuard[G](decl: BipGuard[G]) extends Referrable[G] @@ -478,9 +485,8 @@ case class RefPVLEndpoint[G](decl: PVLEndpoint[G]) case class RefPVLChoreography[G](decl: PVLChoreography[G]) extends Referrable[G] with ThisTarget[G] case class RefPVLChorRun[G](decl: PVLChorRun[G]) extends Referrable[G] - -case class RefLlvmSpecFunction[G](decl: LlvmSpecFunction[G]) - extends Referrable[G] with LlvmInvocationTarget[G] with ResultTarget[G] +case class RefLLVMSpecFunction[G](decl: LLVMSpecFunction[G]) + extends Referrable[G] with LLVMInvocationTarget[G] with ResultTarget[G] case class RefChoreography[G](decl: Choreography[G]) extends Referrable[G] with ThisTarget[G] case class RefEndpoint[G](decl: Endpoint[G]) extends Referrable[G] diff --git a/src/col/vct/col/resolve/lang/LLVM.scala b/src/col/vct/col/resolve/lang/LLVM.scala index e010a9aaec..f81a8bbfbf 100644 --- a/src/col/vct/col/resolve/lang/LLVM.scala +++ b/src/col/vct/col/resolve/lang/LLVM.scala @@ -10,12 +10,12 @@ object LLVM { def findCallable[G]( name: String, ctx: ReferenceResolutionContext[G], - ): Option[LlvmCallable[G]] = { + ): Option[LLVMCallable[G]] = { // look in context val callable = ctx.stack.flatten.map { - case RefLlvmGlobal(decl, i) => + case RefLLVMGlobalSpecification(decl, i) => decl.data.get(i) match { - case f: LlvmSpecFunction[G] if f.name == name => Some(f) + case f: LLVMSpecFunction[G] if f.name == name => Some(f) case _ => None } case _ => None @@ -25,7 +25,7 @@ object LLVM { case Some(callable) => Some(callable) case None => ctx.currentResult.get match { - case RefLlvmFunctionDefinition(decl) => + case RefLLVMFunctionDefinition(decl) => decl.contract.invokableRefs.find(ref => ref._1 == name) match { case Some(ref) => Some(ref._2.decl) case None => None diff --git a/src/col/vct/col/serialize/SerializeOrigin.scala b/src/col/vct/col/serialize/SerializeOrigin.scala index 86e479728c..bcee9d6735 100644 --- a/src/col/vct/col/serialize/SerializeOrigin.scala +++ b/src/col/vct/col/serialize/SerializeOrigin.scala @@ -54,7 +54,7 @@ case class DeserializedContext( (context, tail) override protected def inlineContextHere( tail: Origin, - compress: Boolean, + compress: Boolean = true, ): (String, Origin) = (inlineContext, tail) override protected def shortPositionHere(tail: Origin): (String, Origin) = (shortPosition, tail) diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index baf88034cf..b7d62063bf 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -301,6 +301,10 @@ abstract class CoercingRewriter[Pre <: Generation]() case CoerceCIntCFloat(_) => e case CoerceCIntInt() => e case CoerceCFloatFloat(_, _) => e + + case CoerceLLVMIntInt() => e + case CoerceLLVMPointer(_, _) => e + case CoerceLLVMArray(_, _) => e } } @@ -349,8 +353,9 @@ abstract class CoercingRewriter[Pre <: Generation]() case node: BipGlueAccepts[Pre] => node case node: BipGlueDataWire[Pre] => node case node: BipTransitionSignature[Pre] => node - case node: LlvmFunctionContract[Pre] => node - case node: LlvmLoopContract[Pre] => node + case node: LLVMFunctionContract[Pre] => node + case node: LLVMLoopContract[Pre] => node + case node: LLVMMemoryOrdering[Pre] => node case node: ProverLanguage[Pre] => node case node: SmtlibFunctionSymbol[Pre] => node case node: ChorRun[Pre] => node @@ -543,6 +548,15 @@ abstract class CoercingRewriter[Pre <: Generation]() (ApplyCoercion(e, coercion)(coercionOrigin(e)), t) case None => throw IncoercibleText(e, s"pointer") } + def llvmPointer( + e: Expr[Pre], + innerType: Type[Pre], + ): (Expr[Pre], TPointer[Pre]) = + CoercionUtils.getAnyLLVMPointerCoercion(e.t, innerType) match { + case Some((coercion, t)) => + (ApplyCoercion(e, coercion)(coercionOrigin(e)), t) + case None => throw IncoercibleText(e, s"llvm pointer of $innerType") + } def matrix(e: Expr[Pre]): (Expr[Pre], TMatrix[Pre]) = CoercionUtils.getAnyMatrixCoercion(e.t) match { case Some((coercion, t)) => @@ -1648,15 +1662,15 @@ abstract class CoercingRewriter[Pre <: Generation]() coerceYields(yields, inv), )(inv.blame) ) - case inv @ LlvmFunctionInvocation(ref, args, givenMap, yields) => - LlvmFunctionInvocation(ref, args, givenMap, yields)(inv.blame) - case inv @ LlvmAmbiguousFunctionInvocation( + case inv @ LLVMFunctionInvocation(ref, args, givenMap, yields) => + LLVMFunctionInvocation(ref, args, givenMap, yields)(inv.blame) + case inv @ LLVMAmbiguousFunctionInvocation( name, args, givenMap, yields, ) => - LlvmAmbiguousFunctionInvocation(name, args, givenMap, yields)(inv.blame) + LLVMAmbiguousFunctionInvocation(name, args, givenMap, yields)(inv.blame) case ProcessApply(process, args) => ProcessApply(process, coerceArgs(args, process.decl)) case ProcessChoice(left, right) => @@ -2127,13 +2141,35 @@ abstract class CoercingRewriter[Pre <: Generation]() Z3TransitiveClosure(ref, coerceArgs(args, ref.ref.decl)) case localIncoming: BipLocalIncomingData[Pre] => localIncoming case glue: JavaBipGlue[Pre] => glue - case LlvmLocal(name) => e case PVLSender() => e case PVLReceiver() => e case PVLMessage() => e case Sender(_) => e case Receiver(_) => e case Message(_) => e + case LLVMLocal(name) => e + case LLVMAllocA(allocationType, numElements) => e + case LLVMLoad(loadType, p, ordering) => + LLVMLoad(loadType, llvmPointer(p, loadType)._1, ordering) + case LLVMGetElementPointer(structureType, resultType, pointer, indices) => + LLVMGetElementPointer( + structureType, + resultType, + llvmPointer(pointer, structureType)._1, + indices, + ) + case LLVMSignExtend(inputType, outputType, value) => e + case LLVMZeroExtend(inputType, outputType, value) => e + case LLVMTruncate(inputType, outputType, value) => e + case LLVMIntegerValue(value, integerType) => e + case LLVMPointerValue(value) => e + case LLVMFunctionPointerValue(value) => e + case LLVMStructValue(value, structType) => e + case LLVMArrayValue(value, arrayType) => e + case LLVMRawArrayValue(value, arrayType) => e + case LLVMVectorValue(value, vectorType) => e + case LLVMRawVectorValue(value, vectorType) => e + case LLVMZeroedAggregateValue(aggregateType) => e } } @@ -2240,8 +2276,10 @@ abstract class CoercingRewriter[Pre <: Generation]() case l @ Lock(obj) => Lock(cls(obj))(l.blame) case Loop(init, cond, update, contract, body) => Loop(init, bool(cond), update, contract, body) - case LlvmLoop(cond, contract, body) => - LlvmLoop(bool(cond), contract, body) + case LLVMLoop(cond, contract, body) => + LLVMLoop(bool(cond), contract, body) + case LLVMStore(value, p, ordering) => + LLVMStore(value, llvmPointer(p, value.t)._1, ordering) case ModelDo(model, perm, after, action, impl) => ModelDo(model, rat(perm), after, action, impl) case n @ Notify(obj) => Notify(cls(obj))(n.blame) @@ -2507,11 +2545,11 @@ abstract class CoercingRewriter[Pre <: Generation]() case glue: BipGlue[Pre] => glue case synchronization: BipPortSynchronization[Pre] => synchronization case synchronization: BipTransitionSynchronization[Pre] => synchronization - case definition: LlvmFunctionDefinition[Pre] => definition + case definition: LLVMFunctionDefinition[Pre] => definition case typ: ProverType[Pre] => typ case func: ProverFunction[Pre] => func - case function: LlvmSpecFunction[Pre] => - new LlvmSpecFunction[Pre]( + case function: LLVMSpecFunction[Pre] => + new LLVMSpecFunction[Pre]( function.name, function.returnType, function.args, @@ -2521,7 +2559,8 @@ abstract class CoercingRewriter[Pre <: Generation]() function.inline, function.threadLocal, )(function.blame) - case glob: LlvmGlobal[Pre] => glob + case glob: LLVMGlobalSpecification[Pre] => glob + case glob: LLVMGlobalVariable[Pre] => glob case endpoint: PVLEndpoint[Pre] => endpoint case seqProg: PVLChoreography[Pre] => seqProg case seqRun: PVLChorRun[Pre] => seqRun @@ -2922,8 +2961,9 @@ abstract class CoercingRewriter[Pre <: Generation]() def coerce(node: JavaBipGlueElement[Pre]): JavaBipGlueElement[Pre] = node def coerce(node: JavaBipGlueName[Pre]): JavaBipGlueName[Pre] = node - def coerce(node: LlvmFunctionContract[Pre]): LlvmFunctionContract[Pre] = node - def coerce(node: LlvmLoopContract[Pre]): LlvmLoopContract[Pre] = node + def coerce(node: LLVMFunctionContract[Pre]): LLVMFunctionContract[Pre] = node + def coerce(node: LLVMLoopContract[Pre]): LLVMLoopContract[Pre] = node + def coerce(node: LLVMMemoryOrdering[Pre]): LLVMMemoryOrdering[Pre] = node def coerce(node: ProverLanguage[Pre]): ProverLanguage[Pre] = node def coerce(node: SmtlibFunctionSymbol[Pre]): SmtlibFunctionSymbol[Pre] = node diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index 47295af0af..dcffe8c925 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -6,6 +6,8 @@ import vct.col.origin.{DiagnosticOrigin, Origin} import vct.col.resolve.lang.{C, CPP} import vct.col.resolve.lang.CPP.getBaseTypeFromSpecs +import scala.annotation.tailrec + case object CoercionUtils { private implicit val o: Origin = DiagnosticOrigin @@ -123,6 +125,8 @@ case object CoercionUtils { case (TNull(), TAnyClass()) => CoerceNullAnyClass() case (TNull(), TPointer(target)) => CoerceNullPointer(target) case (TNull(), CTPointer(target)) => CoerceNullPointer(target) + case (TNull(), LLVMTPointer(Some(target))) => CoerceNullPointer(target) + case (TNull(), LLVMTPointer(None)) => CoerceNullPointer(TAny()) case (TNull(), TEnum(target)) => CoerceNullEnum(target) case (CTArray(_, innerType), TArray(element)) if element == innerType => @@ -193,6 +197,7 @@ case object CoercionUtils { CoerceCFloatFloat(coercedCFloat, target), )) case (TCInt(), TInt()) => CoerceCIntInt() + case (LLVMTInt(_), TInt()) => CoerceLLVMIntInt() case (TBoundedInt(gte, lt), TFraction()) if gte >= 1 && lt <= 2 => CoerceBoundIntFrac() @@ -300,6 +305,26 @@ case object CoercionUtils { case None => return None } + // TODO: Back and forth should not be needed... + case (LLVMTPointer(Some(_)), LLVMTPointer(None)) => + CoerceIdentity(LLVMTPointer(None)) + case (LLVMTPointer(None), LLVMTPointer(Some(innerType))) => + CoerceIdentity(LLVMTPointer(Some(innerType))) + case (TPointer(_), LLVMTPointer(None)) => + CoerceIdentity(LLVMTPointer(None)) + case (LLVMTPointer(None), TPointer(innerType)) => + CoerceLLVMPointer(None, innerType) + case ( + LLVMTPointer(Some(LLVMTArray(numElements, elementType))), + TPointer(innerType), + ) if numElements > 0 => + getAnyCoercion(elementType, innerType).getOrElse(return None) + case (LLVMTPointer(Some(leftInner)), TPointer(rightInner)) => + getAnyCoercion(leftInner, rightInner).getOrElse(return None) + + case (TPointer(TAny()), TPointer(any)) => CoerceIdentity(TPointer(any)) + case (TPointer(any), TPointer(TAny())) => CoerceIdentity(TPointer(any)) + // Something with TVar? // Unsafe coercions @@ -436,12 +461,54 @@ case object CoercionUtils { case t: CPPPrimitiveType[G] => chainCPPCoercion(t, getAnyPointerCoercion) case t: CPPTArray[G] => Some((CoerceCPPArrayPointer(t.innerType), TPointer(t.innerType))) + case LLVMTPointer(None) => + Some((CoerceIdentity(source), TPointer[G](TAnyValue()))) + case LLVMTPointer(Some(innerType)) => + Some((CoerceIdentity(source), TPointer(innerType))) + case LLVMTArray(numElements, innerType) if numElements > 0 => + Some((CoerceIdentity(source), TPointer(innerType))) case _: TNull[G] => val t = TPointer[G](TAnyValue()) Some((CoerceNullPointer(t), t)) case _ => None } + @tailrec + def firstElementIsType[G](aggregate: Type[G], innerType: Type[G]): Boolean = + aggregate match { + case aggregate if getAnyCoercion(aggregate, innerType).isDefined => true + case clazz: TClass[G] => + firstElementIsType( + clazz.cls.decl.declarations.head.asInstanceOf[InstanceField[G]].t, + innerType, + ) + case TArray(element) => firstElementIsType(element, innerType) + case LLVMTStruct(_, _, elements) => + firstElementIsType(elements.head, innerType) + case LLVMTArray(numElements, elementType) => + numElements > 0 && firstElementIsType(elementType, innerType) + case LLVMTVector(_, _) => false // TODO: Should this be possible? + case _ => false + } + + def getAnyLLVMPointerCoercion[G]( + source: Type[G], + innerType: Type[G], + ): Option[(Coercion[G], TPointer[G])] = + source match { + case LLVMTPointer(None) => + Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) + case LLVMTPointer(Some(t)) if firstElementIsType(t, innerType) => + Some(CoerceLLVMPointer(Some(t), innerType), TPointer[G](innerType)) + case TPointer(TAny()) => + Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) + case TPointer(t) if firstElementIsType(t, innerType) => + Some(CoerceLLVMPointer(Some(t), innerType), TPointer[G](innerType)) + case _: TNull[G] => + Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) + case _ => None + } + def getAnyCArrayCoercion[G]( source: Type[G] ): Option[(Coercion[G], CTArray[G])] = @@ -479,6 +546,21 @@ case object CoercionUtils { .asInstanceOf[TArray[G]], )) case t: TArray[G] => Some((CoerceIdentity(source), t)) + case t: LLVMTArray[G] => { + val t2 = TArray[G](t.elementType) + Some(CoerceLLVMArray(t, t2), t2) + } + case LLVMTPointer(None) => + Some(CoerceIdentity(source), TArray[G](TAnyValue())) + case LLVMTPointer(Some(t)) => + getAnyArrayCoercion(t) match { + case Some(inner) => + Some( + CoercionSequence(Seq(inner._1, CoerceIdentity(source))), + inner._2, + ) + case None => None + } case _: TNull[G] => val t = TArray[G](TAnyValue()) Some((CoerceNullArray(t), t)) diff --git a/src/col/vct/col/typerules/Types.scala b/src/col/vct/col/typerules/Types.scala index c06bdb2e45..118485dbe5 100644 --- a/src/col/vct/col/typerules/Types.scala +++ b/src/col/vct/col/typerules/Types.scala @@ -91,6 +91,9 @@ object Types { case (TBoundedInt(leftGte, leftLt), TBoundedInt(rightGte, rightLt)) => TBoundedInt(leftGte.min(rightGte), leftLt.max(rightLt)) + case (LLVMTInt(leftWidth), LLVMTInt(rightWidth)) => + LLVMTInt(leftWidth.max(rightWidth)) + // Unrelated types below rational are simply a rational case (left, right) if TRational().superTypeOf(left) && TRational().superTypeOf(right) => diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 0618c8caaa..925cbd876a 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -126,7 +126,7 @@ object AstBuildHelpers { case function: ADTFunction[Pre] => function.rewrite(args = args) case process: ModelProcess[Pre] => process.rewrite(args = args) case action: ModelAction[Pre] => action.rewrite(args = args) - case llvm: LlvmFunctionDefinition[Pre] => llvm.rewrite(args = args) + case llvm: LLVMFunctionDefinition[Pre] => llvm.rewrite(args = args) case prover: ProverFunction[Pre] => prover.rewrite(args = args) } } @@ -185,7 +185,7 @@ object AstBuildHelpers { inline = Some(inline), contract = contract, ) - case function: LlvmSpecFunction[Pre] => + case function: LLVMSpecFunction[Pre] => function.rewrite( args = args, returnType = returnType, @@ -319,7 +319,7 @@ object AstBuildHelpers { threadLocal = Some(threadLocal), blame = blame, ) - case function: LlvmSpecFunction[Pre] => + case function: LLVMSpecFunction[Pre] => function.rewrite( returnType = returnType, args = args, @@ -366,7 +366,7 @@ object AstBuildHelpers { apply match { case inv: ADTFunctionInvocation[Pre] => inv.rewrite(args = args) case inv: ProverFunctionInvocation[Pre] => inv.rewrite(args = args) - case inv: LlvmFunctionInvocation[Pre] => inv.rewrite(args = args) + case inv: LLVMFunctionInvocation[Pre] => inv.rewrite(args = args) case apply: ApplyAnyPredicate[Pre] => new ApplyAnyPredicateBuildHelpers(apply).rewrite(args = args) case inv: Invocation[Pre] => diff --git a/src/llvm/include/Origin/ContextDeriver.h b/src/llvm/include/Origin/ContextDeriver.h index 7f1cf1decd..2f26945c1d 100644 --- a/src/llvm/include/Origin/ContextDeriver.h +++ b/src/llvm/include/Origin/ContextDeriver.h @@ -1,35 +1,41 @@ -#ifndef VCLLVM_CONTEXTDERIVER_H -#define VCLLVM_CONTEXTDERIVER_H +#ifndef PALLAS_CONTEXTDERIVER_H +#define PALLAS_CONTEXTDERIVER_H #include /** - * Generators for VerCors origin objects context fields for various LLVM Value types. + * Generators for VerCors origin objects context fields for various LLVM Value + * types. * - * For more info on VerCors origins see: https://github.com/utwente-fmt/vercors/discussions/884 + * For more info on VerCors origins see: + * https://github.com/utwente-fmt/vercors/discussions/884 */ -namespace llvm2Col { - // module derivers - std::string deriveModuleContext(llvm::Module &llvmModule); +namespace llvm2col { +// module derivers +std::string deriveModuleContext(llvm::Module &llvmModule); - // function derivers - std::string deriveFunctionContext(llvm::Function &llvmFunction); +// function derivers +std::string deriveFunctionContext(llvm::Function &llvmFunction); - // block derivers - std::string deriveLabelContext(llvm::BasicBlock &llvmBlock); +// block derivers +std::string deriveLabelContext(llvm::BasicBlock &llvmBlock); - std::string deriveBlockContext(llvm::BasicBlock &llvmBlock); +std::string deriveBlockContext(llvm::BasicBlock &llvmBlock); - // instruction derivers - std::string deriveSurroundingInstructionContext(llvm::Instruction &llvmInstruction); +// instruction derivers +std::string +deriveSurroundingInstructionContext(llvm::Instruction &llvmInstruction); - std::string deriveInstructionContext(llvm::Instruction &llvmInstruction); +std::string deriveInstructionContext(llvm::Instruction &llvmInstruction); - std::string deriveInstructionLhs(llvm::Instruction &llvmInstruction); +std::string +deriveGlobalVariableContext(llvm::GlobalVariable &llvmGlobalVariable); - std::string deriveInstructionRhs(llvm::Instruction &llvmInstruction); +std::string deriveInstructionLhs(llvm::Instruction &llvmInstruction); - // operand derivers - std::string deriveOperandContext(llvm::Value &llvmOperand); -} -#endif //VCLLVM_CONTEXTDERIVER_H +std::string deriveInstructionRhs(llvm::Instruction &llvmInstruction); + +// operand derivers +std::string deriveOperandContext(llvm::Value &llvmOperand); +} // namespace llvm2col +#endif // PALLAS_CONTEXTDERIVER_H diff --git a/src/llvm/include/Origin/OriginProvider.h b/src/llvm/include/Origin/OriginProvider.h index 2dd6067535..e11c079662 100644 --- a/src/llvm/include/Origin/OriginProvider.h +++ b/src/llvm/include/Origin/OriginProvider.h @@ -1,46 +1,64 @@ -#ifndef VCLLVM_ORIGINPROVIDER_H -#define VCLLVM_ORIGINPROVIDER_H +#ifndef PALLAS_ORIGINPROVIDER_H +#define PALLAS_ORIGINPROVIDER_H -#include -#include #include "vct/col/ast/Origin.pb.h" +#include +#include /** * Generators for VerCors origin objects for various LLVM Value types. * - * For more info on VerCors origins see: https://github.com/utwente-fmt/vercors/discussions/884 + * For more info on VerCors origins see: + * https://github.com/utwente-fmt/vercors/discussions/884 */ -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; + +col::Origin *generateProgramOrigin(llvm::Module &llvmModule); + +col::Origin *generateFuncDefOrigin(llvm::Function &llvmFunction); + +col::Origin *generateFunctionContractOrigin(llvm::Function &llvmFunction, + const std::string &contract); - col::Origin *generateProgramOrigin(llvm::Module &llvmModule); +col::Origin *generateGlobalValOrigin(llvm::Module &llvmModule, + const std::string &globVal); - col::Origin *generateFuncDefOrigin(llvm::Function &llvmFunction); +col::Origin *generateArgumentOrigin(llvm::Argument &llvmArgument); - col::Origin *generateFunctionContractOrigin(llvm::Function &llvmFunction, const std::string& contract); +col::Origin *generateBlockOrigin(llvm::BasicBlock &llvmBlock); - col::Origin *generateGlobalValOrigin(llvm::Module &llvmModule, const std::string &globVal); +col::Origin *generateLabelOrigin(llvm::BasicBlock &llvmBlock); - col::Origin *generateArgumentOrigin(llvm::Argument &llvmArgument); +col::Origin *generateSingleStatementOrigin(llvm::Instruction &llvmInstruction); - col::Origin *generateBlockOrigin(llvm::BasicBlock &llvmBlock); +col::Origin *generateAssignTargetOrigin(llvm::Instruction &llvmInstruction); - col::Origin *generateLabelOrigin(llvm::BasicBlock &llvmBlock); +col::Origin *generateBinExprOrigin(llvm::Instruction &llvmInstruction); - col::Origin *generateSingleStatementOrigin(llvm::Instruction &llvmInstruction); +col::Origin *generateFunctionCallOrigin(llvm::CallInst &callInstruction); - col::Origin *generateAssignTargetOrigin(llvm::Instruction &llvmInstruction); +col::Origin *generateOperandOrigin(llvm::Instruction &llvmInstruction, + llvm::Value &llvmOperand); - col::Origin *generateBinExprOrigin(llvm::Instruction &llvmInstruction); +col::Origin * +generateGlobalVariableOrigin(llvm::Module &llvmModule, + llvm::GlobalVariable &llvmGlobalVariable); - col::Origin *generateFunctionCallOrigin(llvm::CallInst &callInstruction); +col::Origin *generateGlobalVariableInitializerOrigin( + llvm::Module &llvmModule, llvm::GlobalVariable &llvmGlobalVariable, + llvm::Value &llvmInitializer); - col::Origin *generateOperandOrigin(llvm::Instruction &llvmInstruction, llvm::Value &llvmOperand); +col::Origin *generateVoidOperandOrigin(llvm::Instruction &llvmInstruction); - col::Origin *generateTypeOrigin(llvm::Type &llvmType); +col::Origin *generateTypeOrigin(llvm::Type &llvmType); -} -#endif //VCLLVM_ORIGINPROVIDER_H +col::Origin *generateMemoryOrderingOrigin(llvm::AtomicOrdering &llvmOrdering); +std::string extractShortPosition(const col::Origin &origin); +col::Origin *deepenOperandOrigin(const col::Origin &origin, + llvm::Value &llvmOperand); +} // namespace llvm2col +#endif // PALLAS_ORIGINPROVIDER_H diff --git a/src/llvm/include/Origin/PreferredNameDeriver.h b/src/llvm/include/Origin/PreferredNameDeriver.h index 3b6c7f938d..3de8fd7fcb 100644 --- a/src/llvm/include/Origin/PreferredNameDeriver.h +++ b/src/llvm/include/Origin/PreferredNameDeriver.h @@ -1,18 +1,23 @@ -#ifndef VCLLVM_PREFERREDNAMEDERIVER_H -#define VCLLVM_PREFERREDNAMEDERIVER_H +#ifndef PALLAS_PREFERREDNAMEDERIVER_H +#define PALLAS_PREFERREDNAMEDERIVER_H #include +#include /** - * Generators for VerCors origin objects preferredName fields for various LLVM Value types. + * Generators for VerCors origin objects preferredName fields for various LLVM + * Value types. * - * For more info on VerCors origins see: https://github.com/utwente-fmt/vercors/discussions/884 + * For more info on VerCors origins see: + * https://github.com/utwente-fmt/vercors/discussions/884 */ -namespace llvm2Col { - std::string deriveOperandPreferredName(llvm::Value &llvmOperand); +namespace llvm2col { +std::string deriveOperandPreferredName(llvm::Value &llvmOperand); - std::string deriveTypePreferredName(llvm::Type &llvmType); +std::string deriveTypePreferredName(llvm::Type &llvmType); - std::string deriveArgumentPreferredName(llvm::Argument &llvmArgument); +std::string +deriveMemoryOrderingPreferredName(llvm::AtomicOrdering &llvmOrdering); -} -#endif //VCLLVM_PREFERREDNAMEDERIVER_H +std::string deriveArgumentPreferredName(llvm::Argument &llvmArgument); +} // namespace llvm2col +#endif // PALLAS_PREFERREDNAMEDERIVER_H diff --git a/src/llvm/include/Origin/ShortPositionDeriver.h b/src/llvm/include/Origin/ShortPositionDeriver.h index 4ef8fe7c1f..ee1efaf8b4 100644 --- a/src/llvm/include/Origin/ShortPositionDeriver.h +++ b/src/llvm/include/Origin/ShortPositionDeriver.h @@ -1,25 +1,29 @@ -#ifndef VCLLVM_SHORTPOSITIONDERIVER_H -#define VCLLVM_SHORTPOSITIONDERIVER_H +#ifndef PALLAS_SHORTPOSITIONDERIVER_H +#define PALLAS_SHORTPOSITIONDERIVER_H #include /** - * Generators for VerCors origin objects shortPosition fields for various LLVM Value types. + * Generators for VerCors origin objects shortPosition fields for various LLVM + * Value types. * - * It generates a path from the highest level abstraction to the lowest in order of Module -> Function -> Block -> Instruction. + * It generates a path from the highest level abstraction to the lowest in order + * of Module -> Function -> Block -> Instruction. * - * Each abstraction level calls its parent generator to generate its path (e.g. deriveBlockShortPosition calls - * deriveFunctionShortPosition and deriveInstructionShortPosition calls deriveBlockShortPosition) + * Each abstraction level calls its parent generator to generate its path (e.g. + * deriveBlockShortPosition calls deriveFunctionShortPosition and + * deriveInstructionShortPosition calls deriveBlockShortPosition) * - * For more info on VerCors origins see: https://github.com/utwente-fmt/vercors/discussions/884 + * For more info on VerCors origins see: + * https://github.com/utwente-fmt/vercors/discussions/884 */ -namespace llvm2Col { - std::string deriveModuleShortPosition(llvm::Module &llvmModule); +namespace llvm2col { +std::string deriveModuleShortPosition(llvm::Module &llvmModule); - std::string deriveFunctionShortPosition(llvm::Function &llvmFunction); +std::string deriveFunctionShortPosition(llvm::Function &llvmFunction); - std::string deriveBlockShortPosition(llvm::BasicBlock &llvmBlock); +std::string deriveBlockShortPosition(llvm::BasicBlock &llvmBlock); - std::string deriveInstructionShortPosition(llvm::Instruction &llvmInstruction); -} -#endif //VCLLVM_SHORTPOSITIONDERIVER_H +std::string deriveInstructionShortPosition(llvm::Instruction &llvmInstruction); +} // namespace llvm2col +#endif // PALLAS_SHORTPOSITIONDERIVER_H diff --git a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h index 8c6b9309a1..3cc962b394 100644 --- a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h +++ b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h @@ -1,162 +1,176 @@ -#ifndef VCLLVM_FUNCTIONBODYTRANSFORMER_H -#define VCLLVM_FUNCTIONBODYTRANSFORMER_H +#ifndef PALLAS_FUNCTIONBODYTRANSFORMER_H +#define PALLAS_FUNCTIONBODYTRANSFORMER_H #include "vct/col/ast/col.pb.h" #include #include "FunctionDeclarer.h" +#include "vct/col/ast/col.pb.h" +/** + * The FunctionBodyTransformer that transforms LLVM blocks and instructions into + * suitable VerCors COL abstractions. + */ +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; + +struct LabeledColBlock { + col::Label &label; + col::Block █ +}; + /** - * The FunctionBodyTransformer that transforms LLVM blocks and instructions into suitable VerCors COL abstractions. + * The FunctionCursor is a stateful utility class to transform a LLVM function + * body to a COL function body. */ -namespace vcllvm { - using namespace llvm; - namespace col = vct::col::ast; +class FunctionCursor { + friend class FunctionBodyTransformerPass; + + private: + col::Scope &functionScope; + + col::Block &functionBody; + + llvm::Function &llvmFunction; + + /// Gives access to all other analysis passes ran by pallas as well as + /// existing LLVM analysis passes (i.e. loop analysis). + llvm::FunctionAnalysisManager &FAM; + + /// Most LLVM instructions are transformed to a COL assignment to a COL + /// variable. The resulting end product is a 1-to-1 mapping from and LLVM + /// Value to a COL variable. The generic LLVM Value was chosen to also + /// include function arguments in the lut. + std::unordered_map variableMap; + + /// All LLVM blocks mapped 1-to-1 to a COL block. This mapping is not direct + /// in the sense that it uses the intermediate LabeledColBlock struct which + /// contains both the COL label and COL block associated to the LLVM block + std::unordered_map + llvmBlock2LabeledColBlock; + + /// set of all COL blocks that have been completed. Completed meaning all + /// instructions of the corresponding LLVM block have been transformed. This + /// excludes possible future phi node back transformations. + std::set completedColBlocks; + + /// Almost always when adding a variable to the variableMap, some extra + /// processing is required which is why this method is private as to not + /// accidentally use it outside the functionCursor + void addVariableMapEntry(llvm::Value &llvmValue, col::Variable &colVar); + + public: + explicit FunctionCursor(col::Scope &functionScope, col::Block &functionBody, + llvm::Function &llvmFunction, + llvm::FunctionAnalysisManager &FAM); + + const col::Scope &getFunctionScope(); + + /** + * declares variable in the function scope + * @param llvmInstruction + * @return the created variable declaration + */ + col::Variable &declareVariable(Instruction &llvmInstruction, + Type *llvmPointerType = nullptr); + + /** + * Functionality is twofold: + *
    + *
  1. Creates a variable declaration in the function scope (declare + * variable)
  2. Creates an assignment in the provided colBlock
  3. + *
+ * @param llvmInstruction + * @param colBlock + * @return The created col assignment + */ + col::Assign & + createAssignmentAndDeclaration(Instruction &llvmInstruction, + col::Block &colBlock, + Type *llvmPointerType = nullptr); - struct LabeledColBlock { - col::Label &label; - col::Block █ - }; + /** + * Creates an assignment in the provided colBlock referencing the provided + * variable declaration + * + * @param llvmInstruction + * @param colBlock + * @param varDecl + * @return the created col assignment + */ + col::Assign &createAssignment(Instruction &llvmInstruction, + col::Block &colBlock, col::Variable &varDecl); + + col::Variable &getVariableMapEntry(llvm::Value &llvmValue, bool inPhiNode); + + /** + * In many cases during transformation, it is not possible to derive whether + * a COL block has yet been mapped and initialised. This is why we have a + * get or set method which does the following"
  • If a mapping between + * the given LLVM block and a COL block already exists, return the COL + * block
  • Else, initalise a new COL block in the buffer, add it to + * the llvmBlock2LabeledColBlock lut and return the newly created COL + * block
  • + *
+ * + * @param llvmBlock + * @return A LabeledColBlock struct to which this llvmBlock is mapped to. + */ + LabeledColBlock & + getOrSetLLVMBlock2LabeledColBlockEntry(BasicBlock &llvmBlock); + + llvm::FunctionAnalysisManager &getFunctionAnalysisManager(); + + /** + * Indicates whether a LLVM block has been visited (i.e. whether a mapping + * exists to a COL block). Note that does not mean that it has been fully + * transformed. For that see the isComplete + * + * @param llvmBlock + * @return + */ + bool isVisited(llvm::BasicBlock &llvmBlock); + + /** + * Mark COL Block as complete by adding it to the completedColBlocks set. + * @param llvmBlock + */ + void complete(col::Block &colBlock); + + /** + * Indicates whether an llvmBlock has been fully transformed (excluding + * possible phi node back transformations). Any completed block is also + * visited. + * @return true if block is in the completedColBlocks set, false otherwise. + */ + bool isComplete(col::Block &colBlock); + + LoopInfo &getLoopInfo(); + + LoopInfo &getLoopInfo(llvm::Function &otherLLVMFunction); + + /** + * Retrieve the FunctionDeclarerPass analysis result from the function this + * FunctionCursor is associated with by querying the + * FunctionAnalysisManager. + * @return + */ + FDResult &getFDResult(); /** - * The FunctionCursor is a stateful utility class to transform a LLVM function body to a COL function body. + * Retrieve the FunctionDeclarerPass analysis result from a function in the + * current program by querying the FunctionAnalysisManager. + * @param otherLLVMFunction + * @return */ - class FunctionCursor { - friend class FunctionBodyTransformerPass; - - private: - col::Scope &functionScope; - - col::Block &functionBody; - - llvm::Function &llvmFunction; - - /// Gives access to all other analysis passes ran by vcllvm as well as existing LLVM analysis passes (i.e. loop - /// analysis). - llvm::FunctionAnalysisManager &FAM; - - /// Most LLVM instructions are transformed to a COL assignment to a COL variable. The resulting end product is - /// a 1-to-1 mapping from and LLVM Value to a COL variable. The generic LLVM Value was chosen to also include - /// function arguments in the lut. - std::unordered_map variableMap; - - /// All LLVM blocks mapped 1-to-1 to a COL block. This mapping is not direct in the sense that it uses the - /// intermediate LabeledColBlock struct which contains both the COL label and COL block associated to the LLVM - /// block - std::unordered_map llvmBlock2LabeledColBlock; - - /// set of all COL blocks that have been completed. Completed meaning all instructions of the corresponding LLVM - /// block have been transformed. This excludes possible future phi node back transformations. - std::set completedColBlocks; - - /// Almost always when adding a variable to the variableMap, some extra processing is required which is why this - /// method is private as to not accidentally use it outside the functionCursor - void addVariableMapEntry(llvm::Value &llvmValue, col::Variable &colVar); - - public: - explicit FunctionCursor(col::Scope &functionScope, - col::Block &functionBody, - llvm::Function &llvmFunction, - llvm::FunctionAnalysisManager &FAM); - - const col::Scope &getFunctionScope(); - - /** - * declares variable in the function scope - * @param llvmInstruction - * @return the created variable declaration - */ - col::Variable &declareVariable(Instruction &llvmInstruction); - - /** - * Functionality is twofold: - *
    - *
  1. Creates a variable declaration in the function scope (declare variable)
  2. - *
  3. Creates an assignment in the provided colBlock
  4. - *
- * @param llvmInstruction - * @param colBlock - * @return The created col assignment - */ - col::Assign &createAssignmentAndDeclaration(Instruction &llvmInstruction, col::Block &colBlock); - - /** - * Creates an assignment in the provided colBlock referencing the provided variable declaration - * - * @param llvmInstruction - * @param colBlock - * @param varDecl - * @return the created col assignment - */ - col::Assign &createAssignment(Instruction &llvmInstruction, col::Block &colBlock, col::Variable &varDecl); - - col::Variable &getVariableMapEntry(llvm::Value &llvmValue); - - /** - * In many cases during transformation, it is not possible to derive whether a COL block has yet been mapped and - * initialised. This is why we have a get or set method which does the following" - *
    - *
  • If a mapping between the given LLVM block and a COL block already exists, return the COL block
  • - *
  • Else, initalise a new COL block in the buffer, add it to the llvmBlock2LabeledColBlock lut and return - * the newly created COL block
  • - *
- * - * @param llvmBlock - * @return A LabeledColBlock struct to which this llvmBlock is mapped to. - */ - LabeledColBlock &getOrSetLlvmBlock2LabeledColBlockEntry(BasicBlock &llvmBlock); - - /** - * Indicates whether a LLVM block has been visited (i.e. whether a mapping exists to a COL block). - * Note that does not mean that it has been fully transformed. For that see the isComplete - * - * @param llvmBlock - * @return - */ - bool isVisited(llvm::BasicBlock &llvmBlock); - - /** - * Mark COL Block as complete by adding it to the completedColBlocks set. - * @param llvmBlock - */ - void complete(col::Block &colBlock); - - /** - * Indicates whether an llvmBlock has been fully transformed (excluding possible phi node back transformations). - * Any completed block is also visited. - * @return true if block is in the completedColBlocks set, false otherwise. - */ - bool isComplete(col::Block &colBlock); - - LoopInfo &getLoopInfo(); - - LoopInfo &getLoopInfo(llvm::Function &otherLlvmFunction); - - /** - * Retrieve the FunctionDeclarerPass analysis result from the function this FunctionCursor is associated with by - * querying the FunctionAnalysisManager. - * @return - */ - FDResult &getFDResult(); - - /** - * Retrieve the FunctionDeclarerPass analysis result from a function in the current program by querying - * the FunctionAnalysisManager. - * @param otherLlvmFunction - * @return - */ - FDResult &getFDResult(llvm::Function &otherLlvmFunction); - - }; - - class FunctionBodyTransformerPass : public PassInfoMixin { - private: - std::shared_ptr pProgram; - - public: - explicit FunctionBodyTransformerPass(std::shared_ptr pProgram); - - PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); - }; -} -#endif //VCLLVM_FUNCTIONBODYTRANSFORMER_H + FDResult &getFDResult(llvm::Function &otherLLVMFunction); +}; + +class FunctionBodyTransformerPass + : public PassInfoMixin { + public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); +}; +} // namespace pallas +#endif // PALLAS_FUNCTIONBODYTRANSFORMER_H diff --git a/src/llvm/include/Passes/Function/FunctionContractDeclarer.h b/src/llvm/include/Passes/Function/FunctionContractDeclarer.h index e6ac3d28a7..07ef66ff2e 100644 --- a/src/llvm/include/Passes/Function/FunctionContractDeclarer.h +++ b/src/llvm/include/Passes/Function/FunctionContractDeclarer.h @@ -1,68 +1,70 @@ -#ifndef VCLLVM_FUNCTIONCONTRACTDECLARER_H -#define VCLLVM_FUNCTIONCONTRACTDECLARER_H +#ifndef PALLAS_FUNCTIONCONTRACTDECLARER_H +#define PALLAS_FUNCTIONCONTRACTDECLARER_H #include "vct/col/ast/col.pb.h" #include /** - * Pass that adds an LLVMFunctionContract to its corresponding LLVMFunctionDefinition in the presence - * of a contract metadata node. The resulting FDCResult class can be used by a FunctionAnalysisManager to access the - * created contract and add named references to the contract (e.g. map functions arguments string representations to COL - * variables representing these same arguments). + * Pass that adds an LlvmfunctionContract to its corresponding + * LlvmfunctionDefinition in the presence of a contract metadata node. The + * resulting FDCResult class can be used by a FunctionAnalysisManager to access + * the created contract and add named references to the contract (e.g. map + * functions arguments string representations to COL variables representing + * these same arguments). * - * The pass is twofold: it has an analysis pass (FunctionContractDeclarer) that merely creates objects in the buffer and - * adds them to the associated result object. This way, the result object of this pass can be queried by other passes in - * order to retrieve the relevant COL nodes associated to this LLVM function. + * The pass is twofold: it has an analysis pass (FunctionContractDeclarer) that + * merely creates objects in the buffer and adds them to the associated result + * object. This way, the result object of this pass can be queried by other + * passes in order to retrieve the relevant COL nodes associated to this LLVM + * function. * - * The second pass is a regular function pass (FunctionContractDeclarerPass) that finishes the transformation started by - * the FunctionContractDeclarer analysis pass. + * The second pass is a regular function pass (FunctionContractDeclarerPass) + * that finishes the transformation started by the FunctionContractDeclarer + * analysis pass. */ -namespace vcllvm { - using namespace llvm; - namespace col = vct::col::ast; +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; - class FDCResult { - private: - col::LlvmFunctionContract &associatedColFuncContract; - public: - explicit FDCResult(col::LlvmFunctionContract &colFuncContract); +class FDCResult { + private: + col::LlvmFunctionContract &associatedColFuncContract; - col::LlvmFunctionContract &getAssociatedColFuncContract(); - }; + public: + explicit FDCResult(col::LlvmFunctionContract &colFuncContract); - class FunctionContractDeclarer : public AnalysisInfoMixin { - friend AnalysisInfoMixin; - static AnalysisKey Key; - private: - std::shared_ptr pProgram; - public: - using Result = FDCResult; + col::LlvmFunctionContract &getAssociatedColFuncContract(); +}; - explicit FunctionContractDeclarer(std::shared_ptr pProgram); +class FunctionContractDeclarer + : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + static AnalysisKey Key; - /** - * Merely creates a COL LlvmFunctionDefinition object in the buffer and sets it in a FDCResult object. - * @param F - * @param FAM - * @return - */ - Result run(Function &F, FunctionAnalysisManager &FAM); - }; + public: + using Result = FDCResult; - class FunctionContractDeclarerPass : public AnalysisInfoMixin { - private: - std::shared_ptr pProgram; - public: - explicit FunctionContractDeclarerPass(std::shared_ptr pProgram); + /** + * Merely creates a COL LlvmfunctionDefinition object in the buffer and sets + * it in a FDCResult object. + * @param F + * @param FAM + * @return + */ + Result run(Function &F, FunctionAnalysisManager &FAM); +}; - /** - * Retrieves the LlvmFunctionDefinition object in the buffer from the FDCResult object and sets the origin and - * string value of the contract. - * @param F - * @param FAM - * @return - */ - PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); - }; -} -#endif //VCLLVM_FUNCTIONCONTRACTDECLARER_H +class FunctionContractDeclarerPass + : public AnalysisInfoMixin { + public: + /** + * Retrieves the LlvmfunctionDefinition object in the buffer from the + * FDCResult object and sets the origin and string value of the contract. + * @param F + * @param FAM + * @return + */ + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); +}; +} // namespace pallas +#endif // PALLAS_FUNCTIONCONTRACTDECLARER_H diff --git a/src/llvm/include/Passes/Function/FunctionDeclarer.h b/src/llvm/include/Passes/Function/FunctionDeclarer.h index eeb5fc4666..002c5d64d6 100644 --- a/src/llvm/include/Passes/Function/FunctionDeclarer.h +++ b/src/llvm/include/Passes/Function/FunctionDeclarer.h @@ -1,96 +1,95 @@ -#ifndef VCLLVM_FUNCTIONDECLARER_H -#define VCLLVM_FUNCTIONDECLARER_H +#ifndef PALLAS_FUNCTIONDECLARER_H +#define PALLAS_FUNCTIONDECLARER_H #include "vct/col/ast/col.pb.h" #include /** - * Pass that creates a signature for a LLVMFunctionDefinition in COL and exposes an FDResult object that - * binds the the LLVM IR Function to a LLVMFunctionDefinition COL object. The actual function implementation is + * Pass that creates a signature for a LlvmfunctionDefinition in COL and exposes + * an FDResult object that binds the the LLVM IR Function to a + * LlvmfunctionDefinition COL object. The actual function implementation is * transformed by the FunctionBodyTransformer pass. * - * The pass is twofold: it has an analysis pass (FunctionDeclarer) that merely creates objects in the buffer and adds - * them to the associated result object. This way, the result object of this pass can be queried by other passes in order - * to retrieve the relevant COL nodes associated to this LLVM function. + * The pass is twofold: it has an analysis pass (FunctionDeclarer) that merely + * creates objects in the buffer and adds them to the associated result object. + * This way, the result object of this pass can be queried by other passes in + * order to retrieve the relevant COL nodes associated to this LLVM function. * - * The second pass is a regular function pass (FunctionDeclarerPass) that finishes the transformation started by the - * FunctionDeclarer analysis pass. + * The second pass is a regular function pass (FunctionDeclarerPass) that + * finishes the transformation started by the FunctionDeclarer analysis pass. */ -namespace vcllvm { - using namespace llvm; - namespace col = vct::col::ast; - - /// wrapper struct for a COL scope and block. Intended use is the block to be declared in the scope. - struct ColScopedFuncBody { - col::Scope *scope; - col::Block *block; - }; - - class FDResult { - friend class FunctionDeclarer; - - private: - col::LlvmFunctionDefinition &associatedColFuncDef; - ColScopedFuncBody associatedScopedColFuncBody; - int64_t functionId; - /// contains the 1-to-1 mapping from LLVM function arguments to COL variables that are used as function - /// arguments. - std::unordered_map funcArgMap; - - void addFuncArgMapEntry(llvm::Argument &llvmArg, col::Variable &colArg); - - public: - explicit FDResult(col::LlvmFunctionDefinition &colFuncDef, - ColScopedFuncBody associatedScopedColFuncBody, - int64_t functionId); - - col::LlvmFunctionDefinition &getAssociatedColFuncDef(); - - ColScopedFuncBody getAssociatedScopedColFuncBody(); - - col::Variable &getFuncArgMapEntry(llvm::Argument &arg); - - int64_t &getFunctionId(); - }; - - class FunctionDeclarer : public AnalysisInfoMixin { - friend AnalysisInfoMixin; - static AnalysisKey Key; - private: - std::shared_ptr pProgram; - public: - using Result = FDResult; - - explicit FunctionDeclarer(std::shared_ptr pProgram); - - /** - * Creates a COL LlvmFunctionDefinition in the buffer, including a function scope and body and their origins. - * It maps the corresponding LLVM Function to the created COL LlvmFunctionDefinition. - * - * Additionally, it creates the function arguments (COL variables) in the buffer and maps the corresponding - * LLVM arguments to the created COL arguments. - * - * @param F - * @param FAM - * @return - */ - Result run(Function &F, FunctionAnalysisManager &FAM); - - }; - - class FunctionDeclarerPass : public AnalysisInfoMixin { - private: - std::shared_ptr pProgram; - public: - explicit FunctionDeclarerPass(std::shared_ptr pProgram); - /** - * Completes the function definition transformation by adding a return type to the COL LLVMFunctionDefinition - * - * @param F - * @param FAM - * @return - */ - PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); - }; -} -#endif //VCLLVM_FUNCTIONDECLARER_H \ No newline at end of file +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; + +/// wrapper struct for a COL scope and block. Intended use is the block to be +/// declared in the scope. +struct ColScopedFuncBody { + col::Scope *scope; + col::Block *block; +}; + +class FDResult { + friend class FunctionDeclarer; + + private: + col::LlvmFunctionDefinition &associatedColFuncDef; + ColScopedFuncBody associatedScopedColFuncBody; + int64_t functionId; + /// contains the 1-to-1 mapping from LLVM function arguments to COL + /// variables that are used as function arguments. + std::unordered_map funcArgMap; + + void addFuncArgMapEntry(llvm::Argument &llvmArg, col::Variable &colArg); + + public: + explicit FDResult(col::LlvmFunctionDefinition &colFuncDef, + ColScopedFuncBody associatedScopedColFuncBody, + int64_t functionId); + + col::LlvmFunctionDefinition &getAssociatedColFuncDef(); + + ColScopedFuncBody getAssociatedScopedColFuncBody(); + + col::Variable &getFuncArgMapEntry(llvm::Argument &arg); + + int64_t &getFunctionId(); +}; + +class FunctionDeclarer : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + static AnalysisKey Key; + + public: + using Result = FDResult; + + /** + * Creates a COL LlvmfunctionDefinition in the buffer, including a function + * scope and body and their origins. It maps the corresponding LLVM Function + * to the created COL LlvmfunctionDefinition. + * + * Additionally, it creates the function arguments (COL variables) in the + * buffer and maps the corresponding LLVM arguments to the created COL + * arguments. + * + * @param F + * @param FAM + * @return + */ + Result run(Function &F, FunctionAnalysisManager &FAM); +}; + +class FunctionDeclarerPass : public AnalysisInfoMixin { + public: + /** + * Completes the function definition transformation by adding a return type + * to the COL LlvmfunctionDefinition + * + * @param F + * @param FAM + * @return + */ + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); +}; +} // namespace pallas +#endif // PALLAS_FUNCTIONDECLARER_H diff --git a/src/llvm/include/Passes/Function/PureAssigner.h b/src/llvm/include/Passes/Function/PureAssigner.h index 6cd8797b72..a41e01bbe1 100644 --- a/src/llvm/include/Passes/Function/PureAssigner.h +++ b/src/llvm/include/Passes/Function/PureAssigner.h @@ -1,28 +1,19 @@ -#ifndef VCLLVM_PUREASSIGNER_H -#define VCLLVM_PUREASSIGNER_H +#ifndef PALLAS_PUREASSIGNER_H +#define PALLAS_PUREASSIGNER_H #include "vct/col/ast/col.pb.h" #include /** - * The PureAssignerPass checks if a LLVM function is pure (i.e. whether the !VC.pure metadata node is set) + * The PureAssignerPass checks if a LLVM function is pure (i.e. whether the + * !VC.pure metadata node is set) */ -namespace vcllvm { - using namespace llvm; - namespace col = vct::col::ast; +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; - class PureAssignerPass : public PassInfoMixin { - private: - std::shared_ptr pProgram; - public: - explicit PureAssignerPass(std::shared_ptr pProgram); - - PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); - }; - /** - * Helper function to generate errors generated by this Pass - * @param F - * @param explanation - */ - void reportError(Function &F, const std::string &explanation); -} -#endif //VCLLVM_PUREASSIGNER_H +class PureAssignerPass : public PassInfoMixin { + public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); +}; +} // namespace pallas +#endif // PALLAS_PUREASSIGNER_H diff --git a/src/llvm/include/Passes/Module/GlobalVariableDeclarer.h b/src/llvm/include/Passes/Module/GlobalVariableDeclarer.h new file mode 100644 index 0000000000..5ba3e66cb2 --- /dev/null +++ b/src/llvm/include/Passes/Module/GlobalVariableDeclarer.h @@ -0,0 +1,17 @@ +#ifndef PALLAS_GLOBALVARIABLEDECLARER_H +#define PALLAS_GLOBALVARIABLEDECLARER_H + +#include "vct/col/ast/col.pb.h" +#include + +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; + +class GlobalVariableDeclarerPass + : public AnalysisInfoMixin { + public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +}; +} // namespace pallas +#endif // PALLAS_GLOBALVARIABLEDECLARER_H diff --git a/src/llvm/include/Passes/Module/ModuleSpecCollector.h b/src/llvm/include/Passes/Module/ModuleSpecCollector.h index c95bce722d..5131b6b37d 100644 --- a/src/llvm/include/Passes/Module/ModuleSpecCollector.h +++ b/src/llvm/include/Passes/Module/ModuleSpecCollector.h @@ -1,23 +1,21 @@ -#ifndef VCLLVM_MODULESPECCOLLECTOR_H -#define VCLLVM_MODULESPECCOLLECTOR_H +#ifndef PALLAS_MODULESPECCOLLECTOR_H +#define PALLAS_MODULESPECCOLLECTOR_H #include "vct/col/ast/col.pb.h" #include /** - * Pass that adds global specifications (i.e. not related to a loop or function) to the AST as unparsed strings. It's - * VerCors job to parse the string into any global declaration as if it were in a spec comment. + * Pass that adds global specifications (i.e. not related to a loop or function) + * to the AST as unparsed strings. It's VerCors job to parse the string into any + * global declaration as if it were in a spec comment. */ -namespace vcllvm { - using namespace llvm; - namespace col = vct::col::ast; +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; - class ModuleSpecCollectorPass : public AnalysisInfoMixin { - private: - std::shared_ptr pProgram; - public: - explicit ModuleSpecCollectorPass(std::shared_ptr pProgram); - - PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); - }; -} -#endif //VCLLVM_MODULESPECCOLLECTOR_H +class ModuleSpecCollectorPass + : public AnalysisInfoMixin { + public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +}; +} // namespace pallas +#endif // PALLAS_MODULESPECCOLLECTOR_H diff --git a/src/llvm/include/Passes/Module/ProtobufPrinter.h b/src/llvm/include/Passes/Module/ProtobufPrinter.h new file mode 100644 index 0000000000..538b6f74f0 --- /dev/null +++ b/src/llvm/include/Passes/Module/ProtobufPrinter.h @@ -0,0 +1,16 @@ +#ifndef PALLAS_PROTOBUFPRINTER_H +#define PALLAS_PROTOBUFPRINTER_H + +#include "vct/col/ast/col.pb.h" +#include + +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; + +class ProtobufPrinter : public AnalysisInfoMixin { + public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +}; +} // namespace pallas +#endif // PALLAS_PROTOBUFPRINTER_H diff --git a/src/llvm/include/Passes/Module/RootContainer.h b/src/llvm/include/Passes/Module/RootContainer.h new file mode 100644 index 0000000000..52806fac84 --- /dev/null +++ b/src/llvm/include/Passes/Module/RootContainer.h @@ -0,0 +1,28 @@ +#ifndef PALLAS_ROOTCONTAINER_H +#define PALLAS_ROOTCONTAINER_H + +#include "vct/col/ast/col.pb.h" +#include + +namespace pallas { +using namespace llvm; +namespace col = vct::col::ast; + +class ProgramWrapper { + public: + std::shared_ptr program; + bool invalidate(Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &); +}; + +class RootContainer : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + static AnalysisKey Key; + + public: + using Result = ProgramWrapper; + + Result run(Module &M, ModuleAnalysisManager &MAM); +}; +} // namespace pallas +#endif // PALLAS_ROOTCONTAINER_H diff --git a/src/llvm/include/Transform/BlockTransform.h b/src/llvm/include/Transform/BlockTransform.h index e250d96b77..d0215de03a 100644 --- a/src/llvm/include/Transform/BlockTransform.h +++ b/src/llvm/include/Transform/BlockTransform.h @@ -1,52 +1,54 @@ -#ifndef VCLLVM_BLOCKTRANSFORM_H -#define VCLLVM_BLOCKTRANSFORM_H +#ifndef PALLAS_BLOCKTRANSFORM_H +#define PALLAS_BLOCKTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - /** - * Entry point for each block transformation. It performs the following steps: - *
    - *
  1. Create or fetch the corresponding labeled col block from the function cursor
  2. - *
  3. Check if all predecessor blocks have been visited yet, otherwise, return
  4. - *
  5. If block turns out to be a loop header, hand over control to the transformLoop function. - * Else, transform instructions of the block
  6. - *
- * - * Note: The transformTermOp function will take care of subsequent blocks recursively - * @param functionCursor - * @param llvmBlock - */ - void transformLlvmBlock(llvm::BasicBlock &llvmBlock, vcllvm::FunctionCursor &functionCursor); +/** + * Entry point for each block transformation. It performs the following steps: + *
    + *
  1. Create or fetch the corresponding labeled col block from the function + * cursor
  2. Check if all predecessor blocks have been visited yet, + * otherwise, return
  3. If block turns out to be a loop header, hand over + * control to the transformLoop function. Else, transform + * instructions of the block
  4. + *
+ * + * Note: The transformTermOp function will take care of subsequent + * blocks recursively + * @param functionCursor + * @param llvmBlock + */ +void transformLLVMBlock(llvm::BasicBlock &llvmBlock, + pallas::FunctionCursor &functionCursor); - /** - * Unimplemented - * @param llvmBlock - * @param functionCursor - */ - void transformLoop(llvm::BasicBlock &llvmBlock, vcllvm::FunctionCursor &functionCursor); +/** + * Unimplemented + * @param llvmBlock + * @param functionCursor + */ +void transformLoop(llvm::BasicBlock &llvmBlock, + pallas::FunctionCursor &functionCursor); - /** - * Instructions are split up in their separate LLVM categories and transformed by their respective transformer. - * These are: - *
    - *
  1. Binary operators
  2. - *
  3. Casting operators
  4. - *
  5. Funclet pad operators
  6. - *
  7. Memory operators
  8. - *
  9. Terminal operators
  10. - *
  11. Unary operators
  12. - *
- * @param funcCursor - * @param llvmInstruction - * @param colBodyBlock - */ - void transformInstruction(vcllvm::FunctionCursor &funcCursor, - llvm::Instruction &llvmInstruction, - col::Block &colBodyBlock); - - void reportUnsupportedOperatorError(const std::string &source, llvm::Instruction &llvmInstruction); -} -#endif //VCLLVM_BLOCKTRANSFORM_H +/** + * Instructions are split up in their separate LLVM categories and transformed + * by their respective transformer. These are:
  1. Binary operators
  2. + *
  3. Casting operators
  4. + *
  5. Funclet pad operators
  6. + *
  7. Memory operators
  8. + *
  9. Terminal operators
  10. + *
  11. Unary operators
  12. + *
+ * @param funcCursor + * @param llvmInstruction + * @param colBodyBlock + */ +void transformInstruction(pallas::FunctionCursor &funcCursor, + llvm::Instruction &llvmInstruction, + col::Block &colBodyBlock); +void reportUnsupportedOperatorError(const std::string &source, + llvm::Instruction &llvmInstruction); +} // namespace llvm2col +#endif // PALLAS_BLOCKTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/BinaryOpTransform.h b/src/llvm/include/Transform/Instruction/BinaryOpTransform.h index 59f2989ed2..8e965b2bfa 100644 --- a/src/llvm/include/Transform/Instruction/BinaryOpTransform.h +++ b/src/llvm/include/Transform/Instruction/BinaryOpTransform.h @@ -1,15 +1,13 @@ -#ifndef VCLLVM_BINARYOPTRANSFORM_H -#define VCLLVM_BINARYOPTRANSFORM_H +#ifndef PALLAS_BINARYOPTRANSFORM_H +#define PALLAS_BINARYOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; +void transformBinaryOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); - void transformBinaryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); - -} -#endif //VCLLVM_BINARYOPTRANSFORM_H +} // namespace llvm2col +#endif // PALLAS_BINARYOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/CastOpTransform.h b/src/llvm/include/Transform/Instruction/CastOpTransform.h index c42dbdeac8..b69d436e43 100644 --- a/src/llvm/include/Transform/Instruction/CastOpTransform.h +++ b/src/llvm/include/Transform/Instruction/CastOpTransform.h @@ -1,12 +1,20 @@ -#ifndef VCLLVM_CASTOPTRANSFORM_H -#define VCLLVM_CASTOPTRANSFORM_H +#ifndef PALLAS_CASTOPTRANSFORM_H +#define PALLAS_CASTOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void convertCastOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} -#endif //VCLLVM_CASTOPTRANSFORM_H +void transformCastOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformSExt(llvm::SExtInst &sextInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformZExt(llvm::ZExtInst &sextInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformTrunc(llvm::TruncInst &truncInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +} // namespace llvm2col +#endif // PALLAS_CASTOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/FuncletPadOpTransform.h b/src/llvm/include/Transform/Instruction/FuncletPadOpTransform.h index c6bc303977..ae810415b1 100644 --- a/src/llvm/include/Transform/Instruction/FuncletPadOpTransform.h +++ b/src/llvm/include/Transform/Instruction/FuncletPadOpTransform.h @@ -1,12 +1,12 @@ -#ifndef VCLLVM_FUNCLETPADOPTRANSFORM_H -#define VCLLVM_FUNCLETPADOPTRANSFORM_H +#ifndef PALLAS_FUNCLETPADOPTRANSFORM_H +#define PALLAS_FUNCLETPADOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void transformFuncletPadOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} -#endif //VCLLVM_FUNCLETPADOPTRANSFORM_H +void transformFuncletPadOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +} // namespace llvm2col +#endif // PALLAS_FUNCLETPADOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/MemoryOpTransform.h b/src/llvm/include/Transform/Instruction/MemoryOpTransform.h index eb8d274db7..71707639f0 100644 --- a/src/llvm/include/Transform/Instruction/MemoryOpTransform.h +++ b/src/llvm/include/Transform/Instruction/MemoryOpTransform.h @@ -1,12 +1,27 @@ -#ifndef VCLLVM_MEMORYOPTRANSFORM_H -#define VCLLVM_MEMORYOPTRANSFORM_H +#ifndef PALLAS_MEMORYOPTRANSFORM_H +#define PALLAS_MEMORYOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void transformMemoryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} -#endif //VCLLVM_MEMORYOPTRANSFORM_H +void transformMemoryOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformAllocA(llvm::AllocaInst &allocAInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformAtomicOrdering(llvm::AtomicOrdering ordering, + col::LlvmMemoryOrdering *colOrdering); + +void transformLoad(llvm::LoadInst &loadInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformStore(llvm::StoreInst &storeInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +void transformGetElementPtr(llvm::GetElementPtrInst &gepInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +} // namespace llvm2col +#endif // PALLAS_MEMORYOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/OtherOpTransform.h b/src/llvm/include/Transform/Instruction/OtherOpTransform.h index 0ab0519ce0..cd86314fab 100644 --- a/src/llvm/include/Transform/Instruction/OtherOpTransform.h +++ b/src/llvm/include/Transform/Instruction/OtherOpTransform.h @@ -1,38 +1,39 @@ -#ifndef VCLLVM_OTHEROPTRANSFORM_H -#define VCLLVM_OTHEROPTRANSFORM_H +#ifndef PALLAS_OTHEROPTRANSFORM_H +#define PALLAS_OTHEROPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void transformOtherOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); - /** - * Phi nodes get transformed retroactively by creating a variable declaration and retroactively assign the variable - * in each originating COL block of each phi pair. - * @param phiInstruction - * @param funcCursor - */ - void transformPhi(llvm::PHINode &phiInstruction, vcllvm::FunctionCursor &funcCursor); +void transformOtherOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +/** + * Phi nodes get transformed retroactively by creating a variable declaration + * and retroactively assign the variable in each originating COL block of each + * phi pair. + * @param phiInstruction + * @param funcCursor + */ +void transformPhi(llvm::PHINode &phiInstruction, + pallas::FunctionCursor &funcCursor); - void transformICmp(llvm::ICmpInst &icmpInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); - /** - * Transforms the common part of all compare instructions (the argument pair). Currently only used by transformIcmp - * but could also be used in the future by for example an FCMP transformation. - * @param cmpInstruction - * @param colCompareExpr - * @param funcCursor - */ - void transformCmpExpr(llvm::CmpInst &cmpInstruction, - auto &colCompareExpr, - vcllvm::FunctionCursor &funcCursor); +void transformICmp(llvm::ICmpInst &icmpInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +/** + * Transforms the common part of all compare instructions (the argument pair). + * Currently only used by transformIcmp but could also be used in the future by + * for example an FCMP transformation. + * @param cmpInstruction + * @param colCompareExpr + * @param funcCursor + */ +void transformCmpExpr(llvm::CmpInst &cmpInstruction, auto &colCompareExpr, + pallas::FunctionCursor &funcCursor); - void transformCallExpr(llvm::CallInst &callInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} +void transformCallExpr(llvm::CallInst &callInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); -#endif //VCLLVM_OTHEROPTRANSFORM_H +bool checkCallSupport(llvm::CallInst &callInstruction); +} // namespace llvm2col + +#endif // PALLAS_OTHEROPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/TermOpTransform.h b/src/llvm/include/Transform/Instruction/TermOpTransform.h index b9fec00303..8f3eb03d2a 100644 --- a/src/llvm/include/Transform/Instruction/TermOpTransform.h +++ b/src/llvm/include/Transform/Instruction/TermOpTransform.h @@ -1,25 +1,23 @@ -#ifndef VCLLVM_TERMOPTRANSFORM_H -#define VCLLVM_TERMOPTRANSFORM_H +#ifndef PALLAS_TERMOPTRANSFORM_H +#define PALLAS_TERMOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void transformTermOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); +void transformTermOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); - void transformRet(llvm::ReturnInst &llvmRetInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); +void transformRet(llvm::ReturnInst &llvmRetInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); - void transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); +void transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); - void transformUnConditionalBranch(llvm::BranchInst &llvmBrInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} -#endif //VCLLVM_TERMOPTRANSFORM_H +void transformUnConditionalBranch(llvm::BranchInst &llvmBrInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +} // namespace llvm2col +#endif // PALLAS_TERMOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Instruction/UnaryOpTransform.h b/src/llvm/include/Transform/Instruction/UnaryOpTransform.h index 03dac632fe..2995eb075f 100644 --- a/src/llvm/include/Transform/Instruction/UnaryOpTransform.h +++ b/src/llvm/include/Transform/Instruction/UnaryOpTransform.h @@ -1,13 +1,12 @@ -#ifndef VCLLVM_UNARYOPTRANSFORM_H -#define VCLLVM_UNARYOPTRANSFORM_H +#ifndef PALLAS_UNARYOPTRANSFORM_H +#define PALLAS_UNARYOPTRANSFORM_H #include "Passes/Function/FunctionBodyTransformer.h" -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - void transformUnaryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor); -} -#endif //VCLLVM_UNARYOPTRANSFORM_H \ No newline at end of file +void transformUnaryOp(llvm::Instruction &llvmInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor); +} // namespace llvm2col +#endif // PALLAS_UNARYOPTRANSFORM_H diff --git a/src/llvm/include/Transform/Transform.h b/src/llvm/include/Transform/Transform.h index 52addb856a..4506aa4a97 100644 --- a/src/llvm/include/Transform/Transform.h +++ b/src/llvm/include/Transform/Transform.h @@ -1,82 +1,87 @@ -#ifndef VCLLVM_TRANSFORM_H -#define VCLLVM_TRANSFORM_H +#ifndef PALLAS_TRANSFORM_H +#define PALLAS_TRANSFORM_H -#include "Passes/Function/FunctionBodyTransformer.h" #include "Origin/OriginProvider.h" +#include "Passes/Function/FunctionBodyTransformer.h" /** * General helper functions for transformations */ -namespace llvm2Col { - namespace col = vct::col::ast; +namespace llvm2col { +namespace col = vct::col::ast; - // type transformers - void transformAndSetType(llvm::Type &llvmType, col::Type &colType); +// type transformers +void transformAndSetPointerType(llvm::Type &llvmType, col::Type &colType); - /** - * ATTEMPTS to convert any integer constant to a BigInt representation. - * @param apInt - * @param colIntegerValue - */ - void transformAndSetIntegerValue(llvm::APInt &apInt, col::IntegerValue &colIntegerValue); +void transformAndSetType(llvm::Type &llvmType, col::Type &colType); - /** - * Transforms and set LLVM expression in the buffer which in practice are either constants (e.g. 0, 0.1, false etc..) - * or variables (i.e. LLVM Values) (e.g. %3, %variable) - * @param functionCursor - * @param llvmInstruction - * @param llvmOperand - * @param colExpr - */ - void transformAndSetExpr(vcllvm::FunctionCursor &functionCursor, llvm::Instruction &llvmInstruction, - llvm::Value &llvmOperand, col::Expr &colExpr); - /** - * Used by TransformAndSetExpr - * @param llvmInstruction - * @param llvmConstant - * @param colExpr - */ - void transformAndSetConstExpr(llvm::Instruction &llvmInstruction, llvm::Constant &llvmConstant, col::Expr &colExpr); +/** + * ATTEMPTS to convert any integer constant to a BigInt representation. + * @param apInt + * @param colIntegerValue + */ +void transformAndSetBigInt(llvm::APInt &apInt, col::BigInt &bigInt); + +/** + * Transforms and set LLVM expression in the buffer which in practice are either + * constants (e.g. 0, 0.1, false etc..) or variables (i.e. LLVM Values) (e.g. + * %3, %variable) + * @param functionCursor + * @param llvmInstruction + * @param llvmOperand + * @param colExpr + */ +void transformAndSetExpr(pallas::FunctionCursor &functionCursor, + llvm::Instruction &llvmInstruction, + llvm::Value &llvmOperand, col::Expr &colExpr); +/** + * Used by TransformAndSetExpr + * @param llvmInstruction + * @param llvmConstant + * @param colExpr + */ +void transformAndSetConstExpr(llvm::FunctionAnalysisManager &FAM, + col::Origin *origin, llvm::Constant &llvmConstant, + col::Expr &colExpr); - /** - * Used by TransformAndSetExpr - * @param functionCursor - * @param llvmInstruction - * @param llvmOperand - * @param colExpr - */ - void transformAndSetVarExpr(vcllvm::FunctionCursor &functionCursor, llvm::Instruction &llvmInstruction, - llvm::Value &llvmOperand, col::Expr &colExpr); - template - void transformBinExpr(llvm::Instruction &llvmInstruction, - ColBinExpr &colBinExpr, - vcllvm::FunctionCursor &funcCursor) { - // set origin of entire expression - colBinExpr.set_allocated_origin(generateBinExprOrigin(llvmInstruction)); - // transform left operand - col::Expr *lExpr = colBinExpr.mutable_left(); - llvm2Col::transformAndSetExpr( - funcCursor, llvmInstruction, *llvmInstruction.getOperand(0), - *lExpr); - // transform right operand - col::Expr *rExpr = colBinExpr.mutable_right(); - llvm2Col::transformAndSetExpr( - funcCursor, llvmInstruction, *llvmInstruction.getOperand(1), - *rExpr); - } +/** + * Used by TransformAndSetExpr + * @param functionCursor + * @param llvmInstruction + * @param llvmOperand + * @param colExpr + */ +void transformAndSetVarExpr(pallas::FunctionCursor &functionCursor, + col::Origin *origin, bool inPhiNode, + llvm::Value &llvmOperand, col::Expr &colExpr); +template +void transformBinExpr(llvm::Instruction &llvmInstruction, + ColBinExpr &colBinExpr, + pallas::FunctionCursor &funcCursor) { + // set origin of entire expression + colBinExpr.set_allocated_origin(generateBinExprOrigin(llvmInstruction)); + // transform left operand + col::Expr *lExpr = colBinExpr.mutable_left(); + llvm2col::transformAndSetExpr(funcCursor, llvmInstruction, + *llvmInstruction.getOperand(0), *lExpr); + // transform right operand + col::Expr *rExpr = colBinExpr.mutable_right(); + llvm2col::transformAndSetExpr(funcCursor, llvmInstruction, + *llvmInstruction.getOperand(1), *rExpr); +} - template - int64_t setColNodeId(IDNode &idNode) { - auto id = reinterpret_cast(idNode); - idNode->set_id(id); - return id; - } - /** - * Returns a string representation of any LLVM value as it would be displayed in human readable LLVM IR - * @param llvmValue - * @return - */ - std::string getValueName(llvm::Value &llvmValue); +template int64_t setColNodeId(IDNode &idNode) { + auto id = reinterpret_cast(idNode); + idNode->set_id(id); + return id; } -#endif //VCLLVM_TRANSFORM_H +/** + * Returns a string representation of any LLVM value as it would be displayed in + * human readable LLVM IR + * @param llvmValue + * @return + */ +std::string getValueName(llvm::Value &llvmValue); +} // namespace llvm2col +#endif // PALLAS_TRANSFORM_H diff --git a/src/llvm/include/Util/Constants.h b/src/llvm/include/Util/Constants.h index 09ef619a5b..2556aa5b53 100644 --- a/src/llvm/include/Util/Constants.h +++ b/src/llvm/include/Util/Constants.h @@ -1,15 +1,15 @@ -#ifndef VCLLVM_CONSTANTS_H -#define VCLLVM_CONSTANTS_H +#ifndef PALLAS_CONSTANTS_H +#define PALLAS_CONSTANTS_H #include /** * Useful string constants to use for searching out metadata nodes */ -namespace vcllvm::constants { - const std::string VC_PREFIX = "VC."; +namespace pallas::constants { +const std::string VC_PREFIX = "VC."; - const std::string METADATA_PURE_KEYWORD = VC_PREFIX + "pure"; - const std::string METADATA_CONTRACT_KEYWORD = VC_PREFIX + "contract"; - const std::string METADATA_GLOBAL_KEYWORD = VC_PREFIX + "global"; -} +const std::string METADATA_PURE_KEYWORD = VC_PREFIX + "pure"; +const std::string METADATA_CONTRACT_KEYWORD = VC_PREFIX + "contract"; +const std::string METADATA_GLOBAL_KEYWORD = VC_PREFIX + "global"; +} // namespace pallas::constants -#endif //VCLLVM_CONSTANTS_H +#endif // PALLAS_CONSTANTS_H diff --git a/src/llvm/include/Util/Exceptions.h b/src/llvm/include/Util/Exceptions.h index ecfcf14bb2..ec6df5217e 100644 --- a/src/llvm/include/Util/Exceptions.h +++ b/src/llvm/include/Util/Exceptions.h @@ -1,36 +1,75 @@ -#ifndef VCLLVM_EXCEPTIONS_H -#define VCLLVM_EXCEPTIONS_H +#ifndef PALLAS_EXCEPTIONS_H +#define PALLAS_EXCEPTIONS_H #include +#include /** - * Error handler for VCLLVM. Contains exception types and a static ErrorReporter class to which errors can be added from - * anywhere in the program. Before attempting to serialize the buffer, VCLLVM will check for errors. If there are errors, - * VCLLVM will present them and aboard the program. + * Error handler for pallas. Contains exception types and a static ErrorReporter + * class to which errors can be added from anywhere in the program. Before + * attempting to serialize the buffer, pallas will check for errors. If there + * are errors, pallas will present them and aboard the program. */ -namespace vcllvm { - struct UnsupportedTypeException : public std::exception { - [[nodiscard]] const char * what() const noexcept override; - }; +namespace pallas { +struct UnsupportedTypeException : public std::exception { + UnsupportedTypeException(const llvm::Type &); - class ErrorReporter { - private: - static u_int32_t errorCount; + [[nodiscard]] const char *what() const noexcept; - static void addError(const std::string &source, const std::string &message, const std::string &origin); + private: + std::string str; +}; - public: - static void addError(const std::string &source, const std::string &message, llvm::Module &llvmModule); +class ErrorReporter { + private: + static u_int32_t errorCount; + static u_int32_t warningCount; - static void addError(const std::string &source, const std::string &message, llvm::Function &llvmFunction); + public: + static void addError(const std::string &source, const std::string &message); - static void addError(const std::string &source, const std::string &message, llvm::BasicBlock &llvmBlock); + static void addError(const std::string &source, const std::string &message, + const std::string &origin); - static void addError(const std::string &source, const std::string &message, llvm::Instruction &llvmInstruction); + static void addError(const std::string &source, const std::string &message, + llvm::Module &llvmModule); - static bool hasErrors(); + static void addError(const std::string &source, const std::string &message, + llvm::Function &llvmFunction); - static u_int32_t getErrorCount(); - }; -} -#endif //VCLLVM_EXCEPTIONS_H + static void addError(const std::string &source, const std::string &message, + llvm::BasicBlock &llvmBlock); + + static void addError(const std::string &source, const std::string &message, + llvm::Instruction &llvmInstruction); + + static void addWarning(const std::string &source, + const std::string &message); + + static void addWarning(const std::string &source, + const std::string &message, + const std::string &origin); + + static void addWarning(const std::string &source, + const std::string &message, + llvm::Module &llvmModule); + + static void addWarning(const std::string &source, + const std::string &message, + llvm::Function &llvmFunction); + + static void addWarning(const std::string &source, + const std::string &message, + llvm::BasicBlock &llvmBlock); + + static void addWarning(const std::string &source, + const std::string &message, + llvm::Instruction &llvmInstruction); + + static bool hasErrors(); + + static u_int32_t getErrorCount(); + static u_int32_t getWarningCount(); +}; +} // namespace pallas +#endif // PALLAS_EXCEPTIONS_H diff --git a/src/llvm/lib/Origin/ContextDeriver.cpp b/src/llvm/lib/Origin/ContextDeriver.cpp new file mode 100644 index 0000000000..b838e0d4c9 --- /dev/null +++ b/src/llvm/lib/Origin/ContextDeriver.cpp @@ -0,0 +1,81 @@ +#include "Origin/ContextDeriver.h" + +#include +#include + +// module derivers +std::string llvm2col::deriveModuleContext(llvm::Module &llvmModule) { + std::string context; + llvm::raw_string_ostream(context) << llvmModule; + return context; +} + +// function derivers +std::string llvm2col::deriveFunctionContext(llvm::Function &llvmFunction) { + std::string context; + llvm::raw_string_ostream(context) << llvmFunction; + return context; +} + +// block derivers +std::string llvm2col::deriveLabelContext(llvm::BasicBlock &llvmBlock) { + if (llvmBlock.isEntryBlock()) { + return ""; + } + std::string fullContext; + llvm::raw_string_ostream(fullContext) << llvmBlock; + return fullContext.substr(0, fullContext.find(':') + 1); +} + +std::string llvm2col::deriveBlockContext(llvm::BasicBlock &llvmBlock) { + std::string context; + llvm::raw_string_ostream(context) << llvmBlock; + return context; +} + +// instruction derivers +std::string llvm2col::deriveSurroundingInstructionContext( + llvm::Instruction &llvmInstruction) { + std::string context; + if (llvmInstruction.getPrevNode() != nullptr) { + llvm::raw_string_ostream(context) + << *llvmInstruction.getPrevNode() << '\n'; + } + llvm::raw_string_ostream(context) << llvmInstruction; + if (llvmInstruction.getNextNode() != nullptr) { + llvm::raw_string_ostream(context) << '\n' + << *llvmInstruction.getNextNode(); + } + return context; +} + +std::string +llvm2col::deriveInstructionContext(llvm::Instruction &llvmInstruction) { + std::string context; + llvm::raw_string_ostream(context) << llvmInstruction; + return context; +} + +std::string llvm2col::deriveGlobalVariableContext( + llvm::GlobalVariable &llvmGlobalVariable) { + std::string context; + llvm::raw_string_ostream(context) << llvmGlobalVariable; + return context; +} + +std::string llvm2col::deriveInstructionLhs(llvm::Instruction &llvmInstruction) { + std::string fullContext = deriveInstructionContext(llvmInstruction); + return fullContext.substr(0, fullContext.find('=')); +} + +std::string llvm2col::deriveInstructionRhs(llvm::Instruction &llvmInstruction) { + std::string fullContext = deriveInstructionContext(llvmInstruction); + return fullContext.substr(fullContext.find('=') + 1); +} + +std::string llvm2col::deriveOperandContext(llvm::Value &llvmOperand) { + std::string context; + llvm::raw_string_ostream contextStream = llvm::raw_string_ostream(context); + llvmOperand.printAsOperand(contextStream, false); + return context; +} diff --git a/src/llvm/lib/Origin/OriginProvider.cpp b/src/llvm/lib/Origin/OriginProvider.cpp new file mode 100644 index 0000000000..21bea6dbfb --- /dev/null +++ b/src/llvm/lib/Origin/OriginProvider.cpp @@ -0,0 +1,343 @@ +#include "Origin/OriginProvider.h" + +#include + +#include "Origin/ContextDeriver.h" +#include "Origin/PreferredNameDeriver.h" +#include "Origin/ShortPositionDeriver.h" + +namespace col = vct::col::ast; + +col::Origin *llvm2col::generateProgramOrigin(llvm::Module &llvmModule) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name("program:" + llvmModule.getName().str()); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveModuleContext(llvmModule)); + context->set_inline_context(deriveModuleContext(llvmModule)); + context->set_short_position(deriveModuleShortPosition(llvmModule)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateFuncDefOrigin(llvm::Function &llvmFunction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name(llvmFunction.getName().str()); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveFunctionContext(llvmFunction)); + context->set_inline_context(deriveFunctionContext(llvmFunction)); + context->set_short_position(deriveFunctionShortPosition(llvmFunction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateFunctionContractOrigin(llvm::Function &llvmFunction, + const std::string &contract) { + col::Origin *origin = new col::Origin(); + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(contract); + context->set_inline_context(contract); + context->set_short_position(deriveFunctionShortPosition(llvmFunction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateGlobalValOrigin(llvm::Module &llvmModule, + const std::string &globVal) { + col::Origin *origin = new col::Origin(); + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(globVal); + context->set_inline_context(globVal); + context->set_short_position(deriveModuleShortPosition(llvmModule)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateArgumentOrigin(llvm::Argument &llvmArgument) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name( + deriveArgumentPreferredName(llvmArgument)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveFunctionContext(*llvmArgument.getParent())); + context->set_inline_context( + deriveFunctionContext(*llvmArgument.getParent())); + context->set_short_position( + deriveFunctionShortPosition(*llvmArgument.getParent())); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateBlockOrigin(llvm::BasicBlock &llvmBlock) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name("block"); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveBlockContext(llvmBlock)); + context->set_inline_context(deriveBlockContext(llvmBlock)); + context->set_short_position(deriveBlockShortPosition(llvmBlock)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateLabelOrigin(llvm::BasicBlock &llvmBlock) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name("label"); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveLabelContext(llvmBlock)); + context->set_inline_context(deriveLabelContext(llvmBlock)); + context->set_short_position(deriveBlockShortPosition(llvmBlock)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name( + deriveOperandPreferredName(llvmInstruction)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionContext(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name("var"); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionLhs(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateBinExprOrigin(llvm::Instruction &llvmInstruction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionContext(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateFunctionCallOrigin(llvm::CallInst &callInstruction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name( + callInstruction.getCalledFunction()->getName().str()); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveSurroundingInstructionContext(callInstruction)); + context->set_inline_context(deriveInstructionRhs(callInstruction)); + context->set_short_position( + deriveInstructionShortPosition(callInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateOperandOrigin(llvm::Instruction &llvmInstruction, + llvm::Value &llvmOperand) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name(deriveOperandPreferredName(llvmOperand)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context(deriveOperandContext(llvmOperand)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateGlobalVariableOrigin( + llvm::Module &llvmModule, llvm::GlobalVariable &llvmGlobalVariable) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name(llvmGlobalVariable.getName().str()); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveGlobalVariableContext(llvmGlobalVariable)); + context->set_inline_context("unknown"); + context->set_short_position(deriveModuleShortPosition(llvmModule)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateGlobalVariableInitializerOrigin( + llvm::Module &llvmModule, llvm::GlobalVariable &llvmGlobalVariable, + llvm::Value &llvmInitializer) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name( + deriveOperandPreferredName(llvmInitializer)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveGlobalVariableContext(llvmGlobalVariable)); + context->set_inline_context(deriveOperandContext(llvmInitializer)); + context->set_short_position(deriveModuleShortPosition(llvmModule)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin * +llvm2col::generateVoidOperandOrigin(llvm::Instruction &llvmInstruction) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name("void"); + preferredNameContent->set_allocated_preferred_name(preferredName); + + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context("void"); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + + return origin; +} + +col::Origin *llvm2col::generateTypeOrigin(llvm::Type &llvmType) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name(deriveTypePreferredName(llvmType)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + return origin; +} + +col::Origin * +llvm2col::generateMemoryOrderingOrigin(llvm::AtomicOrdering &llvmOrdering) { + col::Origin *origin = new col::Origin(); + col::OriginContent *preferredNameContent = origin->add_content(); + col::PreferredName *preferredName = new col::PreferredName(); + preferredName->add_preferred_name( + deriveMemoryOrderingPreferredName(llvmOrdering)); + preferredNameContent->set_allocated_preferred_name(preferredName); + + return origin; +} + +std::string llvm2col::extractShortPosition(const col::Origin &origin) { + for (const col::OriginContent &content : origin.content()) { + if (content.has_context()) { + return content.context().short_position(); + } + } + return "unknown"; +} + +col::Origin *llvm2col::deepenOperandOrigin(const col::Origin &origin, + llvm::Value &llvmOperand) { + col::Origin *newOrigin = new col::Origin(origin); + + bool foundName = false; + bool foundContext = false; + for (col::OriginContent &content : *newOrigin->mutable_content()) { + if (content.has_preferred_name()) { + col::PreferredName *preferredName = + content.mutable_preferred_name(); + preferredName->clear_preferred_name(); + preferredName->add_preferred_name( + deriveOperandPreferredName(llvmOperand)); + foundName = true; + } else if (content.has_context()) { + content.mutable_context()->set_inline_context( + deriveOperandContext(llvmOperand)); + foundContext = true; + } + } + + if (!foundName) { + col::PreferredName *preferredName = + newOrigin->add_content()->mutable_preferred_name(); + preferredName->clear_preferred_name(); + preferredName->add_preferred_name( + deriveOperandPreferredName(llvmOperand)); + } + + if (!foundContext) { + col::Context *context = newOrigin->add_content()->mutable_context(); + context->set_context("unknown"); + context->set_inline_context(deriveOperandContext(llvmOperand)); + context->set_short_position("unknown"); + } + + return newOrigin; +} diff --git a/src/llvm/lib/Origin/PreferredNameDeriver.cpp b/src/llvm/lib/Origin/PreferredNameDeriver.cpp new file mode 100644 index 0000000000..bd0301f8f7 --- /dev/null +++ b/src/llvm/lib/Origin/PreferredNameDeriver.cpp @@ -0,0 +1,102 @@ +#include "Origin/PreferredNameDeriver.h" + +#include +#include +#include + +std::string llvm2col::deriveOperandPreferredName(llvm::Value &llvmOperand) { + if (!llvmOperand.getName().empty()) + return std::string(llvmOperand.getName()); + std::string preferredName; + llvm::raw_string_ostream preferredNameStream = + llvm::raw_string_ostream(preferredName); + preferredNameStream << (llvm::isa(llvmOperand) ? "const_" + + : "var_"); + + llvmOperand.printAsOperand(preferredNameStream, false); + return preferredName; +} + +std::string llvm2col::deriveTypePreferredName(llvm::Type &llvmType) { + std::string prefix = "t_"; + switch (llvmType.getTypeID()) { + case llvm::Type::HalfTyID: + return prefix + "half"; + case llvm::Type::BFloatTyID: + return prefix + "bfloat"; + case llvm::Type::FloatTyID: + return prefix + "float"; + case llvm::Type::DoubleTyID: + return prefix + "double"; + case llvm::Type::X86_FP80TyID: + return prefix + "x86fp80"; + case llvm::Type::FP128TyID: + return prefix + "fp128"; + case llvm::Type::PPC_FP128TyID: + return prefix + "ppcfp128"; + case llvm::Type::VoidTyID: + return prefix + "void"; + case llvm::Type::LabelTyID: + return prefix + "label"; + case llvm::Type::MetadataTyID: + return prefix + "metadata"; + case llvm::Type::X86_MMXTyID: + return prefix + "x86mmx"; + case llvm::Type::X86_AMXTyID: + return prefix + "x86amx"; + case llvm::Type::TokenTyID: + return prefix + "token"; + case llvm::Type::IntegerTyID: + return prefix + + (llvmType.getIntegerBitWidth() == 1 ? "boolean" : "integer"); + case llvm::Type::FunctionTyID: + return prefix + "function"; + case llvm::Type::PointerTyID: + return prefix + "ptr"; + case llvm::Type::StructTyID: + return prefix + "struct"; + case llvm::Type::ArrayTyID: + return prefix + "array"; + case llvm::Type::FixedVectorTyID: + return prefix + "fixedvector"; + case llvm::Type::ScalableVectorTyID: + return prefix + "scalevector"; + case llvm::Type::TypedPointerTyID: + return prefix + "typedptr"; + case llvm::Type::TargetExtTyID: + return prefix + "targetext"; + } + return "UNKNOWN"; +} + +std::string llvm2col::deriveMemoryOrderingPreferredName( + llvm::AtomicOrdering &llvmOrdering) { + switch (llvmOrdering) { + case llvm::AtomicOrdering::NotAtomic: + return "NotAtomic"; + case llvm::AtomicOrdering::Unordered: + return "Unordered"; + case llvm::AtomicOrdering::Monotonic: + return "Monotonic"; + case llvm::AtomicOrdering::Acquire: + return "Acquire"; + case llvm::AtomicOrdering::Release: + return "Release"; + case llvm::AtomicOrdering::AcquireRelease: + return "AcquireRelease"; + case llvm::AtomicOrdering::SequentiallyConsistent: + return "SequentiallyConsistent"; + } + return "UNKNOWN"; +} + +std::string +llvm2col::deriveArgumentPreferredName(llvm::Argument &llvmArgument) { + std::string preferredName; + llvm::raw_string_ostream preferredNameStream = + llvm::raw_string_ostream(preferredName); + preferredNameStream << "arg_"; + llvmArgument.printAsOperand(preferredNameStream, false); + return preferredName; +} diff --git a/src/llvm/lib/Origin/ShortPositionDeriver.cpp b/src/llvm/lib/Origin/ShortPositionDeriver.cpp new file mode 100644 index 0000000000..ff5faf775e --- /dev/null +++ b/src/llvm/lib/Origin/ShortPositionDeriver.cpp @@ -0,0 +1,49 @@ +#include "Origin/ShortPositionDeriver.h" +#include "Origin/ContextDeriver.h" + +const std::string POSITION_POINTER = "\n\t -> "; + +std::string llvm2col::deriveModuleShortPosition(llvm::Module &llvmModule) { + return "file " + llvmModule.getSourceFileName(); +} + +std::string +llvm2col::deriveFunctionShortPosition(llvm::Function &llvmFunction) { + std::string functionPosition = + deriveModuleShortPosition(*llvmFunction.getParent()); + llvm::raw_string_ostream functionPosStream = + llvm::raw_string_ostream(functionPosition); + functionPosStream << POSITION_POINTER << "function "; + llvmFunction.printAsOperand(functionPosStream, false); + return functionPosition; +} + +std::string llvm2col::deriveBlockShortPosition(llvm::BasicBlock &llvmBlock) { + std::string blockPosition = + deriveFunctionShortPosition(*llvmBlock.getParent()); + llvm::raw_string_ostream blockPosStream = + llvm::raw_string_ostream(blockPosition); + blockPosStream << POSITION_POINTER << "block "; + llvmBlock.printAsOperand(blockPosStream, false); + blockPosStream << (llvmBlock.isEntryBlock() ? " (entryblock)" : ""); + return blockPosition; +} + +std::string +llvm2col::deriveInstructionShortPosition(llvm::Instruction &llvmInstruction) { + std::string instructionPosition = + deriveBlockShortPosition(*llvmInstruction.getParent()); + llvm::raw_string_ostream instructionPosStream = + llvm::raw_string_ostream(instructionPosition); + int pos = 0; + llvm::BasicBlock *bb = llvmInstruction.getParent(); + for (auto &I : *bb) { + pos++; + if (&I == &llvmInstruction) { + break; + } + } + instructionPosStream << POSITION_POINTER << "instruction #" << pos << " (" + << deriveInstructionContext(llvmInstruction) << ')'; + return instructionPosition; +} diff --git a/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp new file mode 100644 index 0000000000..215d9ae74e --- /dev/null +++ b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp @@ -0,0 +1,185 @@ +#include "Passes/Function/FunctionBodyTransformer.h" + +#include "Origin/OriginProvider.h" +#include "Passes/Function/FunctionContractDeclarer.h" +#include "Passes/Function/FunctionDeclarer.h" +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" +#include + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Function::FunctionBodyTransformer"; + +FunctionCursor::FunctionCursor(col::Scope &functionScope, + col::Block &functionBody, + llvm::Function &llvmFunction, + llvm::FunctionAnalysisManager &FAM) + : functionScope(functionScope), functionBody(functionBody), + llvmFunction(llvmFunction), FAM(FAM) {} + +const col::Scope &FunctionCursor::getFunctionScope() { return functionScope; } + +void FunctionCursor::addVariableMapEntry(Value &llvmValue, + col::Variable &colVar) { + variableMap.insert({&llvmValue, &colVar}); + // add reference to reference lut of function contract + col::Tuple2_String_Ref_VctColAstVariable *ref = + FAM.getResult(llvmFunction) + .getAssociatedColFuncContract() + .add_variable_refs(); + ref->set_v1(llvm2col::getValueName(llvmValue)); + ref->mutable_v2()->set_id(colVar.id()); +} + +col::Variable &FunctionCursor::getVariableMapEntry(Value &llvmValue, + bool inPhiNode) { + if (auto variablePair = variableMap.find(&llvmValue); + variablePair != variableMap.end()) { + return *variablePair->second; + } else { + if (!inPhiNode) { + std::string str; + llvm::raw_string_ostream output(str); + output << "Use of undeclared variable: '" << llvmValue << "'"; + ErrorReporter::addError(SOURCE_LOC, str); + } + + col::Variable *colVar = new col::Variable(); + addVariableMapEntry(llvmValue, *colVar); + return *colVar; + } +} + +bool FunctionCursor::isVisited(BasicBlock &llvmBlock) { + return llvmBlock2LabeledColBlock.contains(&llvmBlock); +} + +void FunctionCursor::complete(col::Block &colBlock) { + completedColBlocks.insert(&colBlock); +} +bool FunctionCursor::isComplete(col::Block &colBlock) { + return completedColBlocks.contains(&colBlock); +} + +LabeledColBlock & +FunctionCursor::getOrSetLLVMBlock2LabeledColBlockEntry(BasicBlock &llvmBlock) { + if (!llvmBlock2LabeledColBlock.contains(&llvmBlock)) { + // create label in buffer + col::Label *label = functionBody.add_statements()->mutable_label(); + // set label origin + label->set_allocated_origin(llvm2col::generateLabelOrigin(llvmBlock)); + // create label declaration in buffer + col::LabelDecl *labelDecl = label->mutable_decl(); + // set label decl origin + labelDecl->set_allocated_origin( + llvm2col::generateLabelOrigin(llvmBlock)); + // set label decl id + llvm2col::setColNodeId(labelDecl); + // create block inside label statement + col::Block *block = label->mutable_stat()->mutable_block(); + // set block origin + block->set_allocated_origin(llvm2col::generateBlockOrigin(llvmBlock)); + // add labeled block to the block2block lut + LabeledColBlock labeledColBlock = {*label, *block}; + llvmBlock2LabeledColBlock.insert({&llvmBlock, labeledColBlock}); + } + return llvmBlock2LabeledColBlock.at(&llvmBlock); +} + +LoopInfo &FunctionCursor::getLoopInfo() { + return FAM.getResult(llvmFunction); +} + +LoopInfo &FunctionCursor::getLoopInfo(Function &otherLLVMFunction) { + return FAM.getResult(otherLLVMFunction); +} + +FDResult &FunctionCursor::getFDResult() { + return FAM.getResult(llvmFunction); +} + +FDResult &FunctionCursor::getFDResult(Function &otherLLVMFunction) { + return FAM.getResult(otherLLVMFunction); +} + +col::Variable &FunctionCursor::declareVariable(Instruction &llvmInstruction, + Type *llvmPointerType) { + // create declaration in buffer + col::Variable *varDecl = functionScope.add_locals(); + // set type of declaration + try { + if (llvmPointerType == nullptr) { + llvm2col::transformAndSetType(*llvmInstruction.getType(), + *varDecl->mutable_t()); + } else { + llvm2col::transformAndSetPointerType(*llvmPointerType, + *varDecl->mutable_t()); + } + } catch (pallas::UnsupportedTypeException &e) { + std::stringstream errorStream; + errorStream << e.what() << " in variable declaration."; + ErrorReporter::addError(SOURCE_LOC, errorStream.str(), llvmInstruction); + } + // set id + llvm2col::setColNodeId(varDecl); + // set origin + varDecl->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(llvmInstruction)); + // add to the variable lut + this->addVariableMapEntry(llvmInstruction, *varDecl); + return *varDecl; +} + +col::Assign &FunctionCursor::createAssignmentAndDeclaration( + Instruction &llvmInstruction, col::Block &colBlock, Type *llvmPointerType) { + col::Variable &varDecl = declareVariable(llvmInstruction, llvmPointerType); + return createAssignment(llvmInstruction, colBlock, varDecl); +} + +col::Assign &FunctionCursor::createAssignment(Instruction &llvmInstruction, + col::Block &colBlock, + col::Variable &varDecl) { + col::Assign *assignment = colBlock.add_statements()->mutable_assign(); + assignment->set_allocated_blame(new col::Blame()); + assignment->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(llvmInstruction)); + // create local target in buffer and set origin + col::Local *colLocal = assignment->mutable_target()->mutable_local(); + colLocal->set_allocated_origin( + llvm2col::generateAssignTargetOrigin(llvmInstruction)); + // set target to refer to var decl + colLocal->mutable_ref()->set_id(varDecl.id()); + if (isComplete(colBlock)) { + // if the colBlock is completed, the assignment will be inserted after + // the goto/branch statement this can occur due to e.g. phi nodes back + // tracking assignments in their origin blocks. therefore we need to + // swap the last two elements of the block (i.e. the goto statement and + // the newest assignment) + int lastIndex = colBlock.statements_size() - 1; + colBlock.mutable_statements()->SwapElements(lastIndex, lastIndex - 1); + } + return *assignment; +} + +llvm::FunctionAnalysisManager &FunctionCursor::getFunctionAnalysisManager() { + return FAM; +} + +PreservedAnalyses +FunctionBodyTransformerPass::run(Function &F, FunctionAnalysisManager &FAM) { + ColScopedFuncBody scopedFuncBody = + FAM.getResult(F).getAssociatedScopedColFuncBody(); + FunctionCursor funcCursor = + FunctionCursor(*scopedFuncBody.scope, *scopedFuncBody.block, F, FAM); + // add function arguments to the variableMap + for (auto &A : F.args()) { + funcCursor.addVariableMapEntry( + A, FAM.getResult(F).getFuncArgMapEntry(A)); + } + // start recursive block code gen with basic block + llvm::BasicBlock &entryBlock = F.getEntryBlock(); + llvm2col::transformLLVMBlock(entryBlock, funcCursor); + return PreservedAnalyses::all(); +} +} // namespace pallas diff --git a/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp b/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp new file mode 100644 index 0000000000..f97ca35cfe --- /dev/null +++ b/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp @@ -0,0 +1,87 @@ +#include "Passes/Function/FunctionContractDeclarer.h" + +#include "Origin/OriginProvider.h" +#include "Passes/Function/FunctionDeclarer.h" +#include "Util/Constants.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Function::FunctionContractDeclarer"; + +using namespace llvm; + +/* + * Function Contract Declarer Result + */ + +FDCResult::FDCResult(vct::col::ast::LlvmFunctionContract &colFuncContract) + : associatedColFuncContract(colFuncContract) {} + +col::LlvmFunctionContract &FDCResult::getAssociatedColFuncContract() { + return associatedColFuncContract; +} + +/* + * Function Contract Declarer (Analysis) + */ + +AnalysisKey FunctionContractDeclarer::Key; + +FunctionContractDeclarer::Result +FunctionContractDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { + // fetch relevant function from the Function Declarer + FDResult fdResult = FAM.getResult(F); + col::LlvmFunctionDefinition &colFunction = + fdResult.getAssociatedColFuncDef(); + // set a contract in the buffer as well as make and return a result object + return FDCResult(*colFunction.mutable_contract()); +} + +/* + * Function Contract Declarer Pass + */ +PreservedAnalyses +FunctionContractDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) { + // get col contract + FDCResult result = FAM.getResult(F); + col::LlvmFunctionContract &colContract = + result.getAssociatedColFuncContract(); + colContract.set_allocated_blame(new col::Blame()); + colContract.set_name(F.getName()); + // check if contract keyword is present + if (!F.hasMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD)) { + // set contract to a tautology + colContract.set_value("requires true;"); + colContract.set_allocated_origin(new col::Origin()); + return PreservedAnalyses::all(); + } + // concatenate all contract lines with new lines + MDNode *contractMDNode = + F.getMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD); + std::stringstream contractStream; + for (u_int32_t i = 0; i < contractMDNode->getNumOperands(); i++) { + auto contractLine = dyn_cast(contractMDNode->getOperand(i)); + if (contractLine == nullptr) { + std::stringstream errorStream; + errorStream << "Unable to cast contract metadata node #" << i + 1 + << "to string type"; + pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); + break; + } + contractStream << contractLine->getString().str() << '\n'; + } + colContract.set_value(contractStream.str()); + colContract.set_allocated_origin( + llvm2col::generateFunctionContractOrigin(F, contractStream.str())); + // add all callable functions to the contracts invokables + for (auto &moduleF : F.getParent()->functions()) { + std::string fName = '@' + moduleF.getName().str(); + int64_t fId = FAM.getResult(moduleF).getFunctionId(); + col::Tuple2_String_Ref_VctColAstLlvmCallable *invokeRef = + colContract.add_invokable_refs(); + invokeRef->set_v1(fName); + invokeRef->mutable_v2()->set_id(fId); + } + return PreservedAnalyses::all(); +} +} // namespace pallas diff --git a/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp b/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp new file mode 100644 index 0000000000..ce692455a2 --- /dev/null +++ b/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp @@ -0,0 +1,137 @@ +#include "Passes/Function/FunctionDeclarer.h" + +#include "Origin/OriginProvider.h" +#include "Passes/Module/RootContainer.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Function::FunctionDeclarer"; +using namespace llvm; + +/** + * Checks function definition for unsupported features that might change + * semantics and adds warning if this is the case. + * @param llvmFunction: the function to be checked + */ +void checkFunctionSupport(llvm::Function &llvmFunction) { + // TODO add syntax support checks that change the semantics of the program + // to function definitions + // TODO see: https://releases.llvm.org/15.0.0/docs/LangRef.html#functions +} + +/* + * Function Declarer Result + */ + +FDResult::FDResult(col::LlvmFunctionDefinition &colFuncDef, + ColScopedFuncBody associatedScopedColFuncBody, + int64_t functionId) + : associatedColFuncDef(colFuncDef), + associatedScopedColFuncBody(associatedScopedColFuncBody), + functionId(functionId) {} + +col::LlvmFunctionDefinition &FDResult::getAssociatedColFuncDef() { + return associatedColFuncDef; +} + +ColScopedFuncBody FDResult::getAssociatedScopedColFuncBody() { + return associatedScopedColFuncBody; +} + +void FDResult::addFuncArgMapEntry(Argument &llvmArg, col::Variable &colArg) { + funcArgMap.insert({&llvmArg, &colArg}); +} + +col::Variable &FDResult::getFuncArgMapEntry(Argument &arg) { + return *funcArgMap.at(&arg); +} + +int64_t &FDResult::getFunctionId() { return functionId; } + +/* + * Function Declarer (Analysis) + */ +AnalysisKey FunctionDeclarer::Key; + +FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { + auto MAM = FAM.getResult(F); + auto pProgram = MAM.getCachedResult(*F.getParent())->program; + checkFunctionSupport(F); + // create llvmFuncDef declaration in buffer + col::GlobalDeclaration *llvmFuncDefDecl = pProgram->add_declarations(); + // generate id + col::LlvmFunctionDefinition *llvmFuncDef = + llvmFuncDefDecl->mutable_llvm_function_definition(); + int64_t functionId = llvm2col::setColNodeId(llvmFuncDef); + // add body block + scope + origin + llvmFuncDef->set_allocated_blame(new col::Blame()); + // set origin + llvmFuncDef->set_allocated_origin(llvm2col::generateFuncDefOrigin(F)); + ColScopedFuncBody funcScopedBody{}; + if (!F.isDeclaration()) { + funcScopedBody.scope = + llvmFuncDef->mutable_function_body()->mutable_scope(); + funcScopedBody.scope->set_allocated_origin( + llvm2col::generateFuncDefOrigin(F)); + funcScopedBody.block = + funcScopedBody.scope->mutable_body()->mutable_block(); + funcScopedBody.block->set_allocated_origin( + llvm2col::generateFuncDefOrigin(F)); + } + FDResult result = FDResult(*llvmFuncDef, funcScopedBody, functionId); + // set args (if present) + for (llvm::Argument &llvmArg : F.args()) { + // set in buffer + col::Variable *colArg = llvmFuncDef->add_args(); + // set origin + colArg->set_allocated_origin(llvm2col::generateArgumentOrigin(llvmArg)); + llvm2col::setColNodeId(colArg); + try { + llvm2col::transformAndSetType(*llvmArg.getType(), + *colArg->mutable_t()); + } catch (pallas::UnsupportedTypeException &e) { + std::stringstream errorStream; + errorStream << e.what() << " in argument #" << llvmArg.getArgNo(); + pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); + } + // add args mapping to result + result.addFuncArgMapEntry(llvmArg, *colArg); + } + llvmFuncDef->set_allocated_blame(new col::Blame()); + // complete the function declaration in proto buffer + // set return type in protobuf of function + try { + llvm2col::transformAndSetType(*F.getReturnType(), + *llvmFuncDef->mutable_return_type()); + } catch (pallas::UnsupportedTypeException &e) { + std::stringstream errorStream; + errorStream << e.what() << " in return signature"; + pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); + } + + if (F.isDeclaration()) { + // Defined outside of this module so we don't know if it's pure or what + // its contract is + col::LlvmFunctionContract *colContract = + llvmFuncDef->mutable_contract(); + colContract->set_allocated_blame(new col::Blame()); + colContract->set_value("requires true;"); + colContract->set_name(F.getName()); + colContract->set_allocated_origin(new col::Origin()); + + llvmFuncDef->set_pure(false); + } + return result; +} + +/* + * Function Declarer Pass + */ +PreservedAnalyses FunctionDeclarerPass::run(Function &F, + FunctionAnalysisManager &FAM) { + FDResult result = FAM.getResult(F); + // Just makes sure we analyse every function + return PreservedAnalyses::all(); +} +} // namespace pallas diff --git a/src/llvm/lib/Passes/Function/PureAssigner.cpp b/src/llvm/lib/Passes/Function/PureAssigner.cpp new file mode 100644 index 0000000000..46b0415de7 --- /dev/null +++ b/src/llvm/lib/Passes/Function/PureAssigner.cpp @@ -0,0 +1,59 @@ +#include "Passes/Function/PureAssigner.h" + +#include "Passes/Function/FunctionDeclarer.h" +#include "Util/Constants.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Function::PureAssigner"; + +using namespace llvm; + +/** + * Helper function to generate errors generated by this Pass + * @param F + * @param explanation + */ +static void reportError(Function &F, const std::string &explanation) { + std::stringstream errorStream; + errorStream << "Malformed Metadata node of type \"" + << pallas::constants::METADATA_PURE_KEYWORD + << "\":" << explanation; + pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); +} + +PreservedAnalyses PureAssignerPass::run(Function &F, + FunctionAnalysisManager &FAM) { + std::ostringstream errorStream; + FDResult result = FAM.getResult(F); + col::LlvmFunctionDefinition &colFunction = result.getAssociatedColFuncDef(); + // check if pure keyword is present, else assume unpure function + if (!F.hasMetadata(pallas::constants::METADATA_PURE_KEYWORD)) { + colFunction.set_pure(false); + return PreservedAnalyses::all(); + } + // check if the 'pure' metadata has only 1 operand, else exit with error + MDNode *pureMDNode = + F.getMetadata(pallas::constants::METADATA_PURE_KEYWORD); + if (pureMDNode->getNumOperands() != 1) { + errorStream << "Expected 1 argument but got " + << pureMDNode->getNumOperands(); + reportError(F, errorStream.str()); + return PreservedAnalyses::all(); + } + // check if the only operand is of type 'i1', else exit with error + auto *pureMDValue = cast(pureMDNode->getOperand(0)); + if (!pureMDValue->getType()->isIntegerTy(1)) { + errorStream << "MD node type must be of type \"i1\""; + reportError(F, errorStream.str()); + return PreservedAnalyses::all(); + } + // attempt down cast to ConstantInt (which shouldn't fail given previous + // checks) + bool purity = + cast(pureMDValue)->getValue()->isOneValue(); + colFunction.set_pure(purity); + return PreservedAnalyses::all(); +} + +} // namespace pallas diff --git a/src/llvm/lib/Passes/Module/GlobalVariableDeclarer.cpp b/src/llvm/lib/Passes/Module/GlobalVariableDeclarer.cpp new file mode 100644 index 0000000000..e79f351909 --- /dev/null +++ b/src/llvm/lib/Passes/Module/GlobalVariableDeclarer.cpp @@ -0,0 +1,49 @@ +#include "Passes/Module/GlobalVariableDeclarer.h" +#include "Passes/Module/RootContainer.h" +#include "Transform/Transform.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Module::GlobalVariableDeclarer"; + +using namespace llvm; + +PreservedAnalyses GlobalVariableDeclarerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + auto pProgram = MAM.getResult(M).program; + + for (auto &global : M.globals()) { + col::GlobalDeclaration *globDecl = pProgram->add_declarations(); + col::LlvmGlobalVariable *colGlobal = + globDecl->mutable_llvm_global_variable(); + + llvm2col::transformAndSetType(*global.getType(), + *colGlobal->mutable_variable_type()); + if (global.hasInitializer()) { + llvm2col::transformAndSetConstExpr( + MAM.getResult(M) + .getManager(), + llvm2col::generateGlobalVariableInitializerOrigin( + M, global, *global.getInitializer()), + *global.getInitializer(), *colGlobal->mutable_value()); + + llvm2col::transformAndSetType(*global.getInitializer()->getType(), + *colGlobal->mutable_variable_type()); + } else { + // We don't know more about the type because we don't have an + // initializer + // TODO: This breaks the assumption that the type of the global + // declaration type is the inner type of the pointer. We should + // instead set the type to be TAny maybe? + llvm2col::transformAndSetType(*global.getType(), + *colGlobal->mutable_variable_type()); + } + colGlobal->set_constant(global.isConstant()); + colGlobal->set_allocated_origin( + llvm2col::generateGlobalVariableOrigin(M, global)); + colGlobal->set_id(reinterpret_cast(&global)); + } + + return PreservedAnalyses::all(); +} + +} // namespace pallas diff --git a/src/llvm/lib/Passes/Module/ModuleSpecCollector.cpp b/src/llvm/lib/Passes/Module/ModuleSpecCollector.cpp new file mode 100644 index 0000000000..c5e1ced8b5 --- /dev/null +++ b/src/llvm/lib/Passes/Module/ModuleSpecCollector.cpp @@ -0,0 +1,46 @@ +#include "Passes/Module/ModuleSpecCollector.h" +#include "Origin/OriginProvider.h" +#include "Passes/Module/RootContainer.h" +#include "Transform/Transform.h" +#include "Util/Constants.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Module::ModuleSpecCollector"; + +using namespace llvm; + +PreservedAnalyses ModuleSpecCollectorPass::run(Module &M, + ModuleAnalysisManager &MAM) { + auto pProgram = MAM.getResult(M).program; + NamedMDNode *globalMDNode = + M.getNamedMetadata(pallas::constants::METADATA_GLOBAL_KEYWORD); + if (globalMDNode == nullptr) { + return PreservedAnalyses::all(); + } + for (u_int32_t i = 0; i < globalMDNode->getNumOperands(); i++) { + for (u_int32_t j = 0; j < globalMDNode->getOperand(i)->getNumOperands(); + j++) { + auto globVal = + dyn_cast(globalMDNode->getOperand(i)->getOperand(j)); + if (globVal == nullptr) { + std::stringstream errorStream; + errorStream << "Unable to cast global metadata node #" << i + 1 + << "to string type"; + pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), + M); + break; + } + col::GlobalDeclaration *globDecl = pProgram->add_declarations(); + col::LlvmGlobalSpecification *colGlobal = + globDecl->mutable_llvm_global_specification(); + llvm2col::setColNodeId(colGlobal); + colGlobal->set_value(globVal->getString().str()); + colGlobal->set_allocated_origin(llvm2col::generateGlobalValOrigin( + M, globVal->getString().str())); + } + } + return PreservedAnalyses::all(); +} + +} // namespace pallas diff --git a/src/llvm/lib/Passes/Module/ProtobufPrinter.cpp b/src/llvm/lib/Passes/Module/ProtobufPrinter.cpp new file mode 100644 index 0000000000..c2f083a3ad --- /dev/null +++ b/src/llvm/lib/Passes/Module/ProtobufPrinter.cpp @@ -0,0 +1,30 @@ +#include "Passes/Module/ProtobufPrinter.h" +#include "Passes/Module/RootContainer.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Module::ProtobufPrinter"; + +using namespace llvm; + +PreservedAnalyses ProtobufPrinter::run(Module &M, ModuleAnalysisManager &MAM) { + if (ErrorReporter::hasErrors()) { + llvm::errs() << "[ERROR] [pallas] Conversion failed with " + << ErrorReporter::getWarningCount() << " warnings and " + << ErrorReporter::getErrorCount() << " errors\n"; + } else { + llvm::errs() << "[INFO] [pallas] Conversion succeeded with " + << ErrorReporter::getWarningCount() << "warnings\n"; + } + auto pProgram = MAM.getResult(M).program; + if (pProgram->IsInitialized()) { + std::cout << pProgram->SerializeAsString(); + } else { + llvm::errs() << "[ERROR] [pallas] Internal error, invalid protobuf " + "construction\n"; + pProgram->CheckInitialized(); + } + return PreservedAnalyses::all(); +} + +} // namespace pallas diff --git a/src/llvm/lib/Passes/Module/RootContainer.cpp b/src/llvm/lib/Passes/Module/RootContainer.cpp new file mode 100644 index 0000000000..774fe9b669 --- /dev/null +++ b/src/llvm/lib/Passes/Module/RootContainer.cpp @@ -0,0 +1,26 @@ +#include "Passes/Module/RootContainer.h" + +#include "Origin/OriginProvider.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Module::RootContainer"; +using namespace llvm; + +bool ProgramWrapper::invalidate(Module &M, const PreservedAnalyses &PA, + ModuleAnalysisManager::Invalidator &) { + return !PA.getChecker().preservedWhenStateless(); +} + +AnalysisKey RootContainer::Key; + +ProgramWrapper RootContainer::run(Module &M, ModuleAnalysisManager &MAM) { + auto pProgram = std::make_shared(); + // set program origin + pProgram->set_allocated_origin(llvm2col::generateProgramOrigin(M)); + pProgram->set_allocated_blame(new col::Blame()); + + return ProgramWrapper{pProgram}; +} +} // namespace pallas diff --git a/src/llvm/lib/Plugin.cpp b/src/llvm/lib/Plugin.cpp new file mode 100644 index 0000000000..2495dd6b54 --- /dev/null +++ b/src/llvm/lib/Plugin.cpp @@ -0,0 +1,69 @@ +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" + +#include "Passes/Function/FunctionBodyTransformer.h" +#include "Passes/Function/FunctionContractDeclarer.h" +#include "Passes/Function/FunctionDeclarer.h" +#include "Passes/Function/PureAssigner.h" +#include "Passes/Module/GlobalVariableDeclarer.h" +#include "Passes/Module/ModuleSpecCollector.h" +#include "Passes/Module/ProtobufPrinter.h" +#include "Passes/Module/RootContainer.h" + +using namespace llvm; + +llvm::PassPluginLibraryInfo getPallasPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "Pallas", LLVM_VERSION_STRING, + [](PassBuilder &PB) { + PB.registerAnalysisRegistrationCallback( + [](llvm::ModuleAnalysisManager &MAM) { + MAM.registerPass( + [&] { return pallas::RootContainer(); }); + }); + PB.registerAnalysisRegistrationCallback( + [](llvm::FunctionAnalysisManager &FAM) { + FAM.registerPass( + [&] { return pallas::FunctionDeclarer(); }); + FAM.registerPass( + [&] { return pallas::FunctionContractDeclarer(); }); + }); + PB.registerPipelineParsingCallback( + [](StringRef Name, llvm::ModulePassManager &MPM, + ArrayRef) { + if (Name == "pallas-collect-module-spec") { + MPM.addPass(pallas::ModuleSpecCollectorPass()); + return true; + } else if (Name == "pallas-declare-variables") { + MPM.addPass(pallas::GlobalVariableDeclarerPass()); + return true; + } else if (Name == "pallas-print-protobuf") { + MPM.addPass(pallas::ProtobufPrinter()); + return true; + } + return false; + }); + PB.registerPipelineParsingCallback( + [](StringRef Name, llvm::FunctionPassManager &FPM, + ArrayRef) { + if (Name == "pallas-declare-function") { + FPM.addPass(pallas::FunctionDeclarerPass()); + return true; + } else if (Name == "pallas-assign-pure") { + FPM.addPass(pallas::PureAssignerPass()); + return true; + } else if (Name == "pallas-declare-function-contract") { + FPM.addPass(pallas::FunctionContractDeclarerPass()); + return true; + } else if (Name == "pallas-transform-function-body") { + FPM.addPass(pallas::FunctionBodyTransformerPass()); + return true; + } + return false; + }); + }}; +} + +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo +llvmGetPassPluginInfo() { + return getPallasPluginInfo(); +} diff --git a/src/llvm/lib/Transform/BlockTransform.cpp b/src/llvm/lib/Transform/BlockTransform.cpp new file mode 100644 index 0000000000..c32a1c05c7 --- /dev/null +++ b/src/llvm/lib/Transform/BlockTransform.cpp @@ -0,0 +1,77 @@ +#include "Transform/BlockTransform.h" + +#include "Transform/Instruction/BinaryOpTransform.h" +#include "Transform/Instruction/CastOpTransform.h" +#include "Transform/Instruction/FuncletPadOpTransform.h" +#include "Transform/Instruction/MemoryOpTransform.h" +#include "Transform/Instruction/OtherOpTransform.h" +#include "Transform/Instruction/TermOpTransform.h" +#include "Transform/Instruction/UnaryOpTransform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::BlockTransform"; + +void llvm2col::transformLLVMBlock(llvm::BasicBlock &llvmBlock, + pallas::FunctionCursor &functionCursor) { + if (functionCursor.isVisited(llvmBlock)) + return; + col::Block &colBlock = + functionCursor.getOrSetLLVMBlock2LabeledColBlockEntry(llvmBlock).block; + /* for (auto *B : llvm::predecessors(&llvmBlock)) { */ + /* if (!functionCursor.isVisited(*B)) */ + /* return; */ + /* } */ + /* if (functionCursor.getLoopInfo().isLoopHeader(&llvmBlock)) { */ + /* transformLoop(llvmBlock, functionCursor); */ + /* return; */ + /* } */ + for (auto &I : llvmBlock) { + transformInstruction(functionCursor, I, colBlock); + } + functionCursor.complete(colBlock); +} + +void llvm2col::transformInstruction(pallas::FunctionCursor &funcCursor, + llvm::Instruction &llvmInstruction, + col::Block &colBodyBlock) { + u_int32_t opCode = llvmInstruction.getOpcode(); + if (llvm::Instruction::TermOpsBegin <= opCode && + opCode < llvm::Instruction::TermOpsEnd) { + llvm2col::transformTermOp(llvmInstruction, colBodyBlock, funcCursor); + } else if (llvm::Instruction::BinaryOpsBegin <= opCode && + opCode < llvm::Instruction::BinaryOpsEnd) { + llvm2col::transformBinaryOp(llvmInstruction, colBodyBlock, funcCursor); + } else if (llvm::Instruction::UnaryOpsBegin <= opCode && + opCode < llvm::Instruction::UnaryOpsEnd) { + llvm2col::transformUnaryOp(llvmInstruction, colBodyBlock, funcCursor); + } else if (llvm::Instruction::MemoryOpsBegin <= opCode && + opCode < llvm::Instruction::MemoryOpsEnd) { + llvm2col::transformMemoryOp(llvmInstruction, colBodyBlock, funcCursor); + } else if (llvm::Instruction::CastOpsBegin <= opCode && + opCode < llvm::Instruction::CastOpsEnd) { + llvm2col::transformCastOp(llvmInstruction, colBodyBlock, funcCursor); + } else if (llvm::Instruction::FuncletPadOpsBegin <= opCode && + opCode < llvm::Instruction::FuncletPadOpsEnd) { + llvm2col::transformFuncletPadOp(llvmInstruction, colBodyBlock, + funcCursor); + } else if (llvm::Instruction::OtherOpsBegin <= opCode && + opCode < llvm::Instruction::OtherOpsEnd) { + llvm2col::transformOtherOp(llvmInstruction, colBodyBlock, funcCursor); + } else { + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + } +} + +void llvm2col::transformLoop(llvm::BasicBlock &llvmBlock, + pallas::FunctionCursor &functionCursor) { + pallas::ErrorReporter::addError(SOURCE_LOC, "Unsupported loop detected", + llvmBlock); +} + +void llvm2col::reportUnsupportedOperatorError( + const std::string &source, llvm::Instruction &llvmInstruction) { + std::stringstream errorStream; + errorStream << "Unsupported operator \"" << llvmInstruction.getOpcodeName() + << '"'; + pallas::ErrorReporter::addError(source, errorStream.str(), llvmInstruction); +} diff --git a/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp new file mode 100644 index 0000000000..de070a21e2 --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp @@ -0,0 +1,82 @@ +#include "Transform/Instruction/BinaryOpTransform.h" + +#include "Origin/OriginProvider.h" +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::BinaryOp"; + +void llvm2col::transformBinaryOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(llvmInstruction, colBlock); + switch (llvm::Instruction::BinaryOps(llvmInstruction.getOpcode())) { + case llvm::Instruction::Add: { + col::Plus &expr = *assignment.mutable_value()->mutable_plus(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::Sub: { + col::Minus &expr = *assignment.mutable_value()->mutable_minus(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::Mul: { + col::Mult &expr = *assignment.mutable_value()->mutable_mult(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::SDiv: + case llvm::Instruction::UDiv: { + // XXX: There is an assumption here that signed and unsigned division + // are equal + if (llvmInstruction.isExact()) { + // XXX (Alexander): I'm not sure why we wouldn't support exact + // division because it seems to me that it is simply a promise used + // by optimisations that the right operand divides the left exactly + /* pallas::ErrorReporter::addError( */ + /* SOURCE_LOC, "Exact division not supported", llvmInstruction); + */ + } + col::FloorDiv &expr = *assignment.mutable_value()->mutable_floor_div(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + // TODO: All of these are currently bitwise operators, verify that works + // correctly when operating on booleans in VerCors + case llvm::Instruction::And: { + col::BitAnd &expr = *assignment.mutable_value()->mutable_bit_and(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::Or: { + col::BitOr &expr = *assignment.mutable_value()->mutable_bit_or(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::Xor: { + col::BitXor &expr = *assignment.mutable_value()->mutable_bit_xor(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::Shl: { + col::BitShl &expr = *assignment.mutable_value()->mutable_bit_shl(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::LShr: { + col::BitUShr &expr = *assignment.mutable_value()->mutable_bit_u_shr(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + case llvm::Instruction::AShr: { + col::BitShr &expr = *assignment.mutable_value()->mutable_bit_shr(); + transformBinExpr(llvmInstruction, expr, funcCursor); + break; + } + default: + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + } +} diff --git a/src/llvm/lib/Transform/Instruction/CastOpTransform.cpp b/src/llvm/lib/Transform/Instruction/CastOpTransform.cpp new file mode 100644 index 0000000000..3e4a987231 --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/CastOpTransform.cpp @@ -0,0 +1,87 @@ +#include "Transform/Instruction/CastOpTransform.h" + +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::CastOp"; +void llvm2col::transformCastOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + switch (llvm::Instruction::CastOps(llvmInstruction.getOpcode())) { + case llvm::Instruction::SExt: + transformSExt(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::ZExt: + transformZExt(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::Trunc: + transformTrunc(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + default: + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + } +} + +void llvm2col::transformSExt(llvm::SExtInst &sextInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(sextInstruction, colBlock); + col::Expr *sextExpr = assignment.mutable_value(); + col::LlvmSignExtend *sext = sextExpr->mutable_llvm_sign_extend(); + sext->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(sextInstruction)); + llvm2col::transformAndSetType(*sextInstruction.getSrcTy(), + *sext->mutable_input_type()); + llvm2col::transformAndSetType(*sextInstruction.getDestTy(), + *sext->mutable_output_type()); + // TODO: Surely there must be a better way to access this operand than + // getOperand(0) + llvm2col::transformAndSetExpr(funcCursor, sextInstruction, + *sextInstruction.getOperand(0), + *sext->mutable_value()); +} + +void llvm2col::transformZExt(llvm::ZExtInst &zextInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(zextInstruction, colBlock); + col::Expr *zextExpr = assignment.mutable_value(); + col::LlvmSignExtend *zext = zextExpr->mutable_llvm_sign_extend(); + zext->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(zextInstruction)); + llvm2col::transformAndSetType(*zextInstruction.getSrcTy(), + *zext->mutable_input_type()); + llvm2col::transformAndSetType(*zextInstruction.getDestTy(), + *zext->mutable_output_type()); + // TODO: Surely there must be a better way to access this operand than + // getOperand(0) + llvm2col::transformAndSetExpr(funcCursor, zextInstruction, + *zextInstruction.getOperand(0), + *zext->mutable_value()); +} + +void llvm2col::transformTrunc(llvm::TruncInst &truncInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(truncInstruction, colBlock); + col::Expr *truncExpr = assignment.mutable_value(); + col::LlvmSignExtend *trunc = truncExpr->mutable_llvm_sign_extend(); + trunc->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(truncInstruction)); + llvm2col::transformAndSetType(*truncInstruction.getSrcTy(), + *trunc->mutable_input_type()); + llvm2col::transformAndSetType(*truncInstruction.getDestTy(), + *trunc->mutable_output_type()); + // TODO: Surely there must be a better way to access this operand than + // getOperand(0) + llvm2col::transformAndSetExpr(funcCursor, truncInstruction, + *truncInstruction.getOperand(0), + *trunc->mutable_value()); +} diff --git a/src/llvm/lib/Transform/Instruction/FuncletPadOpTransform.cpp b/src/llvm/lib/Transform/Instruction/FuncletPadOpTransform.cpp new file mode 100644 index 0000000000..104b71d3cf --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/FuncletPadOpTransform.cpp @@ -0,0 +1,13 @@ +#include "Transform/Instruction/FuncletPadOpTransform.h" + +#include "Transform/BlockTransform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::FuncletPadOp"; + +void llvm2col::transformFuncletPadOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // TODO stub + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); +} diff --git a/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp new file mode 100644 index 0000000000..ac8e3511ee --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp @@ -0,0 +1,149 @@ +#include "Transform/Instruction/MemoryOpTransform.h" + +#include "Origin/OriginProvider.h" +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::MemoryOp"; + +void llvm2col::transformMemoryOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + switch (llvm::Instruction::MemoryOps(llvmInstruction.getOpcode())) { + case llvm::Instruction::Alloca: + transformAllocA(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::Load: + transformLoad(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::Store: + transformStore(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::GetElementPtr: + transformGetElementPtr( + llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + default: + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + } +} + +void llvm2col::transformAllocA(llvm::AllocaInst &allocAInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Assign &assignment = funcCursor.createAssignmentAndDeclaration( + allocAInstruction, colBlock, + /* pointer type*/ allocAInstruction.getAllocatedType()); + col::Expr *allocAExpr = assignment.mutable_value(); + col::LlvmAllocA *allocA = allocAExpr->mutable_llvm_alloc_a(); + allocA->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(allocAInstruction)); + llvm2col::transformAndSetType(*allocAInstruction.getAllocatedType(), + *allocA->mutable_allocation_type()); + llvm2col::transformAndSetExpr(funcCursor, allocAInstruction, + *allocAInstruction.getArraySize(), + *allocA->mutable_num_elements()); +} + +void llvm2col::transformAtomicOrdering(llvm::AtomicOrdering ordering, + col::LlvmMemoryOrdering *colOrdering) { + switch (ordering) { + case llvm::AtomicOrdering::NotAtomic: + colOrdering->mutable_llvm_memory_not_atomic()->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::Unordered: + colOrdering->mutable_llvm_memory_unordered()->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::Monotonic: + colOrdering->mutable_llvm_memory_monotonic()->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::Acquire: + colOrdering->mutable_llvm_memory_acquire()->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::Release: + colOrdering->mutable_llvm_memory_release()->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::AcquireRelease: + colOrdering->mutable_llvm_memory_acquire_release() + ->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + case llvm::AtomicOrdering::SequentiallyConsistent: + colOrdering->mutable_llvm_memory_sequentially_consistent() + ->set_allocated_origin( + llvm2col::generateMemoryOrderingOrigin(ordering)); + break; + } +} + +void llvm2col::transformLoad(llvm::LoadInst &loadInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // We are not storing isVolatile and getAlign + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(loadInstruction, colBlock); + col::Expr *loadExpr = assignment.mutable_value(); + col::LlvmLoad *load = loadExpr->mutable_llvm_load(); + load->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(loadInstruction)); + llvm::errs() << "Working on " << loadInstruction << " has type " + << *loadInstruction.getType() << "\n"; + llvm2col::transformAndSetType(*loadInstruction.getType(), + *load->mutable_load_type()); + llvm2col::transformAndSetExpr(funcCursor, loadInstruction, + *loadInstruction.getPointerOperand(), + *load->mutable_pointer()); + llvm2col::transformAtomicOrdering(loadInstruction.getOrdering(), + load->mutable_ordering()); +} + +void llvm2col::transformStore(llvm::StoreInst &storeInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // We are not storing isVolatile and getAlign + col::LlvmStore *store = colBlock.add_statements()->mutable_llvm_store(); + store->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(storeInstruction)); + llvm2col::transformAndSetExpr(funcCursor, storeInstruction, + *storeInstruction.getValueOperand(), + *store->mutable_value()); + llvm2col::transformAndSetExpr(funcCursor, storeInstruction, + *storeInstruction.getPointerOperand(), + *store->mutable_pointer()); + llvm2col::transformAtomicOrdering(storeInstruction.getOrdering(), + store->mutable_ordering()); +} + +void llvm2col::transformGetElementPtr(llvm::GetElementPtrInst &gepInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + + col::Assign &assignment = funcCursor.createAssignmentAndDeclaration( + gepInstruction, colBlock, gepInstruction.getResultElementType()); + col::Expr *gepExpr = assignment.mutable_value(); + col::LlvmGetElementPointer *gep = + gepExpr->mutable_llvm_get_element_pointer(); + gep->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(gepInstruction)); + llvm2col::transformAndSetType(*gepInstruction.getSourceElementType(), + *gep->mutable_structure_type()); + llvm2col::transformAndSetType(*gepInstruction.getResultElementType(), + *gep->mutable_result_type()); + llvm2col::transformAndSetExpr(funcCursor, gepInstruction, + *gepInstruction.getPointerOperand(), + *gep->mutable_pointer()); + for (auto &index : gepInstruction.indices()) { + llvm2col::transformAndSetExpr(funcCursor, gepInstruction, *index.get(), + *gep->add_indices()); + } +} diff --git a/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp new file mode 100644 index 0000000000..81453faa62 --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp @@ -0,0 +1,190 @@ +#include "Transform/Instruction/OtherOpTransform.h" +#include + +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::OtherOp"; + +void llvm2col::transformOtherOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + switch (llvm::Instruction::OtherOps(llvmInstruction.getOpcode())) { + case llvm::Instruction::PHI: + transformPhi(llvm::cast(llvmInstruction), funcCursor); + break; + case llvm::Instruction::ICmp: + transformICmp(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::Call: + transformCallExpr(llvm::cast(llvmInstruction), colBlock, + funcCursor); + break; + default: + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + } +} + +void llvm2col::transformPhi(llvm::PHINode &phiInstruction, + pallas::FunctionCursor &funcCursor) { + col::Variable &varDecl = funcCursor.declareVariable(phiInstruction); + for (auto &B : phiInstruction.blocks()) { + // add assignment of the variable to target block + col::Block &targetBlock = + funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*B).block; + col::Assign &assignment = + funcCursor.createAssignment(phiInstruction, targetBlock, varDecl); + // assign correct value by looking at the value-block pair of phi + // instruction. + col::Expr *value = assignment.mutable_value(); + llvm2col::transformAndSetExpr( + funcCursor, phiInstruction, + *phiInstruction.getIncomingValueForBlock(B), *value); + } +} + +void llvm2col::transformICmp(llvm::ICmpInst &icmpInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // we only support integer comparison + if (not icmpInstruction.getOperand(0)->getType()->isIntegerTy()) { + pallas::ErrorReporter::addError(SOURCE_LOC, "Unsupported compare type", + icmpInstruction); + return; + } + col::Assign &assignment = + funcCursor.createAssignmentAndDeclaration(icmpInstruction, colBlock); + switch (llvm::ICmpInst::Predicate(icmpInstruction.getPredicate())) { + case llvm::CmpInst::ICMP_EQ: { + col::Eq &eq = *assignment.mutable_value()->mutable_eq(); + transformCmpExpr(icmpInstruction, eq, funcCursor); + break; + } + case llvm::CmpInst::ICMP_NE: { + col::Neq &neq = *assignment.mutable_value()->mutable_neq(); + transformCmpExpr(icmpInstruction, neq, funcCursor); + break; + } + case llvm::CmpInst::ICMP_SGT: + case llvm::CmpInst::ICMP_UGT: { + col::Greater > = *assignment.mutable_value()->mutable_greater(); + transformCmpExpr(icmpInstruction, gt, funcCursor); + break; + } + case llvm::CmpInst::ICMP_SGE: + case llvm::CmpInst::ICMP_UGE: { + col::GreaterEq &geq = *assignment.mutable_value()->mutable_greater_eq(); + transformCmpExpr(icmpInstruction, geq, funcCursor); + break; + } + case llvm::CmpInst::ICMP_SLT: + case llvm::CmpInst::ICMP_ULT: { + col::Less < = *assignment.mutable_value()->mutable_less(); + transformCmpExpr(icmpInstruction, lt, funcCursor); + break; + } + case llvm::CmpInst::ICMP_SLE: + case llvm::CmpInst::ICMP_ULE: { + col::LessEq &leq = *assignment.mutable_value()->mutable_less_eq(); + transformCmpExpr(icmpInstruction, leq, funcCursor); + break; + } + default: + pallas::ErrorReporter::addError(SOURCE_LOC, "Unknown ICMP predicate", + icmpInstruction); + } +} + +void llvm2col::transformCmpExpr(llvm::CmpInst &cmpInstruction, + auto &colCompareExpr, + pallas::FunctionCursor &funcCursor) { + transformBinExpr(cmpInstruction, colCompareExpr, funcCursor); +} + +bool llvm2col::checkCallSupport(llvm::CallInst &callInstruction) { + if (callInstruction.isIndirectCall()) { + pallas::ErrorReporter::addError( + SOURCE_LOC, "Indirect calls are not supported", callInstruction); + return false; + } + // tail recursion + if (callInstruction.isMustTailCall()) { + pallas::ErrorReporter::addError(SOURCE_LOC, + "Tail call optimization not supported", + callInstruction); + return false; + } + // fast math + if (callInstruction.getFastMathFlags().any()) { + pallas::ErrorReporter::addError(SOURCE_LOC, "Fast math not supported", + callInstruction); + return false; + } + // return attributes + for (auto &A : callInstruction.getAttributes().getRetAttrs()) { + // TODO: Deal with these most of them do not affect the semantics we + // care about so we could ignore them + std::stringstream errorStream; + errorStream << "Return attribute \"" << A.getAsString() + << "\" not supported"; + pallas::ErrorReporter::addWarning(SOURCE_LOC, errorStream.str(), + callInstruction); + return true; + } + // address space is platform dependent (unlikely to change semantics) + // function attributes are just extra compiler information (no semanatic + // changes) + + // operand bundles + if (callInstruction.hasOperandBundles()) { + pallas::ErrorReporter::addError( + SOURCE_LOC, "Operand bundles not supported", callInstruction); + return false; + } + + return true; +} + +void llvm2col::transformCallExpr(llvm::CallInst &callInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + if (!checkCallSupport(callInstruction) || + callInstruction.getCalledFunction() == nullptr) + return; + + if (callInstruction.getCalledFunction()->isIntrinsic()) { + // TODO: Deal with intrinsic functions + return; + } + // allocate expression to host the function call in advance + col::Expr *functionCallExpr; + // if void function add an eval expression + if (callInstruction.getType()->isVoidTy()) { + col::Eval *eval = colBlock.add_statements()->mutable_eval(); + eval->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(callInstruction)); + functionCallExpr = eval->mutable_expr(); + } else { // else create an assignment + col::Assign &assignment = funcCursor.createAssignmentAndDeclaration( + callInstruction, colBlock); + functionCallExpr = assignment.mutable_value(); + } + // create actual invocation + col::LlvmFunctionInvocation *invocation = + functionCallExpr->mutable_llvm_function_invocation(); + invocation->set_allocated_blame(new col::Blame()); + // set origin + invocation->set_allocated_origin( + llvm2col::generateFunctionCallOrigin(callInstruction)); + // set function reference + invocation->mutable_ref()->set_id( + funcCursor.getFDResult(*callInstruction.getCalledFunction()) + .getFunctionId()); + // process function arguments + for (auto &A : callInstruction.args()) { + llvm2col::transformAndSetExpr(funcCursor, callInstruction, *A, + *invocation->add_args()); + } +} diff --git a/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp b/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp new file mode 100644 index 0000000000..66ec7d8565 --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp @@ -0,0 +1,141 @@ +#include "Transform/Instruction/TermOpTransform.h" + +#include "Origin/OriginProvider.h" +#include "Transform/BlockTransform.h" +#include "Transform/Transform.h" +#include "Util/Exceptions.h" + +const std::string SOURCE_LOC = "Transform::Instruction::TermOp"; + +void llvm2col::transformTermOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + switch (llvm::Instruction::TermOps(llvmInstruction.getOpcode())) { + case llvm::Instruction::Ret: + transformRet(cast(llvmInstruction), colBlock, + funcCursor); + break; + case llvm::Instruction::Br: { + auto &llvmBranchInst = cast(llvmInstruction); + llvmBranchInst.isConditional() + ? transformConditionalBranch(llvmBranchInst, colBlock, funcCursor) + : transformUnConditionalBranch(llvmBranchInst, colBlock, + funcCursor); + break; + } + default: + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); + break; + } +} + +void llvm2col::transformRet(llvm::ReturnInst &llvmRetInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Return *returnStatement = colBlock.add_statements()->mutable_return_(); + returnStatement->set_allocated_origin( + generateSingleStatementOrigin(llvmRetInstruction)); + + col::Expr *returnExpr = returnStatement->mutable_result(); + if (llvmRetInstruction.getReturnValue() == nullptr) { + returnExpr->mutable_void_()->set_allocated_origin( + generateVoidOperandOrigin(llvmRetInstruction)); + } else { + llvm2col::transformAndSetExpr(funcCursor, llvmRetInstruction, + *llvmRetInstruction.getReturnValue(), + *returnExpr); + } +} + +void llvm2col::transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + col::Branch *colBranch = colBlock.add_statements()->mutable_branch(); + colBranch->set_allocated_origin( + generateSingleStatementOrigin(llvmBrInstruction)); + // pre-declare completion because the final branch statement is already + // present + funcCursor.complete(colBlock); + // true branch + col::Tuple2_VctColAstExpr_VctColAstStatement *colTrueBranch = + colBranch->add_branches(); + // set conditional + transformAndSetExpr(funcCursor, llvmBrInstruction, + *llvmBrInstruction.getCondition(), + *colTrueBranch->mutable_v1()); + // get or pre-generate target labeled block + /* + * I hear you think, why query the 2nd operand? wouldn't that be the false + * branch i.e the else branch? While any logical implementation of getting + * operands would give the operands in order, the branch instruction is no + * ordinary instruction. For you see to get the branch argument we use the + * 0th index (so far so good), for the true evaluation of the branch + * instruction we use the 2nd index (uhhh okay, we might be skipping an + * index?) and the false evaluation of the branch instruction we use the 1st + * index (WHAT!?!?) + * + * Visualized: + * br i1 %var, label %yay, label %nay + * 0 2 1 + * + * Just smile and wave, don't question LLVM. + */ + auto *llvmTrueBlock = + cast(llvmBrInstruction.getOperand(2)); + // transform llvm true block + transformLLVMBlock(*llvmTrueBlock, funcCursor); + pallas::LabeledColBlock labeledTrueColBlock = + funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmTrueBlock); + // goto statement to true block + col::Goto *trueGoto = colTrueBranch->mutable_v2()->mutable_goto_(); + trueGoto->mutable_lbl()->set_id(labeledTrueColBlock.label.decl().id()); + // set origin for goto to true block + trueGoto->set_allocated_origin( + generateSingleStatementOrigin(llvmBrInstruction)); + + // false branch + col::Tuple2_VctColAstExpr_VctColAstStatement *colFalseBranch = + colBranch->add_branches(); + // set conditional (which is a true constant as else == else if(true))) + col::BooleanValue *elseCondition = + colFalseBranch->mutable_v1()->mutable_boolean_value(); + elseCondition->set_value(true); + // set origin of else condition + elseCondition->set_allocated_origin(generateOperandOrigin( + llvmBrInstruction, *llvmBrInstruction.getCondition())); + // get llvm block targeted by the llvm branch + auto *llvmFalseBlock = + cast(llvmBrInstruction.getOperand(1)); + // transform llvm falseBlock + transformLLVMBlock(*llvmFalseBlock, funcCursor); + // get or pre-generate target labeled block + pallas::LabeledColBlock labeledFalseColBlock = + funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmFalseBlock); + // goto statement to false block + col::Goto *falseGoto = colFalseBranch->mutable_v2()->mutable_goto_(); + falseGoto->mutable_lbl()->set_id(labeledFalseColBlock.label.decl().id()); + // set origin for goto to false block + falseGoto->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(llvmBrInstruction)); +} + +void llvm2col::transformUnConditionalBranch( + llvm::BranchInst &llvmBrInstruction, col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // get llvm target block + auto *llvmTargetBlock = + cast(llvmBrInstruction.getOperand(0)); + // transform llvm targetBlock + transformLLVMBlock(*llvmTargetBlock, funcCursor); + // get or pre generate target labeled block + pallas::LabeledColBlock labeledColBlock = + funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmTargetBlock); + // create goto to target labeled block + col::Goto *colGoto = colBlock.add_statements()->mutable_goto_(); + colGoto->mutable_lbl()->set_id(labeledColBlock.label.decl().id()); + // set origin of goto statement + colGoto->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(llvmBrInstruction)); + // pre-declare completion because the final goto is already present + funcCursor.complete(colBlock); +} diff --git a/src/llvm/lib/Transform/Instruction/UnaryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/UnaryOpTransform.cpp new file mode 100644 index 0000000000..a3e25de8a7 --- /dev/null +++ b/src/llvm/lib/Transform/Instruction/UnaryOpTransform.cpp @@ -0,0 +1,10 @@ +#include "Transform/Instruction/UnaryOpTransform.h" +#include "Transform/BlockTransform.h" + +const std::string SOURCE_LOC = "Transform::Instruction::UnaryOp"; +void llvm2col::transformUnaryOp(llvm::Instruction &llvmInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + // TODO stub + reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); +} diff --git a/src/llvm/lib/Transform/Transform.cpp b/src/llvm/lib/Transform/Transform.cpp new file mode 100644 index 0000000000..c908cc9b91 --- /dev/null +++ b/src/llvm/lib/Transform/Transform.cpp @@ -0,0 +1,306 @@ +#include "Transform/Transform.h" + +#include +#include + +#include "Origin/OriginProvider.h" +#include "Passes/Function/FunctionBodyTransformer.h" +#include "Util/Exceptions.h" + +/** + * Utility function that converts LLVM types to col types + * @param type + */ +const std::string SOURCE_LOC = "Transform::Transform"; + +namespace col = vct::col::ast; + +void llvm2col::transformAndSetPointerType(llvm::Type &llvmType, + col::Type &colType) { + col::LlvmtPointer *pointerType = colType.mutable_llvmt_pointer(); + pointerType->set_allocated_origin(generateTypeOrigin(llvmType)); + llvm2col::transformAndSetType(llvmType, *pointerType->mutable_inner_type()); +} + +void llvm2col::transformAndSetType(llvm::Type &llvmType, col::Type &colType) { + switch (llvmType.getTypeID()) { + case llvm::Type::IntegerTyID: + if (llvmType.getIntegerBitWidth() == 1) { + colType.mutable_t_bool()->set_allocated_origin( + generateTypeOrigin(llvmType)); + } else { + col::LlvmtInt *colInt = colType.mutable_llvmt_int(); + colInt->set_bit_width(llvmType.getIntegerBitWidth()); + colInt->set_allocated_origin(generateTypeOrigin(llvmType)); + } + break; + case llvm::Type::VoidTyID: + colType.mutable_t_void()->set_allocated_origin( + generateTypeOrigin(llvmType)); + break; + case llvm::Type::PointerTyID: + colType.mutable_llvmt_pointer()->set_allocated_origin( + generateTypeOrigin(llvmType)); + break; + case llvm::Type::MetadataTyID: + colType.mutable_llvmt_metadata()->set_allocated_origin( + generateTypeOrigin(llvmType)); + break; + case llvm::Type::StructTyID: { + llvm::StructType &structType = llvm::cast(llvmType); + col::LlvmtStruct *colStruct = colType.mutable_llvmt_struct(); + colStruct->set_allocated_origin(generateTypeOrigin(llvmType)); + if (!structType.isLiteral()) { + // TODO: Instead of storing the name do we want keep only a single + // instance of the col::LLVMTStruct per non-literal struct type? + // XXX: This name can be the empty string for unnamed types, and it + // won't be set for literal types + colStruct->set_name(structType.getName().str()); + } + colStruct->set_packed(structType.isPacked()); + for (llvm::Type *element : structType.elements()) { + llvm2col::transformAndSetType(*element, *colStruct->add_elements()); + } + break; + } + case llvm::Type::ArrayTyID: { + llvm::ArrayType &arrayType = llvm::cast(llvmType); + col::LlvmtArray *colArray = colType.mutable_llvmt_array(); + colArray->set_allocated_origin(generateTypeOrigin(llvmType)); + llvm2col::transformAndSetType(*arrayType.getElementType(), + *colArray->mutable_element_type()); + colArray->set_num_elements(arrayType.getNumElements()); + break; + } + case llvm::Type::FixedVectorTyID: + case llvm::Type::ScalableVectorTyID: { + llvm::VectorType &vectorType = llvm::cast(llvmType); + col::LlvmtVector *colVector = colType.mutable_llvmt_vector(); + colVector->set_allocated_origin(generateTypeOrigin(llvmType)); + llvm2col::transformAndSetType(*vectorType.getElementType(), + *colVector->mutable_element_type()); + colVector->set_num_elements( + vectorType.getElementCount().getKnownMinValue()); + break; + } + + default: + throw pallas::UnsupportedTypeException(llvmType); + } +} + +void llvm2col::transformAndSetExpr(pallas::FunctionCursor &functionCursor, + llvm::Instruction &llvmInstruction, + llvm::Value &llvmOperand, + col::Expr &colExpr) { + col::Origin *origin = generateOperandOrigin(llvmInstruction, llvmOperand); + if (llvm::isa(llvmOperand)) { + transformAndSetConstExpr( + functionCursor.getFunctionAnalysisManager(), origin, + llvm::cast(llvmOperand), colExpr); + } else { + transformAndSetVarExpr(functionCursor, origin, + llvmInstruction.getOpcode() == + llvm::Instruction::PHI, + llvmOperand, colExpr); + } +} + +void llvm2col::transformAndSetVarExpr(pallas::FunctionCursor &functionCursor, + col::Origin *origin, bool inPhiNode, + llvm::Value &llvmOperand, + col::Expr &colExpr) { + col::Variable colVar = + functionCursor.getVariableMapEntry(llvmOperand, inPhiNode); + col::Local *colLocal = colExpr.mutable_local(); + colLocal->set_allocated_origin(origin); + colLocal->mutable_ref()->set_id(colVar.id()); +} + +void llvm2col::transformAndSetConstExpr(llvm::FunctionAnalysisManager &FAM, + col::Origin *origin, + llvm::Constant &llvmConstant, + col::Expr &colExpr) { + if (llvm::isa(llvmConstant)) { + col::LlvmZeroedAggregateValue *colZero = + colExpr.mutable_llvm_zeroed_aggregate_value(); + + colZero->set_allocated_origin(origin); + llvm2col::transformAndSetType(*llvmConstant.getType(), + *colZero->mutable_aggregate_type()); + return; + } + llvm::Type *constType = llvmConstant.getType(); + switch (llvmConstant.getType()->getTypeID()) { + case llvm::Type::IntegerTyID: + if (constType->getIntegerBitWidth() == 1) { + col::BooleanValue *boolValue = colExpr.mutable_boolean_value(); + boolValue->set_allocated_origin(origin); + boolValue->set_value(llvmConstant.isOneValue()); + } else { + col::LlvmIntegerValue *integerValue = + colExpr.mutable_llvm_integer_value(); + integerValue->set_allocated_origin(origin); + llvm::APInt apInt = llvmConstant.getUniqueInteger(); + transformAndSetBigInt(apInt, *integerValue->mutable_value()); + col::LlvmtInt *colInt = + integerValue->mutable_integer_type()->mutable_llvmt_int(); + colInt->set_bit_width(constType->getIntegerBitWidth()); + colInt->set_allocated_origin(generateTypeOrigin(*constType)); + } + break; + case llvm::Type::PointerTyID: { + // Can't be a function since we caught that in transformAndSetExpr + llvm::Value *stripped = llvmConstant.stripPointerCastsAndAliases(); + if (llvm::isa(stripped)) { + col::LlvmFunctionPointerValue *funcPointer = + colExpr.mutable_llvm_function_pointer_value(); + funcPointer->set_allocated_origin(origin); + funcPointer->mutable_value()->set_id( + FAM.getResult( + llvm::cast(*stripped)) + .getAssociatedColFuncDef() + .id()); + } else if (llvm::isa(stripped)) { + // XXX: To avoid having a map of GlobalVariables to their COL nodes + // we break with the convention and use the memory location of the + // LLVM value instead of the memory location of the COL node as the + // id + auto id = reinterpret_cast(stripped); + col::LlvmPointerValue *pointer = + colExpr.mutable_llvm_pointer_value(); + pointer->set_allocated_origin(origin); + pointer->mutable_value()->set_id(id); + } else if (llvm::isa(stripped)) { + col::Null *pointer = colExpr.mutable_null(); + pointer->set_allocated_origin(origin); + } else { + std::string errCtx; + llvm::raw_string_ostream(errCtx) << llvmConstant; + std::stringstream errorStream; + errorStream << "Unknown constant pointer '" << errCtx << "' " + << llvm::isa(stripped) << ", " + << llvm::isa(stripped) << ", " + << llvm::isa(stripped) << ", " + << llvm::isa(stripped) << ", " + << llvm::isa(stripped) << ", " + << llvm::isa(stripped); + pallas::ErrorReporter::addError( + SOURCE_LOC, errorStream.str(), + llvm2col::extractShortPosition(*origin)); + } + break; + } + case llvm::Type::StructTyID: { + llvm::ConstantStruct &llvmStruct = + llvm::cast(llvmConstant); + col::LlvmStructValue *colStruct = colExpr.mutable_llvm_struct_value(); + + for (auto &operand : llvmStruct.operands()) { + llvm2col::transformAndSetConstExpr( + FAM, llvm2col::deepenOperandOrigin(*origin, *operand.get()), + llvm::cast(*operand.get()), + *colStruct->add_value()); + } + colStruct->set_allocated_origin(origin); + llvm2col::transformAndSetType(*llvmStruct.getType(), + *colStruct->mutable_struct_type()); + + break; + } + case llvm::Type::ArrayTyID: { + if (llvm::isa(llvmConstant)) { + llvm::ConstantArray &llvmArray = + llvm::cast(llvmConstant); + col::LlvmArrayValue *colArray = colExpr.mutable_llvm_array_value(); + + for (auto &operand : llvmArray.operands()) { + llvm2col::transformAndSetConstExpr( + FAM, llvm2col::deepenOperandOrigin(*origin, *operand.get()), + llvm::cast(*operand.get()), + *colArray->add_value()); + } + colArray->set_allocated_origin(origin); + llvm2col::transformAndSetType(*llvmArray.getType(), + *colArray->mutable_array_type()); + } else { + llvm::ConstantDataArray &llvmArray = + llvm::cast(llvmConstant); + col::LlvmRawArrayValue *colArray = + colExpr.mutable_llvm_raw_array_value(); + + // TODO: This is not a very useful format. Ideally we detect the + // type and get elements individually as integers or floats or + // something + colArray->set_value(llvmArray.getRawDataValues().str()); + colArray->set_allocated_origin(origin); + llvm::errs() << "Array constant " << llvmArray << " has type " + << *llvmArray.getType() << "\n"; + llvm2col::transformAndSetType(*llvmArray.getType(), + *colArray->mutable_array_type()); + } + + break; + } + case llvm::Type::FixedVectorTyID: { + if (llvm::isa(llvmConstant)) { + llvm::ConstantVector &llvmVector = + llvm::cast(llvmConstant); + col::LlvmVectorValue *colVector = + colExpr.mutable_llvm_vector_value(); + + for (auto &operand : llvmVector.operands()) { + llvm2col::transformAndSetConstExpr( + FAM, llvm2col::deepenOperandOrigin(*origin, *operand.get()), + llvm::cast(*operand.get()), + *colVector->add_value()); + } + colVector->set_allocated_origin(origin); + llvm2col::transformAndSetType(*llvmVector.getType(), + *colVector->mutable_vector_type()); + } else { + llvm::ConstantDataVector &llvmVector = + llvm::cast(llvmConstant); + col::LlvmRawVectorValue *colVector = + colExpr.mutable_llvm_raw_vector_value(); + + // TODO: This is not a very useful format. Ideally we detect the + // type and get elements individually as integers or floats or + // something + colVector->set_value(llvmVector.getRawDataValues().str()); + colVector->set_allocated_origin(origin); + llvm2col::transformAndSetType(*llvmVector.getType(), + *colVector->mutable_vector_type()); + } + + break; + } + default: + std::string errCtx; + llvm::raw_string_ostream(errCtx) << llvmConstant; + std::stringstream errorStream; + errorStream << "Unknown constant '" << errCtx << "' of type '" + << constType->getTypeID() << "'"; + pallas::ErrorReporter::addError( + SOURCE_LOC, errorStream.str(), + llvm2col::extractShortPosition(*origin)); + } +} + +void llvm2col::transformAndSetBigInt(llvm::APInt &apInt, col::BigInt &bigInt) { + // TODO works for "small" signed and unsigned numbers, may break for values + // >=2^64 + llvm::APInt byteSwapped = apInt.byteSwap(); + std::vector byteVector; + for (uint32_t i = 0; i < byteSwapped.getNumWords(); i++) { + byteVector.push_back(byteSwapped.getRawData()[i]); + } + bigInt.set_data(byteVector.data(), apInt.getBitWidth() / 8); +} + +std::string llvm2col::getValueName(llvm::Value &llvmValue) { + std::string name; + llvm::raw_string_ostream contextStream = llvm::raw_string_ostream(name); + llvmValue.printAsOperand(contextStream, false); + return name; +} diff --git a/src/llvm/lib/Util/Exceptions.cpp b/src/llvm/lib/Util/Exceptions.cpp new file mode 100644 index 0000000000..9b14d2e251 --- /dev/null +++ b/src/llvm/lib/Util/Exceptions.cpp @@ -0,0 +1,109 @@ +#include "Util/Exceptions.h" +#include "Origin/ShortPositionDeriver.h" + +#include +#include +#include +#include + +namespace pallas { +UnsupportedTypeException::UnsupportedTypeException(const llvm::Type &type) { + llvm::raw_string_ostream output(str); + output << "Type '" << type << "' not supported"; +} + +[[nodiscard]] const char *UnsupportedTypeException::what() const noexcept { + return str.c_str(); +} + +u_int32_t ErrorReporter::errorCount; +u_int32_t ErrorReporter::warningCount; + +void ErrorReporter::addError(const std::string &source, + const std::string &message) { + llvm::errs() << "[ERROR] [pallas] [" << source << "] " << message << "\n\n"; + ErrorReporter::errorCount++; +} + +void ErrorReporter::addError(const std::string &source, + const std::string &message, + const std::string &origin) { + llvm::errs() << "[ERROR] [pallas] [" << source << "] " << message + << " @\n " << origin << "\n\n"; + ErrorReporter::errorCount++; +} + +void ErrorReporter::addError(const std::string &source, + const std::string &message, + llvm::Module &llvmModule) { + addError(source, message, llvm2col::deriveModuleShortPosition(llvmModule)); +} + +void ErrorReporter::addError(const std::string &source, + const std::string &message, + llvm::Function &llvmFunction) { + addError(source, message, + llvm2col::deriveFunctionShortPosition(llvmFunction)); +} + +void ErrorReporter::addError(const std::string &source, + const std::string &message, + llvm::BasicBlock &llvmBlock) { + addError(source, message, llvm2col::deriveBlockShortPosition(llvmBlock)); +} + +void ErrorReporter::addError(const std::string &source, + const std::string &message, + llvm::Instruction &llvmInstruction) { + addError(source, message, + llvm2col::deriveInstructionShortPosition(llvmInstruction)); +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message) { + llvm::errs() << "[WARN] [pallas] [" << source << "] " << message << "\n\n"; + ErrorReporter::warningCount++; +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message, + const std::string &origin) { + llvm::errs() << "[WARN] [pallas] [" << source << "] " << message << " @\n " + << origin << "\n\n"; + ErrorReporter::warningCount++; +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message, + llvm::Module &llvmModule) { + addWarning(source, message, + llvm2col::deriveModuleShortPosition(llvmModule)); +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message, + llvm::Function &llvmFunction) { + addWarning(source, message, + llvm2col::deriveFunctionShortPosition(llvmFunction)); +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message, + llvm::BasicBlock &llvmBlock) { + addWarning(source, message, llvm2col::deriveBlockShortPosition(llvmBlock)); +} + +void ErrorReporter::addWarning(const std::string &source, + const std::string &message, + llvm::Instruction &llvmInstruction) { + addWarning(source, message, + llvm2col::deriveInstructionShortPosition(llvmInstruction)); +} + +bool ErrorReporter::hasErrors() { return ErrorReporter::errorCount > 0; } + +u_int32_t ErrorReporter::getErrorCount() { return ErrorReporter::errorCount; } +u_int32_t ErrorReporter::getWarningCount() { + return ErrorReporter::warningCount; +} +} // namespace pallas diff --git a/src/llvm/lib/origin/ContextDeriver.cpp b/src/llvm/lib/origin/ContextDeriver.cpp deleted file mode 100644 index c731ba9348..0000000000 --- a/src/llvm/lib/origin/ContextDeriver.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include "Origin/ContextDeriver.h" - -#include -#include - -namespace llvm2Col { - // module derivers - std::string deriveModuleContext(llvm::Module &llvmModule) { - std::string context; - llvm::raw_string_ostream(context) << llvmModule; - return context; - } - - // function derivers - std::string deriveFunctionContext(llvm::Function &llvmFunction) { - std::string context; - llvm::raw_string_ostream(context) << llvmFunction; - return context; - } - - // block derivers - - std::string deriveLabelContext(llvm::BasicBlock &llvmBlock) { - if (llvmBlock.isEntryBlock()) { - return ""; - } - std::string fullContext; - llvm::raw_string_ostream(fullContext) << llvmBlock; - return fullContext.substr(0, fullContext.find(':') + 1); - } - - std::string deriveBlockContext(llvm::BasicBlock &llvmBlock) { - std::string context; - llvm::raw_string_ostream(context) << llvmBlock; - return context; - } - - // instruction derivers - - std::string deriveSurroundingInstructionContext(llvm::Instruction &llvmInstruction) { - std::string context; - if (llvmInstruction.getPrevNode() != nullptr) { - llvm::raw_string_ostream(context) << *llvmInstruction.getPrevNode() << '\n'; - } - llvm::raw_string_ostream(context) << llvmInstruction; - if (llvmInstruction.getNextNode() != nullptr) { - llvm::raw_string_ostream(context) << '\n' << *llvmInstruction.getNextNode(); - } - return context; - } - - std::string deriveInstructionContext(llvm::Instruction &llvmInstruction) { - std::string context; - llvm::raw_string_ostream(context) << llvmInstruction; - return context; - } - - std::string deriveInstructionLhs(llvm::Instruction &llvmInstruction) { - std::string fullContext = deriveInstructionContext(llvmInstruction); - return fullContext.substr(0, fullContext.find('=')); - } - - std::string deriveInstructionRhs(llvm::Instruction &llvmInstruction) { - std::string fullContext = deriveInstructionContext(llvmInstruction); - return fullContext.substr(fullContext.find('=') + 1); - } - - std::string deriveOperandContext(llvm::Value &llvmOperand) { - std::string context; - llvm::raw_string_ostream contextStream = llvm::raw_string_ostream(context); - llvmOperand.printAsOperand(contextStream, false); - return context; - } -} \ No newline at end of file diff --git a/src/llvm/lib/origin/OriginProvider.cpp b/src/llvm/lib/origin/OriginProvider.cpp deleted file mode 100644 index 248967860a..0000000000 --- a/src/llvm/lib/origin/OriginProvider.cpp +++ /dev/null @@ -1,216 +0,0 @@ -#include "Origin/OriginProvider.h" - -#include - -#include "Origin/PreferredNameDeriver.h" -#include "Origin/ContextDeriver.h" -#include "Origin/ShortPositionDeriver.h" - -namespace llvm2Col { - col::Origin *generateProgramOrigin(llvm::Module &llvmModule) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name("program:" + llvmModule.getName().str()); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveModuleContext(llvmModule)); - context->set_inline_context(deriveModuleContext(llvmModule)); - context->set_short_position(deriveModuleShortPosition(llvmModule)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateFuncDefOrigin(llvm::Function &llvmFunction) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name(llvmFunction.getName().str()); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveFunctionContext(llvmFunction)); - context->set_inline_context(deriveFunctionContext(llvmFunction)); - context->set_short_position(deriveFunctionShortPosition(llvmFunction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateFunctionContractOrigin(llvm::Function &llvmFunction, const std::string& contract) { - col::Origin *origin = new col::Origin(); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(contract); - context->set_inline_context(contract); - context->set_short_position(deriveFunctionShortPosition(llvmFunction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateGlobalValOrigin(llvm::Module &llvmModule, const std::string &globVal) { - col::Origin *origin = new col::Origin(); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(globVal); - context->set_inline_context(globVal); - context->set_short_position(deriveModuleShortPosition(llvmModule)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateArgumentOrigin(llvm::Argument &llvmArgument) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name(deriveArgumentPreferredName(llvmArgument)); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveFunctionContext(*llvmArgument.getParent())); - context->set_inline_context(deriveFunctionContext(*llvmArgument.getParent())); - context->set_short_position(deriveFunctionShortPosition(*llvmArgument.getParent())); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateBlockOrigin(llvm::BasicBlock &llvmBlock) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name("block"); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveBlockContext(llvmBlock)); - context->set_inline_context(deriveBlockContext(llvmBlock)); - context->set_short_position(deriveBlockShortPosition(llvmBlock)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateLabelOrigin(llvm::BasicBlock &llvmBlock) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name("label"); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveLabelContext(llvmBlock)); - context->set_inline_context(deriveLabelContext(llvmBlock)); - context->set_short_position(deriveBlockShortPosition(llvmBlock)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { - col::Origin *origin = new col::Origin(); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionContext(llvmInstruction)); - context->set_short_position(deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name("var"); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionLhs(llvmInstruction)); - context->set_short_position(deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateBinExprOrigin(llvm::Instruction &llvmInstruction) { - col::Origin *origin = new col::Origin(); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionContext(llvmInstruction)); - context->set_short_position(deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateFunctionCallOrigin(llvm::CallInst &callInstruction) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name(callInstruction.getCalledFunction()->getName().str()); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(callInstruction)); - context->set_inline_context(deriveInstructionRhs(callInstruction)); - context->set_short_position(deriveInstructionShortPosition(callInstruction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateOperandOrigin(llvm::Instruction &llvmInstruction, llvm::Value &llvmOperand) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name(deriveOperandPreferredName(llvmOperand)); - preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveInstructionContext(llvmInstruction)); - context->set_inline_context(deriveOperandContext(llvmOperand)); - context->set_short_position(deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); - - origin->CheckInitialized(); - return origin; - } - - col::Origin *generateTypeOrigin(llvm::Type &llvmType) { - col::Origin *origin = new col::Origin(); - col::OriginContent *preferredNameContent = origin->add_content(); - col::PreferredName *preferredName = new col::PreferredName(); - preferredName->add_preferred_name(deriveTypePreferredName(llvmType)); - preferredNameContent->set_allocated_preferred_name(preferredName); - - origin->CheckInitialized(); - return origin; - } -} diff --git a/src/llvm/lib/origin/PreferredNameDeriver.cpp b/src/llvm/lib/origin/PreferredNameDeriver.cpp deleted file mode 100644 index 94717208f5..0000000000 --- a/src/llvm/lib/origin/PreferredNameDeriver.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include "Origin/PreferredNameDeriver.h" - -#include -#include -#include - -namespace llvm2Col { - std::string deriveOperandPreferredName(llvm::Value &llvmOperand) { - std::string preferredName; - llvm::raw_string_ostream preferredNameStream = llvm::raw_string_ostream(preferredName); - preferredNameStream << (llvm::isa(llvmOperand) ? "const_" : "var_"); - llvmOperand.printAsOperand(preferredNameStream, false); - return preferredName; - } - - std::string deriveTypePreferredName(llvm::Type &llvmType) { - std::string prefix = "t_"; - switch(llvmType.getTypeID()) { - case llvm::Type::HalfTyID: - return prefix + "half"; - case llvm::Type::BFloatTyID: - return prefix + "bfloat"; - case llvm::Type::FloatTyID: - return prefix + "float"; - case llvm::Type::DoubleTyID: - return prefix + "double"; - case llvm::Type::X86_FP80TyID: - return prefix + "x86fp80"; - case llvm::Type::FP128TyID: - return prefix + "fp128"; - case llvm::Type::PPC_FP128TyID: - return prefix + "ppcfp128"; - case llvm::Type::VoidTyID: - return prefix + "void"; - case llvm::Type::LabelTyID: - return prefix + "label"; - case llvm::Type::MetadataTyID: - return prefix + "metadata"; - case llvm::Type::X86_MMXTyID: - return prefix + "x86mmx"; - case llvm::Type::X86_AMXTyID: - return prefix + "x86amx"; - case llvm::Type::TokenTyID: - return prefix + "token"; - case llvm::Type::IntegerTyID: - return prefix + (llvmType.getIntegerBitWidth() == 1 ? "boolean" : "integer"); - case llvm::Type::FunctionTyID: - return prefix + "function"; - case llvm::Type::PointerTyID: - return prefix + "ptr"; - case llvm::Type::StructTyID: - return prefix + "struct"; - case llvm::Type::ArrayTyID: - return prefix + "array"; - case llvm::Type::FixedVectorTyID: - return prefix + "fixedvector"; - case llvm::Type::ScalableVectorTyID: - return prefix + "scalevector"; - } - } - std::string deriveArgumentPreferredName(llvm::Argument &llvmArgument) { - std::string preferredName; - llvm::raw_string_ostream preferredNameStream = llvm::raw_string_ostream(preferredName); - preferredNameStream << "arg_"; - llvmArgument.printAsOperand(preferredNameStream, false); - return preferredName; - } - -} \ No newline at end of file diff --git a/src/llvm/lib/origin/ShortPositionDeriver.cpp b/src/llvm/lib/origin/ShortPositionDeriver.cpp deleted file mode 100644 index 5438d15886..0000000000 --- a/src/llvm/lib/origin/ShortPositionDeriver.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "Origin/ShortPositionDeriver.h" -#include "Origin/ContextDeriver.h" - -namespace llvm2Col { - const std::string POSITION_POINTER = "\n\t -> "; - - std::string deriveModuleShortPosition(llvm::Module &llvmModule) { - return "file " + llvmModule.getSourceFileName(); - } - - std::string deriveFunctionShortPosition(llvm::Function &llvmFunction) { - std::string functionPosition = deriveModuleShortPosition(*llvmFunction.getParent()); - llvm::raw_string_ostream functionPosStream = llvm::raw_string_ostream(functionPosition); - functionPosStream << POSITION_POINTER << "function "; - llvmFunction.printAsOperand(functionPosStream, false); - return functionPosition; - } - - std::string deriveBlockShortPosition(llvm::BasicBlock &llvmBlock) { - std::string blockPosition = deriveFunctionShortPosition(*llvmBlock.getParent()); - llvm::raw_string_ostream blockPosStream = llvm::raw_string_ostream(blockPosition); - blockPosStream << POSITION_POINTER << "block "; - llvmBlock.printAsOperand(blockPosStream, false); - blockPosStream << (llvmBlock.isEntryBlock() ? " (entryblock)" : ""); - return blockPosition; - } - - std::string deriveInstructionShortPosition(llvm::Instruction &llvmInstruction) { - std::string instructionPosition = deriveBlockShortPosition(*llvmInstruction.getParent()); - llvm::raw_string_ostream instructionPosStream = llvm::raw_string_ostream(instructionPosition); - int pos = 0; - llvm::BasicBlock *bb = llvmInstruction.getParent(); - for (auto &I: *bb) { - pos++; - if (&I == &llvmInstruction) { - break; - } - } - instructionPosStream << POSITION_POINTER << "instruction #" << pos << " (" - << deriveInstructionContext(llvmInstruction) << ')'; - return instructionPosition; - } -} \ No newline at end of file diff --git a/src/llvm/lib/passes/Function/FunctionBodyTransformer.cpp b/src/llvm/lib/passes/Function/FunctionBodyTransformer.cpp deleted file mode 100644 index 3b27a2394c..0000000000 --- a/src/llvm/lib/passes/Function/FunctionBodyTransformer.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "Passes/Function/FunctionBodyTransformer.h" -#include "Passes/Function/FunctionContractDeclarer.h" -#include - -#include "Passes/Function/FunctionDeclarer.h" -#include "Transform/BlockTransform.h" -#include "Transform/Transform.h" -#include "Origin/OriginProvider.h" -#include "Util/Exceptions.h" - - -namespace vcllvm { - const std::string SOURCE_LOC = "Passes::Function::FunctionBodyTransformer"; - - FunctionCursor::FunctionCursor(col::Scope &functionScope, - col::Block &functionBody, - llvm::Function &llvmFunction, - llvm::FunctionAnalysisManager &FAM) : - functionScope(functionScope), functionBody(functionBody), llvmFunction(llvmFunction), FAM(FAM) {} - - const col::Scope &FunctionCursor::getFunctionScope() { - return functionScope; - } - - void FunctionCursor::addVariableMapEntry(Value &llvmValue, col::Variable &colVar) { - variableMap.insert({&llvmValue, &colVar}); - // add reference to reference lut of function contract - col::Tuple2_String_Ref_VctColAstVariable *ref = FAM.getResult(llvmFunction).getAssociatedColFuncContract().add_variable_refs(); - ref->set_v1(llvm2Col::getValueName(llvmValue)); - ref->mutable_v2()->set_id(colVar.id()); - } - - col::Variable &FunctionCursor::getVariableMapEntry(Value &llvmValue) { - return *variableMap.at(&llvmValue); - } - - bool FunctionCursor::isVisited(BasicBlock &llvmBlock) { - return llvmBlock2LabeledColBlock.contains(&llvmBlock); - } - - void FunctionCursor::complete(col::Block &colBlock) { - completedColBlocks.insert(&colBlock); - } - bool FunctionCursor::isComplete(col::Block &colBlock) { - return completedColBlocks.contains(&colBlock); - } - - LabeledColBlock &FunctionCursor::getOrSetLlvmBlock2LabeledColBlockEntry(BasicBlock &llvmBlock) { - if (!llvmBlock2LabeledColBlock.contains(&llvmBlock)) { - // create label in buffer - col::Label *label = functionBody.add_statements()->mutable_label(); - // set label origin - label->set_allocated_origin(llvm2Col::generateLabelOrigin(llvmBlock)); - // create label declaration in buffer - col::LabelDecl *labelDecl = label->mutable_decl(); - // set label decl origin - labelDecl->set_allocated_origin(llvm2Col::generateLabelOrigin(llvmBlock)); - // set label decl id - llvm2Col::setColNodeId(labelDecl); - // create block inside label statement - col::Block *block = label->mutable_stat()->mutable_block(); - // set block origin - block->set_allocated_origin(llvm2Col::generateBlockOrigin(llvmBlock)); - // add labeled block to the block2block lut - LabeledColBlock labeledColBlock = {*label, *block}; - llvmBlock2LabeledColBlock.insert({&llvmBlock, labeledColBlock}); - } - return llvmBlock2LabeledColBlock.at(&llvmBlock); - } - - LoopInfo &FunctionCursor::getLoopInfo() { - return FAM.getResult(llvmFunction); - } - - LoopInfo &FunctionCursor::getLoopInfo(Function &otherLlvmFunction) { - return FAM.getResult(otherLlvmFunction); - } - - FDResult &FunctionCursor::getFDResult() { - return FAM.getResult(llvmFunction); - } - - FDResult &FunctionCursor::getFDResult(Function &otherLlvmFunction) { - return FAM.getResult(otherLlvmFunction); - } - - col::Variable &FunctionCursor::declareVariable(Instruction &llvmInstruction) { - // create declaration in buffer - col::Variable *varDecl = functionScope.add_locals(); - // set type of declaration - try { - llvm2Col::transformAndSetType(*llvmInstruction.getType(), *varDecl->mutable_t()); - } catch (vcllvm::UnsupportedTypeException &e) { - std::stringstream errorStream; - errorStream << e.what() << " in variable declaration."; - ErrorReporter::addError(SOURCE_LOC, errorStream.str(), llvmInstruction); - } - // set id - llvm2Col::setColNodeId(varDecl); - // set origin - varDecl->set_allocated_origin(llvm2Col::generateSingleStatementOrigin(llvmInstruction)); - // add to the variable lut - this->addVariableMapEntry(llvmInstruction, *varDecl); - return *varDecl; - } - - col::Assign &FunctionCursor::createAssignmentAndDeclaration(Instruction &llvmInstruction, col::Block &colBlock) { - col::Variable &varDecl = declareVariable(llvmInstruction); - return createAssignment(llvmInstruction, colBlock, varDecl); - } - - col::Assign &FunctionCursor::createAssignment(Instruction &llvmInstruction, - col::Block &colBlock, - col::Variable &varDecl) { - col::Assign *assignment = colBlock.add_statements()->mutable_assign(); - assignment->set_allocated_origin(llvm2Col::generateSingleStatementOrigin(llvmInstruction)); - assignment->set_allocated_blame(new col::Blame()); - // create local target in buffer and set origin - col::Local *colLocal = assignment->mutable_target()->mutable_local(); - colLocal->set_allocated_origin(llvm2Col::generateAssignTargetOrigin(llvmInstruction)); - // set target to refer to var decl - colLocal->mutable_ref()->set_id(varDecl.id()); - if(isComplete(colBlock)) { - // if the colBlock is completed, the assignment will be inserted after the goto/branch statement - // this can occur due to e.g. phi nodes back tracking assignments in their origin blocks. - // therefore we need to swap the last two elements of the block - // (i.e. the goto statement and the newest assignment) - int lastIndex = colBlock.statements_size() - 1; - colBlock.mutable_statements()->SwapElements(lastIndex, lastIndex - 1); - } - return *assignment; - } - - FunctionBodyTransformerPass::FunctionBodyTransformerPass(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - PreservedAnalyses FunctionBodyTransformerPass::run(Function &F, FunctionAnalysisManager &FAM) { - ColScopedFuncBody scopedFuncBody = FAM.getResult(F).getAssociatedScopedColFuncBody(); - FunctionCursor funcCursor = FunctionCursor(*scopedFuncBody.scope, *scopedFuncBody.block, F, FAM); - // add function arguments to the variableMap - for (auto &A: F.args()) { - funcCursor.addVariableMapEntry(A, FAM.getResult(F).getFuncArgMapEntry(A)); - } - // start recursive block code gen with basic block - llvm::BasicBlock &entryBlock = F.getEntryBlock(); - llvm2Col::transformLlvmBlock(entryBlock, funcCursor); - return PreservedAnalyses::all(); - } -} diff --git a/src/llvm/lib/passes/Function/FunctionContractDeclarer.cpp b/src/llvm/lib/passes/Function/FunctionContractDeclarer.cpp deleted file mode 100644 index 1e03aa54ce..0000000000 --- a/src/llvm/lib/passes/Function/FunctionContractDeclarer.cpp +++ /dev/null @@ -1,90 +0,0 @@ -#include "Passes/Function/FunctionContractDeclarer.h" -#include - -#include "Passes/Function/FunctionDeclarer.h" -#include "Util/Constants.h" -#include "Util/Exceptions.h" -#include "Origin/OriginProvider.h" - - -namespace vcllvm { - const std::string SOURCE_LOC = "Passes::Function::FunctionContractDeclarer"; - - using namespace llvm; - - /* - * Function Contract Declarer Result - */ - - FDCResult::FDCResult(vct::col::ast::LlvmFunctionContract &colFuncContract) : - associatedColFuncContract(colFuncContract) {} - - col::LlvmFunctionContract &FDCResult::getAssociatedColFuncContract() { - return associatedColFuncContract; - } - - /* - * Function Contract Declarer (Analysis) - */ - - AnalysisKey FunctionContractDeclarer::Key; - - - FunctionContractDeclarer::FunctionContractDeclarer(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - FunctionContractDeclarer::Result FunctionContractDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { - // fetch relevant function from the Function Declarer - FDResult fdResult = FAM.getResult(F); - col::LlvmFunctionDefinition &colFunction = fdResult.getAssociatedColFuncDef(); - // set a contract in the buffer as well as make and return a result object - return FDCResult(*colFunction.mutable_contract()); - } - - /* - * Function Contract Declarer Pass - */ - - FunctionContractDeclarerPass::FunctionContractDeclarerPass(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - PreservedAnalyses FunctionContractDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) { - // get col contract - FDCResult result = FAM.getResult(F); - col::LlvmFunctionContract &colContract = result.getAssociatedColFuncContract(); - colContract.set_allocated_blame(new col::Blame()); - // check if contract keyword is present - if (!F.hasMetadata(vcllvm::constants::METADATA_CONTRACT_KEYWORD)) { - // set contract to a tautology - colContract.set_value("requires true;"); - colContract.set_allocated_origin(new col::Origin()); - return PreservedAnalyses::all(); - } - // concatenate all contract lines with new lines - MDNode *contractMDNode = F.getMetadata(vcllvm::constants::METADATA_CONTRACT_KEYWORD); - std::stringstream contractStream; - for (u_int32_t i = 0; i < contractMDNode->getNumOperands(); i++) { - auto contractLine = dyn_cast(contractMDNode->getOperand(i)); - if (contractLine == nullptr) { - std::stringstream errorStream; - errorStream << "Unable to cast contract metadata node #" << i + 1 << "to string type"; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); - break; - } - contractStream << contractLine->getString().str() << '\n'; - } - colContract.set_value(contractStream.str()); - colContract.set_allocated_origin(llvm2Col::generateFunctionContractOrigin(F, contractStream.str())); - // add all callable functions to the contracts invokables - for(auto &moduleF : F.getParent()->functions()) { - std::string fName = '@' + moduleF.getName().str(); - int64_t fId = FAM.getResult(moduleF).getFunctionId(); - col::Tuple2_String_Ref_VctColAstLlvmCallable *invokeRef = colContract.add_invokable_refs(); - invokeRef->set_v1(fName); - invokeRef->mutable_v2()->set_id(fId); - invokeRef->CheckInitialized(); - } - colContract.CheckInitialized(); - return PreservedAnalyses::all(); - } -} diff --git a/src/llvm/lib/passes/Function/FunctionDeclarer.cpp b/src/llvm/lib/passes/Function/FunctionDeclarer.cpp deleted file mode 100644 index 82173404a6..0000000000 --- a/src/llvm/lib/passes/Function/FunctionDeclarer.cpp +++ /dev/null @@ -1,120 +0,0 @@ -#include -#include "Passes/Function/FunctionDeclarer.h" - -#include "Transform/Transform.h" -#include "Origin/OriginProvider.h" -#include "Util/Exceptions.h" - - -namespace vcllvm { - const std::string SOURCE_LOC = "Passes::Function::FunctionDeclarer"; - using namespace llvm; - - /** - * Checks function definition for unsupported features that might change semantics and - * adds warning if this is the case. - * @param llvmFunction: the function to be checked - */ - void checkFunctionSupport(llvm::Function &llvmFunction) { - // TODO add syntax support checks that change the semantics of the program to function definitions - // TODO see: https://releases.llvm.org/15.0.0/docs/LangRef.html#functions - } - - /* - * Function Declarer Result - */ - - FDResult::FDResult(col::LlvmFunctionDefinition &colFuncDef, - ColScopedFuncBody associatedScopedColFuncBody, - int64_t functionId) : - associatedColFuncDef(colFuncDef), - associatedScopedColFuncBody(associatedScopedColFuncBody), - functionId(functionId) {} - - col::LlvmFunctionDefinition &FDResult::getAssociatedColFuncDef() { - return associatedColFuncDef; - } - - ColScopedFuncBody FDResult::getAssociatedScopedColFuncBody() { - return associatedScopedColFuncBody; - } - - void FDResult::addFuncArgMapEntry(Argument &llvmArg, col::Variable &colArg) { - funcArgMap.insert({&llvmArg, &colArg}); - } - - col::Variable &FDResult::getFuncArgMapEntry(Argument &arg) { - return *funcArgMap.at(&arg); - } - - int64_t &FDResult::getFunctionId() { - return functionId; - } - - - /* - * Function Declarer (Analysis) - */ - AnalysisKey FunctionDeclarer::Key; - - FunctionDeclarer::FunctionDeclarer(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { - checkFunctionSupport(F); - // create llvmFuncDef declaration in buffer - col::GlobalDeclaration *llvmFuncDefDecl = pProgram->add_declarations(); - // generate id - col::LlvmFunctionDefinition *llvmFuncDef = llvmFuncDefDecl->mutable_llvm_function_definition(); - int64_t functionId = llvm2Col::setColNodeId(llvmFuncDef); - // add body block + scope + origin - // set origin - llvmFuncDef->set_allocated_origin(llvm2Col::generateFuncDefOrigin(F)); - llvmFuncDef->set_allocated_blame(new col::Blame()); - ColScopedFuncBody funcScopedBody{}; - funcScopedBody.scope = llvmFuncDef->mutable_function_body()->mutable_scope(); - funcScopedBody.scope->set_allocated_origin(llvm2Col::generateFuncDefOrigin(F)); - funcScopedBody.block = funcScopedBody.scope->mutable_body()->mutable_block(); - funcScopedBody.block->set_allocated_origin(llvm2Col::generateFuncDefOrigin(F)); - FDResult result = FDResult(*llvmFuncDef, funcScopedBody, functionId); - // set args (if present) - for (llvm::Argument &llvmArg: F.args()) { - // set in buffer - col::Variable *colArg = llvmFuncDef->add_args(); - // set origin - colArg->set_allocated_origin(llvm2Col::generateArgumentOrigin(llvmArg)); - llvm2Col::setColNodeId(colArg); - try { - llvm2Col::transformAndSetType(*llvmArg.getType(), *colArg->mutable_t()); - } catch (vcllvm::UnsupportedTypeException &e) { - std::stringstream errorStream; - errorStream << e.what() << " in argument #" << llvmArg.getArgNo(); - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); - } - // add args mapping to result - result.addFuncArgMapEntry(llvmArg, *colArg); - } - return result; - } - - /* - * Function Declarer Pass - */ - FunctionDeclarerPass::FunctionDeclarerPass(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - PreservedAnalyses FunctionDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) { - FDResult result = FAM.getResult(F); - col::LlvmFunctionDefinition &colFunction = result.getAssociatedColFuncDef(); - // complete the function declaration in proto buffer - // set return type in protobuf of function - try { - llvm2Col::transformAndSetType(*F.getReturnType(), *colFunction.mutable_return_type()); - } catch (vcllvm::UnsupportedTypeException &e) { - std::stringstream errorStream; - errorStream << e.what() << " in return signature"; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); - } - return PreservedAnalyses::all(); - } -} diff --git a/src/llvm/lib/passes/Function/PureAssigner.cpp b/src/llvm/lib/passes/Function/PureAssigner.cpp deleted file mode 100644 index a4378dfc56..0000000000 --- a/src/llvm/lib/passes/Function/PureAssigner.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "Passes/Function/PureAssigner.h" - -#include "Passes/Function/FunctionDeclarer.h" -#include "Util/Constants.h" -#include "Util/Exceptions.h" - -namespace vcllvm { - const std::string SOURCE_LOC = "Passes::Function::PureAssigner"; - - using namespace llvm; - - PureAssignerPass::PureAssignerPass(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - PreservedAnalyses PureAssignerPass::run(Function &F, FunctionAnalysisManager &FAM) { - std::ostringstream errorStream; - FDResult result = FAM.getResult(F); - col::LlvmFunctionDefinition &colFunction = result.getAssociatedColFuncDef(); - // check if pure keyword is present, else assume unpure function - if (!F.hasMetadata(vcllvm::constants::METADATA_PURE_KEYWORD)) { - colFunction.set_pure(false); - return PreservedAnalyses::all(); - } - // check if the 'pure' metadata has only 1 operand, else exit with error - MDNode *pureMDNode = F.getMetadata(vcllvm::constants::METADATA_PURE_KEYWORD); - if (pureMDNode->getNumOperands() != 1) { - errorStream << "Expected 1 argument but got " << pureMDNode->getNumOperands(); - reportError(F, errorStream.str()); - return PreservedAnalyses::all(); - } - // check if the only operand is of type 'i1', else exit with error - auto *pureMDValue = cast(pureMDNode->getOperand(0)); - if (!pureMDValue->getType()->isIntegerTy(1)) { - errorStream << "MD node type must be of type \"i1\""; - reportError(F, errorStream.str()); - return PreservedAnalyses::all(); - } - // attempt down cast to ConstantInt (which shouldn't fail given previous checks) - bool purity = cast(pureMDValue)->getValue()->isOneValue(); - colFunction.set_pure(purity); - return PreservedAnalyses::all(); - } - - void reportError(Function &F, const std::string &explanation) { - std::stringstream errorStream; - errorStream << "Malformed Metadata node of type \"" << vcllvm::constants::METADATA_PURE_KEYWORD - << "\":" << explanation; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); - } -} \ No newline at end of file diff --git a/src/llvm/lib/passes/Module/ModuleSpecCollector.cpp b/src/llvm/lib/passes/Module/ModuleSpecCollector.cpp deleted file mode 100644 index b523972560..0000000000 --- a/src/llvm/lib/passes/Module/ModuleSpecCollector.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "Passes/Module/ModuleSpecCollector.h" -#include -#include "Util/Constants.h" -#include "Util/Exceptions.h" -#include "Origin/OriginProvider.h" -#include "Transform/Transform.h" - -namespace vcllvm { - const std::string SOURCE_LOC = "Passes::Module::GlobalSpecCollector"; - - using namespace llvm; - - ModuleSpecCollectorPass::ModuleSpecCollectorPass(std::shared_ptr pProgram) : - pProgram(std::move(pProgram)) {} - - PreservedAnalyses ModuleSpecCollectorPass::run(Module &M, ModuleAnalysisManager &MAM) { - NamedMDNode *globalMDNode = M.getNamedMetadata(vcllvm::constants::METADATA_GLOBAL_KEYWORD); - if(globalMDNode == nullptr) { - return PreservedAnalyses::all(); - } - for (u_int32_t i = 0; i < globalMDNode->getNumOperands(); i++) { - for (u_int32_t j = 0; j < globalMDNode->getOperand(i)->getNumOperands(); j++) { - auto globVal = dyn_cast(globalMDNode->getOperand(i)->getOperand(j)); - if (globVal == nullptr) { - std::stringstream errorStream; - errorStream << "Unable to cast global metadata node #" << i + 1 << "to string type"; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), M); - break; - } - col::GlobalDeclaration *globDecl = pProgram->add_declarations(); - col::LlvmGlobal *colGlobal = globDecl->mutable_llvm_global(); - llvm2Col::setColNodeId(colGlobal); - colGlobal->set_value(globVal->getString().str()); - colGlobal->set_allocated_origin(llvm2Col::generateGlobalValOrigin(M, globVal->getString().str())); - } - } - return PreservedAnalyses::all(); - } - -} \ No newline at end of file diff --git a/src/llvm/lib/transform/BlockTransform.cpp b/src/llvm/lib/transform/BlockTransform.cpp deleted file mode 100644 index d0cf7564da..0000000000 --- a/src/llvm/lib/transform/BlockTransform.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "Transform/BlockTransform.h" - -#include "Transform/Instruction/TermOpTransform.h" -#include "Transform/Instruction/BinaryOpTransform.h" -#include "Transform/Instruction/UnaryOpTransform.h" -#include "Transform/Instruction/MemoryOpTransform.h" -#include "Transform/Instruction/FuncletPadOpTransform.h" -#include "Transform/Instruction/OtherOpTransform.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::BlockTransform"; - - void transformLlvmBlock(llvm::BasicBlock &llvmBlock, vcllvm::FunctionCursor &functionCursor) { - col::Block &colBlock = functionCursor.getOrSetLlvmBlock2LabeledColBlockEntry(llvmBlock).block; - for (auto *B: llvm::predecessors(&llvmBlock)) { - if (!functionCursor.isVisited(*B)) return; - } - if (functionCursor.getLoopInfo().isLoopHeader(&llvmBlock)) { - transformLoop(llvmBlock, functionCursor); - return; - } - for (auto &I: llvmBlock) { - transformInstruction(functionCursor, I, colBlock); - } - functionCursor.complete(colBlock); - } - - void transformInstruction(vcllvm::FunctionCursor &funcCursor, - llvm::Instruction &llvmInstruction, - col::Block &colBodyBlock) { - u_int32_t opCode = llvmInstruction.getOpcode(); - if (llvm::Instruction::TermOpsBegin <= opCode && opCode < llvm::Instruction::TermOpsEnd) { - llvm2Col::transformTermOp(llvmInstruction, colBodyBlock, funcCursor); - } else if (llvm::Instruction::BinaryOpsBegin <= opCode && opCode < llvm::Instruction::BinaryOpsEnd) { - llvm2Col::transformBinaryOp(llvmInstruction, colBodyBlock, funcCursor); - } else if (llvm::Instruction::UnaryOpsBegin <= opCode && opCode < llvm::Instruction::UnaryOpsEnd) { - llvm2Col::transformUnaryOp(llvmInstruction, colBodyBlock, funcCursor); - } else if (llvm::Instruction::MemoryOpsBegin <= opCode && opCode < llvm::Instruction::MemoryOpsEnd) { - llvm2Col::transformMemoryOp(llvmInstruction, colBodyBlock, funcCursor); - } else if (llvm::Instruction::FuncletPadOpsBegin <= opCode && opCode < llvm::Instruction::FuncletPadOpsEnd) { - llvm2Col::transformFuncletPadOp(llvmInstruction, colBodyBlock, funcCursor); - } else if (llvm::Instruction::OtherOpsBegin <= opCode && opCode < llvm::Instruction::OtherOpsEnd) { - llvm2Col::transformOtherOp(llvmInstruction, colBodyBlock, funcCursor); - } else { - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } - } - - void transformLoop(llvm::BasicBlock &llvmBlock, vcllvm::FunctionCursor &functionCursor) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Unsupported loop detected", llvmBlock); - } - - void reportUnsupportedOperatorError(const std::string &source, llvm::Instruction &llvmInstruction) { - std::stringstream errorStream; - errorStream << "Unsupported operator \"" << llvmInstruction.getOpcodeName() << '"'; - vcllvm::ErrorReporter::addError(source, errorStream.str(), llvmInstruction); - } -} diff --git a/src/llvm/lib/transform/Instruction/BinaryOpTransform.cpp b/src/llvm/lib/transform/Instruction/BinaryOpTransform.cpp deleted file mode 100644 index ac9684dc57..0000000000 --- a/src/llvm/lib/transform/Instruction/BinaryOpTransform.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "Transform/Instruction/BinaryOpTransform.h" - - -#include "Transform/Transform.h" -#include "Transform/BlockTransform.h" -#include "Origin/OriginProvider.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::BinaryOp"; - - void transformBinaryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - col::Assign &assignment = funcCursor.createAssignmentAndDeclaration(llvmInstruction, colBlock); - switch (llvm::Instruction::BinaryOps(llvmInstruction.getOpcode())) { - case llvm::Instruction::Add: { - col::Plus &expr = *assignment.mutable_value()->mutable_plus(); - transformBinExpr(llvmInstruction, expr, funcCursor); - break; - } - case llvm::Instruction::Sub: { - col::Minus &expr = *assignment.mutable_value()->mutable_minus(); - transformBinExpr(llvmInstruction, expr, funcCursor); - break; - } - case llvm::Instruction::Mul: { - col::Mult &expr = *assignment.mutable_value()->mutable_mult(); - transformBinExpr(llvmInstruction, expr, funcCursor); - break; - } - case llvm::Instruction::SDiv: - case llvm::Instruction::UDiv: { - if(llvmInstruction.isExact()) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Exact division not supported", llvmInstruction); - } - col::FloorDiv &expr = *assignment.mutable_value()->mutable_floor_div(); - expr.set_allocated_blame(new col::Blame()); - transformBinExpr(llvmInstruction, expr, funcCursor); - break; - } - default: - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } - } -} diff --git a/src/llvm/lib/transform/Instruction/CastOpTransform.cpp b/src/llvm/lib/transform/Instruction/CastOpTransform.cpp deleted file mode 100644 index 60d4cae7c9..0000000000 --- a/src/llvm/lib/transform/Instruction/CastOpTransform.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "Transform/Instruction/CastOpTransform.h" - -#include "Transform/BlockTransform.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::CastOp"; - void convertCastOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - //TODO stub - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } -} \ No newline at end of file diff --git a/src/llvm/lib/transform/Instruction/FuncletPadOpTransform.cpp b/src/llvm/lib/transform/Instruction/FuncletPadOpTransform.cpp deleted file mode 100644 index bc0ebf77ef..0000000000 --- a/src/llvm/lib/transform/Instruction/FuncletPadOpTransform.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "Transform/Instruction/FuncletPadOpTransform.h" - -#include "Transform/BlockTransform.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::FuncletPadOp"; - - void transformFuncletPadOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - //TODO stub - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } -} \ No newline at end of file diff --git a/src/llvm/lib/transform/Instruction/MemoryOpTransform.cpp b/src/llvm/lib/transform/Instruction/MemoryOpTransform.cpp deleted file mode 100644 index d4d9f64a68..0000000000 --- a/src/llvm/lib/transform/Instruction/MemoryOpTransform.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "Transform/Instruction/MemoryOpTransform.h" - -#include "Transform/BlockTransform.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::MemoryOp"; - - void transformMemoryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - //TODO stub - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } -} \ No newline at end of file diff --git a/src/llvm/lib/transform/Instruction/OtherOpTransform.cpp b/src/llvm/lib/transform/Instruction/OtherOpTransform.cpp deleted file mode 100644 index a07112ff48..0000000000 --- a/src/llvm/lib/transform/Instruction/OtherOpTransform.cpp +++ /dev/null @@ -1,151 +0,0 @@ -#include -#include "Transform/Instruction/OtherOpTransform.h" - -#include "Transform/BlockTransform.h" -#include "Transform/Transform.h" -#include "Util/Exceptions.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::OtherOp"; - - void transformOtherOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - switch (llvm::Instruction::OtherOps(llvmInstruction.getOpcode())) { - case llvm::Instruction::PHI: - transformPhi(llvm::cast(llvmInstruction), funcCursor); - break; - case llvm::Instruction::ICmp: - transformICmp(llvm::cast(llvmInstruction), colBlock, funcCursor); - break; - case llvm::Instruction::Call: - transformCallExpr(llvm::cast(llvmInstruction), colBlock, funcCursor); - break; - default: - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } - } - - void transformPhi(llvm::PHINode &phiInstruction, - vcllvm::FunctionCursor &funcCursor) { - col::Variable &varDecl = funcCursor.declareVariable(phiInstruction); - for (auto &B: phiInstruction.blocks()) { - // add assignment of the variable to target block - col::Block &targetBlock = funcCursor.getOrSetLlvmBlock2LabeledColBlockEntry(*B).block; - col::Assign &assignment = funcCursor.createAssignment(phiInstruction, targetBlock, varDecl); - // assign correct value by looking at the value-block pair of phi instruction. - col::Expr *value = assignment.mutable_value(); - llvm2Col::transformAndSetExpr(funcCursor, phiInstruction, - *phiInstruction.getIncomingValueForBlock(B), *value); - } - } - - void transformICmp(llvm::ICmpInst &icmpInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - // we only support integer comparison - if (not icmpInstruction.getOperand(0)->getType()->isIntegerTy()) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Unsupported compare type", icmpInstruction); - return; - } - col::Assign &assignment = funcCursor.createAssignmentAndDeclaration(icmpInstruction, colBlock); - switch (llvm::ICmpInst::Predicate(icmpInstruction.getPredicate())) { - case llvm::CmpInst::ICMP_EQ: { - col::Eq &eq = *assignment.mutable_value()->mutable_eq(); - transformCmpExpr(icmpInstruction, eq, funcCursor); - break; - } - case llvm::CmpInst::ICMP_NE: { - col::Neq &neq = *assignment.mutable_value()->mutable_neq(); - transformCmpExpr(icmpInstruction, neq, funcCursor); - break; - } - case llvm::CmpInst::ICMP_SGT: - case llvm::CmpInst::ICMP_UGT: { - col::Greater > = *assignment.mutable_value()->mutable_greater(); - transformCmpExpr(icmpInstruction, gt, funcCursor); - break; - } - case llvm::CmpInst::ICMP_SGE: - case llvm::CmpInst::ICMP_UGE: { - col::GreaterEq &geq = *assignment.mutable_value()->mutable_greater_eq(); - transformCmpExpr(icmpInstruction, geq, funcCursor); - break; - } - case llvm::CmpInst::ICMP_SLT: - case llvm::CmpInst::ICMP_ULT: { - col::Less < = *assignment.mutable_value()->mutable_less(); - transformCmpExpr(icmpInstruction, lt, funcCursor); - break; - } - case llvm::CmpInst::ICMP_SLE: - case llvm::CmpInst::ICMP_ULE: { - col::LessEq &leq = *assignment.mutable_value()->mutable_less_eq(); - transformCmpExpr(icmpInstruction, leq, funcCursor); - break; - } - default: - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Unknown ICMP predicate", icmpInstruction); - } - } - - void transformCmpExpr(llvm::CmpInst &cmpInstruction, - auto &colCompareExpr, - vcllvm::FunctionCursor &funcCursor) { - transformBinExpr(cmpInstruction, colCompareExpr, funcCursor); - } - - void checkCallSupport(llvm::CallInst &callInstruction) { - // tail recursion - if (callInstruction.isMustTailCall() || callInstruction.isNoTailCall()) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Tail call optimization not supported", callInstruction); - } - // fast math - if (callInstruction.getFastMathFlags().any()) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Fast math not supported", callInstruction); - } - // return attributes - for (auto &A: callInstruction.getAttributes().getRetAttrs()) { - std::stringstream errorStream; - errorStream << "Return attribute \"" << A.getAsString() << "\" not supported"; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), callInstruction); - } - // address space is platform dependent (unlikely to change semantics) - // function attributes are just extra compiler information (no semanatic changes) - - // operand bundles - if (callInstruction.hasOperandBundles()) { - vcllvm::ErrorReporter::addError(SOURCE_LOC, "Operand bundles not supported", callInstruction); - } - } - - void transformCallExpr(llvm::CallInst &callInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - checkCallSupport(callInstruction); - // allocate expression to host the function call in advance - col::Expr *functionCallExpr; - // if void function add an eval expression - if (callInstruction.getType()->isVoidTy()) { - col::Eval *eval = colBlock.add_statements()->mutable_eval(); - eval->set_allocated_origin(llvm2Col::generateSingleStatementOrigin(callInstruction)); - functionCallExpr = eval->mutable_expr(); - } else { // else create an assignment - col::Assign &assignment = funcCursor.createAssignmentAndDeclaration(callInstruction, colBlock); - functionCallExpr = assignment.mutable_value(); - } - // create actual invocation - col::LlvmFunctionInvocation *invocation = functionCallExpr->mutable_llvm_function_invocation(); - // set origin - invocation->set_allocated_origin(llvm2Col::generateFunctionCallOrigin(callInstruction)); - invocation->set_allocated_blame(new col::Blame()); - // set function reference - invocation->mutable_ref()->set_id( - funcCursor.getFDResult(*callInstruction.getCalledFunction()).getFunctionId() - ); - // process function arguments - for (auto &A: callInstruction.args()) { - llvm2Col::transformAndSetExpr(funcCursor, callInstruction, *A, *invocation->add_args()); - } - } -} diff --git a/src/llvm/lib/transform/Instruction/TermOpTransform.cpp b/src/llvm/lib/transform/Instruction/TermOpTransform.cpp deleted file mode 100644 index 752885052f..0000000000 --- a/src/llvm/lib/transform/Instruction/TermOpTransform.cpp +++ /dev/null @@ -1,120 +0,0 @@ -#include "Transform/Instruction/TermOpTransform.h" - -#include "Transform/Transform.h" -#include "Transform/BlockTransform.h" -#include "Util/Exceptions.h" -#include "Origin/OriginProvider.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::TermOp"; - - - void transformTermOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - switch (llvm::Instruction::TermOps(llvmInstruction.getOpcode())) { - case llvm::Instruction::Ret: - transformRet(cast(llvmInstruction), colBlock, funcCursor); - break; - case llvm::Instruction::Br: { - auto &llvmBranchInst = cast(llvmInstruction); - llvmBranchInst.isConditional() ? transformConditionalBranch(llvmBranchInst, colBlock, funcCursor) - : transformUnConditionalBranch(llvmBranchInst, colBlock, funcCursor); - break; - } - default: - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - break; - } - } - - void transformRet(llvm::ReturnInst &llvmRetInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - col::Return *returnStatement = colBlock.add_statements()->mutable_return_(); - returnStatement->set_allocated_origin(generateSingleStatementOrigin(llvmRetInstruction)); - col::Expr *returnExpr = returnStatement->mutable_result(); - llvm2Col::transformAndSetExpr( - funcCursor, - llvmRetInstruction, - *llvmRetInstruction.getReturnValue(), - *returnExpr); - } - - void transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - col::Branch *colBranch = colBlock.add_statements()->mutable_branch(); - colBranch->set_allocated_origin(generateSingleStatementOrigin(llvmBrInstruction)); - // pre-declare completion because the final branch statement is already present - funcCursor.complete(colBlock); - // true branch - col::Tuple2_VctColAstExpr_VctColAstStatement *colTrueBranch = colBranch->add_branches(); - // set conditional - transformAndSetExpr(funcCursor, - llvmBrInstruction, - *llvmBrInstruction.getCondition(), - *colTrueBranch->mutable_v1()); - // get or pre-generate target labeled block - /* - * I hear you think, why query the 2nd operand? wouldn't that be the false branch i.e the else branch? - * While any logical implementation of getting operands would give the operands in order, the branch instruction - * is no ordinary instruction. For you see to get the branch argument we use the 0th index (so far so good), for the true evaluation - * of the branch instruction we use the 2nd index (uhhh okay, we might be skipping an index?) and the false evaluation of the - * branch instruction we use the 1st index (WHAT!?!?) - * - * Visualized: - * br i1 %var, label %yay, label %nay - * 0 2 1 - * - * Just smile and wave, don't question LLVM. - */ - auto *llvmTrueBlock = cast(llvmBrInstruction.getOperand(2)); - vcllvm::LabeledColBlock labeledTrueColBlock = funcCursor.getOrSetLlvmBlock2LabeledColBlockEntry(*llvmTrueBlock); - // goto statement to true block - col::Goto *trueGoto = colTrueBranch->mutable_v2()->mutable_goto_(); - trueGoto->mutable_lbl()->set_id(labeledTrueColBlock.label.decl().id()); - // set origin for goto to true block - trueGoto->set_allocated_origin(generateSingleStatementOrigin(llvmBrInstruction)); - // transform llvm true block - transformLlvmBlock(*llvmTrueBlock, funcCursor); - - // false branch - col::Tuple2_VctColAstExpr_VctColAstStatement *colFalseBranch = colBranch->add_branches(); - // set conditional (which is a true constant as else == else if(true))) - col::BooleanValue *elseCondition = colFalseBranch->mutable_v1()->mutable_boolean_value(); - elseCondition->set_value(true); - // set origin of else condition - elseCondition->set_allocated_origin(generateOperandOrigin(llvmBrInstruction, *llvmBrInstruction.getCondition())); - // get llvm block targeted by the llvm branch - auto *llvmFalseBlock = cast(llvmBrInstruction.getOperand(1)); - // get or pre-generate target labeled block - vcllvm::LabeledColBlock labeledFalseColBlock = funcCursor.getOrSetLlvmBlock2LabeledColBlockEntry( - *llvmFalseBlock); - // goto statement to false block - col::Goto *falseGoto = colFalseBranch->mutable_v2()->mutable_goto_(); - falseGoto->mutable_lbl()->set_id(labeledFalseColBlock.label.decl().id()); - // set origin for goto to false block - falseGoto->set_allocated_origin(llvm2Col::generateSingleStatementOrigin(llvmBrInstruction)); - // transform llvm falseBlock - transformLlvmBlock(*llvmFalseBlock, funcCursor); - } - - void transformUnConditionalBranch(llvm::BranchInst &llvmBrInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - // get llvm target block - auto *llvmTargetBlock = cast(llvmBrInstruction.getOperand(0)); - // get or pre generate target labeled block - vcllvm::LabeledColBlock labeledColBlock = funcCursor.getOrSetLlvmBlock2LabeledColBlockEntry(*llvmTargetBlock); - // create goto to target labeled block - col::Goto *colGoto = colBlock.add_statements()->mutable_goto_(); - colGoto->mutable_lbl()->set_id(labeledColBlock.label.decl().id()); - // set origin of goto statement - colGoto->set_allocated_origin(llvm2Col::generateSingleStatementOrigin(llvmBrInstruction)); - // pre-declare completion because the final goto is already present - funcCursor.complete(colBlock); - // transform llvm targetBlock - transformLlvmBlock(*llvmTargetBlock, funcCursor); - } -} \ No newline at end of file diff --git a/src/llvm/lib/transform/Instruction/UnaryOpTransform.cpp b/src/llvm/lib/transform/Instruction/UnaryOpTransform.cpp deleted file mode 100644 index 321819bfd3..0000000000 --- a/src/llvm/lib/transform/Instruction/UnaryOpTransform.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Transform/Instruction/UnaryOpTransform.h" -#include "Transform/BlockTransform.h" - -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Instruction::UnaryOp"; - void transformUnaryOp(llvm::Instruction &llvmInstruction, - col::Block &colBlock, - vcllvm::FunctionCursor &funcCursor) { - //TODO stub - reportUnsupportedOperatorError(SOURCE_LOC, llvmInstruction); - } -} \ No newline at end of file diff --git a/src/llvm/lib/transform/Transform.cpp b/src/llvm/lib/transform/Transform.cpp deleted file mode 100644 index e1ceb0817d..0000000000 --- a/src/llvm/lib/transform/Transform.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include "Passes/Function/FunctionBodyTransformer.h" -#include "Transform/Transform.h" - -#include -#include - -#include "Util/Exceptions.h" -#include "Origin/OriginProvider.h" - - - - -/** - * Utility function that converts LLVM types to col types - * @param type - */ -namespace llvm2Col { - const std::string SOURCE_LOC = "Transform::Transform"; - - namespace col = vct::col::ast; - - void transformAndSetType(llvm::Type &llvmType, - col::Type &colType) { - switch (llvmType.getTypeID()) { - case llvm::Type::IntegerTyID: - if (llvmType.getIntegerBitWidth() == 1) { - colType.mutable_t_bool()->set_allocated_origin(generateTypeOrigin(llvmType)); - } else { - colType.mutable_t_int()->set_allocated_origin(generateTypeOrigin(llvmType)); - } - break; - default: - throw vcllvm::UnsupportedTypeException(); - } - } - - - void transformAndSetExpr(vcllvm::FunctionCursor &functionCursor, - llvm::Instruction &llvmInstruction, - llvm::Value &llvmOperand, - col::Expr &colExpr) { - if (llvm::isa(llvmOperand)) { - transformAndSetConstExpr(llvmInstruction, llvm::cast(llvmOperand), colExpr); - } else { - transformAndSetVarExpr(functionCursor, llvmInstruction, llvmOperand, colExpr); - } - } - - void transformAndSetVarExpr(vcllvm::FunctionCursor &functionCursor, - llvm::Instruction &llvmInstruction, - llvm::Value &llvmOperand, - col::Expr &colExpr) { - col::Variable colVar = functionCursor.getVariableMapEntry(llvmOperand); - col::Local *colLocal = colExpr.mutable_local(); - colLocal->set_allocated_origin(generateOperandOrigin(llvmInstruction, llvmOperand)); - colLocal->mutable_ref()->set_id(colVar.id()); - } - - void transformAndSetConstExpr(llvm::Instruction &llvmInstruction, - llvm::Constant &llvmConstant, - col::Expr &colExpr) { - llvm::Type *constType = llvmConstant.getType(); - switch (llvmConstant.getType()->getTypeID()) { - case llvm::Type::IntegerTyID: - if (constType->getIntegerBitWidth() == 1) { - col::BooleanValue *boolValue = colExpr.mutable_boolean_value(); - boolValue->set_allocated_origin(generateOperandOrigin(llvmInstruction, llvmConstant)); - boolValue->set_value(llvmConstant.isOneValue()); - } else { - col::IntegerValue *integerValue = colExpr.mutable_integer_value(); - integerValue->set_allocated_origin(generateOperandOrigin(llvmInstruction, llvmConstant)); - llvm::APInt apInt = llvmConstant.getUniqueInteger(); - transformAndSetIntegerValue(apInt, *integerValue); - } - break; - default: - std::string errCtx; - llvm::raw_string_ostream(errCtx) << llvmConstant; - std::stringstream errorStream; - errorStream << "Unknown constant \"" << errCtx << '\"'; - vcllvm::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), llvmInstruction); - } - } - - void transformAndSetIntegerValue(llvm::APInt &apInt, col::IntegerValue &colIntegerValue) { - // TODO works for "small" signed and unsigned numbers, may break for values >=2^64 - llvm::APInt byteSwapped = apInt.byteSwap(); - std::vector byteVector; - for (uint32_t i = 0; i < byteSwapped.getNumWords(); i++) { - byteVector.push_back(byteSwapped.getRawData()[i]); - } - colIntegerValue.mutable_value()->set_data(byteVector.data(), apInt.getBitWidth() / 8); - } - - std::string getValueName(llvm::Value &llvmValue) { - std::string name; - llvm::raw_string_ostream contextStream = llvm::raw_string_ostream(name); - llvmValue.printAsOperand(contextStream, false); - return name; - } -} \ No newline at end of file diff --git a/src/llvm/lib/util/Exceptions.cpp b/src/llvm/lib/util/Exceptions.cpp deleted file mode 100644 index 32ff59d7ac..0000000000 --- a/src/llvm/lib/util/Exceptions.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "Util/Exceptions.h" -#include "Origin/ShortPositionDeriver.h" - - -#include -#include -#include - -namespace vcllvm { - [[nodiscard]] const char *UnsupportedTypeException::what() const noexcept { - return "Type not supported"; - } - - u_int32_t ErrorReporter::errorCount; - - void ErrorReporter::addError(const std::string &source, const std::string &message, const std::string &origin) { - llvm::errs() << "[VCLLVM] [" << source << "] " << message << " @\n " << origin << "\n\n"; - ErrorReporter::errorCount++; - } - - void ErrorReporter::addError(const std::string &source, - const std::string &message, - llvm::Module &llvmModule) { - addError(source, message, llvm2Col::deriveModuleShortPosition(llvmModule)); - } - - void ErrorReporter::addError(const std::string &source, - const std::string &message, - llvm::Function &llvmFunction) { - addError(source, message, llvm2Col::deriveFunctionShortPosition(llvmFunction)); - } - - void ErrorReporter::addError(const std::string &source, - const std::string &message, - llvm::BasicBlock &llvmBlock) { - addError(source, message, llvm2Col::deriveBlockShortPosition(llvmBlock)); - } - - void ErrorReporter::addError(const std::string &source, - const std::string &message, - llvm::Instruction &llvmInstruction) { - addError(source, message, llvm2Col::deriveInstructionShortPosition(llvmInstruction)); - } - - bool ErrorReporter::hasErrors() { - return ErrorReporter::errorCount > 0; - } - - u_int32_t ErrorReporter::getErrorCount() { - return ErrorReporter::errorCount; - } -} \ No newline at end of file diff --git a/src/llvm/tools/vcllvm/VCLLVM.cpp b/src/llvm/tools/vcllvm/VCLLVM.cpp deleted file mode 100644 index af755ba71f..0000000000 --- a/src/llvm/tools/vcllvm/VCLLVM.cpp +++ /dev/null @@ -1,145 +0,0 @@ -#include "vct/col/ast/col.pb.h" -#include "Passes/Function/FunctionBodyTransformer.h" -#include "Passes/Function/FunctionContractDeclarer.h" -#include "Passes/Function/PureAssigner.h" -#include "Passes/Module/ModuleSpecCollector.h" - -#include -#include -#include -#include - -#include "Passes/Function/FunctionDeclarer.h" - -#include "Transform/Transform.h" -#include "Origin/OriginProvider.h" - -#include "Util/Exceptions.h" - -#include -#include - -namespace col = vct::col::ast; - -col::Program sampleCol(bool returnBool) { - col::Program program = col::Program(); - - // class - col::GlobalDeclaration *classDeclaration = program.add_declarations(); - col::VctClass *vctClass = classDeclaration->mutable_vct_class(); - llvm2Col::setColNodeId(vctClass); - col::BooleanValue *lockInvariant = vctClass->mutable_intrinsic_lock_invariant()->mutable_boolean_value(); - lockInvariant->set_value(true); - // class>method - col::ClassDeclaration *methodDeclaration = vctClass->add_decls(); - col::InstanceMethod *method = methodDeclaration->mutable_instance_method(); - llvm2Col::setColNodeId(method); - // class>method>return_type - method->mutable_return_type()->mutable_t_bool(); - // class>method>body - col::Block *body = method->mutable_body()->mutable_scope()->mutable_body()->mutable_block(); - col::Return *returnStatement = body->add_statements()->mutable_return_(); - col::BooleanValue *returnValue = returnStatement->mutable_result()->mutable_boolean_value(); - returnValue->set_value(returnBool); - // class>method>inline - method->set_inline_(false); - // class>method>pure - method->set_pure(false); - // class>method>contract - col::ApplicableContract *contract = method->mutable_contract(); - // class>method>contract>precondition - col::UnitAccountedPredicate *precondition = contract->mutable_requires_()->mutable_unit_accounted_predicate(); - col::BooleanValue *prePred = precondition->mutable_pred()->mutable_boolean_value(); - prePred->set_value(true); - // class>method>contract>postcondition - col::UnitAccountedPredicate *postcondition = contract->mutable_ensures()->mutable_unit_accounted_predicate(); - col::Ref *postRefResult = postcondition->mutable_pred()->mutable_result()->mutable_applicable(); - postRefResult->set_id(method->id()); - // class>method>contract>context_everywhere - col::BooleanValue *contextEverywhere = contract->mutable_context_everywhere()->mutable_boolean_value(); - contextEverywhere->set_value(true); - return program; -} - -static vcllvm::cl::opt inputFileName{"", - vcllvm::cl::desc{"Module to analyze"}, - vcllvm::cl::value_desc{"IR filename"}, - vcllvm::cl::Positional}; -static vcllvm::cl::opt testCol{"sample-col", - vcllvm::cl::desc{"Output a sample col buffer with verdict PASS"}}; - -static vcllvm::cl::opt incorrectTestCol{"sample-col-wrong", - vcllvm::cl::desc{"Output a sample col buffer with verdict FAIL"}}; - -static vcllvm::cl::opt humanReadableOutput{"human-readable", - vcllvm::cl::desc{"Output COL buffer in human readable format"}}; - -int main(int argc, char **argv) { - vcllvm::cl::ParseCommandLineOptions(argc, argv); - // sample mode - if (testCol.getValue() || incorrectTestCol.getValue()) { - std::cout << sampleCol(testCol.getValue()).SerializeAsString(); - std::cout.flush(); - return EXIT_SUCCESS; - } - // parse mode - if (inputFileName.empty()) { - vcllvm::errs() << "no input file given\n"; - return EXIT_FAILURE; - } - vcllvm::LLVMContext context; - vcllvm::SMDiagnostic smDiag; - auto pModule = parseIRFile(inputFileName, smDiag, context); - if (!pModule) { - smDiag.print(inputFileName.c_str(), vcllvm::errs()); - return EXIT_FAILURE; - } - pModule->setSourceFileName(inputFileName); - vcllvm::Module *module = pModule.release(); - auto pProgram = std::make_shared(); - // set program origin - pProgram->set_allocated_origin(llvm2Col::generateProgramOrigin(*module)); - pProgram->set_allocated_blame(new col::Blame()); - // Create the analysis managers. - vcllvm::LoopAnalysisManager LAM; - vcllvm::FunctionAnalysisManager FAM; - vcllvm::CGSCCAnalysisManager CGAM; - vcllvm::ModuleAnalysisManager MAM; - FAM.registerPass([&] { return vcllvm::FunctionDeclarer(pProgram); }); - FAM.registerPass([&] { return vcllvm::FunctionContractDeclarer(pProgram); }); - // Create the new pass manager builder. - // Take a look at the PassBuilder constructor parameters for more - // customization, e.g. specifying a TargetMachine or various debugging - // options. - vcllvm::PassBuilder PB; - // Register all the basic analyses with the managers. - PB.registerModuleAnalyses(MAM); - PB.registerCGSCCAnalyses(CGAM); - PB.registerFunctionAnalyses(FAM); - PB.registerLoopAnalyses(LAM); - PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - - vcllvm::LoopPassManager LPM; - - vcllvm::FunctionPassManager FPM; - FPM.addPass(vcllvm::FunctionDeclarerPass(pProgram)); - FPM.addPass(vcllvm::PureAssignerPass(pProgram)); - FPM.addPass(vcllvm::FunctionContractDeclarerPass(pProgram)); - FPM.addPass(vcllvm::FunctionBodyTransformerPass(pProgram)); - vcllvm::ModulePassManager MPM; - MPM.addPass(vcllvm::ModuleSpecCollectorPass(pProgram)); - MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); - MPM.run(*module, MAM); - if (vcllvm::ErrorReporter::hasErrors()) { - vcllvm::errs() << "While processing \"" << inputFileName << "\" VCLLVM has encountered " - << vcllvm::ErrorReporter::getErrorCount() << " error(s).\n" - << "Exiting with failure code...\n"; - return EXIT_FAILURE; - } - if (humanReadableOutput.getValue()) { - llvm::errs() << pProgram->DebugString(); - } else { - std::cout << pProgram->SerializeAsString(); - } - return EXIT_SUCCESS; -} diff --git a/src/main/vct/main/stages/Parsing.scala b/src/main/vct/main/stages/Parsing.scala index 3e8317ce40..77e01baaad 100644 --- a/src/main/vct/main/stages/Parsing.scala +++ b/src/main/vct/main/stages/Parsing.scala @@ -133,7 +133,7 @@ case class Parsing[G <: Generation]( case Language.SystemC => new ColSystemCParser(Resources.getSystemCConfig) case Language.LLVM => - ColLLVMParser(debugOptions, blameProvider, Resources.getVCLLVM) + ColLLVMParser(debugOptions, blameProvider, Resources.getPallas) } parser.parse[G](readable) diff --git a/src/main/vct/main/stages/Resolution.scala b/src/main/vct/main/stages/Resolution.scala index 64aa310b13..afa2ff5c63 100644 --- a/src/main/vct/main/stages/Resolution.scala +++ b/src/main/vct/main/stages/Resolution.scala @@ -9,8 +9,8 @@ import vct.col.ast.{ CGlobalDeclaration, Expr, GlobalDeclaration, - LlvmFunctionContract, - LlvmGlobal, + LLVMFunctionContract, + LLVMGlobalSpecification, Program, Refute, Verification, @@ -27,7 +27,7 @@ import vct.col.origin.{ ReadableOrigin, } import vct.col.resolve.{Resolve, ResolveReferences, ResolveTypes} -import vct.col.rewrite.Generation +import vct.col.rewrite.{Generation, Rewritten} import vct.col.rewrite.bip.IsolateBipGlue import vct.rewrite.lang.{LangSpecificToCol, LangTypesToCol} import vct.importer.JavaLibraryLoader @@ -36,7 +36,7 @@ import vct.options.Options import vct.options.types.ClassPathEntry import vct.parsers.debug.DebugOptions import vct.parsers.err.FileNotFound -import vct.parsers.parser.{ColJavaParser, ColLLVMParser} +import vct.parsers.parser.{ColJavaParser, ColLLVMContractParser, ColPVLParser} import vct.parsers.transform.BlameProvider import vct.parsers.{ParseResult, parser} import vct.resources.Resources @@ -67,6 +67,14 @@ case object Resolution { case ClassPathEntry.SourcePath(root) => ResolveTypes.JavaClassPathEntry.Path(root) }, + if (options.contractImportFile.isDefined) { + val res = ColPVLParser(options.getParserDebugOptions, blameProvider) + .parse[G]( + options.contractImportFile.get, + Origin(Seq(ReadableOrigin(options.contractImportFile.get))), + ) + res.decls + } else { Seq() }, options.veymontGeneratePermissions, options.devVeymontAllowAssign, ) @@ -84,7 +92,7 @@ case class MyLocalJavaParser( ) extends Resolve.SpecExprParser { override def parse[G](input: String, o: Origin): Expr[G] = { val sr = LiteralReadable("", input) - val cjp = parser.ColJavaParser(debugOptions, blameProvider) + val cjp = ColJavaParser(debugOptions, blameProvider) val x = cjp.parseExpr[G](sr) if (x._2.nonEmpty) { throw SpecExprParseError("...") } x._1 @@ -96,17 +104,17 @@ case class MyLocalLLVMSpecParser( debugOptions: DebugOptions, ) extends Resolve.SpecContractParser { override def parse[G]( - input: LlvmFunctionContract[G], + input: LLVMFunctionContract[G], o: Origin, ): ApplicableContract[G] = - parser.ColLLVMContractParser(debugOptions, blameProvider) + ColLLVMContractParser(debugOptions, blameProvider) .parseFunctionContract[G](new StringReader(input.value), o)._1 override def parse[G]( - input: LlvmGlobal[G], + input: LLVMGlobalSpecification[G], o: Origin, ): Seq[GlobalDeclaration[G]] = - parser.ColLLVMContractParser(debugOptions, blameProvider) + ColLLVMContractParser(debugOptions, blameProvider) .parseReader[G](new StringReader(input.value), o).decls } @@ -117,6 +125,7 @@ case class Resolution[G <: Generation]( ResolveTypes.JavaClassPathEntry.Path(Resources.getJrePath), ResolveTypes.JavaClassPathEntry.SourcePackageRoot, ), + importedDeclarations: Seq[GlobalDeclaration[G]] = Seq(), veymontGeneratePermissions: Boolean = false, veymontAllowAssign: Boolean = false, ) extends Stage[ParseResult[G], Verification[_ <: Generation]] @@ -138,17 +147,26 @@ case class Resolution[G <: Generation]( val joinedProgram = Program(isolatedBipProgram.declarations ++ extraDecls)(blameProvider()) val typedProgram = LangTypesToCol().dispatch(joinedProgram) - ResolveReferences.resolve( - typedProgram, - MyLocalJavaParser(blameProvider, parserDebugOptions), - MyLocalLLVMSpecParser(blameProvider, parserDebugOptions), - ) match { + val javaParser = MyLocalJavaParser(blameProvider, parserDebugOptions) + val llvmParser = MyLocalLLVMSpecParser(blameProvider, parserDebugOptions) + val typedImports = + if (importedDeclarations.isEmpty) { Seq() } + else { + val ast = LangTypesToCol() + .dispatch(Program(importedDeclarations)(blameProvider())) + ResolveReferences.resolve(ast, javaParser, llvmParser, Seq()) + LangSpecificToCol(veymontGeneratePermissions, veymontAllowAssign, Seq()) + .dispatch(ast).asInstanceOf[Program[Rewritten[G]]].declarations + } + ResolveReferences + .resolve(typedProgram, javaParser, llvmParser, typedImports) match { case Nil => // ok case some => throw InputResolutionError(some) } val resolvedProgram = LangSpecificToCol( veymontGeneratePermissions, veymontAllowAssign, + typedImports, ).dispatch(typedProgram) resolvedProgram.check match { case Nil => // ok diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 36fab123a0..f49cf1cd7b 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -35,6 +35,7 @@ import vct.rewrite.{ HeapVariableToRef, MonomorphizeClass, SmtlibToProverTypes, + VariableToPointer, } import vct.rewrite.lang.ReplaceSYCLTypes import vct.rewrite.veymont.{ @@ -325,6 +326,7 @@ case class SilverTransformation( EncodeString, // Encode spec string as seq EncodeChar, CollectLocalDeclarations, // all decls in Scope + VariableToPointer, // should happen before ParBlockEncoder so it can distinguish between variables which can and can't altered in a parallel block DesugarPermissionOperators, // no PointsTo, \pointer, etc. ReadToValue, // resolve wildcard into fractional permission TrivialAddrOf, diff --git a/src/main/vct/options/Options.scala b/src/main/vct/options/Options.scala index 79e522f5e8..d5cd459ee0 100644 --- a/src/main/vct/options/Options.scala +++ b/src/main/vct/options/Options.scala @@ -295,6 +295,9 @@ case object Options { opt[Path]("path-c-preprocessor").valueName("") .action((path, c) => c.copy(cPreprocessorPath = path)) .text("Set the location of the C preprocessor binary"), + opt[PathOrStd]("contract-import-file").valueName("") + .action((path, c) => c.copy(contractImportFile = Some(path))) + .text("Load function contracts from the specified file"), note(""), note("VeyMont Mode"), opt[Unit]("veymont").action((_, c) => c.copy(mode = Mode.VeyMont)).text( @@ -448,6 +451,9 @@ case class Options( // Control flow graph options cfgOutput: Path = null, + + // Pallas options + contractImportFile: Option[PathOrStd] = None, ) { def getParserDebugOptions: vct.parsers.debug.DebugOptions = vct.parsers.debug.DebugOptions( diff --git a/src/main/vct/resources/Resources.scala b/src/main/vct/resources/Resources.scala index 86fbb5e3a3..43508dce57 100644 --- a/src/main/vct/resources/Resources.scala +++ b/src/main/vct/resources/Resources.scala @@ -16,5 +16,5 @@ case object Resources { def getCPPcPath: Path = Paths.get("clang++") def getSystemCConfig: Path = getResource("/systemc/config") def getVeymontPath: Path = getResource("/veymont") - def getVCLLVM: Path = getResource("/vcllvm") + def getPallas: Path = getResource("/pallas") } diff --git a/src/parsers/antlr4/LangPVLLexer.g4 b/src/parsers/antlr4/LangPVLLexer.g4 index 4bd41d3d30..f6ddf942ee 100644 --- a/src/parsers/antlr4/LangPVLLexer.g4 +++ b/src/parsers/antlr4/LangPVLLexer.g4 @@ -45,6 +45,7 @@ PERCENT: '%'; INC: '++'; DEC: '--'; CONS: '::'; +AMPERSAND: '&'; ENUM: 'enum'; CLASS: 'class'; diff --git a/src/parsers/antlr4/LangPVLParser.g4 b/src/parsers/antlr4/LangPVLParser.g4 index cbf9a06ae2..b72c0e47b3 100644 --- a/src/parsers/antlr4/LangPVLParser.g4 +++ b/src/parsers/antlr4/LangPVLParser.g4 @@ -131,6 +131,8 @@ seqAddExpr unaryExpr : '!' unaryExpr | '-' unaryExpr + | '*' unaryExpr + | '&' unaryExpr | valPrefix unaryExpr | newExpr ; diff --git a/src/parsers/vct/parsers/parser/ColLLVMParser.scala b/src/parsers/vct/parsers/parser/ColLLVMParser.scala index 98d65c25ef..510c2d015a 100644 --- a/src/parsers/vct/parsers/parser/ColLLVMParser.scala +++ b/src/parsers/vct/parsers/parser/ColLLVMParser.scala @@ -1,18 +1,21 @@ package vct.parsers.parser +import com.google.protobuf.InvalidProtocolBufferException import com.typesafe.scalalogging.LazyLogging import hre.io.Readable import org.antlr.v4.runtime.{CharStream, CommonTokenStream} import vct.antlr4.generated.{LLVMSpecParser, LangLLVMSpecLexer} -import vct.col.ast.Deserialize import vct.col.ast.serialize.Program +import vct.col.ast.{Declaration, Deserialize, LLVMFunctionDefinition} import vct.col.origin.{ExpectedError, Origin} +import vct.col.ref.Ref +import vct.parsers.transform.{BlameProvider, LLVMContractToCol, OriginProvider} +import vct.result.VerificationError.{SystemError, Unreachable, UserError} +import vct.parsers.{Parser, ParseResult} import vct.parsers.debug.DebugOptions -import vct.parsers.transform.{BlameProvider, LLVMContractToCol} -import vct.parsers.{ParseResult, Parser} -import vct.result.VerificationError.{Unreachable, UserError} import java.io.{IOException, Reader} +import java.nio.file.Path import java.nio.charset.StandardCharsets import java.nio.file.Path import scala.util.{Failure, Using} @@ -20,26 +23,38 @@ import scala.util.{Failure, Using} case class ColLLVMParser( debugOptions: DebugOptions, blameProvider: BlameProvider, - vcllvm: Path, + pallas: Path, ) extends Parser with LazyLogging { - case class LLVMParseError(fileName: String, errorCode: Int, error: String) - extends UserError { + private case class LLVMParseError( + fileName: String, + errorCode: Int, + error: String, + ) extends UserError { override def code: String = "LLVMParseError" override def text: String = - s"[ERROR] Parsing file $fileName failed with exit code $errorCode:\n$error" + messageContext( + s"[ERROR] Parsing file $fileName failed with exit code $errorCode:\n$error" + ) } override def parse[G]( readable: Readable, baseOrigin: Origin = Origin(Nil), ): ParseResult[G] = { - if (vcllvm == null) { + if (pallas == null) { throw Unreachable( - "The COLLVMParser needs to be provided with the path to vcllvm to parse LLVM-IR files" + "The ColLLVMParser needs to be provided with the path to pallas to parse LLVM-IR files" ) } - val command = Seq(vcllvm.toString, readable.fileName) + val command = Seq( + "opt-17", + s"--load-pass-plugin=$pallas", + "--passes=module(pallas-declare-variables,pallas-collect-module-spec),function(pallas-declare-function,pallas-assign-pure,pallas-declare-function-contract,pallas-transform-function-body),module(pallas-print-protobuf)", + readable.fileName, + "--disable-output", + ) + val process = new ProcessBuilder(command: _*).start() val protoProgram = @@ -51,7 +66,7 @@ case class ColLLVMParser( new String( process.getErrorStream.readAllBytes(), StandardCharsets.UTF_8, - ), + ).indent(8), )) }.get @@ -63,7 +78,7 @@ case class ColLLVMParser( new String( process.getErrorStream.readAllBytes(), StandardCharsets.UTF_8, - ), + ).indent(8), ) } diff --git a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala index dfd50a7640..b714a0ed9e 100644 --- a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala +++ b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala @@ -23,7 +23,7 @@ case class LLVMContractToCol[G]( ) extends ToCol(baseOrigin, blameProvider, errors) { def local(ctx: ParserRuleContext, name: String): Expr[G] = - LlvmLocal(name)(blame(ctx))(origin(ctx)) + LLVMLocal(name)(blame(ctx))(origin(ctx)) def createVariable( ctx: ParserRuleContext, @@ -145,7 +145,7 @@ case class LLVMContractToCol[G]( callOp match { case CallInstruction0(_, id, _, exprList, _) => val args: Seq[Expr[G]] = convert(exprList) - LlvmAmbiguousFunctionInvocation(id, args, Nil, Nil)(blame(callOp)) + LLVMAmbiguousFunctionInvocation(id, args, Nil, Nil)(blame(callOp)) } def convert(implicit binOp: BinOpInstructionContext): Expr[G] = @@ -547,7 +547,7 @@ case class LLVMContractToCol[G]( modifiers.foreach(convert(_, modifierCollector)) val namedOrigin = origin(decl).sourceName(convert(name)) - new LlvmSpecFunction( + new LLVMSpecFunction( convert(name), convert(t), args.map(convert(_)).getOrElse(Nil), diff --git a/src/parsers/vct/parsers/transform/PVLToCol.scala b/src/parsers/vct/parsers/transform/PVLToCol.scala index 5e9d424fa6..8bffdc7610 100644 --- a/src/parsers/vct/parsers/transform/PVLToCol.scala +++ b/src/parsers/vct/parsers/transform/PVLToCol.scala @@ -408,8 +408,10 @@ case class PVLToCol[G]( expr match { case UnaryExpr0(_, inner) => Not(convert(inner)) case UnaryExpr1(_, inner) => UMinus(convert(inner)) - case UnaryExpr2(op, inner) => convert(expr, op, convert(inner)) - case UnaryExpr3(inner) => convert(inner) + case UnaryExpr2(_, inner) => DerefPointer(convert(inner))(blame(expr)) + case UnaryExpr3(_, inner) => AddrOf(convert(inner)) + case UnaryExpr4(op, inner) => convert(expr, op, convert(inner)) + case UnaryExpr5(inner) => convert(inner) } def convert(implicit expr: NewExprContext): Expr[G] = diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 32cfe554f9..3b089691cb 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -460,9 +460,9 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { "Instance operator methods are already compiled away", Some(method), ) - case function: LlvmSpecFunction[Pre] => + case function: LLVMSpecFunction[Pre] => throw ExcludedByPassOrder( - "Llvm spec functions are already compiled away", + "LLVM spec functions are already compiled away", Some(function), ) } diff --git a/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala b/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala index d0957ff8da..e00bc3ae21 100644 --- a/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala +++ b/src/rewrite/vct/rewrite/DesugarPermissionOperators.scala @@ -171,7 +171,7 @@ case class DesugarPermissionOperators[Pre <: Generation]() )(FramedPtrOffset), dispatch(perm), ) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } } } diff --git a/src/rewrite/vct/rewrite/EncodeCurrentThread.scala b/src/rewrite/vct/rewrite/EncodeCurrentThread.scala index ade6107c58..32f3e7b8f7 100644 --- a/src/rewrite/vct/rewrite/EncodeCurrentThread.scala +++ b/src/rewrite/vct/rewrite/EncodeCurrentThread.scala @@ -51,7 +51,7 @@ case class EncodeCurrentThread[Pre <: Generation]() extends Rewriter[Pre] { // PB: although a pure method will become a function, it should really be possible to mark a pure method as thread // local. case m: AbstractMethod[Pre] => !m.pure - case m: LlvmFunctionDefinition[Pre] => !m.pure + case m: LLVMFunctionDefinition[Pre] => !m.pure case _: ADTFunction[Pre] => false case _: ProverFunction[Pre] => false diff --git a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala index 0b344f2a41..128f34291c 100644 --- a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala +++ b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala @@ -433,15 +433,14 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() case proof: FramedProof[Pre] => rewriteDefault(proof) case extract: Extract[Pre] => rewriteDefault(extract) case branch: IndetBranch[Pre] => rewriteDefault(branch) - case LlvmLoop(cond, contract, body) => + case LLVMLoop(cond, contract, body) => evaluateOne(cond) match { case (Nil, Nil, cond) => - LlvmLoop(cond, dispatch(contract), dispatch(body)) + LLVMLoop(cond, dispatch(contract), dispatch(body)) case (variables, sideEffects, cond) => val break = new LabelDecl[Post]()(BreakOrigin) - Block(Seq( - LlvmLoop( + LLVMLoop( tt, dispatch(contract), Block(Seq( @@ -471,6 +470,7 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() case _: CStatement[Pre] => throw ExtraNode case _: CPPStatement[Pre] => throw ExtraNode case _: JavaStatement[Pre] => throw ExtraNode + case _: LLVMStatement[Pre] => throw ExtraNode } } diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index edc400f193..1162ff3d92 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -32,11 +32,16 @@ case object TrivialAddrOf extends RewriterBuilder { case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(e: Expr[Pre]): Expr[Post] = e match { + case DerefPointer(PointerAdd(AddrOf(pointer), offset)) + if offset.isInstanceOf[ConstantInt[Pre]] && + offset.asInstanceOf[ConstantInt[Pre]].value.signum == 0 => + dispatch(pointer) case AddrOf(DerefPointer(p)) => dispatch(p) case AddrOf(sub @ PointerSubscript(p, i)) => PointerAdd(dispatch(p), dispatch(i))(SubscriptErrorAddError(sub))(e.o) + case AddrOf(other) if other.t.isInstanceOf[TClass[Pre]] => dispatch(other) case AddrOf(other) => throw UnsupportedLocation(other) case assign @ PreAssignExpression(target, AddrOf(value)) if value.t.isInstanceOf[TClass[Pre]] => @@ -55,7 +60,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { newValue, )(assign.blame) With(newPointer, newAssign) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } override def dispatch(s: Statement[Pre]): Statement[Post] = @@ -77,7 +82,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { newValue, )(assign.blame) Block(Seq(newPointer, newAssign)) - case other => rewriteDefault(other) + case other => other.rewriteDefault() } // TODO: AddressOff needs a more structured approach. Now you could assign a local structure to a pointer, and that pointer diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala new file mode 100644 index 0000000000..ad46a37086 --- /dev/null +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -0,0 +1,213 @@ +package vct.rewrite + +import vct.col.ast._ +import vct.col.ref._ +import vct.col.origin._ +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten} +import vct.col.util.AstBuildHelpers._ +import vct.col.util.SuccessionMap +import vct.result.VerificationError.UserError + +import scala.collection.mutable + +case object VariableToPointer extends RewriterBuilder { + override def key: String = "variableToPointer" + + override def desc: String = + "Translate every local and field to a pointer such that it can have its address taken" + + case class UnsupportedAddrOf(loc: Expr[_]) extends UserError { + override def code: String = "unsupportedAddrOf" + + override def text: String = + loc.o.messageInContext( + "Taking an address of this expression is not supported" + ) + } +} + +case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { + + import VariableToPointer._ + + val addressedSet: mutable.Set[Node[Pre]] = new mutable.HashSet[Node[Pre]]() + val heapVariableMap: SuccessionMap[HeapVariable[Pre], HeapVariable[Post]] = + SuccessionMap() + val variableMap: SuccessionMap[Variable[Pre], Variable[Post]] = + SuccessionMap() + val fieldMap: SuccessionMap[InstanceField[Pre], InstanceField[Post]] = + SuccessionMap() + + override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = { + addressedSet.addAll(program.collect { + case AddrOf(Local(Ref(v))) if !v.t.isInstanceOf[TClass[Pre]] => v + case AddrOf(DerefHeapVariable(Ref(v))) + if !v.t.isInstanceOf[TClass[Pre]] => + v + case AddrOf(Deref(_, Ref(f))) if !f.t.isInstanceOf[TClass[Pre]] => f + }) + super.dispatch(program) + } + + override def dispatch(decl: Declaration[Pre]): Unit = + decl match { + case v: HeapVariable[Pre] if addressedSet.contains(v) => + heapVariableMap(v) = globalDeclarations + .declare(new HeapVariable(TPointer(dispatch(v.t)))(v.o)) + case v: Variable[Pre] if addressedSet.contains(v) => + variableMap(v) = variables + .declare(new Variable(TPointer(dispatch(v.t)))(v.o)) + case f: InstanceField[Pre] if addressedSet.contains(f) => + fieldMap(f) = classDeclarations.declare( + new InstanceField( + TPointer(dispatch(f.t)), + f.flags.map { it => dispatch(it) }, + )(f.o) + ) + case other => allScopes.anySucceed(other, other.rewriteDefault()) + } + + override def dispatch(stat: Statement[Pre]): Statement[Post] = { + implicit val o: Origin = stat.o + stat match { + case s: Scope[Pre] => + s.rewrite( + locals = variables.dispatch(s.locals), + body = Block(s.locals.filter { local => addressedSet.contains(local) } + .map { local => + implicit val o: Origin = local.o + Assign( + Local[Post](variableMap.ref(local)), + NewPointerArray( + variableMap(local).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + } ++ Seq(dispatch(s.body))), + ) + case i @ Instantiate(cls, out) => + Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap { + case f: InstanceField[Pre] => + if (f.t.asClass.isDefined) { + Seq( + Assign( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")), + Assign( + PointerSubscript( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + const[Post](0), + )(PanicBlame("Size is > 0")), + dispatch(NewObject[Pre](f.t.asClass.get.cls)), + )(PanicBlame("Initialisation should always succeed")), + ) + } else if (addressedSet.contains(f)) { + Seq( + Assign( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + ) + } else { Seq() } + case _ => Seq() + }) + case other => other.rewriteDefault() + } + } + + override def dispatch(expr: Expr[Pre]): Expr[Post] = { + implicit val o: Origin = expr.o + expr match { + case deref @ DerefHeapVariable(Ref(v)) if addressedSet.contains(v) => + DerefPointer( + DerefHeapVariable[Post](heapVariableMap.ref(v))(deref.blame) + )(PanicBlame("Should always be accessible")) + case Local(Ref(v)) if addressedSet.contains(v) => + DerefPointer(Local[Post](variableMap.ref(v)))(PanicBlame( + "Should always be accessible" + )) + case deref @ Deref(obj, Ref(f)) if addressedSet.contains(f) => + DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))( + PanicBlame("Should always be accessible") + ) + case newObject @ NewObject(Ref(cls)) => + val obj = new Variable[Post](TClass(succ(cls), Seq())) + ScopedExpr( + Seq(obj), + With( + Block( + Seq(assignLocal(obj.get, newObject.rewriteDefault())) ++ + cls.declarations.flatMap { + case f: InstanceField[Pre] => + if (f.t.asClass.isDefined) { + Seq( + Assign( + Deref[Post](obj.get, anySucc(f))(PanicBlame( + "Initialisation should always succeed" + )), + dispatch(NewObject[Pre](f.t.asClass.get.cls)), + )(PanicBlame("Initialisation should always succeed")) + ) + } else if (addressedSet.contains(f)) { + Seq( + Assign( + Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + ) + } else { Seq() } + case _ => Seq() + } + ), + obj.get, + ), + ) + case other => other.rewriteDefault() + } + } + + override def dispatch(loc: Location[Pre]): Location[Post] = { + implicit val o: Origin = loc.o + loc match { + case HeapVariableLocation(Ref(v)) if addressedSet.contains(v) => + PointerLocation( + DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame( + "Should always be accessible" + )) + )(PanicBlame("Should always be accessible")) + case FieldLocation(obj, Ref(f)) if addressedSet.contains(f) => + PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame( + "Should always be accessible" + )))(PanicBlame("Should always be accessible")) + case PointerLocation( + AddrOf(Deref(obj, Ref(f))) + ) /* if addressedSet.contains(f) always true */ => + FieldLocation[Post](dispatch(obj), fieldMap.ref(f)) + case PointerLocation( + AddrOf(DerefHeapVariable(Ref(v))) + ) /* if addressedSet.contains(v) always true */ => + HeapVariableLocation[Post](heapVariableMap.ref(v)) + case PointerLocation(AddrOf(local @ Local(_))) => + throw UnsupportedAddrOf(local) + case other => other.rewriteDefault() + } + } +} diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 162774a74c..b776df0767 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -2,71 +2,139 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import vct.col.ast._ -import vct.col.origin.Origin -import vct.col.ref.{LazyRef, Ref} -import vct.col.resolve.ctx.RefLlvmFunctionDefinition +import vct.col.origin.{Origin, PanicBlame, SourceName} +import vct.col.ref.{DirectRef, LazyRef, Ref} +import vct.col.resolve.ctx.RefLLVMFunctionDefinition import vct.col.rewrite.{Generation, Rewritten} +import vct.col.util.AstBuildHelpers.{VarBuildHelpers, assignLocal, const, tt} import vct.col.util.{CurrentProgramContext, SuccessionMap} -import vct.result.VerificationError.SystemError -import vct.rewrite.lang.LangLLVMToCol.UnexpectedLlvmNode +import vct.result.VerificationError.{SystemError, UserError} + +import scala.collection.mutable case object LangLLVMToCol { - case class UnexpectedLlvmNode(node: Node[_]) extends SystemError { + case class UnexpectedLLVMNode(node: Node[_]) extends SystemError { override def text: String = context[CurrentProgramContext].map(_.highlight(node)).getOrElse(node.o) .messageInContext( "VerCors assumes this node does not occur here in llvm input." ) } + + case class NonConstantStructIndex(origin: Origin) extends UserError { + override def code: String = "nonConstantStructIndex" + + override def text: String = + origin.messageInContext( + s"This struct indexing operation (getelementptr) uses a non-constant struct index which we do not support." + ) + } } case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends LazyLogging { + + import LangLLVMToCol._ + type Post = Rewritten[Pre] implicit val implicitRewriter: AbstractRewriter[Pre, Post] = rw private val llvmFunctionMap - : SuccessionMap[LlvmFunctionDefinition[Pre], Procedure[Post]] = + : SuccessionMap[LLVMFunctionDefinition[Pre], Procedure[Post]] = SuccessionMap() private val specFunctionMap - : SuccessionMap[LlvmSpecFunction[Pre], Function[Post]] = SuccessionMap() + : SuccessionMap[LLVMSpecFunction[Pre], Function[Post]] = SuccessionMap() + private val globalVariableMap + : SuccessionMap[LLVMGlobalVariable[Pre], HeapVariable[Post]] = + SuccessionMap() + private val structMap: SuccessionMap[LLVMTStruct[Pre], Class[Post]] = + SuccessionMap() + private val structFieldMap + : SuccessionMap[(LLVMTStruct[Pre], Int), InstanceField[Post]] = + SuccessionMap() - def rewriteLocal(local: LlvmLocal[Pre]): Expr[Post] = { + private val globalVariableTypeGuesses + : mutable.HashMap[LLVMGlobalVariable[Pre], mutable.HashSet[Type[Pre]]] = + mutable.HashMap() + private val structFieldTypeGuesses + : mutable.HashMap[(LLVMTStruct[Pre], Int), mutable.HashSet[Type[Pre]]] = + mutable.HashMap() + private val localTypeGuesses + : mutable.HashMap[Variable[Pre], mutable.HashSet[Type[Pre]]] = mutable + .HashMap() + + def rewriteLocal(local: LLVMLocal[Pre]): Expr[Post] = { implicit val o: Origin = local.o Local(rw.succ(local.ref.get.decl)) } - def rewriteFunctionDef(func: LlvmFunctionDefinition[Pre]): Unit = { + def rewriteFunctionDef(func: LLVMFunctionDefinition[Pre]): Unit = { implicit val o: Origin = func.o + val importedDecl = rw.importedDeclarations.find { + case procedure: Procedure[Pre] => + func.contract.name == procedure.o.get[SourceName].name + } val procedure = rw.labelDecls.scope { - rw.globalDeclarations.declare( + rw.globalDeclarations.declare(if (importedDecl.isDefined) { + val importedProcedure = importedDecl.get.asInstanceOf[Procedure[Pre]] + val newArgs = importedProcedure.args.map { it => it.rewriteDefault() } + new Procedure[Post]( + returnType = rw.dispatch(importedProcedure.returnType), + args = + rw.variables.collect { + func.args.zip(newArgs).foreach { case (a, b) => + rw.variables.succeed(a, b) + } + }._1, + outArgs = Nil, + typeArgs = Nil, + body = + func.functionBody match { + case None => None + case Some(functionBody) => + if (func.pure) + Some(GotoEliminator(functionBody match { + case scope: Scope[Pre] => scope; + case other => throw UnexpectedLLVMNode(other) + }).eliminate()) + else + Some(rw.dispatch(functionBody)) + }, + contract = rw.dispatch(func.contract.data.get), + pure = func.pure, + )(func.blame) + } else { new Procedure[Post]( returnType = rw.dispatch(func.returnType), args = rw.variables.collect { func.args.foreach(rw.dispatch) }._1, outArgs = Nil, typeArgs = Nil, body = - if (func.pure) - Some(GotoEliminator(func.functionBody match { - case scope: Scope[Pre] => scope; - case other => throw UnexpectedLlvmNode(other) - }).eliminate()) - else - Some(rw.dispatch(func.functionBody)), + func.functionBody match { + case None => None + case Some(functionBody) => + if (func.pure) + Some(GotoEliminator(functionBody match { + case scope: Scope[Pre] => scope; + case other => throw UnexpectedLLVMNode(other) + }).eliminate()) + else + Some(rw.dispatch(functionBody)) + }, contract = rw.dispatch(func.contract.data.get), pure = func.pure, )(func.blame) - ) + }) } llvmFunctionMap.update(func, procedure) } def rewriteAmbiguousFunctionInvocation( - inv: LlvmAmbiguousFunctionInvocation[Pre] + inv: LLVMAmbiguousFunctionInvocation[Pre] ): Invocation[Post] = { implicit val o: Origin = inv.o inv.ref.get.decl match { - case func: LlvmFunctionDefinition[Pre] => + case func: LLVMFunctionDefinition[Pre] => new ProcedureInvocation[Post]( ref = new LazyRef[Post, Procedure[Post]](llvmFunctionMap(func)), args = inv.args.map(rw.dispatch), @@ -79,7 +147,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) outArgs = Seq.empty, typeArgs = Seq.empty, )(inv.blame) - case func: LlvmSpecFunction[Pre] => + case func: LLVMSpecFunction[Pre] => new FunctionInvocation[Post]( ref = new LazyRef[Post, Function[Post]](specFunctionMap(func)), args = inv.args.map(rw.dispatch), @@ -96,7 +164,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } def rewriteFunctionInvocation( - inv: LlvmFunctionInvocation[Pre] + inv: LLVMFunctionInvocation[Pre] ): ProcedureInvocation[Post] = { implicit val o: Origin = inv.o new ProcedureInvocation[Post]( @@ -113,11 +181,11 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(inv.blame) } - def rewriteGlobal(decl: LlvmGlobal[Pre]): Unit = { + def rewriteGlobal(decl: LLVMGlobalSpecification[Pre]): Unit = { implicit val o: Origin = decl.o decl.data.get.foreach { decl => rw.globalDeclarations.declare(decl match { - case function: LlvmSpecFunction[Pre] => + case function: LLVMSpecFunction[Pre] => val rwFunction = new Function[Post]( rw.dispatch(function.returnType), @@ -134,12 +202,328 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(function.blame) specFunctionMap.update(function, rwFunction) rwFunction - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) }) } } - def result(ref: RefLlvmFunctionDefinition[Pre])( + def rewriteFunctionPointer( + pointer: LLVMFunctionPointerValue[Pre] + ): LLVMFunctionPointerValue[Post] = { + implicit val o: Origin = pointer.o + new LLVMFunctionPointerValue[Post](value = + new LazyRef[Post, GlobalDeclaration[Post]](llvmFunctionMap( + pointer.value.decl.asInstanceOf[LLVMFunctionDefinition[Pre]] + )) + ) + } + + def rewriteStruct(t: LLVMTStruct[Pre]): Unit = { + val LLVMTStruct(name, packed, elements) = t + val newStruct = + new Class[Post]( + Seq(), + rw.classDeclarations.collect { + elements.zipWithIndex.foreach { case (fieldType, idx) => + structFieldMap((t, idx)) = + new InstanceField(rw.dispatch(fieldType), flags = Nil)( + fieldType.o + ) + rw.classDeclarations.declare(structFieldMap((t, idx))) + } + }._1, + Seq(), + tt[Post], + )(t.o) + + rw.globalDeclarations.declare(newStruct) + structMap(t) = newStruct + } + + def rewriteGlobalVariable(decl: LLVMGlobalVariable[Pre]): Unit = { + // TODO: Handle the initializer + // TODO: Include array and vector bounds somehow + decl.variableType match { + case struct: LLVMTStruct[Pre] => { + rewriteStruct(struct) + globalVariableMap.update( + decl, + rw.globalDeclarations.declare( + new HeapVariable[Post]( + new TClass[Post]( + new DirectRef[Post, Class[Post]](structMap(struct)), + Seq(), + )(struct.o) + )(decl.o) + ), + ) + } + case array: LLVMTArray[Pre] => { + globalVariableMap.update( + decl, + rw.globalDeclarations.declare( + new HeapVariable[Post]( + new TPointer[Post](rw.dispatch(array.elementType))(array.o) + )(decl.o) + ), + ) + } + case vector: LLVMTVector[Pre] => { + globalVariableMap.update( + decl, + rw.globalDeclarations.declare( + new HeapVariable[Post]( + new TPointer[Post](rw.dispatch(vector.elementType))(vector.o) + )(decl.o) + ), + ) + } + case _ => { ??? } + } + } + + def rewritePointerChain( + pointer: Expr[Post], + t: Type[Pre], + indices: Seq[Expr[Pre]], + )(implicit o: Origin): Expr[Post] = { + if (indices.isEmpty) { return pointer } + t match { + case struct: LLVMTStruct[Pre] => { + if (!structMap.contains(struct)) { rewriteStruct(struct) } + indices.head match { + case value: LLVMIntegerValue[Pre] => + rewritePointerChain( + Deref[Post]( + pointer, + structFieldMap.ref((struct, value.value.intValue)), + )(o), + struct.elements(value.value.intValue), + indices.tail, + ) + case value: IntegerValue[Pre] => + rewritePointerChain( + Deref[Post]( + pointer, + structFieldMap.ref((struct, value.value.intValue)), + )(o), + struct.elements(value.value.intValue), + indices.tail, + ) + case _ => throw NonConstantStructIndex(o) + } + } + case array: LLVMTArray[Pre] => ??? + case vector: LLVMTVector[Pre] => ??? + } + } + + def derefUntil( + pointer: Expr[Post], + currentType: Type[Pre], + untilType: Type[Pre], + ): (Expr[Post], Type[Pre]) = { + implicit val o: Origin = pointer.o + currentType match { + case _ if currentType == untilType => (AddrOf(pointer), currentType) + case LLVMTPointer(None) => (pointer, LLVMTPointer[Pre](Some(untilType))) + case LLVMTPointer(Some(inner)) if inner == untilType => + (pointer, currentType) + case LLVMTPointer(Some(LLVMTArray(numElements, elementType))) => { + val (expr, inner) = derefUntil( + PointerSubscript[Post]( + DerefPointer(pointer)(pointer.o), + IntegerValue(BigInt(0)), + )(pointer.o), + elementType, + untilType, + ) + (expr, LLVMTPointer[Pre](Some(LLVMTArray(numElements, inner)))) + } + case LLVMTArray(numElements, elementType) => { + val (expr, inner) = derefUntil( + PointerSubscript[Post](pointer, IntegerValue(BigInt(0)))(pointer.o), + elementType, + untilType, + ) + (expr, LLVMTArray[Pre](numElements, inner)) + } + case LLVMTPointer(Some(LLVMTVector(numElements, elementType))) => { + val (expr, inner) = derefUntil( + PointerSubscript[Post]( + DerefPointer(pointer)(pointer.o), + IntegerValue(BigInt(0)), + )(pointer.o), + elementType, + untilType, + ) + (expr, LLVMTPointer[Pre](Some(LLVMTVector(numElements, inner)))) + } + case LLVMTVector(numElements, elementType) => { + val (expr, inner) = derefUntil( + PointerSubscript[Post](pointer, IntegerValue(BigInt(0)))(pointer.o), + elementType, + untilType, + ) + (expr, LLVMTVector[Pre](numElements, inner)) + } + case LLVMTPointer(Some(struct @ LLVMTStruct(name, packed, elements))) => { + val (expr, inner) = derefUntil( + Deref[Post]( + DerefPointer(pointer)(pointer.o), + structFieldMap.ref((struct, 0)), + )(pointer.o), + elements.head, + untilType, + ) + ( + expr, + LLVMTPointer[Pre](Some( + LLVMTStruct(name, packed, inner +: elements.tail) + )), + ) + } + case struct @ LLVMTStruct(name, packed, elements) => { + val (expr, inner) = derefUntil( + Deref[Post](pointer, structFieldMap.ref((struct, 0)))(pointer.o), + elements.head, + untilType, + ) + (expr, LLVMTStruct[Pre](name, packed, inner +: elements.tail)) + } + } + } + + def rewriteGetElementPointer(gep: LLVMGetElementPointer[Pre]): Expr[Post] = { + implicit val o: Origin = gep.o + val t = gep.structureType + t match { + case struct: LLVMTStruct[Pre] => { + // TODO: We don't support variables in GEP yet and this just assumes all the indices are integer constants + // TODO: Use an actual Blame + + // Acquire the actual struct through a PointerAdd + // TODO: Can we somehow wrap the rw.dispatch(gep.pointer) to add the known type structureType? + gep.pointer.t match { + case LLVMTPointer(None) => + val structPointer = + DerefPointer( + PointerAdd( + rw.dispatch(gep.pointer), + rw.dispatch(gep.indices.head), + )(o) + )(o) + AddrOf(rewritePointerChain(structPointer, struct, gep.indices.tail)) + case LLVMTPointer(Some(inner)) if inner == t => + val structPointer = + DerefPointer( + PointerAdd( + rw.dispatch(gep.pointer), + rw.dispatch(gep.indices.head), + )(o) + )(o) + AddrOf(rewritePointerChain(structPointer, struct, gep.indices.tail)) + case LLVMTPointer(Some(_)) => + val (pointer, inferredType) = derefUntil( + rw.dispatch(gep.pointer), + gep.pointer.t, + t, + ) + addTypeGuess(gep.pointer, inferredType) + val structPointer = + DerefPointer( + PointerAdd(pointer, rw.dispatch(gep.indices.head))(o) + )(o) + val ret = AddrOf( + rewritePointerChain(structPointer, struct, gep.indices.tail) + ) + ret + } + } + case array: LLVMTArray[Pre] => ??? + case vector: LLVMTVector[Pre] => ??? + } + // Deref might not be the correct thing to use here since technically the pointer is only dereferenced in the load or store instruction + } + + def rewriteStore(store: LLVMStore[Pre]): Statement[Post] = { + implicit val o: Origin = store.o + val (pointer, inferredType) = derefUntil( + rw.dispatch(store.pointer), + store.pointer.t, + store.value.t, + ) + addTypeGuess(store.pointer, inferredType) + Assign(DerefPointer(pointer)(store.o), rw.dispatch(store.value))(store.o) + } + + def rewriteLoad(load: LLVMLoad[Pre]): Expr[Post] = { + val (pointer, inferredType) = derefUntil( + rw.dispatch(load.pointer), + load.pointer.t, + load.loadType, + ) + addTypeGuess(load.pointer, inferredType) + DerefPointer(pointer)(load.o)(load.o) + } + + def rewriteAllocA(alloc: LLVMAllocA[Pre]): Expr[Post] = { + implicit val o: Origin = alloc.o + val t = rw.dispatch(alloc.allocationType) + val v = new Variable[Post](TPointer(t))(alloc.o) + alloc.allocationType match { + case structType: LLVMTStruct[Pre] => + With( + Block(Seq( + LocalDecl(v), + assignLocal( + v.get, + NewPointerArray[Post]( + rw.dispatch(alloc.allocationType), + rw.dispatch(alloc.numElements), + )(PanicBlame("allocation should never fail")), + ), + Assign( + DerefPointer(v.get)(alloc.o), + NewObject[Post](structMap.ref(structType)), + )(PanicBlame("assignment should never fail")), + )), + v.get, + ) + case _ => + NewPointerArray[Post](t, rw.dispatch(alloc.numElements))(PanicBlame( + "allocation should never fail" + )) + } + } + + private def addTypeGuess(pointer: Expr[Pre], inferredType: Type[Pre]): Unit = + pointer match { + case Local(Ref(v)) => + localTypeGuesses.getOrElseUpdate(v, { mutable.HashSet() }) + .add(LLVMTPointer[Pre](Some(inferredType))) + case LLVMPointerValue(Ref(g)) => + globalVariableTypeGuesses.getOrElseUpdate( + g.asInstanceOf[LLVMGlobalVariable[Pre]], + { mutable.HashSet() }, + ).add(inferredType) + case it => { + println(it) + ??? + } + } + + def rewritePointerValue(pointer: LLVMPointerValue[Pre]): Expr[Post] = { + implicit val o: Origin = pointer.o + // Will be transformed by VariableToPointer pass + new AddrOf[Post]( + DerefHeapVariable[Post](globalVariableMap.ref( + pointer.value.decl.asInstanceOf[LLVMGlobalVariable[Pre]] + ))(pointer.o) + ) + } + + def result(ref: RefLLVMFunctionDefinition[Pre])( implicit o: Origin ): Expr[Post] = Result[Post](llvmFunctionMap.ref(ref.decl)) @@ -148,7 +532,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) effectively transforming the CFG into a tree. More efficient restructuring algorithms but this works for now. This of course only works for acyclic CFGs as otherwise replacement would be infinitely recursive. - Loop restructuring should be handled by VCLLVM as it has much more analytical and contextual information about + Loop restructuring should be handled by pallas as it has much more analytical and contextual information about the program. */ case class GotoEliminator(bodyScope: Scope[Pre]) extends LazyLogging { @@ -157,9 +541,9 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case block: Block[Pre] => block.statements.map { case label: Label[Pre] => (label.decl, label) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) }.toMap - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) } def eliminate(): Scope[Post] = { @@ -171,12 +555,12 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case bodyBlock: Block[Pre] => Block[Post](bodyBlock.statements.head match { case label: Label[Pre] => Seq(eliminate(label)) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) })(scope.body.o) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) }, )(scope.o) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) } } @@ -193,16 +577,16 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _: Return[Pre] => rw.dispatch(block) match { case block: Block[Post] => block - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) } case branch: Branch[Pre] => Block[Post]( block.statements.dropRight(1).map(rw.dispatch) :+ eliminate(branch) ) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) } - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) } } @@ -213,10 +597,27 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) rw.dispatch(bs._1), bs._2 match { case goto: Goto[Pre] => eliminate(labelDeclMap(goto.lbl.decl)) - case other => throw UnexpectedLlvmNode(other) + case other => throw UnexpectedLLVMNode(other) }, ) )) } } + + def structType(t: LLVMTStruct[Pre]): Type[Post] = { + val targetClass = new LazyRef[Post, Class[Post]](structMap(t)) + TClass[Post](targetClass, Seq())(t.o) + } + + def pointerType(t: LLVMTPointer[Pre]): Type[Post] = + t.innerType match { + case Some(innerType) => TPointer[Post](rw.dispatch(innerType))(t.o) + case None => TPointer[Post](TAny())(t.o) + } + + def arrayType(t: LLVMTArray[Pre]): Type[Post] = + TPointer(rw.dispatch(t.elementType))(t.o) + + def vectorType(t: LLVMTVector[Pre]): Type[Post] = + TPointer(rw.dispatch(t.elementType))(t.o) } diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 213997930e..b1b3fd9802 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -13,6 +13,7 @@ import vct.col.rewrite.{ Rewriter, RewriterBuilderArg, RewriterBuilderArg2, + Rewritten, } import vct.result.VerificationError.UserError import vct.rewrite.lang.LangSpecificToCol.NotAValue @@ -22,6 +23,12 @@ case object LangSpecificToCol extends RewriterBuilderArg2[Boolean, Boolean] { override def desc: String = "Translate language-specific constructs to a common subset of nodes." + override def apply[Pre <: Generation]( + veymontGeneratePermissions: Boolean, + veymontAllowAssign: Boolean, + ): AbstractRewriter[Pre, _ <: Generation] = + LangSpecificToCol(veymontGeneratePermissions, veymontAllowAssign, Seq()) + def ThisVar(): Origin = Origin(Seq(PreferredName(Seq("this")), LabelContext("constructor this"))) @@ -35,6 +42,7 @@ case object LangSpecificToCol extends RewriterBuilderArg2[Boolean, Boolean] { case class LangSpecificToCol[Pre <: Generation]( veymontGeneratePermissions: Boolean = false, veymontAllowAssign: Boolean = false, + importedDeclarations: Seq[GlobalDeclaration[Pre]] = Seq(), ) extends Rewriter[Pre] with LazyLogging { val java: LangJavaToCol[Pre] = LangJavaToCol(this) val bip: LangBipToCol[Pre] = LangBipToCol(this) @@ -194,8 +202,9 @@ case class LangSpecificToCol[Pre <: Generation]( cpp.storeIfSYCLFunction(func) } - case func: LlvmFunctionDefinition[Pre] => llvm.rewriteFunctionDef(func) - case global: LlvmGlobal[Pre] => llvm.rewriteGlobal(global) + case func: LLVMFunctionDefinition[Pre] => llvm.rewriteFunctionDef(func) + case global: LLVMGlobalSpecification[Pre] => llvm.rewriteGlobal(global) + case global: LLVMGlobalVariable[Pre] => llvm.rewriteGlobalVariable(global) case cls: Class[Pre] => currentClass.having(cls) { @@ -265,6 +274,7 @@ case class LangSpecificToCol[Pre <: Generation]( rewriteDefault(unfold) } + case store: LLVMStore[Pre] => llvm.rewriteStore(store) case other => rewriteDefault(other) } @@ -281,7 +291,7 @@ case class LangSpecificToCol[Pre <: Generation]( case ref: RefCGlobalDeclaration[Pre] => c.result(ref) case ref: RefCPPFunctionDefinition[Pre] => cpp.result(ref) case ref: RefCPPGlobalDeclaration[Pre] => cpp.result(ref) - case ref: RefLlvmFunctionDefinition[Pre] => llvm.result(ref) + case ref: RefLLVMFunctionDefinition[Pre] => llvm.result(ref) case RefFunction(decl) => Result[Post](anySucc(decl)) case RefProcedure(decl) => Result[Post](anySucc(decl)) case RefJavaMethod(decl) => Result[Post](java.javaMethod.ref(decl)) @@ -290,7 +300,7 @@ case class LangSpecificToCol[Pre <: Generation]( case RefInstanceMethod(decl) => Result[Post](anySucc(decl)) case RefInstanceOperatorFunction(decl) => Result[Post](anySucc(decl)) case RefInstanceOperatorMethod(decl) => Result[Post](anySucc(decl)) - case RefLlvmSpecFunction(decl) => Result[Post](anySucc(decl)) + case RefLLVMSpecFunction(decl) => Result[Post](anySucc(decl)) } case diz @ AmbiguousThis() => currentThis.top @@ -381,11 +391,18 @@ case class LangSpecificToCol[Pre <: Generation]( silver.adtInvocation(inv) case map: SilverUntypedNonemptyLiteralMap[Pre] => silver.nonemptyMap(map) - case inv: LlvmFunctionInvocation[Pre] => + case inv: LLVMFunctionInvocation[Pre] => llvm.rewriteFunctionInvocation(inv) - case inv: LlvmAmbiguousFunctionInvocation[Pre] => + case inv: LLVMAmbiguousFunctionInvocation[Pre] => llvm.rewriteAmbiguousFunctionInvocation(inv) - case local: LlvmLocal[Pre] => llvm.rewriteLocal(local) + case local: LLVMLocal[Pre] => llvm.rewriteLocal(local) + case pointer: LLVMFunctionPointerValue[Pre] => + llvm.rewriteFunctionPointer(pointer) + case pointer: LLVMPointerValue[Pre] => llvm.rewritePointerValue(pointer) + case gep: LLVMGetElementPointer[Pre] => llvm.rewriteGetElementPointer(gep) + case load: LLVMLoad[Pre] => llvm.rewriteLoad(load) + case alloc: LLVMAllocA[Pre] => llvm.rewriteAllocA(alloc) + case int: LLVMIntegerValue[Pre] => IntegerValue(int.value)(int.o) case other => rewriteDefault(other) } @@ -398,6 +415,15 @@ case class LangSpecificToCol[Pre <: Generation]( case t: TOpenCLVector[Pre] => c.vectorType(t) case t: CTArray[Pre] => c.arrayType(t) case t: CTStruct[Pre] => c.structType(t) + case t: LLVMTInt[Pre] => TInt()(t.o) + case t: LLVMTStruct[Pre] => llvm.structType(t) + case t: LLVMTPointer[Pre] => llvm.pointerType(t) + case t: LLVMTArray[Pre] => llvm.arrayType(t) + case t: LLVMTVector[Pre] => llvm.vectorType(t) + case t: LLVMTMetadata[Pre] => + TInt()( + t.o + ) // TODO: Ignore these by just assuming they're integers... or could we do TVoid? case t: CPPTArray[Pre] => cpp.arrayType(t) case other => rewriteDefault(other) } From 5cb4a2e0265b7385f9450d24b7427a8fe23f50e6 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 30 May 2024 17:10:11 +0200 Subject: [PATCH 02/47] Fixed the LLVM tests, add a blame for LLVM generated nodes, add formatting hook with clang-format --- src/col/vct/col/ast/Deserialize.scala | 7 +- .../vct/col/serialize/SerializeBlame.scala | 3 +- src/col/vct/col/typerules/CoercionUtils.scala | 1 + .../ast/helpers/generator/Deserialize.scala | 12 ++-- .../helpers/generator/DeserializeFamily.scala | 6 +- .../Passes/Function/FunctionBodyTransformer.h | 13 ++++ .../Function/FunctionBodyTransformer.cpp | 71 +++++++++++++++++-- src/llvm/lib/Transform/BlockTransform.cpp | 6 +- .../Instruction/BinaryOpTransform.cpp | 1 + .../Instruction/OtherOpTransform.cpp | 4 +- src/main/vct/importer/Util.scala | 6 +- .../vct/parsers/parser/ColLLVMParser.scala | 3 +- .../{VcllvmSpec.scala => LLVMSpec.scala} | 4 +- .../integration/meta/ExampleCoverage.scala | 2 +- util/githooks/pre-commit | 27 ++++++- 15 files changed, 135 insertions(+), 31 deletions(-) rename test/main/vct/test/integration/examples/{VcllvmSpec.scala => LLVMSpec.scala} (76%) diff --git a/src/col/vct/col/ast/Deserialize.scala b/src/col/vct/col/ast/Deserialize.scala index f9dcb4f35c..99b00a587c 100644 --- a/src/col/vct/col/ast/Deserialize.scala +++ b/src/col/vct/col/ast/Deserialize.scala @@ -1,12 +1,15 @@ package vct.col.ast import vct.col.ast.ops.deserialize.DeserializeProgram +import vct.col.origin.{Blame, Origin, VerificationFailure} import scala.collection.mutable object Deserialize { def deserializeProgram[G]( program: vct.col.ast.serialize.Program, - y: scala.Any, - ): Program[G] = { DeserializeProgram.deserialize(program, mutable.HashMap()) } + blameProvider: Origin => Blame[VerificationFailure], + ): Program[G] = { + DeserializeProgram.deserialize(program, mutable.HashMap(), blameProvider) + } } diff --git a/src/col/vct/col/serialize/SerializeBlame.scala b/src/col/vct/col/serialize/SerializeBlame.scala index cd19ca94e1..e8dd5259ef 100644 --- a/src/col/vct/col/serialize/SerializeBlame.scala +++ b/src/col/vct/col/serialize/SerializeBlame.scala @@ -10,7 +10,8 @@ object SerializeBlame { @unused blame: ser.Blame, origin: Origin, - ): Blame[T] = origin + blameProvider: Origin => Blame[VerificationFailure], + ): Blame[T] = blameProvider(origin) def serialize(blame: Blame[_]): ser.Blame = ser.Blame(ser.Blame.Blame.BlameInput(ser.BlameInput())) diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index dcffe8c925..fbe1bea2fa 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -198,6 +198,7 @@ case object CoercionUtils { )) case (TCInt(), TInt()) => CoerceCIntInt() case (LLVMTInt(_), TInt()) => CoerceLLVMIntInt() + case (TInt(), LLVMTInt(_)) => CoerceIdentity(target) case (TBoundedInt(gte, lt), TFraction()) if gte >= 1 && lt <= 2 => CoerceBoundIntFrac() diff --git a/src/helpers/vct/col/ast/helpers/generator/Deserialize.scala b/src/helpers/vct/col/ast/helpers/generator/Deserialize.scala index e74528ce42..896c1e0778 100644 --- a/src/helpers/vct/col/ast/helpers/generator/Deserialize.scala +++ b/src/helpers/vct/col/ast/helpers/generator/Deserialize.scala @@ -26,7 +26,7 @@ class Deserialize extends NodeGenerator { package $DeserializePackage object ${deserializeObjectName(node)} { - def deserialize[G](node: ${scalapbType(node.name)}, decls: $MutMap[$Long, $Declaration[G]]): ${typ(node)}[G] = + def deserialize[G](node: ${scalapbType(node.name)}, decls: $MutMap[$Long, $Declaration[G]], blameProvider: $Origin => $Blame[$VerificationFailure]): ${typ(node)}[G] = ${deserializeNode(q"node", node)} } """ @@ -47,7 +47,7 @@ class Deserialize extends NodeGenerator { q""" { val `~o`: $Origin = $SerializeOrigin.deserialize($term.origin) - new ${t"${typ(node)}[G]"}(..$fields)($SerializeBlame.deserialize($term.blame, `~o`))(`~o`) + new ${t"${typ(node)}[G]"}(..$fields)($SerializeBlame.deserialize($term.blame, `~o`, blameProvider))(`~o`) } """ @@ -76,7 +76,7 @@ class Deserialize extends NodeGenerator { case (Proto.Repeated(pt), ST.Seq(st)) => q"$term.map(`~x` => ${deserializeTerm(q"`~x`", pt, st)})" case (Proto.Repeated(_), ST.DeclarationSeq(name)) => - q"$term.map(`~x` => ${deserializeFamilyObject(name)}.deserialize[G](`~x`.parseAs[${scalapbType(name)}], decls))" + q"$term.map(`~x` => ${deserializeFamilyObject(name)}.deserialize[G](`~x`.parseAs[${scalapbType(name)}], decls, blameProvider))" case (pt, st) => err(pt, st) } @@ -89,9 +89,9 @@ class Deserialize extends NodeGenerator { case (_, ST.Seq(ST.ExpectedError)) => q"$SeqObj.empty" case (Proto.FamilyType(_), ST.Node(name)) => - q"${deserializeFamilyObject(name)}.deserialize[G]($term.parseAs[${scalapbType(name)}], decls)" + q"${deserializeFamilyObject(name)}.deserialize[G]($term.parseAs[${scalapbType(name)}], decls, blameProvider)" case (Proto.FamilyType(_), ST.Declaration(name)) => - q"${deserializeFamilyObject(name)}.deserialize[G]($term.parseAs[${scalapbType(name)}], decls)" + q"${deserializeFamilyObject(name)}.deserialize[G]($term.parseAs[${scalapbType(name)}], decls, blameProvider)" case (Proto.StandardType(_), ST.Ref(node)) => q"new ${Init(t"$LazyRef[G, ${typ(node.name)}[G]]", Name.Anonymous(), List(List(q"decls($term.id)")))}" case (Proto.StandardType(_), ST.MultiRef(node)) => @@ -103,7 +103,7 @@ class Deserialize extends NodeGenerator { case (Proto.AuxType(_), ST.Seq(structuralType)) => q"$term.value.map(`~x` => ${deserializeTerm(q"`~x`", ProtoNaming.getType(structuralType).t, structuralType)})" case (Proto.AuxType(_), ST.DeclarationSeq(name)) => - q"$term.value.map(`~x` => ${deserializeFamilyObject(name)}.deserialize[G](`~x`.parseAs[${scalapbType(name)}], decls))" + q"$term.value.map(`~x` => ${deserializeFamilyObject(name)}.deserialize[G](`~x`.parseAs[${scalapbType(name)}], decls, blameProvider))" case (Proto.AuxType(_), ST.Option(structuralType)) => q"$term.value.map(`~x` => ${deserializeTerm(q"`~x`", ProtoNaming.getPrimitiveType(structuralType).t, structuralType)})" case (Proto.AuxType(_), ST.Either(left, right)) => diff --git a/src/helpers/vct/col/ast/helpers/generator/DeserializeFamily.scala b/src/helpers/vct/col/ast/helpers/generator/DeserializeFamily.scala index d074374663..0ec76a15ed 100644 --- a/src/helpers/vct/col/ast/helpers/generator/DeserializeFamily.scala +++ b/src/helpers/vct/col/ast/helpers/generator/DeserializeFamily.scala @@ -24,9 +24,9 @@ class DeserializeFamily extends FamilyGenerator { package $DeserializePackage object ${Term.Name(deserializeFamilyName(name))} { - def deserialize[G](node: ${scalapbType(name)}, decls: $MutMap[$Long, $Declaration[G]]): ${typ(name)}[G] = + def deserialize[G](node: ${scalapbType(name)}, decls: $MutMap[$Long, $Declaration[G]], blameProvider: $Origin => $Blame[$VerificationFailure]): ${typ(name)}[G] = ${if (nodes == Seq(name)) - q"${Term.Name(deserializeName(name))}.deserialize[G](node, decls)" + q"${Term.Name(deserializeName(name))}.deserialize[G](node, decls, blameProvider)" else deserializeOneof(q"node.v", name, nodes)} } @@ -44,6 +44,6 @@ class DeserializeFamily extends FamilyGenerator { Case( Lit.Int(number), None, - q"${deserializeObject(node)}.deserialize[G]($term.value.asInstanceOf[${scalapbType(node)}], decls)", + q"${deserializeObject(node)}.deserialize[G]($term.value.asInstanceOf[${scalapbType(node)}], decls, blameProvider)", ) } diff --git a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h index 3cc962b394..ce34ab2688 100644 --- a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h +++ b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h @@ -55,6 +55,13 @@ class FunctionCursor { /// excludes possible future phi node back transformations. std::set completedColBlocks; + /// set of all COL blocks that we have started transforming. + std::set visitedColBlocks; + + /// map of assignments which should be added to the basic block when it is + /// completed. + std::unordered_multimap phiAssignBuffer; + /// Almost always when adding a variable to the variableMap, some extra /// processing is required which is why this method is private as to not /// accidentally use it outside the functionCursor @@ -102,6 +109,10 @@ class FunctionCursor { col::Assign &createAssignment(Instruction &llvmInstruction, col::Block &colBlock, col::Variable &varDecl); + col::Assign &createPhiAssignment(Instruction &llvmInstruction, + col::Block &colBlock, + col::Variable &varDecl); + col::Variable &getVariableMapEntry(llvm::Value &llvmValue, bool inPhiNode); /** @@ -120,6 +131,8 @@ class FunctionCursor { LabeledColBlock & getOrSetLLVMBlock2LabeledColBlockEntry(BasicBlock &llvmBlock); + LabeledColBlock &visitLLVMBlock(BasicBlock &llvmBlock); + llvm::FunctionAnalysisManager &getFunctionAnalysisManager(); /** diff --git a/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp index 215d9ae74e..928f1f201a 100644 --- a/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp +++ b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp @@ -46,18 +46,32 @@ col::Variable &FunctionCursor::getVariableMapEntry(Value &llvmValue, } col::Variable *colVar = new col::Variable(); + llvm2col::setColNodeId(colVar); addVariableMapEntry(llvmValue, *colVar); return *colVar; } } bool FunctionCursor::isVisited(BasicBlock &llvmBlock) { - return llvmBlock2LabeledColBlock.contains(&llvmBlock); + return visitedColBlocks.contains( + &this->getOrSetLLVMBlock2LabeledColBlockEntry(llvmBlock).block); } void FunctionCursor::complete(col::Block &colBlock) { + int lastIndex = colBlock.statements_size() - 1; + bool found = false; + auto range = phiAssignBuffer.equal_range(&colBlock); + for (auto it = range.first; it != range.second; ++it) { + found = true; + colBlock.add_statements()->set_allocated_assign(it->second); + } + if (found) { + colBlock.mutable_statements()->SwapElements( + colBlock.statements_size() - 1, lastIndex); + } completedColBlocks.insert(&colBlock); } + bool FunctionCursor::isComplete(col::Block &colBlock) { return completedColBlocks.contains(&colBlock); } @@ -87,6 +101,13 @@ FunctionCursor::getOrSetLLVMBlock2LabeledColBlockEntry(BasicBlock &llvmBlock) { return llvmBlock2LabeledColBlock.at(&llvmBlock); } +LabeledColBlock &FunctionCursor::visitLLVMBlock(BasicBlock &llvmBlock) { + LabeledColBlock &labeledBlock = + this->getOrSetLLVMBlock2LabeledColBlockEntry(llvmBlock); + visitedColBlocks.insert(&labeledBlock.block); + return labeledBlock; +} + LoopInfo &FunctionCursor::getLoopInfo() { return FAM.getResult(llvmFunction); } @@ -105,8 +126,19 @@ FDResult &FunctionCursor::getFDResult(Function &otherLLVMFunction) { col::Variable &FunctionCursor::declareVariable(Instruction &llvmInstruction, Type *llvmPointerType) { - // create declaration in buffer - col::Variable *varDecl = functionScope.add_locals(); + col::Variable *varDecl; + if (auto variablePair = variableMap.find(&llvmInstruction); + variablePair != variableMap.end()) { + varDecl = functionScope.add_locals(); + *varDecl = *variablePair->second; + } else { + // create declaration in buffer + varDecl = functionScope.add_locals(); + // set id + llvm2col::setColNodeId(varDecl); + // add to the variable lut + this->addVariableMapEntry(llvmInstruction, *varDecl); + } // set type of declaration try { if (llvmPointerType == nullptr) { @@ -121,13 +153,9 @@ col::Variable &FunctionCursor::declareVariable(Instruction &llvmInstruction, errorStream << e.what() << " in variable declaration."; ErrorReporter::addError(SOURCE_LOC, errorStream.str(), llvmInstruction); } - // set id - llvm2col::setColNodeId(varDecl); // set origin varDecl->set_allocated_origin( llvm2col::generateSingleStatementOrigin(llvmInstruction)); - // add to the variable lut - this->addVariableMapEntry(llvmInstruction, *varDecl); return *varDecl; } @@ -162,6 +190,35 @@ col::Assign &FunctionCursor::createAssignment(Instruction &llvmInstruction, return *assignment; } +col::Assign &FunctionCursor::createPhiAssignment(Instruction &llvmInstruction, + col::Block &colBlock, + col::Variable &varDecl) { + col::Assign *assignment = new col::Assign(); + assignment->set_allocated_blame(new col::Blame()); + assignment->set_allocated_origin( + llvm2col::generateSingleStatementOrigin(llvmInstruction)); + // create local target in buffer and set origin + col::Local *colLocal = assignment->mutable_target()->mutable_local(); + colLocal->set_allocated_origin( + llvm2col::generateAssignTargetOrigin(llvmInstruction)); + // set target to refer to var decl + colLocal->mutable_ref()->set_id(varDecl.id()); + if (isComplete(colBlock)) { + // if the colBlock is completed, the assignment will be inserted after + // the goto/branch statement this can occur due to e.g. phi nodes back + // tracking assignments in their origin blocks. therefore we need to + // swap the last two elements of the block (i.e. the goto statement and + // the newest assignment) + int lastIndex = colBlock.statements_size() - 1; + colBlock.add_statements()->set_allocated_assign(assignment); + colBlock.mutable_statements()->SwapElements(lastIndex, lastIndex - 1); + } else { + // Buffer the phi assignments so they appear at the end + phiAssignBuffer.insert({&colBlock, assignment}); + } + return *assignment; +} + llvm::FunctionAnalysisManager &FunctionCursor::getFunctionAnalysisManager() { return FAM; } diff --git a/src/llvm/lib/Transform/BlockTransform.cpp b/src/llvm/lib/Transform/BlockTransform.cpp index c32a1c05c7..9ce1085009 100644 --- a/src/llvm/lib/Transform/BlockTransform.cpp +++ b/src/llvm/lib/Transform/BlockTransform.cpp @@ -13,10 +13,10 @@ const std::string SOURCE_LOC = "Transform::BlockTransform"; void llvm2col::transformLLVMBlock(llvm::BasicBlock &llvmBlock, pallas::FunctionCursor &functionCursor) { - if (functionCursor.isVisited(llvmBlock)) + if (functionCursor.isVisited(llvmBlock)) { return; - col::Block &colBlock = - functionCursor.getOrSetLLVMBlock2LabeledColBlockEntry(llvmBlock).block; + } + col::Block &colBlock = functionCursor.visitLLVMBlock(llvmBlock).block; /* for (auto *B : llvm::predecessors(&llvmBlock)) { */ /* if (!functionCursor.isVisited(*B)) */ /* return; */ diff --git a/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp index de070a21e2..1794174b7b 100644 --- a/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/BinaryOpTransform.cpp @@ -42,6 +42,7 @@ void llvm2col::transformBinaryOp(llvm::Instruction &llvmInstruction, } col::FloorDiv &expr = *assignment.mutable_value()->mutable_floor_div(); transformBinExpr(llvmInstruction, expr, funcCursor); + expr.set_allocated_blame(new col::Blame()); break; } // TODO: All of these are currently bitwise operators, verify that works diff --git a/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp index 81453faa62..43e2e4d2e0 100644 --- a/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp @@ -34,8 +34,8 @@ void llvm2col::transformPhi(llvm::PHINode &phiInstruction, // add assignment of the variable to target block col::Block &targetBlock = funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*B).block; - col::Assign &assignment = - funcCursor.createAssignment(phiInstruction, targetBlock, varDecl); + col::Assign &assignment = funcCursor.createPhiAssignment( + phiInstruction, targetBlock, varDecl); // assign correct value by looking at the value-block pair of phi // instruction. col::Expr *value = assignment.mutable_value(); diff --git a/src/main/vct/importer/Util.scala b/src/main/vct/importer/Util.scala index c007cf334b..b0bf5c374d 100644 --- a/src/main/vct/importer/Util.scala +++ b/src/main/vct/importer/Util.scala @@ -65,8 +65,10 @@ case object Util extends LazyLogging { } Using(Files.newInputStream(result)) { in => - Deserialize - .deserializeProgram[G](vct.col.ast.serialize.Program.parseFrom(in), 0) + Deserialize.deserializeProgram[G]( + vct.col.ast.serialize.Program.parseFrom(in), + o => o, + ) }.get } diff --git a/src/parsers/vct/parsers/parser/ColLLVMParser.scala b/src/parsers/vct/parsers/parser/ColLLVMParser.scala index 510c2d015a..7cbc8fe295 100644 --- a/src/parsers/vct/parsers/parser/ColLLVMParser.scala +++ b/src/parsers/vct/parsers/parser/ColLLVMParser.scala @@ -82,8 +82,9 @@ case class ColLLVMParser( ) } + // Use the origin in the blame provider val COLProgram = Deserialize - .deserializeProgram[G](protoProgram, readable.fileName) + .deserializeProgram[G](protoProgram, _ => blameProvider.apply()) ParseResult(COLProgram.declarations, Seq.empty) } diff --git a/test/main/vct/test/integration/examples/VcllvmSpec.scala b/test/main/vct/test/integration/examples/LLVMSpec.scala similarity index 76% rename from test/main/vct/test/integration/examples/VcllvmSpec.scala rename to test/main/vct/test/integration/examples/LLVMSpec.scala index d6f5752d5e..508bd05ef4 100644 --- a/test/main/vct/test/integration/examples/VcllvmSpec.scala +++ b/test/main/vct/test/integration/examples/LLVMSpec.scala @@ -2,11 +2,11 @@ package vct.test.integration.examples import vct.test.integration.helper.VercorsSpec -class VcllvmSpec extends VercorsSpec { +class LLVMSpec extends VercorsSpec { vercors should verify using silicon example "concepts/llvm/cantor.c" vercors should verify using silicon example "concepts/llvm/cantor.ll" vercors should verify using silicon example "concepts/llvm/date.c" - vercors should verify using silicon example "concepts/llvm/date.ll" + vercors should fail withCode "preFailed:false" using silicon example "concepts/llvm/date.ll" vercors should verify using silicon example "concepts/llvm/fib.c" vercors should verify using silicon example "concepts/llvm/fib.ll" } diff --git a/test/main/vct/test/integration/meta/ExampleCoverage.scala b/test/main/vct/test/integration/meta/ExampleCoverage.scala index a443e18ee9..0de6dbc107 100644 --- a/test/main/vct/test/integration/meta/ExampleCoverage.scala +++ b/test/main/vct/test/integration/meta/ExampleCoverage.scala @@ -56,7 +56,7 @@ class ExampleCoverage extends AnyFlatSpec { new TechnicalVeyMontExamplesSpec(), new TerminationSpec(), new TypeValuesSpec(), - new VcllvmSpec(), + new LLVMSpec(), new VerifyThisSpec(), new VeyMontToolPaperSpec(), new VeyMontExamplesSpec(), diff --git a/util/githooks/pre-commit b/util/githooks/pre-commit index 3f22973560..63f897331d 100755 --- a/util/githooks/pre-commit +++ b/util/githooks/pre-commit @@ -2,4 +2,29 @@ echo "=== pre-commit tasks ===" echo "these may be skipped with git commit --no-verify" ./mill __.reformatStaged -echo "=== end of pre-commit tasks ===" \ No newline at end of file + +if [ -z "$CLANG_FMT" ] +then + if command -v clang-format > /dev/null 2>&1 + then + CLANG_FMT=clang-format + elif command -v clang-format-18 > /dev/null 2>&1 + then + CLANG_FMT=clang-format-17 + elif command -v clang-format-16 > /dev/null 2>&1 + then + CLANG_FMT=clang-format-16 + + elif command -v clang-format-15 > /dev/null 2>&1 + then + CLANG_FMT=clang-format-15 + fi +fi + +if [ -z "$CLANG_FMT" ] +then + echo "[WARN] clang-format was not found so the cpp sources could not be formatted" + echo "[NOTE] you can set the CLANG_FMT environment variable to specify the clang-format executable that should be used if the one you want to use isn't automatically detected" +fi +find src -regex '.*\.\(h\|cpp\)' -print0 | xargs -0 $CLANG_FMT -i +echo "=== end of pre-commit tasks ===" From 0088121535ce16dbca1ba078a9d9f6ddb3ad49d3 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 31 May 2024 09:54:16 +0200 Subject: [PATCH 03/47] Use absolute path for finding Pallas shared library --- src/parsers/vct/parsers/parser/ColLLVMParser.scala | 2 +- util/githooks/pre-commit | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parsers/vct/parsers/parser/ColLLVMParser.scala b/src/parsers/vct/parsers/parser/ColLLVMParser.scala index 7cbc8fe295..6d1b836abb 100644 --- a/src/parsers/vct/parsers/parser/ColLLVMParser.scala +++ b/src/parsers/vct/parsers/parser/ColLLVMParser.scala @@ -49,7 +49,7 @@ case class ColLLVMParser( } val command = Seq( "opt-17", - s"--load-pass-plugin=$pallas", + s"--load-pass-plugin=${pallas.toAbsolutePath}", "--passes=module(pallas-declare-variables,pallas-collect-module-spec),function(pallas-declare-function,pallas-assign-pure,pallas-declare-function-contract,pallas-transform-function-body),module(pallas-print-protobuf)", readable.fileName, "--disable-output", diff --git a/util/githooks/pre-commit b/util/githooks/pre-commit index 63f897331d..73bb81c003 100755 --- a/util/githooks/pre-commit +++ b/util/githooks/pre-commit @@ -8,7 +8,7 @@ then if command -v clang-format > /dev/null 2>&1 then CLANG_FMT=clang-format - elif command -v clang-format-18 > /dev/null 2>&1 + elif command -v clang-format-17 > /dev/null 2>&1 then CLANG_FMT=clang-format-17 elif command -v clang-format-16 > /dev/null 2>&1 From 7b884e63002294f8a4e6226e88ecf28824718b2c Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 31 May 2024 11:25:57 +0200 Subject: [PATCH 04/47] Remove pallas binary from output jar --- build.sc | 1 - src/parsers/vct/parsers/parser/ColLLVMParser.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/build.sc b/build.sc index 4304306658..30e13dc5ac 100644 --- a/build.sc +++ b/build.sc @@ -433,7 +433,6 @@ object vercors extends Module { ivy"org.apache.logging.log4j:log4j-to-slf4j:2.23.1", ) override def moduleDeps = Seq(hre, col, serialize) - override def unmanagedClasspath = super.unmanagedClasspath() ++ Agg(PathRef(pallas.compile().path / os.up)) val includePallasCross = interp.watchValue { if(os.exists(settings.root / ".include-pallas")) { diff --git a/src/parsers/vct/parsers/parser/ColLLVMParser.scala b/src/parsers/vct/parsers/parser/ColLLVMParser.scala index 6d1b836abb..7cbc8fe295 100644 --- a/src/parsers/vct/parsers/parser/ColLLVMParser.scala +++ b/src/parsers/vct/parsers/parser/ColLLVMParser.scala @@ -49,7 +49,7 @@ case class ColLLVMParser( } val command = Seq( "opt-17", - s"--load-pass-plugin=${pallas.toAbsolutePath}", + s"--load-pass-plugin=$pallas", "--passes=module(pallas-declare-variables,pallas-collect-module-spec),function(pallas-declare-function,pallas-assign-pure,pallas-declare-function-contract,pallas-transform-function-body),module(pallas-print-protobuf)", readable.fileName, "--disable-output", From d52394f1c782cb5d9ab44c9624ab9207f6ededb0 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 14 May 2024 16:30:42 +0200 Subject: [PATCH 05/47] Create separate classes for by reference and by value classes --- src/col/vct/col/ast/Node.scala | 59 ++++- .../global/ByReferenceClassImpl.scala | 10 + .../declaration/global/ByValueClassImpl.scala | 12 + .../ast/declaration/global/ClassImpl.scala | 28 +- .../declaration/singular/EndpointImpl.scala | 4 +- .../singular/LocalHeapVariableImpl.scala | 18 ++ .../expr/ambiguous/AmbiguousPlusImpl.scala | 14 +- .../ast/expr/context/AmbiguousThisImpl.scala | 5 +- .../col/ast/expr/context/ThisObjectImpl.scala | 4 +- .../ast/expr/heap/alloc/NewObjectImpl.scala | 4 +- .../vct/col/ast/expr/misc/HeapLocalImpl.scala | 17 ++ .../family/coercion/CoerceNullClassImpl.scala | 2 +- .../family/coercion/CoerceSupportsImpl.scala | 4 +- .../nonexecutable/HeapLocalDeclImpl.scala | 14 + .../col/ast/type/TByReferenceClassImpl.scala | 8 + .../vct/col/ast/type/TByValueClassImpl.scala | 8 + src/col/vct/col/ast/type/TClassImpl.scala | 24 +- .../col/ast/unsorted/ConstructorImpl.scala | 3 +- .../col/ast/unsorted/PVLEndpointImpl.scala | 2 +- src/col/vct/col/origin/Blame.scala | 21 +- src/col/vct/col/origin/Origin.scala | 3 + src/col/vct/col/resolve/Resolve.scala | 9 +- src/col/vct/col/resolve/lang/Java.scala | 3 +- src/col/vct/col/resolve/lang/PVL.scala | 12 +- src/col/vct/col/resolve/lang/Spec.scala | 24 +- .../vct/col/typerules/CoercingRewriter.scala | 13 +- src/col/vct/col/typerules/CoercionUtils.scala | 19 +- src/col/vct/col/typerules/Types.scala | 4 +- src/col/vct/col/util/AstBuildHelpers.scala | 8 + src/llvm/tools/vcllvm/VCLLVM.cpp | 8 +- src/main/vct/main/stages/Transformation.scala | 2 + .../vct/parsers/transform/PVLToCol.scala | 2 +- .../systemctocol/engine/ClassTransformer.java | 8 +- .../engine/KnownTypeTransformer.java | 10 +- .../systemctocol/engine/MainTransformer.java | 6 +- .../rewrite/CheckContractSatisfiability.scala | 36 +-- .../vct/rewrite/CheckProcessAlgebra.scala | 2 +- src/rewrite/vct/rewrite/ClassToRef.scala | 32 ++- .../vct/rewrite/ConstantifyFinalFields.scala | 22 +- .../vct/rewrite/EncodeArrayValues.scala | 4 +- src/rewrite/vct/rewrite/EncodeAutoValue.scala | 52 ++-- .../vct/rewrite/EncodeByValueClass.scala | 249 ++++++++++++++++++ src/rewrite/vct/rewrite/EncodeForkJoin.scala | 8 +- .../vct/rewrite/EncodeIntrinsicLock.scala | 4 +- .../vct/rewrite/EncodeResourceValues.scala | 22 +- .../ExtractInlineQuantifierPatterns.scala | 72 ++--- .../vct/rewrite/MonomorphizeClass.scala | 64 +++-- .../MonomorphizeContractApplicables.scala | 8 +- src/rewrite/vct/rewrite/ParBlockEncoder.scala | 48 ++-- .../ResolveExpressionSideEffects.scala | 33 ++- src/rewrite/vct/rewrite/adt/ImportADT.scala | 2 +- src/rewrite/vct/rewrite/bip/EncodeBip.scala | 4 +- src/rewrite/vct/rewrite/cfg/Utils.scala | 6 +- .../vct/rewrite/exc/EncodeBreakReturn.scala | 22 +- .../vct/rewrite/lang/LangCPPToCol.scala | 45 ++-- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 114 +------- .../vct/rewrite/lang/LangJavaToCol.scala | 10 +- .../vct/rewrite/lang/LangPVLToCol.scala | 2 +- .../vct/rewrite/lang/LangSpecificToCol.scala | 12 +- .../vct/rewrite/lang/LangTypesToCol.scala | 5 +- .../vct/rewrite/lang/NoSupportSelfLoop.scala | 10 +- src/rewrite/vct/rewrite/util/Extract.scala | 4 +- .../vct/rewrite/veymont/EncodeChannels.scala | 16 +- .../rewrite/veymont/EncodeChoreography.scala | 8 +- .../EncodeChoreographyParameters.scala | 10 +- .../GenerateChoreographyPermissions.scala | 13 +- .../veymont/GenerateImplementation.scala | 18 +- .../veymont/SpecializeEndpointClasses.scala | 3 +- .../vct/helper/SimpleProgramGenerator.scala | 2 +- .../vct/test/integration/examples/CSpec.scala | 4 +- 70 files changed, 937 insertions(+), 421 deletions(-) create mode 100644 src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala create mode 100644 src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala create mode 100644 src/col/vct/col/ast/declaration/singular/LocalHeapVariableImpl.scala create mode 100644 src/col/vct/col/ast/expr/misc/HeapLocalImpl.scala create mode 100644 src/col/vct/col/ast/statement/nonexecutable/HeapLocalDeclImpl.scala create mode 100644 src/col/vct/col/ast/type/TByReferenceClassImpl.scala create mode 100644 src/col/vct/col/ast/type/TByValueClassImpl.scala create mode 100644 src/rewrite/vct/rewrite/EncodeByValueClass.scala diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 838740745f..67e3e06e5a 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -107,6 +107,7 @@ final case class VerificationContext[G](program: Program[G])( @scopes[ModelDeclaration] @scopes[EnumConstant] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Program[G](declarations: Seq[GlobalDeclaration[G]])( val blame: Blame[UnsafeCoercion] )(implicit val o: Origin) @@ -218,9 +219,17 @@ sealed trait DeclaredType[G] extends Type[G] with DeclaredTypeImpl[G] final case class TModel[G](model: Ref[G, Model[G]])( implicit val o: Origin = DiagnosticOrigin ) extends DeclaredType[G] with TModelImpl[G] -final case class TClass[G](cls: Ref[G, Class[G]], typeArgs: Seq[Type[G]])( - implicit val o: Origin = DiagnosticOrigin -) extends DeclaredType[G] with TClassImpl[G] +sealed trait TClass[G] extends DeclaredType[G] with TClassImpl[G] +final case class TByReferenceClass[G]( + cls: Ref[G, Class[G]], + typeArgs: Seq[Type[G]], +)(implicit val o: Origin = DiagnosticOrigin) + extends TClass[G] with TByReferenceClassImpl[G] +final case class TByValueClass[G]( + cls: Ref[G, Class[G]], + typeArgs: Seq[Type[G]], +)(implicit val o: Origin = DiagnosticOrigin) + extends TClass[G] with TByValueClassImpl[G] final case class TAnyClass[G]()(implicit val o: Origin = DiagnosticOrigin) extends DeclaredType[G] with TAnyClassImpl[G] final case class TAxiomatic[G]( @@ -249,6 +258,7 @@ final case class ParSequential[G](regions: Seq[ParRegion[G]])( )(implicit val o: Origin) extends ParRegion[G] with ParSequentialImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] @scopes[SendDecl] @scopes[ParBlockDecl] final case class ParBlock[G]( @@ -276,6 +286,7 @@ final case class IterationContract[G]( extends LoopContract[G] with IterationContractImpl[G] @family @scopes[Variable] +@scopes[LocalHeapVariable] final case class CatchClause[G](decl: Variable[G], body: Statement[G])( implicit val o: Origin ) extends NodeFamily[G] with CatchClauseImpl[G] @@ -314,6 +325,10 @@ final case class LocalDecl[G](local: Variable[G])(implicit val o: Origin) extends NonExecutableStatement[G] with PurelySequentialStatement[G] with LocalDeclImpl[G] +final case class HeapLocalDecl[G](local: LocalHeapVariable[G])(implicit val o: Origin) + extends NonExecutableStatement[G] + with PurelySequentialStatement[G] + with HeapLocalDeclImpl[G] final case class SpecIgnoreStart[G]()(implicit val o: Origin) extends NonExecutableStatement[G] with PurelySequentialStatement[G] @@ -520,6 +535,7 @@ final case class Block[G](statements: Seq[Statement[G]])(implicit val o: Origin) with ControlContainerStatement[G] with BlockImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] @scopes[CLocalDeclaration] @scopes[CPPLocalDeclaration] @scopes[JavaLocalDeclaration] @@ -555,6 +571,7 @@ final case class Loop[G]( with ControlContainerStatement[G] with LoopImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class RangedFor[G]( iter: IterVariable[G], contract: LoopContract[G], @@ -609,6 +626,7 @@ final case class ParStatement[G](impl: ParRegion[G])(implicit val o: Origin) with PurelySequentialStatement[G] with ParStatementImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class VecBlock[G]( iters: Seq[IterVariable[G]], requires: Expr[G], @@ -643,18 +661,26 @@ final class HeapVariable[G](val t: Type[G])(implicit val o: Origin) final class SimplificationRule[G](val axiom: Expr[G])(implicit val o: Origin) extends GlobalDeclaration[G] with SimplificationRuleImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class AxiomaticDataType[G]( val decls: Seq[ADTDeclaration[G]], val typeArgs: Seq[Variable[G]], )(implicit val o: Origin) extends GlobalDeclaration[G] with AxiomaticDataTypeImpl[G] -final class Class[G]( +sealed trait Class[G] extends GlobalDeclaration[G] with ClassImpl[G] +final class ByReferenceClass[G]( val typeArgs: Seq[Variable[G]], val decls: Seq[ClassDeclaration[G]], val supports: Seq[Type[G]], val intrinsicLockInvariant: Expr[G], )(implicit val o: Origin) - extends GlobalDeclaration[G] with ClassImpl[G] + extends Class[G] with ByReferenceClassImpl[G] +final class ByValueClass[G]( + val typeArgs: Seq[Variable[G]], + val decls: Seq[ClassDeclaration[G]], + val supports: Seq[Type[G]], +)(implicit val o: Origin) + extends Class[G] with ByValueClassImpl[G] final class Model[G](val declarations: Seq[ModelDeclaration[G]])( implicit val o: Origin ) extends GlobalDeclaration[G] with Declarator[G] with ModelImpl[G] @@ -687,6 +713,7 @@ final class VeSUVMainMethod[G](val body: Option[Statement[G]])( )(implicit val o: Origin) extends GlobalDeclaration[G] with VeSUVMainMethodImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class Predicate[G]( val args: Seq[Variable[G]], val body: Option[Expr[G]], @@ -759,6 +786,7 @@ final class InstanceMethod[G]( with AbstractMethod[G] with InstanceMethodImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class InstancePredicate[G]( val args: Seq[Variable[G]], val body: Option[Expr[G]], @@ -814,6 +842,7 @@ sealed trait ModelDeclaration[G] final class ModelField[G](val t: Type[G])(implicit val o: Origin) extends ModelDeclaration[G] with Field[G] with ModelFieldImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class ModelProcess[G]( val args: Seq[Variable[G]], val impl: Expr[G], @@ -824,6 +853,7 @@ final class ModelProcess[G]( )(val blame: Blame[PostconditionFailed])(implicit val o: Origin) extends ModelDeclaration[G] with Applicable[G] with ModelProcessImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class ModelAction[G]( val args: Seq[Variable[G]], val requires: Expr[G], @@ -838,6 +868,7 @@ sealed trait ADTDeclaration[G] extends Declaration[G] with ADTDeclarationImpl[G] final class ADTAxiom[G](val axiom: Expr[G])(implicit val o: Origin) extends ADTDeclaration[G] with ADTAxiomImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class ADTFunction[G](val args: Seq[Variable[G]], val returnType: Type[G])( implicit val o: Origin ) extends Applicable[G] with ADTDeclaration[G] with ADTFunctionImpl[G] @@ -846,6 +877,9 @@ final class ADTFunction[G](val args: Seq[Variable[G]], val returnType: Type[G])( final class Variable[G](val t: Type[G])(implicit val o: Origin) extends Declaration[G] with VariableImpl[G] @family +final class LocalHeapVariable[G](val t: Type[G])(implicit val o: Origin) + extends Declaration[G] with LocalHeapVariableImpl[G] +@family final class LabelDecl[G]()(implicit val o: Origin) extends Declaration[G] with LabelDeclImpl[G] @family @@ -872,6 +906,7 @@ sealed trait AbstractMethod[G] sealed trait Field[G] extends FieldImpl[G] @family @scopes[Variable] +@scopes[LocalHeapVariable] final case class SignalsClause[G](binding: Variable[G], assn: Expr[G])( implicit val o: Origin ) extends NodeFamily[G] with SignalsClauseImpl[G] @@ -1260,6 +1295,7 @@ final case class MapRemove[G](map: Expr[G], k: Expr[G])(implicit val o: Origin) sealed trait Binder[G] extends Expr[G] with BinderImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Forall[G]( bindings: Seq[Variable[G]], triggers: Seq[Seq[Expr[G]]], @@ -1267,6 +1303,7 @@ final case class Forall[G]( )(implicit val o: Origin) extends Binder[G] with ForallImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Starall[G]( bindings: Seq[Variable[G]], triggers: Seq[Seq[Expr[G]]], @@ -1274,6 +1311,7 @@ final case class Starall[G]( )(val blame: Blame[ReceiverNotInjective])(implicit val o: Origin) extends Binder[G] with StarallImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Exists[G]( bindings: Seq[Variable[G]], triggers: Seq[Seq[Expr[G]]], @@ -1281,6 +1319,7 @@ final case class Exists[G]( )(implicit val o: Origin) extends Binder[G] with ExistsImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Sum[G]( bindings: Seq[Variable[G]], condition: Expr[G], @@ -1288,6 +1327,7 @@ final case class Sum[G]( )(implicit val o: Origin) extends Binder[G] with SumImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Product[G]( bindings: Seq[Variable[G]], condition: Expr[G], @@ -1295,6 +1335,7 @@ final case class Product[G]( )(implicit val o: Origin) extends Binder[G] with ProductImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class ForPerm[G]( bindings: Seq[Variable[G]], loc: Location[G], @@ -1302,10 +1343,12 @@ final case class ForPerm[G]( )(implicit val o: Origin) extends Binder[G] with ForPermImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class ForPermWithValue[G](binding: Variable[G], body: Expr[G])( implicit val o: Origin ) extends Binder[G] with ForPermWithValueImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class Let[G](binding: Variable[G], value: Expr[G], main: Expr[G])( implicit val o: Origin ) extends Binder[G] with LetImpl[G] @@ -1317,12 +1360,16 @@ final case class InlinePattern[G]( extends Expr[G] with InlinePatternImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final case class ScopedExpr[G](declarations: Seq[Variable[G]], body: Expr[G])( implicit val o: Origin ) extends Declarator[G] with Expr[G] with ScopedExprImpl[G] final case class Local[G](ref: Ref[G, Variable[G]])(implicit val o: Origin) extends Expr[G] with LocalImpl[G] +final case class HeapLocal[G](ref: Ref[G, LocalHeapVariable[G]])(implicit val o: Origin) + extends Expr[G] with HeapLocalImpl[G] + final case class EnumUse[G]( enum: Ref[G, Enum[G]], const: Ref[G, EnumConstant[G]], @@ -3106,6 +3153,7 @@ sealed trait JavaClassOrInterface[G] with Declarator[G] with JavaClassOrInterfaceImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class JavaClass[G]( val name: String, val modifiers: Seq[JavaModifier[G]], @@ -3117,6 +3165,7 @@ final class JavaClass[G]( )(val blame: Blame[JavaImplicitConstructorFailure])(implicit val o: Origin) extends JavaClassOrInterface[G] with JavaClassImpl[G] @scopes[Variable] +@scopes[LocalHeapVariable] final class JavaInterface[G]( val name: String, val modifiers: Seq[JavaModifier[G]], diff --git a/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala b/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala new file mode 100644 index 0000000000..ec95adbe7c --- /dev/null +++ b/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.declaration.global + +import vct.col.ast.{ByReferenceClass, TByReferenceClass, Type} +import vct.col.ast.ops.ByReferenceClassOps + +trait ByReferenceClassImpl[G] extends ByReferenceClassOps[G] { + this: ByReferenceClass[G] => + override def classType(typeArgs: Seq[Type[G]]): TByReferenceClass[G] = + TByReferenceClass[G](this.ref, typeArgs) +} diff --git a/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala new file mode 100644 index 0000000000..e776781309 --- /dev/null +++ b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala @@ -0,0 +1,12 @@ +package vct.col.ast.declaration.global + +import vct.col.ast.{ByValueClass, Expr, InstanceField, TByValueClass, Type} +import vct.col.ast.ops.ByValueClassOps +import vct.col.util.AstBuildHelpers._ + +trait ByValueClassImpl[G] extends ByValueClassOps[G] { + this: ByValueClass[G] => + override def intrinsicLockInvariant: Expr[G] = tt + override def classType(typeArgs: Seq[Type[G]]): TByValueClass[G] = + TByValueClass[G](this.ref, typeArgs) +} diff --git a/src/col/vct/col/ast/declaration/global/ClassImpl.scala b/src/col/vct/col/ast/declaration/global/ClassImpl.scala index 6b26f4e42b..7320c8dadf 100644 --- a/src/col/vct/col/ast/declaration/global/ClassImpl.scala +++ b/src/col/vct/col/ast/declaration/global/ClassImpl.scala @@ -1,18 +1,36 @@ package vct.col.ast.declaration.global -import vct.col.ast.{Class, Declaration, TClass, TVar} +import vct.col.ast.{ + Class, + ClassDeclaration, + Declaration, + Expr, + TByReferenceClass, + TClass, + TVar, + Type, + Variable, +} import vct.col.ast.util.Declarator import vct.col.print._ import vct.col.util.AstBuildHelpers.tt -import vct.result.VerificationError.Unreachable -import vct.col.ast.ops.ClassOps -trait ClassImpl[G] extends Declarator[G] with ClassOps[G] { +trait ClassImpl[G] extends Declarator[G] { this: Class[G] => + def typeArgs: Seq[Variable[G]] + def decls: Seq[ClassDeclaration[G]] + def supports: Seq[Type[G]] + def intrinsicLockInvariant: Expr[G] + + def classType(typeArgs: Seq[Type[G]]): TClass[G] + def transSupportArrowsHelper( seen: Set[TClass[G]] ): Seq[(TClass[G], TClass[G])] = { - val t: TClass[G] = TClass(this.ref, typeArgs.map(v => TVar(v.ref))) + // TODO: Does this break things if we have a ByValueClass with supers? + val t: TClass[G] = classType( + typeArgs.map((v: Variable[G]) => TVar(v.ref[Variable[G]])) + ) if (seen.contains(t)) Nil else diff --git a/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala b/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala index 1d2c5b607f..1d236b23f1 100644 --- a/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala +++ b/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala @@ -1,7 +1,7 @@ package vct.col.ast.declaration.singular import vct.col.ast.declaration.DeclarationImpl -import vct.col.ast.{Endpoint, TClass, Type} +import vct.col.ast.{Endpoint, TByReferenceClass, TClass, Type} import vct.col.print._ import vct.col.ast.ops.{EndpointFamilyOps, EndpointOps} import vct.col.check.{CheckContext, CheckError} @@ -14,7 +14,7 @@ trait EndpointImpl[G] Group(t.show <> "(" <> Doc.args(args) <> ");") }) - def t: TClass[G] = TClass(cls, typeArgs) + def t: TClass[G] = TByReferenceClass(cls, typeArgs) override def check(ctx: CheckContext[G]): Seq[CheckError] = super.check(ctx) ++ ctx.checkInScope(this, cls) diff --git a/src/col/vct/col/ast/declaration/singular/LocalHeapVariableImpl.scala b/src/col/vct/col/ast/declaration/singular/LocalHeapVariableImpl.scala new file mode 100644 index 0000000000..2814eed660 --- /dev/null +++ b/src/col/vct/col/ast/declaration/singular/LocalHeapVariableImpl.scala @@ -0,0 +1,18 @@ +package vct.col.ast.declaration.singular + +import vct.col.ast.LocalHeapVariable +import vct.col.print._ +import vct.col.ast.ops.{LocalHeapVariableOps, VariableFamilyOps} +import vct.col.ast.ops.{LocalHeapVariableOps, LocalHeapVariableFamilyOps} + +trait LocalHeapVariableImpl[G] extends LocalHeapVariableOps[G] with LocalHeapVariableFamilyOps[G] { + this: LocalHeapVariable[G] => + override def layout(implicit ctx: Ctx): Doc = + Text("@heap") <+> (ctx.syntax match { + case Ctx.C | Ctx.Cuda | Ctx.OpenCL | Ctx.CPP => + val (spec, decl) = t.layoutSplitDeclarator + spec <+> decl <> ctx.name(this) + case Ctx.PVL | Ctx.Java => t.show <+> ctx.name(this) + case Ctx.Silver => Text(ctx.name(this)) <> ":" <+> t + }) +} diff --git a/src/col/vct/col/ast/expr/ambiguous/AmbiguousPlusImpl.scala b/src/col/vct/col/ast/expr/ambiguous/AmbiguousPlusImpl.scala index 8a227e7898..73233d9105 100644 --- a/src/col/vct/col/ast/expr/ambiguous/AmbiguousPlusImpl.scala +++ b/src/col/vct/col/ast/expr/ambiguous/AmbiguousPlusImpl.scala @@ -20,7 +20,7 @@ trait AmbiguousPlusImpl[G] extends AmbiguousPlusOps[G] { right val decls = subject.t match { - case TClass(Ref(cls), _) => cls.decls + case t: TClass[G] => t.cls.decl.decls case JavaTClass(Ref(cls), _) => cls.decls case _ => return None } @@ -62,6 +62,18 @@ trait AmbiguousPlusImpl[G] extends AmbiguousPlusOps[G] { getCustomPlusOpType().get else getNumericType + if (isProcessOp) + TProcess() + else if (isSeqOp || isBagOp || isSetOp || isVectorOp) + Types.leastCommonSuperType(left.t, right.t) + else if (isPointerOp) + left.t + else if (isStringOp) + TString() + else if (getCustomPlusOpType().isDefined) + getCustomPlusOpType().get + else + getNumericType } override def precedence: Int = Precedence.ADDITIVE diff --git a/src/col/vct/col/ast/expr/context/AmbiguousThisImpl.scala b/src/col/vct/col/ast/expr/context/AmbiguousThisImpl.scala index 01b0ab843c..18cad16207 100644 --- a/src/col/vct/col/ast/expr/context/AmbiguousThisImpl.scala +++ b/src/col/vct/col/ast/expr/context/AmbiguousThisImpl.scala @@ -17,9 +17,8 @@ trait AmbiguousThisImpl[G] extends AmbiguousThisOps[G] { ) match { case RefJavaClass(decl) => JavaTClass(decl.ref, Nil) case RefClass(decl) => - TClass( - decl.ref, - decl.typeArgs.map((v: Variable[G]) => TVar(v.ref[Variable[G]])), + decl.classType( + decl.typeArgs.map((v: Variable[G]) => TVar(v.ref[Variable[G]])) ) case RefModel(decl) => TModel(decl.ref) case RefPVLChoreography(decl) => TPVLChoreography(decl.ref) diff --git a/src/col/vct/col/ast/expr/context/ThisObjectImpl.scala b/src/col/vct/col/ast/expr/context/ThisObjectImpl.scala index 25e90414f7..74a515cf8d 100644 --- a/src/col/vct/col/ast/expr/context/ThisObjectImpl.scala +++ b/src/col/vct/col/ast/expr/context/ThisObjectImpl.scala @@ -8,7 +8,9 @@ import vct.col.check.{CheckContext, CheckError, ThisInConstructorPre} trait ThisObjectImpl[G] extends ThisDeclarationImpl[G] with ThisObjectOps[G] { this: ThisObject[G] => override def t: Type[G] = - TClass(cls, cls.decl.typeArgs.map(v => TVar(v.ref[Variable[G]]))) + cls.decl.classType(cls.decl.typeArgs.map((v: Variable[G]) => + TVar(v.ref[Variable[G]]) + )) override def check(context: CheckContext[G]): Seq[CheckError] = { val inConstructor = diff --git a/src/col/vct/col/ast/expr/heap/alloc/NewObjectImpl.scala b/src/col/vct/col/ast/expr/heap/alloc/NewObjectImpl.scala index dbc0fa2953..f5e8ead476 100644 --- a/src/col/vct/col/ast/expr/heap/alloc/NewObjectImpl.scala +++ b/src/col/vct/col/ast/expr/heap/alloc/NewObjectImpl.scala @@ -1,12 +1,12 @@ package vct.col.ast.expr.heap.alloc -import vct.col.ast.{NewObject, TClass, Type} +import vct.col.ast.{NewObject, Type} import vct.col.print.{Ctx, Doc, Precedence, Text} import vct.col.ast.ops.NewObjectOps trait NewObjectImpl[G] extends NewObjectOps[G] { this: NewObject[G] => - override def t: Type[G] = TClass(cls, Seq()) + override def t: Type[G] = cls.decl.classType(Seq()) override def precedence: Int = Precedence.POSTFIX override def layout(implicit ctx: Ctx): Doc = diff --git a/src/col/vct/col/ast/expr/misc/HeapLocalImpl.scala b/src/col/vct/col/ast/expr/misc/HeapLocalImpl.scala new file mode 100644 index 0000000000..d397452dc6 --- /dev/null +++ b/src/col/vct/col/ast/expr/misc/HeapLocalImpl.scala @@ -0,0 +1,17 @@ +package vct.col.ast.expr.misc + +import vct.col.ast.expr.ExprImpl +import vct.col.ast.{HeapLocal, Type} +import vct.col.check.{CheckContext, CheckError} +import vct.col.print.{Ctx, Doc, Precedence, Text} +import vct.col.ast.ops.HeapLocalOps + +trait HeapLocalImpl[G] extends ExprImpl[G] with HeapLocalOps[G] { + this: HeapLocal[G] => + override def t: Type[G] = ref.decl.t + override def check(context: CheckContext[G]): Seq[CheckError] = + context.checkInScope(this, ref) + + override def precedence: Int = Precedence.ATOMIC + override def layout(implicit ctx: Ctx): Doc = Text(ctx.name(ref)) +} diff --git a/src/col/vct/col/ast/family/coercion/CoerceNullClassImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceNullClassImpl.scala index 4dc7f0662b..01b76c7319 100644 --- a/src/col/vct/col/ast/family/coercion/CoerceNullClassImpl.scala +++ b/src/col/vct/col/ast/family/coercion/CoerceNullClassImpl.scala @@ -5,5 +5,5 @@ import vct.col.ast.ops.CoerceNullClassOps trait CoerceNullClassImpl[G] extends CoerceNullClassOps[G] { this: CoerceNullClass[G] => - override def target: TClass[G] = TClass(targetClass, typeArgs) + override def target: TClass[G] = targetClass.decl.classType(typeArgs) } diff --git a/src/col/vct/col/ast/family/coercion/CoerceSupportsImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceSupportsImpl.scala index e0ce3a4353..0a04b3f6ce 100644 --- a/src/col/vct/col/ast/family/coercion/CoerceSupportsImpl.scala +++ b/src/col/vct/col/ast/family/coercion/CoerceSupportsImpl.scala @@ -7,5 +7,7 @@ trait CoerceSupportsImpl[G] extends CoerceSupportsOps[G] { this: CoerceSupports[G] => // TODO (RR): Integrate coercions with generics? override def target: TClass[G] = - TClass(targetClass, { assert(sourceClass.decl.typeArgs.isEmpty); Seq() }) + targetClass.decl.classType({ + assert(sourceClass.decl.typeArgs.isEmpty); Seq() + }) } diff --git a/src/col/vct/col/ast/statement/nonexecutable/HeapLocalDeclImpl.scala b/src/col/vct/col/ast/statement/nonexecutable/HeapLocalDeclImpl.scala new file mode 100644 index 0000000000..2b7e8650fd --- /dev/null +++ b/src/col/vct/col/ast/statement/nonexecutable/HeapLocalDeclImpl.scala @@ -0,0 +1,14 @@ +package vct.col.ast.statement.nonexecutable + +import vct.col.ast.HeapLocalDecl +import vct.col.print.{Ctx, Doc, Text} +import vct.col.ast.ops.HeapLocalDeclOps + +trait HeapLocalDeclImpl[G] extends HeapLocalDeclOps[G] { + this: HeapLocalDecl[G] => + override def layout(implicit ctx: Ctx): Doc = + ctx.syntax match { + case Ctx.Silver => Text("var") <+> local.show + case _ => local.show <> ";" + } +} diff --git a/src/col/vct/col/ast/type/TByReferenceClassImpl.scala b/src/col/vct/col/ast/type/TByReferenceClassImpl.scala new file mode 100644 index 0000000000..8a800321e9 --- /dev/null +++ b/src/col/vct/col/ast/type/TByReferenceClassImpl.scala @@ -0,0 +1,8 @@ +package vct.col.ast.`type` + +import vct.col.ast.TByReferenceClass +import vct.col.ast.ops.TByReferenceClassOps + +trait TByReferenceClassImpl[G] extends TByReferenceClassOps[G] { + this: TByReferenceClass[G] => +} diff --git a/src/col/vct/col/ast/type/TByValueClassImpl.scala b/src/col/vct/col/ast/type/TByValueClassImpl.scala new file mode 100644 index 0000000000..0f8e0f6bdc --- /dev/null +++ b/src/col/vct/col/ast/type/TByValueClassImpl.scala @@ -0,0 +1,8 @@ +package vct.col.ast.`type` + +import vct.col.ast.TByValueClass +import vct.col.ast.ops.TByValueClassOps + +trait TByValueClassImpl[G] extends TByValueClassOps[G] { + this: TByValueClass[G] => +} diff --git a/src/col/vct/col/ast/type/TClassImpl.scala b/src/col/vct/col/ast/type/TClassImpl.scala index 42a03d7983..1fd098e9c2 100644 --- a/src/col/vct/col/ast/type/TClassImpl.scala +++ b/src/col/vct/col/ast/type/TClassImpl.scala @@ -1,27 +1,23 @@ package vct.col.ast.`type` import vct.col.ast.{ - Applicable, Class, - ClassDeclaration, - Constructor, - ContractApplicable, InstanceField, - InstanceFunction, - InstanceMethod, - InstanceOperatorFunction, - InstanceOperatorMethod, + TByReferenceClass, + TByValueClass, TClass, Type, Variable, } -import vct.col.print.{Ctx, Doc, Empty, Group, Text} -import vct.col.ast.ops.TClassOps +import vct.col.print._ import vct.col.ref.Ref -import vct.result.VerificationError.Unreachable -trait TClassImpl[G] extends TClassOps[G] { +trait TClassImpl[G] { this: TClass[G] => + def cls: Ref[G, Class[G]] + + def typeArgs: Seq[Type[G]] + def transSupportArrowsHelper( seen: Set[TClass[G]] ): Seq[(TClass[G], TClass[G])] = @@ -45,7 +41,9 @@ trait TClassImpl[G] extends TClassOps[G] { def instantiate(t: Type[G]): Type[G] = this match { - case TClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => + case TByReferenceClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => + t.particularize(cls.typeArgs.zip(typeArgs).toMap) + case TByValueClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => t.particularize(cls.typeArgs.zip(typeArgs).toMap) case _ => t } diff --git a/src/col/vct/col/ast/unsorted/ConstructorImpl.scala b/src/col/vct/col/ast/unsorted/ConstructorImpl.scala index 23f656fc03..cec2a2bdfc 100644 --- a/src/col/vct/col/ast/unsorted/ConstructorImpl.scala +++ b/src/col/vct/col/ast/unsorted/ConstructorImpl.scala @@ -8,7 +8,7 @@ trait ConstructorImpl[G] extends ConstructorOps[G] { this: Constructor[G] => override def pure: Boolean = false override def returnType: TClass[G] = - TClass(cls, cls.decl.typeArgs.map((v: Variable[G]) => TVar(v.ref))) + cls.decl.classType(cls.decl.typeArgs.map((v: Variable[G]) => TVar(v.ref))) override def layout(implicit ctx: Ctx): Doc = { Doc.stack(Seq( @@ -19,5 +19,4 @@ trait ConstructorImpl[G] extends ConstructorOps[G] { ) <> body.map(Text(" ") <> _).getOrElse(Text(";")), )) } - } diff --git a/src/col/vct/col/ast/unsorted/PVLEndpointImpl.scala b/src/col/vct/col/ast/unsorted/PVLEndpointImpl.scala index 9aea9ecdf0..04d804ee0f 100644 --- a/src/col/vct/col/ast/unsorted/PVLEndpointImpl.scala +++ b/src/col/vct/col/ast/unsorted/PVLEndpointImpl.scala @@ -8,5 +8,5 @@ trait PVLEndpointImpl[G] extends PVLEndpointOps[G] { this: PVLEndpoint[G] => // override def layout(implicit ctx: Ctx): Doc = ??? - def t: TClass[G] = TClass(cls, typeArgs) + def t: TClass[G] = cls.decl.classType(typeArgs) } diff --git a/src/col/vct/col/origin/Blame.scala b/src/col/vct/col/origin/Blame.scala index 068aa2f893..174b70f334 100644 --- a/src/col/vct/col/origin/Blame.scala +++ b/src/col/vct/col/origin/Blame.scala @@ -158,22 +158,25 @@ case class AssignFieldFailed(node: SilverFieldAssign[_]) s"Insufficient permission for assignment `$source`." } -case class CopyStructFailed(node: Expr[_], field: String) +case class CopyClassFailed(node: Node[_], clazz: ByValueClass[_], field: String) extends AssignFailed with NodeVerificationFailure { - override def code: String = "copyStructFailed" + override def code: String = "copyClassFailed" override def descInContext: String = - s"Insufficient read permission for field '$field' to copy struct." + s"Insufficient read permission for field '$field' to copy ${clazz.o + .find[TypeName].map(_.name).getOrElse("class")}." override def inlineDescWithSource(source: String): String = s"Insufficient permission for assignment `$source`." } -case class CopyStructFailedBeforeCall(node: Expr[_], field: String) - extends AssignFailed - with FrontendInvocationError - with NodeVerificationFailure { - override def code: String = "copyStructFailedBeforeCall" +case class CopyClassFailedBeforeCall( + node: Node[_], + clazz: ByValueClass[_], + field: String, +) extends AssignFailed with InvocationFailure with NodeVerificationFailure { + override def code: String = "copyClassFailedBeforeCall" override def descInContext: String = - s"Insufficient read permission for field '$field' to copy struct before call." + s"Insufficient read permission for field '$field' to copy ${clazz.o + .find[TypeName].map(_.name).getOrElse("class")} before call." override def inlineDescWithSource(source: String): String = s"Insufficient permission for call `$source`." } diff --git a/src/col/vct/col/origin/Origin.scala b/src/col/vct/col/origin/Origin.scala index 987b16e1f4..fd9b9d993d 100644 --- a/src/col/vct/col/origin/Origin.scala +++ b/src/col/vct/col/origin/Origin.scala @@ -109,6 +109,9 @@ case class SourceName(name: String) extends NameStrategy { Some(SourceName.stringToName(name)) } +// Used to disambiguate whether to show a ByValueClass as a class or a struct +case class TypeName(name: String) extends OriginContent + /** Content that provides a bit of context here. By default, this assembles * further context from the remaining origin. contextHere and inlineContextHere * may optionally consume some more contents, otherwise they can just return diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index 1bd8310820..f88099f286 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -220,7 +220,12 @@ case object ResolveTypes { Spec.findModel(name, ctx) .getOrElse(throw NoSuchNameError("model", name, t)) ) - case t @ TClass(ref, _) => + case t @ TByReferenceClass(ref, _) => + ref.tryResolve(name => + Spec.findClass(name, ctx) + .getOrElse(throw NoSuchNameError("class", name, t)) + ) + case t @ TByValueClass(ref, _) => ref.tryResolve(name => Spec.findClass(name, ctx) .getOrElse(throw NoSuchNameError("class", name, t)) @@ -643,7 +648,7 @@ case object ResolveReferences extends LazyLogging { case endpoint: PVLEndpoint[G] => endpoint.ref = Some( PVL.findConstructor( - TClass(endpoint.cls.decl.ref[Class[G]], Seq()), + TByReferenceClass(endpoint.cls.decl.ref[Class[G]], Seq()), Seq(), endpoint.args, ).getOrElse(throw ConstructorNotFound(endpoint)) diff --git a/src/col/vct/col/resolve/lang/Java.scala b/src/col/vct/col/resolve/lang/Java.scala index b65ced5463..1d734628a7 100644 --- a/src/col/vct/col/resolve/lang/Java.scala +++ b/src/col/vct/col/resolve/lang/Java.scala @@ -42,6 +42,7 @@ import vct.col.ast.{ TArray, TBag, TBool, + TByReferenceClass, TChar, TClass, TEnum, @@ -877,7 +878,7 @@ case object Java extends LazyLogging { case t: TFloat[G] => const(0) case TRational() => const(0) case TZFraction() => const(0) - case TClass(_, _) => Null() + case TByReferenceClass(_, _) => Null() case JavaTClass(_, _) => Null() case TEnum(_) => Null() case TAnyClass() => Null() diff --git a/src/col/vct/col/resolve/lang/PVL.scala b/src/col/vct/col/resolve/lang/PVL.scala index 9a8593e5a4..5de6ae8de2 100644 --- a/src/col/vct/col/resolve/lang/PVL.scala +++ b/src/col/vct/col/resolve/lang/PVL.scala @@ -13,8 +13,8 @@ case object PVL { args: Seq[Expr[G]], ): Option[PVLConstructorTarget[G]] = { t match { - case t @ TClass(Ref(cls), _) => - val resolvedCons = cls.decls.collectFirst { + case t: TClass[G] => + val resolvedCons = t.cls.decl.decls.collectFirst { case cons: PVLConstructor[G] if Util .compat(t.typeEnv, args, typeArgs, cons.args, cons.typeArgs) => @@ -23,7 +23,7 @@ case object PVL { args match { case Nil => - resolvedCons.orElse(Some(ImplicitDefaultPVLConstructor(cls))) + resolvedCons.orElse(Some(ImplicitDefaultPVLConstructor(t.cls.decl))) case _ => resolvedCons } case TModel(Ref(model)) if args.isEmpty => Some(RefModel(model)) @@ -69,7 +69,7 @@ case object PVL { ref.decl.declarations.flatMap(Referrable.from).collectFirst { case ref: RefModelField[G] if ref.name == name => ref } - case TClass(ref, _) => findDerefOfClass(ref.decl, name) + case t: TClass[G] => findDerefOfClass(t.cls.decl, name) case _ => Spec.builtinField(obj, name, blame) } @@ -94,8 +94,8 @@ case object PVL { case ref: RefModelAction[G] if ref.name == method => ref case ref: RefModelProcess[G] if ref.name == method => ref }.orElse(Spec.builtinInstanceMethod(obj, method, blame)) - case t @ TClass(ref, _) => - ref.decl.declarations.flatMap(Referrable.from).collectFirst { + case t: TClass[G] => + t.cls.decl.declarations.flatMap(Referrable.from).collectFirst { case ref: RefInstanceFunction[G] if ref.name == method && Util.compat(t.typeEnv, args, typeArgs, ref.decl) => diff --git a/src/col/vct/col/resolve/lang/Spec.scala b/src/col/vct/col/resolve/lang/Spec.scala index df62509925..52b6cd9a32 100644 --- a/src/col/vct/col/resolve/lang/Spec.scala +++ b/src/col/vct/col/resolve/lang/Spec.scala @@ -348,7 +348,11 @@ case object Spec { def findMethod[G](obj: Expr[G], name: String): Option[InstanceMethod[G]] = obj.t match { - case TClass(Ref(cls), _) => + case TByReferenceClass(Ref(cls), _) => + cls.decls.flatMap(Referrable.from).collectFirst { + case ref @ RefInstanceMethod(decl) if ref.name == name => decl + } + case TByValueClass(Ref(cls), _) => cls.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstanceMethod(decl) if ref.name == name => decl } @@ -360,7 +364,11 @@ case object Spec { name: String, ): Option[InstanceFunction[G]] = obj.t match { - case TClass(Ref(cls), _) => + case TByReferenceClass(Ref(cls), _) => + cls.decls.flatMap(Referrable.from).collectFirst { + case ref @ RefInstanceFunction(decl) if ref.name == name => decl + } + case TByValueClass(Ref(cls), _) => cls.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstanceFunction(decl) if ref.name == name => decl } @@ -372,7 +380,11 @@ case object Spec { name: String, ): Option[InstancePredicate[G]] = obj.t match { - case TClass(Ref(cls), _) => + case TByReferenceClass(Ref(cls), _) => + cls.decls.flatMap(Referrable.from).collectFirst { + case ref @ RefInstancePredicate(decl) if ref.name == name => decl + } + case TByValueClass(Ref(cls), _) => cls.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstancePredicate(decl) if ref.name == name => decl } @@ -385,7 +397,11 @@ case object Spec { def findField[G](obj: Expr[G], name: String): Option[InstanceField[G]] = obj.t match { - case TClass(Ref(cls), _) => + case TByReferenceClass(Ref(cls), _) => + cls.decls.flatMap(Referrable.from).collectFirst { + case ref @ RefField(decl) if ref.name == name => decl + } + case TByValueClass(Ref(cls), _) => cls.decls.flatMap(Referrable.from).collectFirst { case ref @ RefField(decl) if ref.name == name => decl } diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index baf88034cf..26a81d7d8a 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1417,6 +1417,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case LiteralTuple(ts, values) => LiteralTuple(ts, values.zip(ts).map { case (v, t) => coerce(v, t) }) case Local(ref) => Local(ref) + case HeapLocal(ref) => HeapLocal(ref) case LocalThreadId() => LocalThreadId() case MapCons(m, k, v) => val (coercedMap, mapType) = map(m) @@ -2187,7 +2188,6 @@ abstract class CoercingRewriter[Pre <: Generation]() givenMap, yields, ) => - val cls = TClass(ref.decl.cls, classTypeArgs) InvokeConstructor( ref, classTypeArgs, @@ -2237,6 +2237,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case j @ Join(obj) => Join(cls(obj))(j.blame) case Label(decl, stat) => Label(decl, stat) case LocalDecl(local) => LocalDecl(local) + case HeapLocalDecl(local) => HeapLocalDecl(local) case l @ Lock(obj) => Lock(cls(obj))(l.blame) case Loop(init, cond, update, contract, body) => Loop(init, bool(cond), update, contract, body) @@ -2322,13 +2323,14 @@ abstract class CoercingRewriter[Pre <: Generation]() case rule: SimplificationRule[Pre] => new SimplificationRule[Pre](bool(rule.axiom)) case dataType: AxiomaticDataType[Pre] => dataType - case clazz: Class[Pre] => - new Class[Pre]( + case clazz: ByReferenceClass[Pre] => + new ByReferenceClass[Pre]( clazz.typeArgs, clazz.decls, clazz.supports, res(clazz.intrinsicLockInvariant), ) + case clazz: ByValueClass[Pre] => clazz case enum: Enum[Pre] => enum case enumConstant: EnumConstant[Pre] => enumConstant case model: Model[Pre] => model @@ -2447,6 +2449,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case axiom: ADTAxiom[Pre] => new ADTAxiom[Pre](bool(axiom.axiom)) case function: ADTFunction[Pre] => function case variable: Variable[Pre] => variable + case variable: LocalHeapVariable[Pre] => variable case decl: LabelDecl[Pre] => decl case decl: SendDecl[Pre] => decl case decl: ParBlockDecl[Pre] => decl @@ -2587,7 +2590,9 @@ abstract class CoercingRewriter[Pre <: Generation]() // PB: types may very well contain expressions eventually, but for now they don't. def coerce(node: Type[Pre]): Type[Pre] = node match { - case t @ TClass(cls, args) => arity(TClass(cls, args)) + case t @ TByReferenceClass(cls, args) => + arity(TByReferenceClass(cls, args)) + case t @ TByValueClass(cls, args) => arity(TByValueClass(cls, args)) case _ => node } diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index 47295af0af..deef9b73ca 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -117,7 +117,7 @@ case object CoercionUtils { case (TNull(), TRef()) => CoerceNullRef() case (TNull(), TArray(target)) => CoerceNullArray(target) - case (TNull(), TClass(target, typeArgs)) => + case (TNull(), TByReferenceClass(target, typeArgs)) => CoerceNullClass(target, typeArgs) case (TNull(), JavaTClass(target, _)) => CoerceNullJavaClass(target) case (TNull(), TAnyClass()) => CoerceNullAnyClass() @@ -211,16 +211,15 @@ case object CoercionUtils { CoercionSequence(Seq(CoerceUnboundInt(source, TInt()), CoerceIntRat())) case (_: IntType[G], TRational()) => CoerceIntRat() - case ( - source @ TClass(sourceClass, Seq()), - target @ TClass(targetClass, Seq()), - ) if source.transSupportArrows.exists { case (_, supp) => - supp.cls.decl == targetClass.decl - } => - CoerceSupports(sourceClass, targetClass) + case (source: TClass[G], target: TClass[G]) + if source.typeArgs.isEmpty && target.typeArgs.isEmpty && + source.transSupportArrows().exists { case (_, supp) => + supp.cls.decl == target.cls.decl + } => + CoerceSupports(source.cls, target.cls) - case (source @ TClass(sourceClass, typeArgs), TAnyClass()) => - CoerceClassAnyClass(sourceClass, typeArgs) + case (source: TClass[G], TAnyClass()) => + CoerceClassAnyClass(source.cls, source.typeArgs) case ( source @ JavaTClass(sourceClass, Nil), diff --git a/src/col/vct/col/typerules/Types.scala b/src/col/vct/col/typerules/Types.scala index c06bdb2e45..2aa301264d 100644 --- a/src/col/vct/col/typerules/Types.scala +++ b/src/col/vct/col/typerules/Types.scala @@ -56,7 +56,7 @@ object Types { case (TType(left), TType(right)) => TType(leastCommonSuperType(left, right)) - case (left @ TClass(_, _), right @ TClass(_, _)) => + case (left: TClass[G], right: TClass[G]) => val leftArrows = left.transSupportArrows val rightArrows = right.transSupportArrows // Shared support are classes where there is an incoming left-arrow and right-arrow @@ -79,7 +79,7 @@ object Types { case other => TUnion(other) } - case (TClass(_, _), TAnyClass()) | (TAnyClass(), TClass(_, _)) => + case (_: TClass[G], TAnyClass()) | (TAnyClass(), _: TClass[G]) => TAnyClass() // TODO similar stuff for JavaClass diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 0618c8caaa..8cbf2ba839 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -105,6 +105,10 @@ object AstBuildHelpers { SilverLocalAssign(new DirectRef(left), right) } + implicit class LocalHeapVarBuildHelpers[G](left: LocalHeapVariable[G]) { + def get(implicit origin: Origin): HeapLocal[G] = HeapLocal(new DirectRef(left)) + } + implicit class FieldBuildHelpers[G](left: SilverDeref[G]) { def <~(right: Expr[G])( implicit blame: Blame[AssignFailed], @@ -764,6 +768,10 @@ object AstBuildHelpers { implicit o: Origin ): Assign[G] = Assign(local, value)(AssignLocalOk) + def assignHeapLocal[G](local: HeapLocal[G], value: Expr[G])( + implicit o: Origin + ): Assign[G] = Assign(local, value)(AssignLocalOk) + def assignField[G]( obj: Expr[G], field: Ref[G, InstanceField[G]], diff --git a/src/llvm/tools/vcllvm/VCLLVM.cpp b/src/llvm/tools/vcllvm/VCLLVM.cpp index af755ba71f..5cb289dc5b 100644 --- a/src/llvm/tools/vcllvm/VCLLVM.cpp +++ b/src/llvm/tools/vcllvm/VCLLVM.cpp @@ -26,12 +26,12 @@ col::Program sampleCol(bool returnBool) { // class col::GlobalDeclaration *classDeclaration = program.add_declarations(); - col::VctClass *vctClass = classDeclaration->mutable_vct_class(); - llvm2Col::setColNodeId(vctClass); - col::BooleanValue *lockInvariant = vctClass->mutable_intrinsic_lock_invariant()->mutable_boolean_value(); + col::ByReferenceClass *clazz = classDeclaration->mutable_by_reference_class(); + llvm2Col::setColNodeId(clazz); + col::BooleanValue *lockInvariant = clazz->mutable_intrinsic_lock_invariant()->mutable_boolean_value(); lockInvariant->set_value(true); // class>method - col::ClassDeclaration *methodDeclaration = vctClass->add_decls(); + col::ClassDeclaration *methodDeclaration = clazz->add_decls(); col::InstanceMethod *method = methodDeclaration->mutable_instance_method(); llvm2Col::setColNodeId(method); // class>method>return_type diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 36fab123a0..15c61d3b8d 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -35,6 +35,7 @@ import vct.rewrite.{ HeapVariableToRef, MonomorphizeClass, SmtlibToProverTypes, + EncodeByValueClass, } import vct.rewrite.lang.ReplaceSYCLTypes import vct.rewrite.veymont.{ @@ -325,6 +326,7 @@ case class SilverTransformation( EncodeString, // Encode spec string as seq EncodeChar, CollectLocalDeclarations, // all decls in Scope + EncodeByValueClass, DesugarPermissionOperators, // no PointsTo, \pointer, etc. ReadToValue, // resolve wildcard into fractional permission TrivialAddrOf, diff --git a/src/parsers/vct/parsers/transform/PVLToCol.scala b/src/parsers/vct/parsers/transform/PVLToCol.scala index 5e9d424fa6..9274d0260b 100644 --- a/src/parsers/vct/parsers/transform/PVLToCol.scala +++ b/src/parsers/vct/parsers/transform/PVLToCol.scala @@ -159,7 +159,7 @@ case class PVLToCol[G]( withContract( contract, contract => { - new Class( + new ByReferenceClass( decls = decls.flatMap(convert(_)), supports = Nil, intrinsicLockInvariant = AstBuildHelpers diff --git a/src/parsers/vct/parsers/transform/systemctocol/engine/ClassTransformer.java b/src/parsers/vct/parsers/transform/systemctocol/engine/ClassTransformer.java index b23134a1df..9b44c43783 100644 --- a/src/parsers/vct/parsers/transform/systemctocol/engine/ClassTransformer.java +++ b/src/parsers/vct/parsers/transform/systemctocol/engine/ClassTransformer.java @@ -50,7 +50,7 @@ public Class create_process_class(ProcessClass process) { // Transform class attributes Ref> main_cls_ref = new LazyRef<>(col_system::get_main, Option.empty(), ClassTag$.MODULE$.apply(Class.class)); - InstanceField m = new InstanceField<>(new TClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); + InstanceField m = new InstanceField<>(new TByReferenceClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); declarations.add(m); col_system.add_class_main_ref(process, m); java.util.Map> fields = create_fields(process.get_generating_function(), process.get_methods(), @@ -75,7 +75,7 @@ public Class create_process_class(ProcessClass process) { // Add all newly generated methods to the declarations as well declarations.addAll(generated_instance_methods); - return new Class<>(Seqs.empty(), + return new ByReferenceClass<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), Seqs.empty(), col_system.TRUE, OriGen.create(create_name(process.get_generating_instance(), process.get_generating_function()))); } @@ -91,7 +91,7 @@ public Class create_state_class(StateClass state_class) { // Transform class attributes Ref> main_cls_ref = new LazyRef<>(col_system::get_main, Option.empty(), ClassTag$.MODULE$.apply(Class.class)); - InstanceField m = new InstanceField<>(new TClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); + InstanceField m = new InstanceField<>(new TByReferenceClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); declarations.add(m); col_system.add_class_main_ref(state_class, m); java.util.Map> fields = create_fields(null, state_class.get_methods(), @@ -126,7 +126,7 @@ public Class create_state_class(StateClass state_class) { // Add newly generated methods to declaration list declarations.addAll(generated_instance_methods); - return new Class<>(Seqs.empty(), + return new ByReferenceClass<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), Seqs.empty(), col_system.TRUE, OriGen.create(create_name(state_class.get_generating_instance()))); } diff --git a/src/parsers/vct/parsers/transform/systemctocol/engine/KnownTypeTransformer.java b/src/parsers/vct/parsers/transform/systemctocol/engine/KnownTypeTransformer.java index 3165bba55d..5cc5087f12 100644 --- a/src/parsers/vct/parsers/transform/systemctocol/engine/KnownTypeTransformer.java +++ b/src/parsers/vct/parsers/transform/systemctocol/engine/KnownTypeTransformer.java @@ -92,7 +92,7 @@ public void transform() { // Add channel field to COL system Ref> ref_to_cls = new DirectRef<>(cls, ClassTag$.MODULE$.apply(Class.class)); - col_system.add_primitive_channel(sc_inst, new InstanceField<>(new TClass<>(ref_to_cls, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, + col_system.add_primitive_channel(sc_inst, new InstanceField<>(new TByReferenceClass<>(ref_to_cls, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create(name.toLowerCase()))); } @@ -119,7 +119,7 @@ private String generate_class_name() { private Class transform_fifo(Origin o, Type t) { // Class fields Ref> main_cls_ref = new LazyRef<>(col_system::get_main, Option.empty(), ClassTag$.MODULE$.apply(Class.class)); - InstanceField m = new InstanceField<>(new TClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); + InstanceField m = new InstanceField<>(new TByReferenceClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); InstanceField buf = new InstanceField<>(new TSeq<>(t, OriGen.create()), col_system.NO_FLAGS, OriGen.create("buffer")); InstanceField nr_read = new InstanceField<>(col_system.T_INT, col_system.NO_FLAGS, OriGen.create("num_read")); InstanceField written = new InstanceField<>(new TSeq<>(t, OriGen.create()), col_system.NO_FLAGS, OriGen.create("written")); @@ -144,7 +144,7 @@ private Class transform_fifo(Origin o, Type t) { // Create the class java.util.List> declarations = java.util.List.of(m, buf, nr_read, written, constructor, fifo_read, fifo_write, fifo_update); - return new Class<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), Seqs.empty(), col_system.TRUE, o); + return new ByReferenceClass<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), Seqs.empty(), col_system.TRUE, o); } /** @@ -524,7 +524,7 @@ private InstanceMethod create_fifo_update_method(InstanceField m, Instance private Class transform_signal(Origin o, Type t) { // Class fields Ref> main_cls_ref = new LazyRef<>(col_system::get_main, Option.empty(), ClassTag$.MODULE$.apply(Class.class)); - InstanceField m = new InstanceField<>(new TClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); + InstanceField m = new InstanceField<>(new TByReferenceClass<>(main_cls_ref, Seqs.empty(), OriGen.create()), col_system.NO_FLAGS, OriGen.create("m")); InstanceField val = new InstanceField<>(t, col_system.NO_FLAGS, OriGen.create("val")); InstanceField _val = new InstanceField<>(t, col_system.NO_FLAGS, OriGen.create("_val")); @@ -545,7 +545,7 @@ private Class transform_signal(Origin o, Type t) { // Create the class java.util.List> class_content = java.util.List.of(m, val, _val, constructor, signal_read, signal_write, signal_update); - return new Class<>(Seqs.empty(), + return new ByReferenceClass<>(Seqs.empty(), List.from(CollectionConverters.asScala(class_content)), Seqs.empty(), col_system.TRUE, o); } diff --git a/src/parsers/vct/parsers/transform/systemctocol/engine/MainTransformer.java b/src/parsers/vct/parsers/transform/systemctocol/engine/MainTransformer.java index b9ff6d8305..01a1850e43 100644 --- a/src/parsers/vct/parsers/transform/systemctocol/engine/MainTransformer.java +++ b/src/parsers/vct/parsers/transform/systemctocol/engine/MainTransformer.java @@ -189,7 +189,7 @@ private void create_instances() { // Get field type Class transformed_class = col_system.get_col_class_translation(process_class); Ref> ref_to_class = new DirectRef<>(transformed_class, ClassTag$.MODULE$.apply(Class.class)); - Type t = new TClass<>(ref_to_class, Seqs.empty(), OriGen.create()); + Type t = new TByReferenceClass<>(ref_to_class, Seqs.empty(), OriGen.create()); // Generate instance field InstanceField inst = new InstanceField<>(t, col_system.NO_FLAGS, OriGen.create(create_instance_name(process_class))); @@ -204,7 +204,7 @@ private void create_instances() { // Get field type Class transformed_class = col_system.get_col_class_translation(state_class); Ref> ref_to_class = new DirectRef<>(transformed_class, ClassTag$.MODULE$.apply(Class.class)); - Type t = new TClass<>(ref_to_class, Seqs.empty(), OriGen.create()); + Type t = new TByReferenceClass<>(ref_to_class, Seqs.empty(), OriGen.create()); // Generate instance field InstanceField inst = new InstanceField<>(t, col_system.NO_FLAGS, OriGen.create(create_instance_name(state_class))); @@ -1237,7 +1237,7 @@ private void assemble_main() { new WritePerm<>(OriGen.create()), OriGen.create()); // Assemble class - Class main_class = new Class<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), + Class main_class = new ByReferenceClass<>(Seqs.empty(), List.from(CollectionConverters.asScala(declarations)), Seqs.empty(), lock_invariant, OriGen.create("Main")); // Register Main class in COL system context diff --git a/src/rewrite/vct/rewrite/CheckContractSatisfiability.scala b/src/rewrite/vct/rewrite/CheckContractSatisfiability.scala index 04f5e18c17..0f8b9ffe4a 100644 --- a/src/rewrite/vct/rewrite/CheckContractSatisfiability.scala +++ b/src/rewrite/vct/rewrite/CheckContractSatisfiability.scala @@ -97,23 +97,25 @@ case class CheckContractSatisfiability[Pre <: Generation]( val result = extractObj.extract(pred) val extractObj.Data(ts, in, _, _, _) = extractObj.finish() variables.scope { - globalDeclarations.declare(procedure( - blame = PanicBlame( - "The postcondition of a method checking satisfiability is empty" - ), - contractBlame = UnsafeDontCare.Satisfiability( - "the precondition of a check-sat method is only there to check it." - ), - requires = - UnitAccountedPredicate( - wellFormednessBlame.having(NotWellFormedIgnoreCheckSat(err)) { - dispatch(result) - } - )(result.o), - typeArgs = variables.dispatch(ts.keys), - args = variables.dispatch(in.keys), - body = Some(Scope[Post](Nil, Assert(ff)(onlyAssertBlame))), - )) + localHeapVariables.scope { + globalDeclarations.declare(procedure( + blame = PanicBlame( + "The postcondition of a method checking satisfiability is empty" + ), + contractBlame = UnsafeDontCare.Satisfiability( + "the precondition of a check-sat method is only there to check it." + ), + requires = + UnitAccountedPredicate( + wellFormednessBlame.having(NotWellFormedIgnoreCheckSat(err)) { + dispatch(result) + } + )(result.o), + typeArgs = variables.dispatch(ts.keys), + args = variables.dispatch(in.keys), + body = Some(Scope[Post](Nil, Assert(ff)(onlyAssertBlame))), + )) + } } } } diff --git a/src/rewrite/vct/rewrite/CheckProcessAlgebra.scala b/src/rewrite/vct/rewrite/CheckProcessAlgebra.scala index 87f378717b..a38730d7be 100644 --- a/src/rewrite/vct/rewrite/CheckProcessAlgebra.scala +++ b/src/rewrite/vct/rewrite/CheckProcessAlgebra.scala @@ -95,7 +95,7 @@ case class CheckProcessAlgebra[Pre <: Generation]() val newClass = currentModel.having(model) { - new Class( + new ByReferenceClass( Seq(), classDeclarations.collect { model.declarations.foreach(dispatch(_)) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 32cfe554f9..b73c52606a 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -2,13 +2,12 @@ package vct.col.rewrite import vct.col.ast._ import vct.col.origin._ +import vct.result.VerificationError import vct.col.util.AstBuildHelpers._ import hre.util.ScopedStack import vct.col.rewrite.error.{ExcludedByPassOrder, ExtraNode} import vct.col.ref.Ref -import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten} import vct.col.util.SuccessionMap -import RewriteHelpers._ import scala.collection.mutable @@ -30,6 +29,7 @@ case object ClassToRef extends RewriterBuilder { override def blame(error: PreconditionFailed): Unit = inner.blame(InstanceNull(inv)) } + } case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { @@ -75,6 +75,27 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) } + def transitiveByValuePermissions( + obj: Expr[Pre], + t: TByValueClass[Pre], + amount: Expr[Pre], + )(implicit o: Origin): Expr[Pre] = { + t.cls.decl.decls.collect[Expr[Pre]] { case field: InstanceField[Pre] => + field.t match { + case field_t: TByValueClass[Pre] => + fieldPerm[Pre](obj, field.ref, amount) &* + transitiveByValuePermissions( + Deref[Pre](obj, field.ref)(PanicBlame( + "Permission should already be ensured" + )), + field_t, + amount, + ) + case _ => fieldPerm(obj, field.ref, amount) + } + }.reduce[Expr[Pre]] { (a, b) => a &* b } + } + def makeInstanceOf: Function[Post] = { implicit val o: Origin = InstanceOfOrigin val sub = new Variable[Post](TInt()) @@ -117,7 +138,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { decl match { case cls: Class[Pre] => if (cls.typeArgs.nonEmpty) - throw vct.result.VerificationError.Unreachable( + throw VerificationError.Unreachable( "Class type parameters should be encoded using monomorphization earlier" ) @@ -407,7 +428,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) case TypeValue(t) => t match { - case TClass(Ref(cls), Seq()) => const(typeNumber(cls))(e.o) + case t: TClass[Pre] if t.typeArgs.isEmpty => + const(typeNumber(t.cls.decl))(e.o) case other => ??? } case TypeOf(value) => @@ -471,7 +493,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(t: Type[Pre]): Type[Post] = t match { - case TClass(_, _) => TRef() + case _: TClass[Pre] => TRef() case TAnyClass() => TRef() case t => rewriteDefault(t) } diff --git a/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala b/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala index 67ab5cceb6..3594585c57 100644 --- a/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala +++ b/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala @@ -80,12 +80,22 @@ case class ConstantifyFinalFields[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = field.o if (isFinal(field)) { val `this` = - new Variable[Post](TClass( - succ(currentClass.top), - currentClass.top.typeArgs.map { v: Variable[Pre] => - TVar(succ(v)) - }, - )) + currentClass.top match { + case _: ByReferenceClass[Pre] => + new Variable[Post](TByReferenceClass( + succ(currentClass.top), + currentClass.top.typeArgs.map { v: Variable[Pre] => + TVar(succ(v)) + }, + )) + case _: ByValueClass[Pre] => + new Variable[Post](TByValueClass( + succ(currentClass.top), + currentClass.top.typeArgs.map { v: Variable[Pre] => + TVar(succ(v)) + }, + )) + } fieldFunction(field) = globalDeclarations .declare(withResult((result: Result[Post]) => function[Post]( diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index dcb8e47ab8..7c1732761f 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -423,8 +423,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val fields = structType match { - case TClass(ref, _) => - ref.decl.declarations.collect { case field: InstanceField[Post] => + case t: TClass[Post] => + t.cls.decl.declarations.collect { case field: InstanceField[Post] => field } case _ => Seq() diff --git a/src/rewrite/vct/rewrite/EncodeAutoValue.scala b/src/rewrite/vct/rewrite/EncodeAutoValue.scala index 0e2beaf27f..c9c71db2e5 100644 --- a/src/rewrite/vct/rewrite/EncodeAutoValue.scala +++ b/src/rewrite/vct/rewrite/EncodeAutoValue.scala @@ -179,31 +179,33 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { } case Let(binding, value, main) => variables.scope { - val top = conditionContext.pop() - val (b, v) = - try { (variables.dispatch(binding), dispatch(value)) } - finally { conditionContext.push(top) } - val mMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() - val m = - conditionContext.having((conditionContext.top._1, mMap)) { - dispatch(main) - } - if (mMap.isEmpty) { Let(b, v, m) } - else { - mMap.foreach(postM => - conditionContext.top._2 - .append((Let(binding, value, postM._1), Let(b, v, postM._2))) - ) - conditionContext.top._1 match { - case InPrecondition() => Let(b, v, m) - case InPostcondition() => - Let( - b, - Old(v, None)(PanicBlame( - "Old should always be valid in a postcondition" - )), - m, - ) + localHeapVariables.scope { + val top = conditionContext.pop() + val (b, v) = + try { (variables.dispatch(binding), dispatch(value)) } + finally { conditionContext.push(top) } + val mMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() + val m = + conditionContext.having((conditionContext.top._1, mMap)) { + dispatch(main) + } + if (mMap.isEmpty) { Let(b, v, m) } + else { + mMap.foreach(postM => + conditionContext.top._2 + .append((Let(binding, value, postM._1), Let(b, v, postM._2))) + ) + conditionContext.top._1 match { + case InPrecondition() => Let(b, v, m) + case InPostcondition() => + Let( + b, + Old(v, None)(PanicBlame( + "Old should always be valid in a postcondition" + )), + m, + ) + } } } } diff --git a/src/rewrite/vct/rewrite/EncodeByValueClass.scala b/src/rewrite/vct/rewrite/EncodeByValueClass.scala new file mode 100644 index 0000000000..3054e2572c --- /dev/null +++ b/src/rewrite/vct/rewrite/EncodeByValueClass.scala @@ -0,0 +1,249 @@ +package vct.rewrite + +import hre.util.ScopedStack +import vct.col.ast._ +import vct.col.origin._ +import vct.col.ref.Ref +import vct.col.resolve.ctx.Referrable +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.util.AstBuildHelpers._ +import vct.result.VerificationError.UserError + +case object EncodeByValueClass extends RewriterBuilder { + override def key: String = "encodeByValueClass" + + override def desc: String = + "Initialise ByValueClasses when they are declared and copy them whenever they're read" + + private case class ClassCopyInAssignmentFailed( + blame: Blame[AssignFailed], + assign: Node[_], + clazz: ByValueClass[_], + field: InstanceField[_], + ) extends Blame[InsufficientPermission] { + override def blame(error: InsufficientPermission): Unit = { + if (blame.isInstanceOf[PanicBlame]) { + assign.o + .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) + } else { + blame + .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) + } + } + } + + private case class ClassCopyInCallFailed( + blame: Blame[InvocationFailure], + inv: Invocation[_], + clazz: ByValueClass[_], + field: InstanceField[_], + ) extends Blame[InsufficientPermission] { + override def blame(error: InsufficientPermission): Unit = { + blame.blame( + CopyClassFailedBeforeCall(inv, clazz, Referrable.originName(field)) + ) + } + } + + case class UnsupportedStructPerm(o: Origin) extends UserError { + override def code: String = "unsupportedStructPerm" + override def text: String = + o.messageInContext( + "Shorthand for Permissions for structs not possible, since the struct has a cyclic reference" + ) + } + + private sealed class CopyContext + + private case class InCall(invocation: Invocation[_]) extends CopyContext + + private case class InAssignmentExpression(assignment: AssignExpression[_]) + extends CopyContext + + private case class InAssignmentStatement(assignment: Assign[_]) + extends CopyContext +} + +case class EncodeByValueClass[Pre <: Generation]() extends Rewriter[Pre] { + + import EncodeByValueClass._ + + private val inAssignment: ScopedStack[Unit] = ScopedStack() + private val copyContext: ScopedStack[CopyContext] = ScopedStack() + + override def dispatch(node: Statement[Pre]): Statement[Post] = + node match { + case s: Scope[Pre] => + cPPLocalDeclarations.scope { + cLocalDeclarations.scope { + variables.scope { + localHeapVariables.scope { + val locals = variables.dispatch(s.locals) + Scope( + locals, + Block(locals.collect { + case v: Variable[Post] + if v.t.isInstanceOf[TByValueClass[Post]] => + Assign( + v.get(v.o), + NewObject(v.t.asInstanceOf[TByValueClass[Post]].cls)(v.o), + )(PanicBlame( + "Instantiating a ByValueClass should always succeed" + ))(v.o) + } ++ Seq(s.body.rewriteDefault()))(node.o), + )(node.o) + } + } + } + } + case assign: Assign[Pre] => { + val target = inAssignment.having(()) { assign.target.rewriteDefault() } + copyContext.having(InAssignmentStatement(assign)) { + assign.rewrite(target = target) + } + } + case _ => node.rewriteDefault() + } + + private def copyClassValue( + obj: Expr[Post], + t: TByValueClass[Pre], + blame: InstanceField[Pre] => Blame[InsufficientPermission], + ): Expr[Post] = { + implicit val o: Origin = obj.o + val v = new Variable[Post](dispatch(t)) + val children = t.cls.decl.decls.collect { case f: InstanceField[Pre] => + f.t match { + case inner: TByValueClass[Pre] => + Assign[Post]( + Deref[Post](v.get, succ(f))(DerefAssignTarget), + copyClassValue(Deref[Post](obj, succ(f))(blame(f)), inner, blame), + )(AssignLocalOk) + case _ => + Assign[Post]( + Deref[Post](v.get, succ(f))(DerefAssignTarget), + Deref[Post](obj, succ(f))(blame(f)), + )(AssignLocalOk) + + } + } + ScopedExpr( + Seq(v), + Then( + PreAssignExpression(v.get, NewObject[Post](succ(t.cls.decl)))( + AssignLocalOk + ), + Block(children), + ), + ) + } + + // def unwrapClassPerm( + // struct: Expr[Post], + // perm: Expr[Pre], + // structType: TByValueClass[Pre], + // origin: Origin, + // visited: Seq[TByValueClass[Pre]] = Seq(), + // ): Expr[Post] = { + // if (visited.contains(structType)) + // throw UnsupportedStructPerm( + // origin + // ) // We do not allow this notation for recursive structs + // implicit val o: Origin = origin + // val blame = PanicBlame("Field permission is framed") + // val Seq(CStructDeclaration(_, fields)) = structType.ref.decl.decl.specs + // val newPerm = dispatch(perm) + // val AmbiguousLocation(newExpr) = struct + // val newFieldPerms = fields.map(member => { + // val loc = + // AmbiguousLocation( + // Deref[Post]( + // newExpr, + // cStructFieldsSuccessor.ref((structType.ref.decl, member)), + // )(blame) + // )(struct.blame) + // member.specs.collectFirst { + // case CSpecificationType(newStruct: CTStruct[Pre]) => + // // We recurse, since a field is another struct + // Perm(loc, newPerm) &* unwrapStructPerm( + // loc, + // perm, + // newStruct, + // origin, + // structType +: visited, + // ) + // }.getOrElse(Perm(loc, newPerm)) + // }) + + // foldStar(newFieldPerms) + // } + override def dispatch(node: Expr[Pre]): Expr[Post] = + if (inAssignment.nonEmpty) + node.rewriteDefault() + else + node match { + case Perm(loc, p) => node.rewriteDefault() + case assign: PreAssignExpression[Pre] => + val target = + inAssignment.having(()) { assign.target.rewriteDefault() } + copyContext.having(InAssignmentExpression(assign)) { + assign.rewrite(target = target) + } + case invocation: Invocation[Pre] => { + copyContext.having(InCall(invocation)) { invocation.rewriteDefault() } + } + case Local(Ref(v)) if v.t.isInstanceOf[TByValueClass[Pre]] => + if (copyContext.isEmpty) { + return node.rewriteDefault() + } // If we are in other kinds of expressions like if statements + val t = v.t.asInstanceOf[TByValueClass[Pre]] + val clazz = t.cls.decl.asInstanceOf[ByValueClass[Pre]] + + copyContext.top match { + case InCall(invocation) => + copyClassValue( + node.rewriteDefault(), + t, + f => + ClassCopyInCallFailed(invocation.blame, invocation, clazz, f), + ) + case InAssignmentExpression(assignment: PreAssignExpression[_]) => + copyClassValue( + node.rewriteDefault(), + t, + f => + ClassCopyInAssignmentFailed( + assignment.blame, + assignment, + clazz, + f, + ), + ) + case InAssignmentExpression(assignment: PostAssignExpression[_]) => + copyClassValue( + node.rewriteDefault(), + t, + f => + ClassCopyInAssignmentFailed( + assignment.blame, + assignment, + clazz, + f, + ), + ) + case InAssignmentStatement(assignment) => + copyClassValue( + node.rewriteDefault(), + t, + f => + ClassCopyInAssignmentFailed( + assignment.blame, + assignment, + clazz, + f, + ), + ) + } + case _ => node.rewriteDefault() + } +} diff --git a/src/rewrite/vct/rewrite/EncodeForkJoin.scala b/src/rewrite/vct/rewrite/EncodeForkJoin.scala index 961b4e8b2e..49a762d685 100644 --- a/src/rewrite/vct/rewrite/EncodeForkJoin.scala +++ b/src/rewrite/vct/rewrite/EncodeForkJoin.scala @@ -129,7 +129,13 @@ case class EncodeForkJoin[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = e.o cls.decls.collectFirst { case run: RunMethod[Pre] => run } match { case Some(_) => - val obj = new Variable[Post](TClass(succ(cls), Seq())) + val obj = + cls match { + case _: ByReferenceClass[Pre] => + new Variable[Post](TByReferenceClass(succ(cls), Seq())) + case _: ByValueClass[Pre] => + new Variable[Post](TByValueClass(succ(cls), Seq())) + } ScopedExpr( Seq(obj), With( diff --git a/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala b/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala index c5308b7b26..def6ba5911 100644 --- a/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala +++ b/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala @@ -85,7 +85,7 @@ case class EncodeIntrinsicLock[Pre <: Generation]() extends Rewriter[Pre] { def getClass(obj: Expr[Pre]): Class[Pre] = obj.t match { - case TClass(Ref(cls), _) => cls + case t: TClass[Pre] => t.cls.decl case _ => throw UnreachableAfterTypeCheck( "This argument is not a class type.", @@ -153,7 +153,7 @@ case class EncodeIntrinsicLock[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(decl: Declaration[Pre]): Unit = decl match { - case cls: Class[Pre] => + case cls: ByReferenceClass[Pre] => globalDeclarations.succeed( cls, cls.rewrite( diff --git a/src/rewrite/vct/rewrite/EncodeResourceValues.scala b/src/rewrite/vct/rewrite/EncodeResourceValues.scala index 531858084e..0c99f1a948 100644 --- a/src/rewrite/vct/rewrite/EncodeResourceValues.scala +++ b/src/rewrite/vct/rewrite/EncodeResourceValues.scala @@ -188,7 +188,11 @@ case class EncodeResourceValues[Pre <: Generation]() case ResourcePattern.HeapVariableLocation(_) => Nil case ResourcePattern.FieldLocation(f) => nonGeneric(fieldOwner(f)) - Seq(TClass(succ(fieldOwner(f)), Seq())) + Seq(fieldOwner(f) match { + case cls: ByReferenceClass[Pre] => + TByReferenceClass(succ(cls), Seq()) + case cls: ByValueClass[Pre] => TByValueClass(succ(cls), Seq()) + }) case ResourcePattern.ModelLocation(f) => Seq(TModel(succ(modelFieldOwner(f)))) case ResourcePattern.SilverFieldLocation(_) => Seq(TRef()) @@ -200,8 +204,12 @@ case class EncodeResourceValues[Pre <: Generation]() ref.args.map(_.t).map(dispatch) case ResourcePattern.InstancePredicateLocation(ref) => nonGeneric(predicateOwner(ref)) - TClass[Post](succ(predicateOwner(ref)), Seq()) +: - ref.args.map(_.t).map(dispatch) + (predicateOwner(ref) match { + case cls: ByReferenceClass[Pre] => + TByReferenceClass(succ[Class[Post]](cls), Seq()) + case cls: ByValueClass[Pre] => + TByValueClass(succ[Class[Post]](cls), Seq()) + }) +: ref.args.map(_.t).map(dispatch) } def freeTypes(pattern: ResourcePattern): Seq[Type[Post]] = @@ -212,8 +220,12 @@ case class EncodeResourceValues[Pre <: Generation]() case ResourcePattern.Predicate(p) => p.args.map(_.t).map(dispatch) case ResourcePattern.InstancePredicate(p) => nonGeneric(predicateOwner(p)) - TClass[Post](succ(predicateOwner(p)), Seq()) +: p.args.map(_.t) - .map(dispatch) + (predicateOwner(p) match { + case cls: ByReferenceClass[Pre] => + TByReferenceClass(succ[Class[Post]](cls), Seq()) + case cls: ByValueClass[Pre] => + TByValueClass(succ[Class[Post]](cls), Seq()) + }) +: p.args.map(_.t).map(dispatch) case ResourcePattern.Star(left, right) => freeTypes(left) ++ freeTypes(right) case ResourcePattern.Implies(res) => freeTypes(res) diff --git a/src/rewrite/vct/rewrite/ExtractInlineQuantifierPatterns.scala b/src/rewrite/vct/rewrite/ExtractInlineQuantifierPatterns.scala index 825d188c08..2768359d6e 100644 --- a/src/rewrite/vct/rewrite/ExtractInlineQuantifierPatterns.scala +++ b/src/rewrite/vct/rewrite/ExtractInlineQuantifierPatterns.scala @@ -111,50 +111,56 @@ case class ExtractInlineQuantifierPatterns[Pre <: Generation]() case f: Forall[Pre] => variables.scope { - val (patternsHere, body) = patterns.collect { - // We only want to inline lets that are defined inside the quantifier - letBindings.having(ScopedStack()) { dispatch(f.body) } + localHeapVariables.scope { + val (patternsHere, body) = patterns.collect { + // We only want to inline lets that are defined inside the quantifier + letBindings.having(ScopedStack()) { dispatch(f.body) } + } + val unsortedGroups = patternsHere.groupBy(_.group) + val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) + val triggers = sortedGroups.map(_.map(_.make())) + Forall( + bindings = variables.collect { f.bindings.foreach(dispatch) }._1, + triggers = f.triggers.map(_.map(dispatch)) ++ triggers, + body = body, + )(f.o) } - val unsortedGroups = patternsHere.groupBy(_.group) - val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) - val triggers = sortedGroups.map(_.map(_.make())) - Forall( - bindings = variables.collect { f.bindings.foreach(dispatch) }._1, - triggers = f.triggers.map(_.map(dispatch)) ++ triggers, - body = body, - )(f.o) } case f: Starall[Pre] => variables.scope { - val (patternsHere, body) = patterns.collect { - // We only want to inline lets that are defined inside the quantifier - letBindings.having(ScopedStack()) { dispatch(f.body) } + localHeapVariables.scope { + val (patternsHere, body) = patterns.collect { + // We only want to inline lets that are defined inside the quantifier + letBindings.having(ScopedStack()) { dispatch(f.body) } + } + val unsortedGroups = patternsHere.groupBy(_.group) + val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) + val triggers = sortedGroups.map(_.map(_.make())) + Starall( + bindings = variables.collect { f.bindings.foreach(dispatch) }._1, + triggers = f.triggers.map(_.map(dispatch)) ++ triggers, + body = body, + )(f.blame)(f.o) } - val unsortedGroups = patternsHere.groupBy(_.group) - val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) - val triggers = sortedGroups.map(_.map(_.make())) - Starall( - bindings = variables.collect { f.bindings.foreach(dispatch) }._1, - triggers = f.triggers.map(_.map(dispatch)) ++ triggers, - body = body, - )(f.blame)(f.o) } case f: Exists[Pre] => variables.scope { - val (patternsHere, body) = patterns.collect { - // We only want to inline lets that are defined inside the quantifier - letBindings.having(ScopedStack()) { dispatch(f.body) } + localHeapVariables.scope { + val (patternsHere, body) = patterns.collect { + // We only want to inline lets that are defined inside the quantifier + letBindings.having(ScopedStack()) { dispatch(f.body) } + } + val unsortedGroups = patternsHere.groupBy(_.group) + val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) + val triggers = sortedGroups.map(_.map(_.make())) + Exists( + bindings = variables.collect { f.bindings.foreach(dispatch) }._1, + triggers = f.triggers.map(_.map(dispatch)) ++ triggers, + body = body, + )(f.o) } - val unsortedGroups = patternsHere.groupBy(_.group) - val sortedGroups = unsortedGroups.toSeq.sortBy(_._1).map(_._2) - val triggers = sortedGroups.map(_.map(_.make())) - Exists( - bindings = variables.collect { f.bindings.foreach(dispatch) }._1, - triggers = f.triggers.map(_.map(dispatch)) ++ triggers, - body = body, - )(f.o) } case other => rewriteDefault(other) diff --git a/src/rewrite/vct/rewrite/MonomorphizeClass.scala b/src/rewrite/vct/rewrite/MonomorphizeClass.scala index 779fa700e4..dd91b27be4 100644 --- a/src/rewrite/vct/rewrite/MonomorphizeClass.scala +++ b/src/rewrite/vct/rewrite/MonomorphizeClass.scala @@ -82,9 +82,16 @@ case class MonomorphizeClass[Pre <: Generation]() globalDeclarations.scope { classDeclarations.scope { variables.scope { - allScopes.anyDeclare( - allScopes.anySucceedOnly(cls, cls.rewrite(typeArgs = Seq())) - ) + localHeapVariables.scope { + allScopes.anyDeclare(allScopes.anySucceedOnly( + cls, + cls match { + case cls: ByReferenceClass[Pre] => + cls.rewrite(typeArgs = Seq()) + case cls: ByValueClass[Pre] => cls.rewrite(typeArgs = Seq()) + }, + )) + } } } } @@ -130,14 +137,25 @@ case class MonomorphizeClass[Pre <: Generation]() override def dispatch(t: Type[Pre]): Type[Post] = (t, ctx.topOption) match { - case (TClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => + case (TByReferenceClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => + val typeValues = + ctx match { + case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) + case None => typeArgs + } + instantiate(cls, typeValues, false) + TByReferenceClass[Post]( + genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), + Seq(), + ) + case (TByValueClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => val typeValues = ctx match { case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) case None => typeArgs } instantiate(cls, typeValues, false) - TClass[Post]( + TByValueClass[Post]( genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), Seq(), ) @@ -158,13 +176,13 @@ case class MonomorphizeClass[Pre <: Generation]() ) case inv: InvokeMethod[Pre] => inv.obj.t match { - case TClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => - val typeValues = ctx.topOption.map(_.evalTypes(typeArgs)) - .getOrElse(typeArgs) - instantiate(cls, typeValues, false) + case t: TClass[Pre] if t.typeArgs.nonEmpty => + val typeValues = ctx.topOption.map(_.evalTypes(t.typeArgs)) + .getOrElse(t.typeArgs) + instantiate(t.cls.decl, typeValues, false) inv.rewrite(ref = genericSucc.ref[Post, InstanceMethod[Post]]( - ((cls, typeValues), inv.ref.decl) + ((t.cls.decl, typeValues), inv.ref.decl) ) ) case _ => inv.rewriteDefault() @@ -176,13 +194,14 @@ case class MonomorphizeClass[Pre <: Generation]() loc match { case loc @ FieldLocation(obj, Ref(field)) => obj.t match { - case TClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => - val typeArgs1 = ctx.topOption.map(_.evalTypes(typeArgs)) - .getOrElse(typeArgs) - instantiate(cls, typeArgs1, false) + case t: TClass[Pre] if t.typeArgs.nonEmpty => + val typeArgs1 = ctx.topOption.map(_.evalTypes(t.typeArgs)) + .getOrElse(t.typeArgs) + instantiate(t.cls.decl, typeArgs1, false) loc.rewrite(field = - genericSucc - .ref[Post, InstanceField[Post]](((cls, typeArgs1), field)) + genericSucc.ref[Post, InstanceField[Post]]( + ((t.cls.decl, typeArgs1), field) + ) ) case _ => loc.rewriteDefault() } @@ -193,13 +212,14 @@ case class MonomorphizeClass[Pre <: Generation]() expr match { case deref @ Deref(obj, Ref(field)) => obj.t match { - case TClass(Ref(cls), typeArgs) if typeArgs.nonEmpty => - val typeArgs1 = ctx.topOption.map(_.evalTypes(typeArgs)) - .getOrElse(typeArgs) - instantiate(cls, typeArgs1, false) + case t: TClass[Pre] if t.typeArgs.nonEmpty => + val typeArgs1 = ctx.topOption.map(_.evalTypes(t.typeArgs)) + .getOrElse(t.typeArgs) + instantiate(t.cls.decl, typeArgs1, false) deref.rewrite(ref = - genericSucc - .ref[Post, InstanceField[Post]](((cls, typeArgs1), field)) + genericSucc.ref[Post, InstanceField[Post]]( + ((t.cls.decl, typeArgs1), field) + ) ) case _ => deref.rewriteDefault() } diff --git a/src/rewrite/vct/rewrite/MonomorphizeContractApplicables.scala b/src/rewrite/vct/rewrite/MonomorphizeContractApplicables.scala index 23012710f3..a7175b4b49 100644 --- a/src/rewrite/vct/rewrite/MonomorphizeContractApplicables.scala +++ b/src/rewrite/vct/rewrite/MonomorphizeContractApplicables.scala @@ -46,9 +46,11 @@ case class MonomorphizeContractApplicables[Pre <: Generation]() globalDeclarations.scope { classDeclarations.scope { variables.scope { - allScopes.anyDeclare( - allScopes.anySucceedOnly(app, app.rewrite(typeArgs = Nil)) - ) + localHeapVariables.scope { + allScopes.anyDeclare( + allScopes.anySucceedOnly(app, app.rewrite(typeArgs = Nil)) + ) + } } } } diff --git a/src/rewrite/vct/rewrite/ParBlockEncoder.scala b/src/rewrite/vct/rewrite/ParBlockEncoder.scala index e477579d47..54a74079e9 100644 --- a/src/rewrite/vct/rewrite/ParBlockEncoder.scala +++ b/src/rewrite/vct/rewrite/ParBlockEncoder.scala @@ -151,29 +151,31 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { scale(dispatch(e)) else variables.scope { - val range = quantVars.map(v => - from(v) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(v) - ).reduceOption[Expr[Post]](And(_, _)).getOrElse(tt) - - e match { - case Forall(bindings, Nil, body) => - Forall( - variables.dispatch(bindings ++ quantVars), - Nil, - range ==> scale(dispatch(body)), - )(body.o) - case s @ Starall(bindings, Nil, body) => - Starall( - variables.dispatch(bindings ++ quantVars), - Nil, - range ==> scale(dispatch(body)), - )(s.blame)(body.o) - case other => - Starall( - variables.dispatch(quantVars), - Nil, - range ==> scale(dispatch(other)), - )(ParBlockNotInjective(block, other))(other.o) + localHeapVariables.scope { + val range = quantVars.map(v => + from(v) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(v) + ).reduceOption[Expr[Post]](And(_, _)).getOrElse(tt) + + e match { + case Forall(bindings, Nil, body) => + Forall( + variables.dispatch(bindings ++ quantVars), + Nil, + range ==> scale(dispatch(body)), + )(body.o) + case s @ Starall(bindings, Nil, body) => + Starall( + variables.dispatch(bindings ++ quantVars), + Nil, + range ==> scale(dispatch(body)), + )(s.blame)(body.o) + case other => + Starall( + variables.dispatch(quantVars), + Nil, + range ==> scale(dispatch(other)), + )(ParBlockNotInjective(block, other))(other.o) + } } } }) diff --git a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala index 0b344f2a41..ba090ae05e 100644 --- a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala +++ b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala @@ -684,7 +684,19 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() givenMap, yields, ) => - val typ = TClass[Post](succ(cons.cls.decl), classTypeArgs.map(dispatch)) + val typ = + cons.cls.decl match { + case cls: ByReferenceClass[Pre] => + TByReferenceClass[Post]( + succ[Class[Post]](cls), + classTypeArgs.map(dispatch), + ) + case cls: ByValueClass[Pre] => + TByValueClass[Post]( + succ[Class[Post]](cls), + classTypeArgs.map(dispatch), + ) + } val res = new Variable[Post](typ)(ResultVar) variables.succeed(res.asInstanceOf[Variable[Pre]], res) effect( @@ -699,12 +711,25 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() yields.map { case (e, Ref(v)) => (inlined(e), succ(v)) }, )(inv.blame)(e.o) ) - stored(res.get(SideEffectOrigin), TClass(cons.cls, classTypeArgs)) + stored( + res.get(SideEffectOrigin), + cons.cls.decl.classType(classTypeArgs), + ) case NewObject(Ref(cls)) => - val res = new Variable[Post](TClass(succ(cls), Seq()))(ResultVar) + val res = + cls match { + case cls: ByReferenceClass[Pre] => + new Variable[Post]( + TByReferenceClass(succ[Class[Post]](cls), Seq()) + )(ResultVar) + case cls: ByValueClass[Pre] => + new Variable[Post](TByValueClass(succ[Class[Post]](cls), Seq()))( + ResultVar + ) + } variables.succeed(res.asInstanceOf[Variable[Pre]], res) effect(Instantiate[Post](succ(cls), res.get(ResultVar))(e.o)) - stored(res.get(SideEffectOrigin), TClass(cls.ref, Seq())) + stored(res.get(SideEffectOrigin), cls.ref.decl.classType(Seq())) case other => stored(ReInliner().dispatch(rewriteDefault(other)), other.t) } } diff --git a/src/rewrite/vct/rewrite/adt/ImportADT.scala b/src/rewrite/vct/rewrite/adt/ImportADT.scala index 08b35817d7..bc78866cd7 100644 --- a/src/rewrite/vct/rewrite/adt/ImportADT.scala +++ b/src/rewrite/vct/rewrite/adt/ImportADT.scala @@ -71,7 +71,7 @@ case object ImportADT { case TZFraction() => "zfract" case TMap(key, value) => "map$" + typeText(key) + "__" + typeText(value) + "$" - case TClass(Ref(cls), _) => cls.o.getPreferredNameOrElse().camel + case t: TClass[_] => t.cls.decl.o.getPreferredNameOrElse().camel case TVar(Ref(v)) => v.o.getPreferredNameOrElse().camel case TUnion(ts) => "union" + ts.map(typeText).mkString("$", "__", "$") case SilverPartialTAxiomatic(Ref(adt), _) => diff --git a/src/rewrite/vct/rewrite/bip/EncodeBip.scala b/src/rewrite/vct/rewrite/bip/EncodeBip.scala index 52652b5f84..3e08112449 100644 --- a/src/rewrite/vct/rewrite/bip/EncodeBip.scala +++ b/src/rewrite/vct/rewrite/bip/EncodeBip.scala @@ -408,7 +408,7 @@ case class EncodeBip[Pre <: Generation](results: VerificationResults) results.declare(component) implicit val o = DiagnosticOrigin val ref = succ[Class[Post]](classOf(constructor)) - val t = TClass[Post](ref, Seq()) + val t = TByReferenceClass[Post](ref, Seq()) rewritingBipConstructorBody.having(component) { constructorSucc(constructor) = globalDeclarations.declare( new Procedure[Post]( @@ -526,7 +526,7 @@ case class EncodeBip[Pre <: Generation](results: VerificationResults) transitions.flatMap { transition => val v = new Variable[Post]( - TClass(succ[Class[Post]](classOf(transition)), Seq()) + TByReferenceClass(succ[Class[Post]](classOf(transition)), Seq()) )(SynchronizationComponentVariableOrigin( synchronization, componentOf(transition), diff --git a/src/rewrite/vct/rewrite/cfg/Utils.scala b/src/rewrite/vct/rewrite/cfg/Utils.scala index c5a85fc946..73bffad58b 100644 --- a/src/rewrite/vct/rewrite/cfg/Utils.scala +++ b/src/rewrite/vct/rewrite/cfg/Utils.scala @@ -115,7 +115,11 @@ object Utils { } private def get_out_variable[G](cls: Ref[G, Class[G]], o: Origin): Local[G] = - Local(new DirectRef[G, Variable[G]](new Variable(TClass(cls, Seq()))(o)))(o) + Local( + new DirectRef[G, Variable[G]](new Variable(TByReferenceClass(cls, Seq()))( + o + )) + )(o) def find_all_cases[G]( body: Statement[G], diff --git a/src/rewrite/vct/rewrite/exc/EncodeBreakReturn.scala b/src/rewrite/vct/rewrite/exc/EncodeBreakReturn.scala index 66e1de14d1..28d2a8994c 100644 --- a/src/rewrite/vct/rewrite/exc/EncodeBreakReturn.scala +++ b/src/rewrite/vct/rewrite/exc/EncodeBreakReturn.scala @@ -130,7 +130,9 @@ case class EncodeBreakReturn[Pre <: Generation]() extends Rewriter[Pre] { after = Block(Nil), catches = Seq(CatchClause( decl = - new Variable(TClass(breakLabelException.ref(decl), Seq())), + new Variable( + TByReferenceClass(breakLabelException.ref(decl), Seq()) + ), body = Block(Nil), )), ) @@ -147,8 +149,9 @@ case class EncodeBreakReturn[Pre <: Generation]() extends Rewriter[Pre] { case Break(Some(Ref(label))) => val cls = breakLabelException.getOrElseUpdate( label, - globalDeclarations - .declare(new Class[Post](Nil, Nil, Nil, tt)(BreakException)), + globalDeclarations.declare( + new ByReferenceClass[Post](Nil, Nil, Nil, tt)(BreakException) + ), ) Throw(NewObject[Post](cls.ref))(PanicBlame( @@ -156,7 +159,7 @@ case class EncodeBreakReturn[Pre <: Generation]() extends Rewriter[Pre] { )) case Return(result) => - val exc = new Variable[Post](TClass(returnClass.get.ref, Seq())) + val exc = new Variable[Post](returnClass.get.classType(Seq())) Scope( Seq(exc), Block(Seq( @@ -196,13 +199,16 @@ case class EncodeBreakReturn[Pre <: Generation]() extends Rewriter[Pre] { ReturnField ) val returnClass = - new Class[Post](Nil, Seq(returnField), Nil, tt)( - ReturnClass - ) + new ByReferenceClass[Post]( + Nil, + Seq(returnField), + Nil, + tt, + )(ReturnClass) globalDeclarations.declare(returnClass) val caughtReturn = - new Variable[Post](TClass(returnClass.ref, Seq())) + new Variable[Post](returnClass.classType(Seq())) TryCatchFinally( body = BreakReturnToException( diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index e66592de4b..debd341db5 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -1348,7 +1348,8 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) // Create a class that can be used to create a 'this' object // It will be linked to the class made near the end of this method. - val preEventClass: Class[Pre] = new Class(Nil, Nil, Nil, tt)(commandGroup.o) + val preEventClass: Class[Pre] = + new ByValueClass(Nil, Nil, Nil)(commandGroup.o) this.currentThis = Some( rw.dispatch(ThisObject[Pre](preEventClass.ref)(preEventClass.o)) ) @@ -1475,8 +1476,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(KernelLambdaRunMethodBlame(kernelDeclaration))(commandGroup.o) // Create the surrounding class + // cl::sycl::event has a default copy constructor hence a ByValueClass val postEventClass = - new Class[Post]( + new ByValueClass[Post]( typeArgs = Seq(), decls = currentKernelType.get.getRangeFields ++ @@ -1484,13 +1486,12 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) .flatMap(acc => acc.instanceField +: acc.rangeIndexFields) ++ Seq(kernelRunner), supports = Seq(), - intrinsicLockInvariant = tt, )(commandGroup.o.where(name = "SYCL_EVENT_CLASS")) rw.globalDeclarations.succeed(preEventClass, postEventClass) // Create a variable to refer to the class instance val eventClassRef = - new Variable[Post](TClass(postEventClass.ref, Seq()))( + new Variable[Post](TByValueClass(postEventClass.ref, Seq()))( commandGroup.o.where(name = "sycl_event_ref") ) // Store the class ref and read-write accessors to be used when the kernel is done running @@ -1976,7 +1977,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) preClass: Class[Pre], commandGroupO: Origin, ): Procedure[Post] = { - val t = rw.dispatch(TClass[Pre](preClass.ref, Seq())) + val t = rw.dispatch(TByValueClass[Pre](preClass.ref, Seq())) rw.globalDeclarations.declare( withResult((result: Result[Post]) => { val constructorPostConditions: mutable.Buffer[Expr[Post]] = @@ -2142,22 +2143,24 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) scale(cond) else rw.variables.scope { - val range = quantVars.map(v => - rangesMap(v)._1 <= Local[Post](v.ref) && - Local[Post](v.ref) < rangesMap(v)._2 - ).reduceOption[Expr[Post]](And(_, _)).getOrElse(tt) - - cond match { - case Forall(bindings, Nil, body) => - Forall(bindings ++ quantVars, Nil, range ==> scale(body)) - case s @ Starall(bindings, Nil, body) => - Starall(bindings ++ quantVars, Nil, range ==> scale(body))( - s.blame - ) - case other => - Starall(quantVars.toSeq, Nil, range ==> scale(other))( - ParBlockNotInjective(block, other) - ) + rw.localHeapVariables.scope { + val range = quantVars.map(v => + rangesMap(v)._1 <= Local[Post](v.ref) && + Local[Post](v.ref) < rangesMap(v)._2 + ).reduceOption[Expr[Post]](And(_, _)).getOrElse(tt) + + cond match { + case Forall(bindings, Nil, body) => + Forall(bindings ++ quantVars, Nil, range ==> scale(body)) + case s @ Starall(bindings, Nil, body) => + Starall(bindings ++ quantVars, Nil, range ==> scale(body))( + s.blame + ) + case other => + Starall(quantVars.toSeq, Nil, range ==> scale(other))( + ParBlockNotInjective(block, other) + ) + } } } }) diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index bcd6d89252..f4ca25352c 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -192,25 +192,6 @@ case object LangCToCol { } } - case class StructCopyFailed( - assign: PreAssignExpression[_], - field: InstanceField[_], - ) extends Blame[InsufficientPermission] { - override def blame(error: InsufficientPermission): Unit = { - assign.blame.blame(CopyStructFailed(assign, Referrable.originName(field))) - } - } - - case class StructCopyBeforeCallFailed( - inv: CInvocation[_], - field: InstanceField[_], - ) extends Blame[InsufficientPermission] { - override def blame(error: InsufficientPermission): Unit = { - inv.blame - .blame(CopyStructFailedBeforeCall(inv, Referrable.originName(field))) - } - } - case class VectorBoundFailed(subscript: AmbiguousSubscript[_]) extends Blame[InvocationFailure] { override def blame(error: InvocationFailure): Unit = @@ -273,6 +254,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) SuccessionMap() val cNameSuccessor: SuccessionMap[CNameTarget[Pre], Variable[Post]] = SuccessionMap() + val cLocalHeapNameSuccessor: SuccessionMap[CNameTarget[Pre], LocalHeapVariable[Post]] = + SuccessionMap() val cGlobalNameSuccessor : SuccessionMap[CNameTarget[Pre], HeapVariable[Post]] = SuccessionMap() val cStructSuccessor: SuccessionMap[CGlobalDeclaration[Pre], Class[Post]] = @@ -303,7 +286,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) private var kernelSpecifier: Option[CGpgpuKernelSpecifier[Pre]] = None private def CStructOrigin(sdecl: CStructDeclaration[_]): Origin = - sdecl.o.sourceName(sdecl.name.get) + sdecl.o.sourceName(sdecl.name.get).withContent(TypeName("struct")) private def CStructFieldOrigin(cdecl: CDeclarator[_]): Origin = cdecl.o.sourceName(nameFromDeclarator(cdecl)) @@ -999,7 +982,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _ => throw WrongStructType(decl) } val newStruct = - new Class[Post]( + new ByValueClass[Post]( Seq(), rw.classDeclarations.collect { decls.foreach { fieldDecl => @@ -1019,7 +1002,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } }._1, Seq(), - tt[Post], )(CStructOrigin(sdecl)) rw.globalDeclarations.declare(newStruct) @@ -1163,21 +1145,14 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = init.o val targetClass: Class[Post] = cStructSuccessor(ref.decl) - val t = TClass[Post](targetClass.ref, Seq()) + val t = TByValueClass[Post](targetClass.ref, Seq()) - val v = new Variable[Post](t)(o.sourceName(info.name)) - cNameSuccessor(RefCLocalDeclaration(decl, 0)) = v - - val initialVal = init.init.map(i => - createStructCopy( - rw.dispatch(i), - ref.decl, - (f: InstanceField[_]) => - PanicBlame("Cannot fail due to insufficient perm"), - ) - ).getOrElse(NewObject[Post](targetClass.ref)) + val v = new LocalHeapVariable[Post](t)(o.sourceName(info.name)) + cLocalHeapNameSuccessor(RefCLocalDeclaration(decl, 0)) = v - Block(Seq(LocalDecl(v), assignLocal(v.get, initialVal))) + if (init.init.isDefined) { + Block(Seq(HeapLocalDecl(v), assignHeapLocal(v.get, rw.dispatch(init.init.get)))) + } else { HeapLocalDecl(v) } } def rewriteLocal(decl: CLocalDeclaration[Pre]): Statement[Post] = { @@ -1340,6 +1315,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))(local.blame) case Some(_) => throw NotAValue(local) } + case ref: RefCLocalDeclaration[Pre] if cLocalHeapNameSuccessor.contains(ref) => HeapLocal(cLocalHeapNameSuccessor.ref(ref)) case ref: RefCLocalDeclaration[Pre] => Local(cNameSuccessor.ref(ref)) case _: RefCudaVec[Pre] => throw NotAValue(local) } @@ -1460,59 +1436,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) foldStar(newFieldPerms) } - def createStructCopy( - value: Expr[Post], - struct: CGlobalDeclaration[Pre], - blame: InstanceField[_] => Blame[InsufficientPermission], - )(implicit o: Origin): Expr[Post] = { - val targetClass: Class[Post] = cStructSuccessor(struct) - val t = TClass[Post](targetClass.ref, Seq()) - - // Assign a new variable towards the value, such that methods do not get executed multiple times. - val vValue = new Variable[Post](t) - // The copy of the value - val vCopy = new Variable[Post](t) - - val fieldAssigns = targetClass.declarations.collect { - case field: InstanceField[Post] => - val ref: Ref[Post, InstanceField[Post]] = field.ref - assignField( - vCopy.get, - ref, - Deref[Post](vValue.get, field.ref)(blame(field)), - PanicBlame("Assignment should work"), - ) - } - - With( - Block( - Seq( - LocalDecl(vCopy), - LocalDecl(vValue), - assignLocal(vValue.get, value), - assignLocal(vCopy.get, NewObject[Post](targetClass.ref)), - ) ++ fieldAssigns - ), - vCopy.get, - ) - } - - def assignStruct(assign: PreAssignExpression[Pre]): Expr[Post] = { - getBaseType(assign.target.t) match { - case CTStruct(ref) => - val copy = - createStructCopy( - rw.dispatch(assign.value), - ref.decl, - (f: InstanceField[_]) => StructCopyFailed(assign, f), - )(assign.o) - PreAssignExpression(rw.dispatch(assign.target), copy)(AssignLocalOk)( - assign.o - ) - case _ => throw WrongStructType(assign.target) - } - } - def createUpdateVectorFunction(size: Int): Function[Post] = { implicit val o: Origin = Origin(Seq(LabelContext("vector update method"))) /* for instance for size 4: @@ -1748,18 +1671,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _ => } - // Create copy for any direct structure arguments - val newArgs = args.map(a => - getBaseType(a.t) match { - case CTStruct(ref) => - createStructCopy( - rw.dispatch(a), - ref.decl, - (f: InstanceField[_]) => StructCopyBeforeCallFailed(inv, f), - )(a.o) - case _ => rw.dispatch(a) - } - ) + val newArgs = args.map(a => rw.dispatch(a)) implicit val o: Origin = inv.o inv.ref.get match { @@ -1998,6 +1910,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) def structType(t: CTStruct[Pre]): Type[Post] = { val targetClass = new LazyRef[Post, Class[Post]](cStructSuccessor(t.ref.decl)) - TClass[Post](targetClass, Seq())(t.o) + TByValueClass[Post](targetClass, Seq())(t.o) } } diff --git a/src/rewrite/vct/rewrite/lang/LangJavaToCol.scala b/src/rewrite/vct/rewrite/lang/LangJavaToCol.scala index 7a65f46ca6..8b73dce94d 100644 --- a/src/rewrite/vct/rewrite/lang/LangJavaToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangJavaToCol.scala @@ -275,7 +275,7 @@ case class LangJavaToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case cons: JavaConstructor[Pre] => logger.debug(s"Constructor for ${cons.o.inlineContextText}") implicit val o: Origin = cons.o - val t = TClass(ref, Seq()) + val t = TByReferenceClass(ref, Seq()) val `this` = ThisObject(ref) val results = currentJavaClass.top.modifiers.collect { @@ -429,7 +429,7 @@ case class LangJavaToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val instanceClass = rw.currentThis.having(ThisObject(javaInstanceClassSuccessor.ref(cls))) { - new Class[Post]( + new ByReferenceClass[Post]( rw.variables.dispatch(cls.typeParams)(rw), rw.classDeclarations.collect { makeJavaClass( @@ -454,7 +454,7 @@ case class LangJavaToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (staticDecls.nonEmpty) { val staticsClass = - new Class[Post]( + new ByReferenceClass[Post]( Seq(), rw.classDeclarations.collect { rw.currentThis @@ -472,7 +472,7 @@ case class LangJavaToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(JavaStaticsClassOrigin(cls)) rw.globalDeclarations.declare(staticsClass) - val t = TClass[Post](staticsClass.ref, Seq()) + val t = TByReferenceClass[Post](staticsClass.ref, Seq()) val singleton = withResult((res: Result[Post]) => function( AbstractApplicable, @@ -754,7 +754,7 @@ case class LangJavaToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) def classType(t: JavaTClass[Pre]): Type[Post] = t.ref.decl match { case classOrInterface: JavaClassOrInterface[Pre] => - TClass( + TByReferenceClass( javaInstanceClassSuccessor.ref(classOrInterface), t.typeArgs.map(rw.dispatch), ) diff --git a/src/rewrite/vct/rewrite/lang/LangPVLToCol.scala b/src/rewrite/vct/rewrite/lang/LangPVLToCol.scala index 4952f2eae4..6db99e77dc 100644 --- a/src/rewrite/vct/rewrite/lang/LangPVLToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangPVLToCol.scala @@ -179,7 +179,7 @@ case class LangPVLToCol[Pre <: Generation]( val PVLNew(t, typeArgs, args, givenMap, yields) = inv val classTypeArgs = t match { - case TClass(_, typeArgs) => typeArgs + case t: TClass[Pre] => t.typeArgs case _ => Seq() } implicit val o: Origin = inv.o diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 213997930e..488782e83b 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -206,7 +206,13 @@ case class LangSpecificToCol[Pre <: Generation]( pvl.maybeDeclareDefaultConstructor(cls) }._1 - globalDeclarations.succeed(cls, cls.rewrite(decls = decls)) + globalDeclarations.succeed( + cls, + cls match { + case cls: ByReferenceClass[Pre] => cls.rewrite(decls = decls) + case cls: ByValueClass[Pre] => cls.rewrite(decls = decls) + }, + ) } } @@ -369,10 +375,6 @@ case class LangSpecificToCol[Pre <: Generation]( case _ => } assign.target.t match { - case CPrimitiveType(specs) if specs.collectFirst { - case CSpecificationType(_: CTStruct[Pre]) => () - }.isDefined => - c.assignStruct(assign) case CPPPrimitiveType(_) => cpp.preAssignExpr(assign) case _ => rewriteDefault(assign) } diff --git a/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala b/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala index 6b7f9fe8eb..adfbeb4be1 100644 --- a/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala @@ -88,7 +88,10 @@ case class LangTypesToCol[Pre <: Generation]() extends Rewriter[Pre] { case t @ PVLNamedType(_, typeArgs) => t.ref.get match { case spec: SpecTypeNameTarget[Pre] => specType(spec, typeArgs) - case RefClass(decl) => TClass(succ(decl), typeArgs.map(dispatch)) + case RefClass(decl: ByReferenceClass[Pre]) => + TByReferenceClass(succ[Class[Post]](decl), typeArgs.map(dispatch)) + case RefClass(decl: ByValueClass[Pre]) => + TByValueClass(succ[Class[Post]](decl), typeArgs.map(dispatch)) } case t @ CPrimitiveType(specs) => dispatch(C.getPrimitiveType(specs, context = Some(t))) diff --git a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala index de0cfe57c6..f59b38fcb1 100644 --- a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala +++ b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala @@ -13,7 +13,15 @@ case object NoSupportSelfLoop extends RewriterBuilder { case class NoSupportSelfLoop[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(decl: Declaration[Pre]): Unit = decl match { - case cls: Class[Pre] => + case cls: ByReferenceClass[Pre] => + globalDeclarations.succeed( + cls, + cls.rewrite(supports = + cls.supports.filter(_.asClass.get.cls.decl != cls) + .map(_.rewriteDefault()) + ), + ) + case cls: ByValueClass[Pre] => globalDeclarations.succeed( cls, cls.rewrite(supports = diff --git a/src/rewrite/vct/rewrite/util/Extract.scala b/src/rewrite/vct/rewrite/util/Extract.scala index 95ef8f02d9..35ba7fe755 100644 --- a/src/rewrite/vct/rewrite/util/Extract.scala +++ b/src/rewrite/vct/rewrite/util/Extract.scala @@ -84,7 +84,9 @@ case class Extract[G]() { t -> Local( getOrElseUpdate( free, - new Variable(extract(TClass(t.cls, Seq())))(ExtractOrigin("this")), + new Variable(extract(t.cls.decl.classType(Seq())))(ExtractOrigin( + "this" + )), ).ref[Variable[G]] )(ExtractOrigin("")) case free @ FreeThisModel(t) => diff --git a/src/rewrite/vct/rewrite/veymont/EncodeChannels.scala b/src/rewrite/vct/rewrite/veymont/EncodeChannels.scala index 5a0d2dc111..6a229562c3 100644 --- a/src/rewrite/vct/rewrite/veymont/EncodeChannels.scala +++ b/src/rewrite/vct/rewrite/veymont/EncodeChannels.scala @@ -5,6 +5,7 @@ import hre.util.ScopedStack import vct.col.ast.{ Assign, Block, + ByValueClass, ChorStatement, Choreography, Class, @@ -23,6 +24,7 @@ import vct.col.ast.{ Program, Scope, Statement, + TByValueClass, TClass, TVar, Type, @@ -81,7 +83,7 @@ case class EncodeChannels[Pre <: Generation](importer: ImportADTImporter) }.get def channelType(comm: Communicate[Pre]): Type[Post] = - TClass[Post](channelClassSucc.ref(comm), Seq()) + TByValueClass[Post](channelClassSucc.ref(comm), Seq()) val currentCommunicate = ScopedStack[Communicate[Pre]]() val currentMsgTVar = ScopedStack[Variable[Pre]]() @@ -159,7 +161,7 @@ case class EncodeChannels[Pre <: Generation](importer: ImportADTImporter) }).succeed(chor) } - case cls: Class[Pre] if isEndpointClass(cls) => + case cls: ByValueClass[Pre] if isEndpointClass(cls) => cls.rewrite(decls = classDeclarations.collect { cls.decls.foreach(dispatch) @@ -174,13 +176,15 @@ case class EncodeChannels[Pre <: Generation](importer: ImportADTImporter) }._1 ).succeed(cls) - case cls: Class[Pre] if cls == genericChannelClass => + case cls: ByValueClass[Pre] if cls == genericChannelClass => globalDeclarations.scope { classDeclarations.scope { variables.scope { - currentMsgTVar.having(cls.typeArgs.head) { - channelClassSucc(currentCommunicate.top) = cls - .rewrite(typeArgs = Seq()).succeed(cls) + localHeapVariables.scope { + currentMsgTVar.having(cls.typeArgs.head) { + channelClassSucc(currentCommunicate.top) = cls + .rewrite(typeArgs = Seq()).succeed(cls) + } } } } diff --git a/src/rewrite/vct/rewrite/veymont/EncodeChoreography.scala b/src/rewrite/vct/rewrite/veymont/EncodeChoreography.scala index 76f4f9513c..a22aff6af8 100644 --- a/src/rewrite/vct/rewrite/veymont/EncodeChoreography.scala +++ b/src/rewrite/vct/rewrite/veymont/EncodeChoreography.scala @@ -31,7 +31,7 @@ import vct.col.ast.{ Scope, Sender, Statement, - TClass, + TByReferenceClass, TVoid, ThisChoreography, Variable, @@ -257,9 +257,9 @@ case class EncodeChoreography[Pre <: Generation]() currentInstanceMethod.having(method) { for (endpoint <- prog.endpoints) { endpointSucc((mode, endpoint)) = - new Variable(TClass(succ[Class[Post]](endpoint.cls.decl), Seq()))( - endpoint.o - ) + new Variable( + TByReferenceClass(succ[Class[Post]](endpoint.cls.decl), Seq()) + )(endpoint.o) } prog.params.foreach(_.drop()) diff --git a/src/rewrite/vct/rewrite/veymont/EncodeChoreographyParameters.scala b/src/rewrite/vct/rewrite/veymont/EncodeChoreographyParameters.scala index e86486d36c..38f0f9c535 100644 --- a/src/rewrite/vct/rewrite/veymont/EncodeChoreographyParameters.scala +++ b/src/rewrite/vct/rewrite/veymont/EncodeChoreographyParameters.scala @@ -4,8 +4,8 @@ import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast.{ Block, + ByReferenceClass, Choreography, - Class, Declaration, Endpoint, EndpointName, @@ -44,8 +44,10 @@ case class EncodeChoreographyParameters[Pre <: Generation]() case p: Choreography[Pre] => p } lazy val allEndpoints = choreographies.flatMap { _.endpoints } - lazy val endpointOfClass: Map[Class[Pre], Endpoint[Pre]] = - allEndpoints.map { endpoint => (endpoint.cls.decl, endpoint) }.toMap + lazy val endpointOfClass: Map[ByReferenceClass[Pre], Endpoint[Pre]] = + allEndpoints.map { endpoint => + (endpoint.cls.decl.asInstanceOf[ByReferenceClass[Pre]], endpoint) + }.toMap lazy val choreographyOfEndpoint: Map[Endpoint[Pre], Choreography[Pre]] = choreographies.flatMap { chor => chor.endpoints.map { ep => (ep, chor) } } .toMap @@ -83,7 +85,7 @@ case class EncodeChoreographyParameters[Pre <: Generation]() }), ) } - case cls: Class[Pre] if endpointOfClass.contains(cls) => + case cls: ByReferenceClass[Pre] if endpointOfClass.contains(cls) => val endpoint = endpointOfClass(cls) val chor = choreographyOfEndpoint(endpoint) implicit val o = chor.o diff --git a/src/rewrite/vct/rewrite/veymont/GenerateChoreographyPermissions.scala b/src/rewrite/vct/rewrite/veymont/GenerateChoreographyPermissions.scala index b66fa3c90b..c074286515 100644 --- a/src/rewrite/vct/rewrite/veymont/GenerateChoreographyPermissions.scala +++ b/src/rewrite/vct/rewrite/veymont/GenerateChoreographyPermissions.scala @@ -230,7 +230,7 @@ case class GenerateChoreographyPermissions[Pre <: Generation]( transitivePerm(Result[Post](anySucc(app)), app.returnType) def classPerm(cls: Class[Pre]): Expr[Post] = - transitivePerm(ThisObject[Post](succ(cls))(cls.o), TClass(cls.ref, Seq()))( + transitivePerm(ThisObject[Post](succ(cls))(cls.o), cls.classType(Seq()))( cls.o ) @@ -276,17 +276,18 @@ case class GenerateChoreographyPermissions[Pre <: Generation]( u, )), ) - case TClass(Ref(cls), _) if !generatingClasses.contains(cls) => - generatingClasses.having(cls) { - foldStar(cls.collect { case f: InstanceField[Pre] => + case t: TClass[Pre] if !generatingClasses.contains(t.cls.decl) => + generatingClasses.having(t.cls.decl) { + foldStar(t.cls.decl.collect { case f: InstanceField[Pre] => fieldTransitivePerm(e, f)(f.o) }) } - case TClass(Ref(cls), _) => + case t: TByReferenceClass[Pre] => // The class we are generating permission for has already been encountered when going through the chain // of fields. So we cut off the computation logger.warn( - s"Not generating permissions for recursive occurrence of ${cls.o.getPreferredNameOrElse().ucamel}. Circular datastructures are not supported by permission generation" + s"Not generating permissions for recursive occurrence of ${t.cls.decl + .o.getPreferredNameOrElse().ucamel}. Circular datastructures are not supported by permission generation" ) tt case _ => tt diff --git a/src/rewrite/vct/rewrite/veymont/GenerateImplementation.scala b/src/rewrite/vct/rewrite/veymont/GenerateImplementation.scala index 06e100d262..dbff2b2b99 100644 --- a/src/rewrite/vct/rewrite/veymont/GenerateImplementation.scala +++ b/src/rewrite/vct/rewrite/veymont/GenerateImplementation.scala @@ -10,6 +10,7 @@ import vct.col.ast.{ Block, BooleanValue, Branch, + ByReferenceClass, ChorBranch, ChorGuard, ChorLoop, @@ -55,6 +56,7 @@ import vct.col.ast.{ Star, Statement, TClass, + TByReferenceClass, TVeyMontChannel, TVoid, ThisChoreography, @@ -211,7 +213,7 @@ case class GenerateImplementation[Pre <: Generation]() override def dispatch(decl: Declaration[Pre]): Unit = { decl match { case p: Procedure[Pre] => super.dispatch(p) - case cls: Class[Pre] if isEndpointClass(cls) => + case cls: ByReferenceClass[Pre] if isEndpointClass(cls) => val chor = choreographyOf(cls) val endpoint = endpointOf(cls) currentThis.having(ThisObject[Post](succ(cls))(cls.o)) { @@ -457,7 +459,7 @@ case class GenerateImplementation[Pre <: Generation]() seqProg.endpoints.foreach(thread => { val threadField = new InstanceField[Post]( - TClass(givenClassSucc.ref(thread.t), Seq()), + TByReferenceClass(givenClassSucc.ref(thread.t), Seq()), Nil, )(thread.o) val channelFields = getChannelFields( @@ -476,16 +478,16 @@ case class GenerateImplementation[Pre <: Generation]() }) } - private def dispatchGivenClass(c: Class[Pre]): Class[Post] = { + private def dispatchGivenClass(c: ByReferenceClass[Pre]): Class[Post] = { val rw = GivenClassRewriter() val gc = c.rewrite(decls = classDeclarations.collect { - (givenClassConstrSucc.get(TClass(c.ref, Seq())).get +: c.declarations) - .foreach(d => rw.dispatch(d)) + (givenClassConstrSucc.get(TByReferenceClass(c.ref, Seq())).get +: + c.declarations).foreach(d => rw.dispatch(d)) }._1 )(rw) - givenClassSucc.update(TClass(c.ref, Seq()), gc) + givenClassSucc.update(TByReferenceClass(c.ref, Seq()), gc) gc } @@ -550,7 +552,7 @@ case class GenerateImplementation[Pre <: Generation]() else rewriteDefault(l) case t: ThisObject[Pre] => - val thisClassType = TClass(t.cls, Seq()) + val thisClassType = TByReferenceClass(t.cls, Seq()) if ( rewritingConstr.nonEmpty && rewritingConstr.top._2 == thisClassType ) @@ -623,7 +625,7 @@ case class GenerateImplementation[Pre <: Generation]() val threadRun = getThreadRunMethod(threadRes.runMethod) classDeclarations.scope { val threadClass = - new Class[Post]( + new ByReferenceClass[Post]( Seq(), (threadRes.threadField +: threadRes.channelFields.values.toSeq) ++ (threadConstr +: threadRun +: threadMethods), diff --git a/src/rewrite/vct/rewrite/veymont/SpecializeEndpointClasses.scala b/src/rewrite/vct/rewrite/veymont/SpecializeEndpointClasses.scala index 177c4064b5..f75250f037 100644 --- a/src/rewrite/vct/rewrite/veymont/SpecializeEndpointClasses.scala +++ b/src/rewrite/vct/rewrite/veymont/SpecializeEndpointClasses.scala @@ -11,6 +11,7 @@ import vct.col.ast.{ Block, BooleanValue, Branch, + ByReferenceClass, ChorGuard, ChorRun, ChorStatement, @@ -140,7 +141,7 @@ case class SpecializeEndpointClasses[Pre <: Generation]() } val wrapperClass = - new Class[Post]( + new ByReferenceClass[Post]( typeArgs = Seq(), supports = Seq(), intrinsicLockInvariant = tt, diff --git a/test/main/vct/helper/SimpleProgramGenerator.scala b/test/main/vct/helper/SimpleProgramGenerator.scala index f0587c3abd..3aee073fe9 100644 --- a/test/main/vct/helper/SimpleProgramGenerator.scala +++ b/test/main/vct/helper/SimpleProgramGenerator.scala @@ -20,7 +20,7 @@ object SimpleProgramGenerator { val contract1 = generateSimpleApplicableContract[G]() val blame1 = origin val method1 = new InstanceMethod(TVoid(), Nil, Nil, Nil, Option(body), contract1)(blame1) - val classNode1 = new Class(Nil, Seq(method1), Nil, tt) + val classNode1 = new ByReferenceClass(Nil, Seq(method1), Nil, tt) Program(Seq(classNode1))(DiagnosticOrigin) } diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index d4cfd0de4e..b5af488aab 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -310,7 +310,7 @@ class CSpec extends VercorsSpec { } """ - vercors should fail withCode "copyStructFailedBeforeCall" using silicon in "Insufficient permission for field x to copy struct before call" c + vercors should fail withCode "copyClassFailedBeforeCall" using silicon in "Insufficient permission for field x to copy struct before call" c """ struct d { int x; @@ -328,7 +328,7 @@ class CSpec extends VercorsSpec { } """ - vercors should fail withCode "copyStructFailed" using silicon in "Insufficient permission for field x to copy struct" c + vercors should fail withCode "copyClassFailed" using silicon in "Insufficient permission for field x to copy struct" c """ struct d { int x; From b01df4d72ab5b35b9a9848e90d4dc86e1c109502 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 11 Jun 2024 15:43:46 +0200 Subject: [PATCH 06/47] Use ADT encoding for ByValueClasses --- examples/concepts/c/structs.c | 19 +- src/col/vct/col/ast/Node.scala | 31 +- .../alloc/NewNonNullPointerArrayImpl.scala | 14 + .../coercion/CoerceNonNullPointerImpl.scala | 9 + .../location/ByValueClassLocationImpl.scala | 10 + .../col/ast/type/TNonNullPointerImpl.scala | 16 + src/col/vct/col/origin/Blame.scala | 9 +- src/col/vct/col/resolve/Resolve.scala | 1 + .../vct/col/typerules/CoercingRewriter.scala | 4 + src/col/vct/col/typerules/CoercionUtils.scala | 5 + src/col/vct/col/util/AstBuildHelpers.scala | 15 +- src/main/vct/main/stages/Transformation.scala | 9 +- src/rewrite/vct/rewrite/ClassToRef.scala | 262 ++++++++++++++-- .../vct/rewrite/DisambiguateLocation.scala | 3 +- .../vct/rewrite/EncodeArrayValues.scala | 60 +++- .../vct/rewrite/EncodeByValueClass.scala | 249 ---------------- .../vct/rewrite/LowerLocalHeapVariables.scala | 111 +++++++ .../vct/rewrite/PrepareByValueClass.scala | 279 ++++++++++++++++++ .../ResolveExpressionSideEffects.scala | 2 + src/rewrite/vct/rewrite/TrivialAddrOf.scala | 4 +- .../vct/rewrite/VariableToPointer.scala | 221 ++++++++++++++ .../vct/rewrite/adt/ImportPointer.scala | 105 +++---- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 57 ++-- .../vct/rewrite/lang/LangSpecificToCol.scala | 22 +- 24 files changed, 1117 insertions(+), 400 deletions(-) create mode 100644 src/col/vct/col/ast/expr/heap/alloc/NewNonNullPointerArrayImpl.scala create mode 100644 src/col/vct/col/ast/family/coercion/CoerceNonNullPointerImpl.scala create mode 100644 src/col/vct/col/ast/family/location/ByValueClassLocationImpl.scala create mode 100644 src/col/vct/col/ast/type/TNonNullPointerImpl.scala delete mode 100644 src/rewrite/vct/rewrite/EncodeByValueClass.scala create mode 100644 src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala create mode 100644 src/rewrite/vct/rewrite/PrepareByValueClass.scala create mode 100644 src/rewrite/vct/rewrite/VariableToPointer.scala diff --git a/examples/concepts/c/structs.c b/examples/concepts/c/structs.c index 44882c9ac6..886ed073f5 100644 --- a/examples/concepts/c/structs.c +++ b/examples/concepts/c/structs.c @@ -21,8 +21,8 @@ struct linked_list{ /*@ context p != NULL ** Perm(p, write); - context Perm(p->x, write); - context Perm(p->y, write); + context Perm(&p->x, write); + context Perm(&p->y, write); ensures p->x == 0; ensures p->y == 0; ensures \old(*p) == *p; @@ -44,14 +44,15 @@ void alter_struct_1(struct point *p){ } /*@ - context Perm(p.x, 1\1); - context Perm(p.y, 1\1); + context Perm(&p.x, 1\1); + context Perm(&p.y, 1\1); @*/ void alter_copy_struct(struct point p){ p.x = 0; p.y = 0; } +// TODO: Should be auto-generated /*@ context Perm(p, 1\1); @*/ @@ -75,7 +76,7 @@ int avr_x(struct triangle *r){ requires inp != NULL && \pointer_length(inp) >= n; requires (\forall* int i; 0 <= i && i < n; Perm(&inp[i], 1\10)); requires (\forall int i, int j; 0<=i && i {:inp[i]:} != {:inp[j]:}); - requires (\forall* int i; 0 <= i && i < n; Perm(inp[i].x, 1\10)); + requires (\forall* int i; 0 <= i && i < n; Perm(&inp[i].x, 1\10)); ensures |\result| == n; ensures (\forall int i; 0 <= i && i < n; \result[i] == inp[i].x); //ensures n>0 ==> \result == inp_to_seq(inp, n-1) + [inp[n-1].x]; @@ -132,7 +133,7 @@ int main(){ struct point *pp; pp = &p; - //@ assert (pp[0] != NULL ); + /* //@ assert (pp[0] != NULL ); */ assert (pp != NULL ); p.x = 1; @@ -163,9 +164,13 @@ int main(){ struct polygon pol, *ppols; ppols = &pol; pol.ps = ps; + //@ assert Perm(&ppols->ps[0], write); + //@ assert Perm(&ppols->ps[1], write); + //@ assert Perm(&ppols->ps[2], write); + //@ assert (\forall* int i; 0<=i && i<3; Perm(&ppols->ps[i], write)); int avr_pol = avr_x_pol(ppols, 3); // assert sum_seq(inp_to_seq(ppols->ps, 3)) == 6; assert(avr_pol == 2); return 0; -} \ No newline at end of file +} diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 67e3e06e5a..c46d938b5b 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -133,6 +133,9 @@ final case class TArray[G](element: Type[G])( final case class TPointer[G](element: Type[G])( implicit val o: Origin = DiagnosticOrigin ) extends Type[G] with TPointerImpl[G] +final case class TNonNullPointer[G](element: Type[G])( + implicit val o: Origin = DiagnosticOrigin +) extends Type[G] with TNonNullPointerImpl[G] final case class TType[G](t: Type[G])(implicit val o: Origin = DiagnosticOrigin) extends Type[G] with TTypeImpl[G] final case class TVar[G](ref: Ref[G, Variable[G]])( @@ -325,8 +328,9 @@ final case class LocalDecl[G](local: Variable[G])(implicit val o: Origin) extends NonExecutableStatement[G] with PurelySequentialStatement[G] with LocalDeclImpl[G] -final case class HeapLocalDecl[G](local: LocalHeapVariable[G])(implicit val o: Origin) - extends NonExecutableStatement[G] +final case class HeapLocalDecl[G](local: LocalHeapVariable[G])( + implicit val o: Origin +) extends NonExecutableStatement[G] with PurelySequentialStatement[G] with HeapLocalDeclImpl[G] final case class SpecIgnoreStart[G]()(implicit val o: Origin) @@ -1008,6 +1012,9 @@ final case class CoerceNullAnyClass[G]()(implicit val o: Origin) final case class CoerceNullPointer[G](pointerElementType: Type[G])( implicit val o: Origin ) extends Coercion[G] with CoerceNullPointerImpl[G] +final case class CoerceNonNullPointer[G](elementType: Type[G])( + implicit val o: Origin +) extends Coercion[G] with CoerceNonNullPointerImpl[G] final case class CoerceNullEnum[G](targetEnum: Ref[G, Enum[G]])( implicit val o: Origin ) extends Coercion[G] with CoerceNullEnumImpl[G] @@ -1367,8 +1374,9 @@ final case class ScopedExpr[G](declarations: Seq[Variable[G]], body: Expr[G])( final case class Local[G](ref: Ref[G, Variable[G]])(implicit val o: Origin) extends Expr[G] with LocalImpl[G] -final case class HeapLocal[G](ref: Ref[G, LocalHeapVariable[G]])(implicit val o: Origin) - extends Expr[G] with HeapLocalImpl[G] +final case class HeapLocal[G](ref: Ref[G, LocalHeapVariable[G]])( + implicit val o: Origin +) extends Expr[G] with HeapLocalImpl[G] final case class EnumUse[G]( enum: Ref[G, Enum[G]], @@ -1735,6 +1743,10 @@ final case class PointerLocation[G](pointer: Expr[G])( val blame: Blame[PointerLocationError] )(implicit val o: Origin) extends Location[G] with PointerLocationImpl[G] +final case class ByValueClassLocation[G](expr: Expr[G])( + val blame: Blame[PointerLocationError] +)(implicit val o: Origin) + extends Location[G] with ByValueClassLocationImpl[G] final case class PredicateLocation[G]( predicate: Ref[G, Predicate[G]], args: Seq[Expr[G]], @@ -1869,6 +1881,10 @@ final case class NewPointerArray[G](element: Type[G], size: Expr[G])( val blame: Blame[ArraySizeError] )(implicit val o: Origin) extends Expr[G] with NewPointerArrayImpl[G] +final case class NewNonNullPointerArray[G](element: Type[G], size: Expr[G])( + val blame: Blame[ArraySizeError] +)(implicit val o: Origin) + extends Expr[G] with NewNonNullPointerArrayImpl[G] final case class FreePointer[G](pointer: Expr[G])( val blame: Blame[PointerFreeError] )(implicit val o: Origin) @@ -2761,10 +2777,9 @@ final case class GpgpuAtomic[G]( extends CStatement[G] with GpgpuAtomicImpl[G] sealed trait CExpr[G] extends Expr[G] with CExprImpl[G] -final case class CLocal[G](name: String)( - val blame: Blame[DerefInsufficientPermission] -)(implicit val o: Origin) - extends CExpr[G] with CLocalImpl[G] { +final case class CLocal[G](name: String)(val blame: Blame[FrontendDerefError])( + implicit val o: Origin +) extends CExpr[G] with CLocalImpl[G] { var ref: Option[CNameTarget[G]] = None } final case class CInvocation[G]( diff --git a/src/col/vct/col/ast/expr/heap/alloc/NewNonNullPointerArrayImpl.scala b/src/col/vct/col/ast/expr/heap/alloc/NewNonNullPointerArrayImpl.scala new file mode 100644 index 0000000000..abc8896755 --- /dev/null +++ b/src/col/vct/col/ast/expr/heap/alloc/NewNonNullPointerArrayImpl.scala @@ -0,0 +1,14 @@ +package vct.col.ast.expr.heap.alloc + +import vct.col.ast.ops.NewNonNullPointerArrayOps +import vct.col.ast.{NewNonNullPointerArray, TNonNullPointer, Type} +import vct.col.print._ + +trait NewNonNullPointerArrayImpl[G] extends NewNonNullPointerArrayOps[G] { + this: NewNonNullPointerArray[G] => + override lazy val t: Type[G] = TNonNullPointer(element) + + override def precedence: Int = Precedence.POSTFIX + override def layout(implicit ctx: Ctx): Doc = + Text("new") <+> element <> "[" <> size <> "]" +} diff --git a/src/col/vct/col/ast/family/coercion/CoerceNonNullPointerImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceNonNullPointerImpl.scala new file mode 100644 index 0000000000..2dd5617753 --- /dev/null +++ b/src/col/vct/col/ast/family/coercion/CoerceNonNullPointerImpl.scala @@ -0,0 +1,9 @@ +package vct.col.ast.family.coercion + +import vct.col.ast.ops.CoerceNonNullPointerOps +import vct.col.ast.{CoerceNonNullPointer, TPointer} + +trait CoerceNonNullPointerImpl[G] extends CoerceNonNullPointerOps[G] { + this: CoerceNonNullPointer[G] => + override def target: TPointer[G] = TPointer(elementType) +} diff --git a/src/col/vct/col/ast/family/location/ByValueClassLocationImpl.scala b/src/col/vct/col/ast/family/location/ByValueClassLocationImpl.scala new file mode 100644 index 0000000000..225d574c96 --- /dev/null +++ b/src/col/vct/col/ast/family/location/ByValueClassLocationImpl.scala @@ -0,0 +1,10 @@ +package vct.col.ast.family.location + +import vct.col.ast.ByValueClassLocation +import vct.col.ast.ops.ByValueClassLocationOps +import vct.col.print.{Ctx, Doc} + +trait ByValueClassLocationImpl[G] extends ByValueClassLocationOps[G] { + this: ByValueClassLocation[G] => + override def layout(implicit ctx: Ctx): Doc = expr.show +} diff --git a/src/col/vct/col/ast/type/TNonNullPointerImpl.scala b/src/col/vct/col/ast/type/TNonNullPointerImpl.scala new file mode 100644 index 0000000000..cd769efa94 --- /dev/null +++ b/src/col/vct/col/ast/type/TNonNullPointerImpl.scala @@ -0,0 +1,16 @@ +package vct.col.ast.`type` + +import vct.col.ast.TNonNullPointer +import vct.col.ast.ops.TNonNullPointerOps +import vct.col.print._ + +trait TNonNullPointerImpl[G] extends TNonNullPointerOps[G] { + this: TNonNullPointer[G] => + override def layoutSplitDeclarator(implicit ctx: Ctx): (Doc, Doc) = { + val (spec, decl) = element.layoutSplitDeclarator + (spec, decl <> "*") + } + + override def layout(implicit ctx: Ctx): Doc = + Group(Text("NonNull") <> open <> element <> close) +} diff --git a/src/col/vct/col/origin/Blame.scala b/src/col/vct/col/origin/Blame.scala index 174b70f334..50a1b2c6dc 100644 --- a/src/col/vct/col/origin/Blame.scala +++ b/src/col/vct/col/origin/Blame.scala @@ -159,7 +159,7 @@ case class AssignFieldFailed(node: SilverFieldAssign[_]) } case class CopyClassFailed(node: Node[_], clazz: ByValueClass[_], field: String) - extends AssignFailed with NodeVerificationFailure { + extends PointerDerefError with NodeVerificationFailure { override def code: String = "copyClassFailed" override def descInContext: String = s"Insufficient read permission for field '$field' to copy ${clazz.o @@ -172,7 +172,9 @@ case class CopyClassFailedBeforeCall( node: Node[_], clazz: ByValueClass[_], field: String, -) extends AssignFailed with InvocationFailure with NodeVerificationFailure { +) extends PointerDerefError + with InvocationFailure + with NodeVerificationFailure { override def code: String = "copyClassFailedBeforeCall" override def descInContext: String = s"Insufficient read permission for field '$field' to copy ${clazz.o @@ -1517,6 +1519,9 @@ object JavaArrayInitializerBlame "The explicit initialization of an array in Java should never generate an assignment that exceeds the bounds of the array" ) +object NonNullPointerNull + extends PanicBlame("A non-null pointer can never be null") + object UnsafeDontCare { case class Satisfiability(reason: String) extends UnsafeDontCare[NontrivialUnsatisfiable] diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index f88099f286..daa25bb41e 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -366,6 +366,7 @@ case object ResolveReferences extends LazyLogging { case CPPDeclarationStatement(decl) => Seq(decl) case JavaLocalDeclarationStatement(decl) => Seq(decl) case LocalDecl(v) => Seq(v) + case HeapLocalDecl(v) => Seq(v) case other => other.subnodes.flatMap(scanScope(ctx)) } diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 26a81d7d8a..4c709d31cd 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -281,6 +281,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case CoerceNullJavaClass(_) => e case CoerceNullAnyClass() => e case CoerceNullPointer(_) => e + case CoerceNonNullPointer(_) => e case CoerceFracZFrac() => e case CoerceZFracRat() => e case CoerceFloatRat(_) => e @@ -1566,6 +1567,8 @@ abstract class CoercingRewriter[Pre <: Generation]() NewArray(element, dims.map(int), moreDims, initialize)(na.blame) case na @ NewPointerArray(element, size) => NewPointerArray(element, size)(na.blame) + case na @ NewNonNullPointerArray(element, size) => + NewNonNullPointerArray(element, size)(na.blame) case NewObject(cls) => NewObject(cls) case NoPerm() => NoPerm() case Not(arg) => Not(bool(arg)) @@ -2686,6 +2689,7 @@ abstract class CoercingRewriter[Pre <: Generation]() ArrayLocation(array(arrayObj)._1, int(subscript))(a.blame) case p @ PointerLocation(pointerExp) => PointerLocation(pointer(pointerExp)._1)(p.blame) + case ByValueClassLocation(expr) => node case PredicateLocation(predicate, args) => PredicateLocation(predicate, coerceArgs(args, predicate.decl)) case InstancePredicateLocation(predicate, obj, args) => diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index deef9b73ca..a71e109e01 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -140,6 +140,9 @@ case object CoercionUtils { TPointer(element), ) => // if element == innerType => getAnyCoercion(element, innerType).getOrElse(return None) + case (TNonNullPointer(innerType), TPointer(element)) + if innerType == element => + CoerceNonNullPointer(innerType) case ( TPointer(element), CTPointer(innerType), @@ -430,6 +433,8 @@ case object CoercionUtils { case t: TPointer[G] => Some((CoerceIdentity(source), t)) case t: CTPointer[G] => Some((CoerceIdentity(source), TPointer(t.innerType))) + case t: TNonNullPointer[G] => + Some((CoerceIdentity(source), TPointer(t.element))) case t: CTArray[G] => Some((CoerceCArrayPointer(t.innerType), TPointer(t.innerType))) case t: CPPPrimitiveType[G] => chainCPPCoercion(t, getAnyPointerCoercion) diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 8cbf2ba839..575d920219 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -106,7 +106,9 @@ object AstBuildHelpers { } implicit class LocalHeapVarBuildHelpers[G](left: LocalHeapVariable[G]) { - def get(implicit origin: Origin): HeapLocal[G] = HeapLocal(new DirectRef(left)) + def get(blame: Blame[PointerDerefError])( + implicit origin: Origin + ): DerefPointer[G] = DerefPointer(HeapLocal[G](new DirectRef(left)))(blame) } implicit class FieldBuildHelpers[G](left: SilverDeref[G]) { @@ -672,6 +674,13 @@ object AstBuildHelpers { )(implicit o: Origin): FunctionInvocation[G] = FunctionInvocation(ref, args, typeArgs, givenMap, yields)(blame) + def adtFunctionInvocation[G]( + ref: Ref[G, ADTFunction[G]], + typeArgs: Option[(Ref[G, AxiomaticDataType[G]], Seq[Type[G]])] = None, + args: Seq[Expr[G]] = Nil, + )(implicit o: Origin): ADTFunctionInvocation[G] = + ADTFunctionInvocation(typeArgs, ref, args) + def methodInvocation[G]( blame: Blame[InstanceInvocationFailure], obj: Expr[G], @@ -768,10 +777,6 @@ object AstBuildHelpers { implicit o: Origin ): Assign[G] = Assign(local, value)(AssignLocalOk) - def assignHeapLocal[G](local: HeapLocal[G], value: Expr[G])( - implicit o: Origin - ): Assign[G] = Assign(local, value)(AssignLocalOk) - def assignField[G]( obj: Expr[G], field: Ref[G, InstanceField[G]], diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 15c61d3b8d..f75ed061ee 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -29,13 +29,15 @@ import vct.result.VerificationError.SystemError import vct.rewrite.adt.ImportSetCompat import vct.rewrite.{ EncodeAutoValue, + PrepareByValueClass, EncodeRange, EncodeResourceValues, ExplicitResourceValues, HeapVariableToRef, + LowerLocalHeapVariables, MonomorphizeClass, SmtlibToProverTypes, - EncodeByValueClass, + VariableToPointer, } import vct.rewrite.lang.ReplaceSYCLTypes import vct.rewrite.veymont.{ @@ -326,7 +328,8 @@ case class SilverTransformation( EncodeString, // Encode spec string as seq EncodeChar, CollectLocalDeclarations, // all decls in Scope - EncodeByValueClass, +// EncodeByValueClass, + VariableToPointer, // should happen before ParBlockEncoder so it can distinguish between variables which can and can't altered in a parallel block DesugarPermissionOperators, // no PointsTo, \pointer, etc. ReadToValue, // resolve wildcard into fractional permission TrivialAddrOf, @@ -335,6 +338,7 @@ case class SilverTransformation( QuantifySubscriptAny, // no arr[*] IterationContractToParBlock, PropagateContextEverywhere, // inline context_everywhere into loop invariants + PrepareByValueClass, EncodeArrayValues, // maybe don't target shift lemmas on generated function for \values GivenYieldsToArgs, CheckProcessAlgebra, @@ -384,6 +388,7 @@ case class SilverTransformation( // No more classes ClassToRef, HeapVariableToRef, + LowerLocalHeapVariables, CheckContractSatisfiability.withArg(checkSat), DesugarCollectionOperators, EncodeNdIndex, diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index b73c52606a..a3c29fbe64 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -22,6 +22,10 @@ case object ClassToRef extends RewriterBuilder { private def InstanceOfOrigin: Origin = Origin(Seq(PreferredName(Seq("subtype")), LabelContext("classToRef"))) + private val PointerCreationOrigin: Origin = Origin( + Seq(LabelContext("classToRef, pointer creation method")) + ) + case class InstanceNullPreconditionFailed( inner: Blame[InstanceNull], inv: InvokingNode[_], @@ -38,11 +42,19 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { private def This: Origin = Origin(Seq(PreferredName(Seq("this")), LabelContext("classToRef"))) - val fieldSucc: SuccessionMap[Field[Pre], SilverField[Post]] = SuccessionMap() + val byRefFieldSucc: SuccessionMap[Field[Pre], SilverField[Post]] = + SuccessionMap() + val byValFieldSucc: SuccessionMap[Field[Pre], ADTFunction[Post]] = + SuccessionMap() + val byValClassSucc + : SuccessionMap[ByValueClass[Pre], AxiomaticDataType[Post]] = + SuccessionMap() val methodSucc: SuccessionMap[InstanceMethod[Pre], Procedure[Post]] = SuccessionMap() val consSucc: SuccessionMap[Constructor[Pre], Procedure[Post]] = SuccessionMap() + val byValConsSucc: SuccessionMap[ByValueClass[Pre], ADTFunction[Post]] = + SuccessionMap() val functionSucc: SuccessionMap[InstanceFunction[Pre], Function[Post]] = SuccessionMap() val predicateSucc: SuccessionMap[InstancePredicate[Pre], Predicate[Post]] = @@ -53,6 +65,26 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { var typeNumberStore: mutable.Map[Class[Pre], Int] = mutable.Map() val typeOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() val instanceOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() + private val pointerCreationMethods + : SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() + + def makePointerCreationMethod(t: Type[Post]): Procedure[Post] = { + implicit val o: Origin = PointerCreationOrigin + + val result = new Variable[Post](TNonNullPointer(t)) + globalDeclarations.declare(procedure[Post]( + blame = AbstractApplicable, + contractBlame = TrueSatisfiable, + returnType = TVoid(), + outArgs = Seq(result), + ensures = UnitAccountedPredicate( + (PointerBlockLength(result.get)(FramedPtrBlockLength) === const(1)) &* + (PointerBlockOffset(result.get)(FramedPtrOffset) === const(0)) &* + Perm(PointerLocation(result.get)(FramedPtrOffset), WritePerm()) + ), + decreases = Some(DecreasesClauseNoRecursion[Post]()), + )) + } def typeNumber(cls: Class[Pre]): Int = typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1) @@ -143,7 +175,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) typeNumber(cls) - cls.drop() cls.decls.foreach { case function: InstanceFunction[Pre] => implicit val o: Origin = function.o @@ -278,33 +309,182 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) } case field: Field[Pre] => - fieldSucc(field) = new SilverField(dispatch(field.t))(field.o) - globalDeclarations.declare(fieldSucc(field)) + if (cls.isInstanceOf[ByReferenceClass[Pre]]) { + byRefFieldSucc(field) = + new SilverField(dispatch(field.t))(field.o) + globalDeclarations.declare(byRefFieldSucc(field)) + } case _ => throw ExtraNode } + cls match { + case cls: ByValueClass[Pre] => + implicit val o: Origin = cls.o + val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) + val (fieldFunctions, fieldTypes) = + cls.decls.collect { case field: Field[Pre] => + val newT = dispatch(field.t) + byValFieldSucc(field) = + new ADTFunction[Post]( + Seq(new Variable(axiomType)(field.o)), + newT, + )(field.o) + (byValFieldSucc(field), newT) + }.unzip + val constructor = + new ADTFunction[Post](fieldTypes.map(new Variable(_)), axiomType)( + cls.o + ) + val destructorAxiom = + new ADTAxiom[Post](foralls( + fieldTypes, + body = + variables => { + foldAnd(variables.zip(fieldFunctions).map { case (v, f) => + adtFunctionInvocation[Post]( + f.ref, + args = Seq(adtFunctionInvocation[Post]( + constructor.ref, + None, + args = variables, + )), + ) === v + }) + }, + triggers = + variables => { + fieldFunctions.map { f => + Seq(adtFunctionInvocation[Post]( + f.ref, + args = Seq(adtFunctionInvocation[Post]( + constructor.ref, + None, + args = variables, + )), + )) + } + }, + )) + val nonNullAxiom = + new ADTAxiom[Post](forall( + axiomType, + body = + v => { + foldAnd(fieldFunctions.map { f => + adtFunctionInvocation[Post]( + f.ref, + None, + args = Seq(v), + ) !== Null() + }) + }, + )) + // TODO: need Non null pointer.... + val injectivityAxiom = + new ADTAxiom[Post](foralls( + Seq(axiomType, axiomType), + body = { case Seq(a0, a1) => + (a0 !== a1) ==> foldAnd(fieldFunctions.map { f => + DerefPointer( + adtFunctionInvocation[Post](f.ref, args = Seq(a0)) + )(NonNullPointerNull)( + o.withContent(TypeName("helloWorld")) + ) !== DerefPointer( + adtFunctionInvocation[Post](f.ref, args = Seq(a1)) + )(NonNullPointerNull)(o.withContent(TypeName("helloWorld"))) + }) + }, + triggers = { case Seq(a0, a1) => + fieldFunctions.map { f => + Seq( + DerefPointer( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a0)) + )(NonNullPointerNull)(o.withContent(TypeName( + "helloWorld" + ))), + DerefPointer( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a1)) + )(NonNullPointerNull)(o.withContent(TypeName( + "helloWorld" + ))), + ) + } + }, + )) + byValConsSucc(cls) = constructor + byValClassSucc(cls) = + new AxiomaticDataType[Post]( + Seq( + constructor, + destructorAxiom, +// nonNullAxiom, + injectivityAxiom, + ) ++ fieldFunctions, + Nil, + ) + globalDeclarations.succeed(cls, byValClassSucc(cls)) + case _ => cls.drop() + } case decl => rewriteDefault(decl) } def instantiate(cls: Class[Pre], target: Ref[Post, Variable[Post]])( implicit o: Origin ): Statement[Post] = { - Block(Seq( - SilverNewRef[Post]( - target, - cls.decls.collect { case field: InstanceField[Pre] => - fieldSucc.ref(field) - }, - ), - Inhale( - FunctionInvocation[Post]( - typeOf.ref(()), - Seq(Local(target)), - Nil, - Nil, - Nil, - )(PanicBlame("typeOf requires nothing.")) === const(typeNumber(cls)) - ), - )) + cls match { + case cls: ByReferenceClass[Pre] => + Block(Seq( + SilverNewRef[Post]( + target, + cls.decls.collect { case field: InstanceField[Pre] => + byRefFieldSucc.ref(field) + }, + ), + Inhale( + FunctionInvocation[Post]( + typeOf.ref(()), + Seq(Local(target)), + Nil, + Nil, + Nil, + )(PanicBlame("typeOf requires nothing.")) === const(typeNumber(cls)) + ), + )) + case cls: ByValueClass[Pre] => + val (assigns, vars) = + cls.decls.collect { case field: InstanceField[Pre] => + val element = field.t.asPointer.get.element + val newE = dispatch(element) + val v = new Variable[Post](TNonNullPointer(newE)) + ( + InvokeProcedure[Post]( + pointerCreationMethods + .getOrElseUpdate(element, makePointerCreationMethod(newE)) + .ref, + Nil, + Seq(v.get), + Nil, + Nil, + Nil, + )(TrueSatisfiable), + v, + ) + }.unzip + Scope( + vars, + Block( + assigns ++ Seq( + Assign( + Local(target), + adtFunctionInvocation[Post]( + byValConsSucc.ref(cls), + args = vars.map(_.get), + ), + )(AssignLocalOk) + // TODO: Add back typeOf here (but use a separate definition for the adt) + ) + ), + ) + } } override def dispatch(stat: Statement[Pre]): Statement[Post] = @@ -423,9 +603,17 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ))(inv.o) case ThisObject(_) => diz.top case deref @ Deref(obj, Ref(field)) => - SilverDeref[Post](dispatch(obj), fieldSucc.ref(field))(deref.blame)( - deref.o - ) + obj.t match { + case _: TByReferenceClass[Pre] => + SilverDeref[Post](dispatch(obj), byRefFieldSucc.ref(field))( + deref.blame + )(deref.o) + case _: TByValueClass[Pre] => + adtFunctionInvocation[Post]( + byValFieldSucc.ref(field), + args = Seq(dispatch(obj)), + )(deref.o) + } case TypeValue(t) => t match { case t: TClass[Pre] if t.typeArgs.isEmpty => @@ -493,7 +681,12 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(t: Type[Pre]): Type[Post] = t match { - case _: TClass[Pre] => TRef() + case _: TByReferenceClass[Pre] => TRef() + case t: TByValueClass[Pre] => + TAxiomatic( + byValClassSucc.ref(t.cls.decl.asInstanceOf[ByValueClass[Pre]]), + Nil, + ) case TAnyClass() => TRef() case t => rewriteDefault(t) } @@ -505,10 +698,21 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { predicateSucc.ref(predicate.decl), dispatch(obj) +: args.map(dispatch), )(loc.o) - case FieldLocation(obj, field) => - SilverFieldLocation[Post](dispatch(obj), fieldSucc.ref(field.decl))( - loc.o - ) + case FieldLocation(obj, Ref(field)) => + obj.t match { + case _: TByReferenceClass[Pre] => + SilverFieldLocation[Post](dispatch(obj), byRefFieldSucc.ref(field))( + loc.o + ) + case _: TByValueClass[Pre] => + PointerLocation[Post]( + adtFunctionInvocation[Post]( + byValFieldSucc.ref(field), + None, + args = Seq(dispatch(obj)), + )(loc.o) + )(NonNullPointerNull)(loc.o) + } case default => rewriteDefault(default) } } diff --git a/src/rewrite/vct/rewrite/DisambiguateLocation.scala b/src/rewrite/vct/rewrite/DisambiguateLocation.scala index 8209136550..e013f80915 100644 --- a/src/rewrite/vct/rewrite/DisambiguateLocation.scala +++ b/src/rewrite/vct/rewrite/DisambiguateLocation.scala @@ -45,6 +45,8 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { ArrayLocation(dispatch(arr), dispatch(index))(expr.blame) case expr if expr.t.asPointer.isDefined => PointerLocation(dispatch(expr))(blame) + case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => + ByValueClassLocation(dispatch(expr))(blame) case PredicateApply(ref, args, WritePerm()) => PredicateLocation(succ(ref.decl), (args.map(dispatch))) case InstancePredicateApply(obj, ref, args, WritePerm()) => @@ -53,7 +55,6 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { dispatch(obj), args.map(dispatch), ) - case InlinePattern(inner, pattern, group) => InLinePatternLocation( exprToLoc(inner, blame), diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 7c1732761f..a134a63a58 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -60,13 +60,15 @@ case object EncodeArrayValues extends RewriterBuilder { } } - case class PointerArrayCreationFailed(arr: NewPointerArray[_]) - extends Blame[InvocationFailure] { + case class PointerArrayCreationFailed( + arr: Expr[_], + blame: Blame[ArraySizeError], + ) extends Blame[InvocationFailure] { override def blame(error: InvocationFailure): Unit = error match { - case PreconditionFailed(_, _, _) => arr.blame.blame(ArraySize(arr)) + case PreconditionFailed(_, _, _) => blame.blame(ArraySize(arr)) case ContextEverywhereFailedInPre(_, _) => - arr.blame.blame(ArraySize(arr)) // Unnecessary? + blame.blame(ArraySize(arr)) // Unnecessary? case other => throw Unreachable(s"Invalid invocation failure: $other") } } @@ -106,6 +108,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val pointerArrayCreationMethods: mutable.Map[Type[Pre], Procedure[Post]] = mutable.Map() + val nonNullPointerArrayCreationMethods + : mutable.Map[Type[Pre], Procedure[Post]] = mutable.Map() val freeMethods: mutable.Map[Type[ Post @@ -189,7 +193,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { // If structure contains structs, the permission for those fields need to be released as well val permFields = t match { - case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) +// case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) case _ => Seq() } requiresT = @@ -421,6 +425,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ) // We do not allow this notation for recursive structs implicit val o: Origin = origin + // TODO: Instead of doing complicated stuff here just generate a Perm(struct.field, write) and rely on EncodyByValueClass to deal with it :) val fields = structType match { case t: TClass[Post] => @@ -431,7 +436,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { } val newFieldPerms = fields.map(member => { val loc = - (i: Variable[Post]) => Deref[Post](struct(i), member.ref)(DerefPerm) + (i: Variable[Post]) => + DerefPointer(Deref[Post](struct(i), member.ref)(DerefPerm))( + NonNullPointerNull + ) var anns: Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = Seq(( makeStruct.makePerm( i => FieldLocation[Post](struct(i), member.ref), @@ -444,7 +452,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ), )) anns = - if (typeIsRef(member.t)) + if (typeIsRef(member.t.asPointer.get.element)) anns :+ ( makeStruct.makeUnique(loc), @@ -452,7 +460,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ) else anns - member.t match { + member.t.asPointer.get.element match { case newStruct: TClass[Post] => // We recurse, since a field is another struct anns ++ unwrapStructPerm( @@ -504,7 +512,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { case _ => false } - def makePointerCreationMethodFor(elementType: Type[Pre]) = { + def makePointerCreationMethodFor( + elementType: Type[Pre], + nullable: Boolean, + ) = { implicit val o: Origin = arrayCreationOrigin // ar != null // ar.length == dim0 @@ -529,9 +540,11 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Seq(access(i), access(j)), ) - var ensures = (result !== Null()) &* + var ensures = (PointerBlockLength(result)(FramedPtrBlockLength) === sizeArg.get) &* - (PointerBlockOffset(result)(FramedPtrBlockOffset) === zero) + (PointerBlockOffset(result)(FramedPtrBlockOffset) === zero) + + if (nullable) { ensures = (result !== Null()) &* ensures } // Pointer location needs pointer add, not pointer subscript ensures = ensures &* makeStruct.makePerm( @@ -561,7 +574,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { procedure( blame = AbstractApplicable, contractBlame = TrueSatisfiable, - returnType = TPointer(dispatch(elementType)), + returnType = + if (nullable) { TPointer(dispatch(elementType)) } + else { TNonNullPointer(dispatch(elementType)) }, args = Seq(sizeArg), requires = UnitAccountedPredicate(requires), ensures = UnitAccountedPredicate(ensures), @@ -597,8 +612,23 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Nil, )(ArrayCreationFailed(newArr)) case newPointerArr @ NewPointerArray(element, size) => - val method = pointerArrayCreationMethods - .getOrElseUpdate(element, makePointerCreationMethodFor(element)) + val method = pointerArrayCreationMethods.getOrElseUpdate( + element, + makePointerCreationMethodFor(element, nullable = true), + ) + ProcedureInvocation[Post]( + method.ref, + Seq(dispatch(size)), + Nil, + Nil, + Nil, + Nil, + )(PointerArrayCreationFailed(newPointerArr, newPointerArr.blame)) + case newPointerArr @ NewNonNullPointerArray(element, size) => + val method = nonNullPointerArrayCreationMethods.getOrElseUpdate( + element, + makePointerCreationMethodFor(element, nullable = false), + ) ProcedureInvocation[Post]( method.ref, Seq(dispatch(size)), @@ -606,7 +636,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, Nil, - )(PointerArrayCreationFailed(newPointerArr)) + )(PointerArrayCreationFailed(newPointerArr, newPointerArr.blame)) case free @ FreePointer(xs) => val newXs = dispatch(xs) val TPointer(t) = newXs.t diff --git a/src/rewrite/vct/rewrite/EncodeByValueClass.scala b/src/rewrite/vct/rewrite/EncodeByValueClass.scala deleted file mode 100644 index 3054e2572c..0000000000 --- a/src/rewrite/vct/rewrite/EncodeByValueClass.scala +++ /dev/null @@ -1,249 +0,0 @@ -package vct.rewrite - -import hre.util.ScopedStack -import vct.col.ast._ -import vct.col.origin._ -import vct.col.ref.Ref -import vct.col.resolve.ctx.Referrable -import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} -import vct.col.util.AstBuildHelpers._ -import vct.result.VerificationError.UserError - -case object EncodeByValueClass extends RewriterBuilder { - override def key: String = "encodeByValueClass" - - override def desc: String = - "Initialise ByValueClasses when they are declared and copy them whenever they're read" - - private case class ClassCopyInAssignmentFailed( - blame: Blame[AssignFailed], - assign: Node[_], - clazz: ByValueClass[_], - field: InstanceField[_], - ) extends Blame[InsufficientPermission] { - override def blame(error: InsufficientPermission): Unit = { - if (blame.isInstanceOf[PanicBlame]) { - assign.o - .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) - } else { - blame - .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) - } - } - } - - private case class ClassCopyInCallFailed( - blame: Blame[InvocationFailure], - inv: Invocation[_], - clazz: ByValueClass[_], - field: InstanceField[_], - ) extends Blame[InsufficientPermission] { - override def blame(error: InsufficientPermission): Unit = { - blame.blame( - CopyClassFailedBeforeCall(inv, clazz, Referrable.originName(field)) - ) - } - } - - case class UnsupportedStructPerm(o: Origin) extends UserError { - override def code: String = "unsupportedStructPerm" - override def text: String = - o.messageInContext( - "Shorthand for Permissions for structs not possible, since the struct has a cyclic reference" - ) - } - - private sealed class CopyContext - - private case class InCall(invocation: Invocation[_]) extends CopyContext - - private case class InAssignmentExpression(assignment: AssignExpression[_]) - extends CopyContext - - private case class InAssignmentStatement(assignment: Assign[_]) - extends CopyContext -} - -case class EncodeByValueClass[Pre <: Generation]() extends Rewriter[Pre] { - - import EncodeByValueClass._ - - private val inAssignment: ScopedStack[Unit] = ScopedStack() - private val copyContext: ScopedStack[CopyContext] = ScopedStack() - - override def dispatch(node: Statement[Pre]): Statement[Post] = - node match { - case s: Scope[Pre] => - cPPLocalDeclarations.scope { - cLocalDeclarations.scope { - variables.scope { - localHeapVariables.scope { - val locals = variables.dispatch(s.locals) - Scope( - locals, - Block(locals.collect { - case v: Variable[Post] - if v.t.isInstanceOf[TByValueClass[Post]] => - Assign( - v.get(v.o), - NewObject(v.t.asInstanceOf[TByValueClass[Post]].cls)(v.o), - )(PanicBlame( - "Instantiating a ByValueClass should always succeed" - ))(v.o) - } ++ Seq(s.body.rewriteDefault()))(node.o), - )(node.o) - } - } - } - } - case assign: Assign[Pre] => { - val target = inAssignment.having(()) { assign.target.rewriteDefault() } - copyContext.having(InAssignmentStatement(assign)) { - assign.rewrite(target = target) - } - } - case _ => node.rewriteDefault() - } - - private def copyClassValue( - obj: Expr[Post], - t: TByValueClass[Pre], - blame: InstanceField[Pre] => Blame[InsufficientPermission], - ): Expr[Post] = { - implicit val o: Origin = obj.o - val v = new Variable[Post](dispatch(t)) - val children = t.cls.decl.decls.collect { case f: InstanceField[Pre] => - f.t match { - case inner: TByValueClass[Pre] => - Assign[Post]( - Deref[Post](v.get, succ(f))(DerefAssignTarget), - copyClassValue(Deref[Post](obj, succ(f))(blame(f)), inner, blame), - )(AssignLocalOk) - case _ => - Assign[Post]( - Deref[Post](v.get, succ(f))(DerefAssignTarget), - Deref[Post](obj, succ(f))(blame(f)), - )(AssignLocalOk) - - } - } - ScopedExpr( - Seq(v), - Then( - PreAssignExpression(v.get, NewObject[Post](succ(t.cls.decl)))( - AssignLocalOk - ), - Block(children), - ), - ) - } - - // def unwrapClassPerm( - // struct: Expr[Post], - // perm: Expr[Pre], - // structType: TByValueClass[Pre], - // origin: Origin, - // visited: Seq[TByValueClass[Pre]] = Seq(), - // ): Expr[Post] = { - // if (visited.contains(structType)) - // throw UnsupportedStructPerm( - // origin - // ) // We do not allow this notation for recursive structs - // implicit val o: Origin = origin - // val blame = PanicBlame("Field permission is framed") - // val Seq(CStructDeclaration(_, fields)) = structType.ref.decl.decl.specs - // val newPerm = dispatch(perm) - // val AmbiguousLocation(newExpr) = struct - // val newFieldPerms = fields.map(member => { - // val loc = - // AmbiguousLocation( - // Deref[Post]( - // newExpr, - // cStructFieldsSuccessor.ref((structType.ref.decl, member)), - // )(blame) - // )(struct.blame) - // member.specs.collectFirst { - // case CSpecificationType(newStruct: CTStruct[Pre]) => - // // We recurse, since a field is another struct - // Perm(loc, newPerm) &* unwrapStructPerm( - // loc, - // perm, - // newStruct, - // origin, - // structType +: visited, - // ) - // }.getOrElse(Perm(loc, newPerm)) - // }) - - // foldStar(newFieldPerms) - // } - override def dispatch(node: Expr[Pre]): Expr[Post] = - if (inAssignment.nonEmpty) - node.rewriteDefault() - else - node match { - case Perm(loc, p) => node.rewriteDefault() - case assign: PreAssignExpression[Pre] => - val target = - inAssignment.having(()) { assign.target.rewriteDefault() } - copyContext.having(InAssignmentExpression(assign)) { - assign.rewrite(target = target) - } - case invocation: Invocation[Pre] => { - copyContext.having(InCall(invocation)) { invocation.rewriteDefault() } - } - case Local(Ref(v)) if v.t.isInstanceOf[TByValueClass[Pre]] => - if (copyContext.isEmpty) { - return node.rewriteDefault() - } // If we are in other kinds of expressions like if statements - val t = v.t.asInstanceOf[TByValueClass[Pre]] - val clazz = t.cls.decl.asInstanceOf[ByValueClass[Pre]] - - copyContext.top match { - case InCall(invocation) => - copyClassValue( - node.rewriteDefault(), - t, - f => - ClassCopyInCallFailed(invocation.blame, invocation, clazz, f), - ) - case InAssignmentExpression(assignment: PreAssignExpression[_]) => - copyClassValue( - node.rewriteDefault(), - t, - f => - ClassCopyInAssignmentFailed( - assignment.blame, - assignment, - clazz, - f, - ), - ) - case InAssignmentExpression(assignment: PostAssignExpression[_]) => - copyClassValue( - node.rewriteDefault(), - t, - f => - ClassCopyInAssignmentFailed( - assignment.blame, - assignment, - clazz, - f, - ), - ) - case InAssignmentStatement(assignment) => - copyClassValue( - node.rewriteDefault(), - t, - f => - ClassCopyInAssignmentFailed( - assignment.blame, - assignment, - clazz, - f, - ), - ) - } - case _ => node.rewriteDefault() - } -} diff --git a/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala new file mode 100644 index 0000000000..9e3fdad2a7 --- /dev/null +++ b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala @@ -0,0 +1,111 @@ +package vct.rewrite + +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.ast.{Variable, _} +import vct.col.origin.{AssignLocalOk, LabelContext, Origin, PanicBlame} +import vct.col.util.AstBuildHelpers._ +import vct.col.ref.Ref +import vct.col.util.{CurrentRewriteProgramContext, SuccessionMap} +import vct.result.VerificationError + +case object LowerLocalHeapVariables extends RewriterBuilder { + override def key: String = "lowerLocalHeapVariables" + + override def desc: String = + "Lower LocalHeapVariables to Variables if their address is never taken" + + private val pointerCreationOrigin: Origin = Origin( + Seq(LabelContext("pointer creation method")) + ) +} + +case class LowerLocalHeapVariables[Pre <: Generation]() extends Rewriter[Pre] { + import LowerLocalHeapVariables._ + + private val stripped: SuccessionMap[LocalHeapVariable[Pre], Variable[Post]] = + SuccessionMap() + private val lowered: SuccessionMap[LocalHeapVariable[Pre], Variable[Post]] = + SuccessionMap() +// private val pointerCreationMethods: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() +// +// def makePointerCreationMethod(t: Type[Pre]): Procedure[Post] = { +// implicit val o: Origin = pointerCreationOrigin +// +// val proc = globalDeclarations.declare(withResult((result: Result[Post]) => { +// +// })) +// } + + override def dispatch(program: Program[Pre]): Program[Post] = { + val dereferencedHeapLocals = program.collect { + case DerefPointer(hl @ HeapLocal(_)) => System.identityHashCode(hl) + } + val nakedHeapLocals = program.collect { + case hl @ HeapLocal(Ref(v)) + if !dereferencedHeapLocals.contains(System.identityHashCode(hl)) => + v + } + VerificationError.withContext(CurrentRewriteProgramContext(program)) { + localHeapVariables.scope { + variables.scope { + enumConstants.scope { + modelDeclarations.scope { + aDTDeclarations.scope { + classDeclarations.scope { + globalDeclarations.scope { + program.collect { + case HeapLocal(Ref(v)) if !nakedHeapLocals.contains(v) => + v + }.foreach(v => + stripped(v) = + new Variable[Post](dispatch(v.t.asPointer.get.element))( + v.o + ) + ) + Program(globalDeclarations.dispatch(program.declarations))( + dispatch(program.blame) + )(program.o) + } + } + } + } + } + } + } + } + } + + override def dispatch(node: Statement[Pre]): Statement[Post] = { + implicit val o: Origin = node.o + node match { + // Same logic as CollectLocalDeclarations + case Scope(vars, impl) => + val (newVars, newImpl) = variables.collect { + vars.foreach(dispatch) + dispatch(impl) + } + Scope(newVars, newImpl) + case HeapLocalDecl(v) => + if (stripped.contains(v)) { variables.declare(stripped(v)) } + else { + lowered(v) = new Variable[Post](dispatch(v.t))(v.o) + variables.declare(lowered(v)) + } + Block(Nil) + case _ => node.rewriteDefault() + } + } + + override def dispatch(node: Expr[Pre]): Expr[Post] = { + implicit val o: Origin = node.o + node match { + case DerefPointer(HeapLocal(Ref(v))) if stripped.contains(v) => + stripped(v).get + case HeapLocal(Ref(v)) if lowered.contains(v) => { + // lowered.contains(v) should always be true since all stripped HeapLocals would be caught by DerefPointer(HeapLocal(Ref(v))) + Local(lowered.ref(v)) + } + case _ => node.rewriteDefault() + } + } +} diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala new file mode 100644 index 0000000000..8067a5ec3b --- /dev/null +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -0,0 +1,279 @@ +package vct.rewrite + +import hre.util.ScopedStack +import vct.col.ast._ +import vct.col.origin._ +import vct.col.ref.Ref +import vct.col.resolve.ctx.Referrable +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.util.AstBuildHelpers._ +import vct.result.VerificationError.{Unreachable, UserError} + +// TODO: Think of a better name +case object PrepareByValueClass extends RewriterBuilder { + override def key: String = "prepareByValueClass" + + override def desc: String = + "Initialise ByValueClasses when they are declared and copy them whenever they're read" + + private case class ClassCopyInAssignmentFailed( + blame: Blame[PointerDerefError], + assign: Node[_], + clazz: ByValueClass[_], + field: InstanceField[_], + ) extends Blame[InsufficientPermission] { + override def blame(error: InsufficientPermission): Unit = { + if (blame.isInstanceOf[PanicBlame]) { + assign.o + .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) + } else { + blame + .blame(CopyClassFailed(assign, clazz, Referrable.originName(field))) + } + } + } + + private case class ClassCopyInCallFailed( + blame: Blame[PointerDerefError], + inv: Invocation[_], + clazz: ByValueClass[_], + field: InstanceField[_], + ) extends Blame[InsufficientPermission] { + override def blame(error: InsufficientPermission): Unit = { + blame.blame( + CopyClassFailedBeforeCall(inv, clazz, Referrable.originName(field)) + ) + } + } + + case class UnsupportedStructPerm(o: Origin) extends UserError { + override def code: String = "unsupportedStructPerm" + override def text: String = + o.messageInContext( + "Shorthand for Permissions for structs not possible, since the struct has a cyclic reference" + ) + } + + private sealed class CopyContext + + private case class InCall(invocation: Invocation[_]) extends CopyContext + + private case class InAssignmentExpression(assignment: AssignExpression[_]) + extends CopyContext + + private case class InAssignmentStatement(assignment: Assign[_]) + extends CopyContext + + case class PointerLocationDerefBlame(blame: Blame[PointerLocationError]) + extends Blame[PointerDerefError] { + override def blame(error: PointerDerefError): Unit = { + error match { + case error: PointerLocationError => blame.blame(error) + case _ => + Unreachable( + "Blame of the respective pointer operation should be used not of DerefPointer" + ) + } + } + } +} + +case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { + + import PrepareByValueClass._ + + private val inAssignment: ScopedStack[Unit] = ScopedStack() + private val copyContext: ScopedStack[CopyContext] = ScopedStack() + + override def dispatch(node: Statement[Pre]): Statement[Post] = { + implicit val o: Origin = node.o + node match { + case HeapLocalDecl(local) + if local.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => { + val newLocal = localHeapVariables.dispatch(local) + val t = newLocal.t.asPointer.get.element + Block(Seq( + HeapLocalDecl(newLocal), + Assign( + HeapLocal[Post](newLocal.ref), + NewNonNullPointerArray(t, const(1))(PanicBlame("Size > 0")), + )(AssignLocalOk), + Assign( + newLocal.get(DerefAssignTarget), + NewObject(t.asInstanceOf[TByValueClass[Post]].cls), + )(AssignLocalOk), + )) + } + case assign: Assign[Pre] => { + val target = inAssignment.having(()) { dispatch(assign.target) } + copyContext.having(InAssignmentStatement(assign)) { + assign.rewrite(target = target) + } + } + case _ => node.rewriteDefault() + } + } + + private def copyClassValue( + obj: Expr[Post], + t: TByValueClass[Pre], + blame: InstanceField[Pre] => Blame[InsufficientPermission], + ): Expr[Post] = { + implicit val o: Origin = obj.o + val ov = new Variable[Post](obj.t) + val v = + new Variable[Post](dispatch(t))(o.withContent(TypeName("HelloWorld"))) + val children = t.cls.decl.decls.collect { case f: InstanceField[Pre] => + f.t match { + case inner: TByValueClass[Pre] => + Assign[Post]( + DerefPointer(Deref[Post](v.get, succ(f))(DerefAssignTarget))( + NonNullPointerNull + ), + copyClassValue(Deref[Post](ov.get, succ(f))(blame(f)), inner, blame), + )(AssignLocalOk) + case _ => + Assign[Post]( + DerefPointer(Deref[Post](v.get, succ(f))(DerefAssignTarget))( + NonNullPointerNull + ), + DerefPointer(Deref[Post](ov.get, succ(f))(blame(f)))( + NonNullPointerNull + ), + )(AssignLocalOk) + + } + } + ScopedExpr( + Seq(ov, v), + Then( + With( + assignLocal(ov.get, obj), + PreAssignExpression(v.get, NewObject[Post](succ(t.cls.decl)))( + AssignLocalOk + ), + ), + Block(children), + ), + ) + } + + private def unwrapClassPerm( + obj: Expr[Post], + perm: Expr[Post], + structType: TByValueClass[Pre], + visited: Seq[TByValueClass[Pre]] = Seq(), + ): Expr[Post] = { + if (visited.contains(structType)) + throw UnsupportedStructPerm( + obj.o + ) // We do not allow this notation for recursive structs + implicit val o: Origin = obj.o + val blame = PanicBlame("Field permission is framed") + val fields = structType.cls.decl.decls.collect { + case f: InstanceField[Pre] => f + } + val newFieldPerms = fields.map(member => { + val loc = FieldLocation[Post](obj, succ(member)) + member.t.asPointer.get.element match { + case inner: TByValueClass[Pre] => + Perm[Post](loc, perm) &* unwrapClassPerm( + DerefPointer(Deref[Post](obj, succ(member))(blame))( + NonNullPointerNull + ), + perm, + inner, + structType +: visited, + ) + case _ => Perm(loc, perm) + } + }) + + foldStar(newFieldPerms) + } + + override def dispatch(node: Expr[Pre]): Expr[Post] = { + implicit val o: Origin = node.o + if (inAssignment.nonEmpty) + node.rewriteDefault() + else + node match { + case Perm(ByValueClassLocation(e), p) => + unwrapClassPerm( + dispatch(e), + dispatch(p), + e.t.asInstanceOf[TByValueClass[Pre]], + ) + // What if I get rid of this... +// case Perm(loc@PointerLocation(e), p) if e.t.asPointer.exists(t => t.element.isInstanceOf[TByValueClass[Pre]])=> +// unwrapClassPerm(DerefPointer(dispatch(e))(PointerLocationDerefBlame(loc.blame))(loc.o), dispatch(p), e.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]]) + case assign: PreAssignExpression[Pre] => + val target = inAssignment.having(()) { dispatch(assign.target) } + copyContext.having(InAssignmentExpression(assign)) { + assign.rewrite(target = target) + } + case invocation: Invocation[Pre] => { + copyContext.having(InCall(invocation)) { invocation.rewriteDefault() } + } + // WHOOPSIE WE ALSO MAKE A COPY IF IT WAS A POINTER + case dp @ DerefPointer(HeapLocal(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case dp @ DerefPointer(DerefHeapVariable(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case dp @ DerefPointer(Deref(_, Ref(f))) + if f.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + rewriteInCopyContext( + dp, + f.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case dp @ DerefPointer(Local(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + // This can happen if the user specifies a local of type pointer to TByValueClass + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case _ => node.rewriteDefault() + } + } + + private def rewriteInCopyContext( + dp: DerefPointer[Pre], + t: TByValueClass[Pre], + ): Expr[Post] = { + if (copyContext.isEmpty) { + // If we are in other kinds of expressions like if statements + return dp.rewriteDefault() + } + val clazz = t.cls.decl.asInstanceOf[ByValueClass[Pre]] + + copyContext.top match { + case InCall(invocation) => + copyClassValue( + dp.rewriteDefault(), + t, + f => ClassCopyInCallFailed(dp.blame, invocation, clazz, f), + ) + case InAssignmentExpression(assignment) => + copyClassValue( + dp.rewriteDefault(), + t, + f => ClassCopyInAssignmentFailed(dp.blame, assignment, clazz, f), + ) + case InAssignmentStatement(assignment) => + copyClassValue( + dp.rewriteDefault(), + t, + f => ClassCopyInAssignmentFailed(dp.blame, assignment, clazz, f), + ) + } + } +} diff --git a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala index ba090ae05e..8fca4ff97a 100644 --- a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala +++ b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala @@ -340,6 +340,7 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() ), ) case decl: LocalDecl[Pre] => rewriteDefault(decl) + case decl: HeapLocalDecl[Pre] => decl.rewriteDefault() case Return(result) => frame( result, @@ -532,6 +533,7 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() val result = target match { case Local(Ref(v)) => Local[Post](succ(v))(target.o) + case HeapLocal(Ref(v)) => HeapLocal[Post](succ(v))(target.o) case deref @ DerefHeapVariable(Ref(v)) => DerefHeapVariable[Post](succ(v))(deref.blame)(target.o) case Deref(obj, Ref(f)) => diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index edc400f193..63b5396c0d 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -39,7 +39,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { case AddrOf(other) => throw UnsupportedLocation(other) case assign @ PreAssignExpression(target, AddrOf(value)) - if value.t.isInstanceOf[TClass[Pre]] => + if value.t.isInstanceOf[TByReferenceClass[Pre]] => implicit val o: Origin = assign.o val (newPointer, newTarget, newValue) = rewriteAssign( target, @@ -61,7 +61,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(s: Statement[Pre]): Statement[Post] = s match { case assign @ Assign(target, AddrOf(value)) - if value.t.isInstanceOf[TClass[Pre]] => + if value.t.isInstanceOf[TByReferenceClass[Pre]] => implicit val o: Origin = assign.o val (newPointer, newTarget, newValue) = rewriteAssign( target, diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala new file mode 100644 index 0000000000..ab399c63a5 --- /dev/null +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -0,0 +1,221 @@ +package vct.rewrite + +import vct.col.ast._ +import vct.col.ref._ +import vct.col.origin._ +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten} +import vct.col.util.AstBuildHelpers._ +import vct.col.util.SuccessionMap +import vct.result.VerificationError.UserError + +import scala.collection.mutable + +case object VariableToPointer extends RewriterBuilder { + override def key: String = "variableToPointer" + + override def desc: String = + "Translate every local and field to a pointer such that it can have its address taken" + + case class UnsupportedAddrOf(loc: Expr[_]) extends UserError { + override def code: String = "unsupportedAddrOf" + + override def text: String = + loc.o.messageInContext( + "Taking an address of this expression is not supported" + ) + } +} + +case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { + + import VariableToPointer._ + + val addressedSet: mutable.Set[Node[Pre]] = new mutable.HashSet[Node[Pre]]() + val heapVariableMap: SuccessionMap[HeapVariable[Pre], HeapVariable[Post]] = + SuccessionMap() + val variableMap: SuccessionMap[Variable[Pre], Variable[Post]] = + SuccessionMap() + val fieldMap: SuccessionMap[InstanceField[Pre], InstanceField[Post]] = + SuccessionMap() + + override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = { + // TODO: Replace the isInstanceOf[TByReferenceClass] checks with something that more clearly communicates that we want to exclude all reference types + addressedSet.addAll(program.collect { + case AddrOf(Local(Ref(v))) if !v.t.isInstanceOf[TByReferenceClass[Pre]] => + v + case AddrOf(DerefHeapVariable(Ref(v))) + if !v.t.isInstanceOf[TByReferenceClass[Pre]] => + v + case AddrOf(Deref(_, Ref(f))) + if !f.t.isInstanceOf[TByReferenceClass[Pre]] => + f + }) + super.dispatch(program) + } + + override def dispatch(decl: Declaration[Pre]): Unit = + decl match { + // TODO: Use some sort of NonNull pointer type instead + case v: HeapVariable[Pre] if addressedSet.contains(v) => + heapVariableMap(v) = globalDeclarations + .succeed(v, new HeapVariable(TPointer(dispatch(v.t)))(v.o)) + case v: Variable[Pre] if addressedSet.contains(v) => + variableMap(v) = variables + .succeed(v, new Variable(TPointer(dispatch(v.t)))(v.o)) + case f: InstanceField[Pre] if addressedSet.contains(f) => + fieldMap(f) = classDeclarations.succeed( + f, + new InstanceField( + TPointer(dispatch(f.t)), + f.flags.map { it => dispatch(it) }, + )(f.o), + ) + case other => allScopes.anySucceed(other, other.rewriteDefault()) + } + + override def dispatch(stat: Statement[Pre]): Statement[Post] = { + implicit val o: Origin = stat.o + stat match { + case s: Scope[Pre] => + s.rewrite( + locals = variables.dispatch(s.locals), + body = Block(s.locals.filter { local => addressedSet.contains(local) } + .map { local => + implicit val o: Origin = local.o + Assign( + Local[Post](variableMap.ref(local)), + NewPointerArray( + variableMap(local).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + } ++ Seq(dispatch(s.body))), + ) + case i @ Instantiate(cls, out) => + // TODO: Make sure that we recursively build newobject for byvalueclasses + // maybe get rid this entirely and only have it in encode by value class + Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap { + case f: InstanceField[Pre] => + if (f.t.asClass.isDefined) { + Seq( + Assign( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")), + Assign( + PointerSubscript( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + const[Post](0), + )(PanicBlame("Size is > 0")), + dispatch(NewObject[Pre](f.t.asClass.get.cls)), + )(PanicBlame("Initialisation should always succeed")), + ) + } else if (addressedSet.contains(f)) { + Seq( + Assign( + Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + ) + } else { Seq() } + case _ => Seq() + }) + case other => other.rewriteDefault() + } + } + + override def dispatch(expr: Expr[Pre]): Expr[Post] = { + implicit val o: Origin = expr.o + expr match { + case deref @ DerefHeapVariable(Ref(v)) if addressedSet.contains(v) => + DerefPointer( + DerefHeapVariable[Post](heapVariableMap.ref(v))(deref.blame) + )(PanicBlame("Should always be accessible")) + case Local(Ref(v)) if addressedSet.contains(v) => + DerefPointer(Local[Post](variableMap.ref(v)))(PanicBlame( + "Should always be accessible" + )) + case deref @ Deref(obj, Ref(f)) if addressedSet.contains(f) => + DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))( + PanicBlame("Should always be accessible") + ) + case newObject @ NewObject(Ref(cls)) => + val obj = new Variable[Post](TByReferenceClass(succ(cls), Seq())) + ScopedExpr( + Seq(obj), + With( + Block( + Seq(assignLocal(obj.get, newObject.rewriteDefault())) ++ + cls.declarations.flatMap { + case f: InstanceField[Pre] => + if (f.t.asClass.isDefined) { + Seq( + Assign( + Deref[Post](obj.get, anySucc(f))(PanicBlame( + "Initialisation should always succeed" + )), + dispatch(NewObject[Pre](f.t.asClass.get.cls)), + )(PanicBlame("Initialisation should always succeed")) + ) + } else if (addressedSet.contains(f)) { + Seq( + Assign( + Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame( + "Initialisation should always succeed" + )), + NewPointerArray( + fieldMap(f).t.asPointer.get.element, + const(1), + )(PanicBlame("Size is > 0")), + )(PanicBlame("Initialisation should always succeed")) + ) + } else { Seq() } + case _ => Seq() + } + ), + obj.get, + ), + ) + case other => other.rewriteDefault() + } + } + + override def dispatch(loc: Location[Pre]): Location[Post] = { + implicit val o: Origin = loc.o + loc match { + case HeapVariableLocation(Ref(v)) if addressedSet.contains(v) => + PointerLocation( + DerefHeapVariable[Post](heapVariableMap.ref(v))(PanicBlame( + "Should always be accessible" + )) + )(PanicBlame("Should always be accessible")) + case FieldLocation(obj, Ref(f)) if addressedSet.contains(f) => + PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame( + "Should always be accessible" + )))(PanicBlame("Should always be accessible")) + case PointerLocation( + AddrOf(Deref(obj, Ref(f))) + ) /* if addressedSet.contains(f) always true */ => + FieldLocation[Post](dispatch(obj), fieldMap.ref(f)) + case PointerLocation( + AddrOf(DerefHeapVariable(Ref(v))) + ) /* if addressedSet.contains(v) always true */ => + HeapVariableLocation[Post](heapVariableMap.ref(v)) + case PointerLocation(AddrOf(local @ Local(_))) => + throw UnsupportedAddrOf(local) + case other => other.rewriteDefault() + } + } +} diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 2e81faf964..09f85ec8eb 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -87,17 +87,30 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) ).ref } + private def unwrapOption( + ptr: Expr[Pre], + blame: Blame[PointerNull], + ): Expr[Post] = { + ptr.t match { + case TPointer(_) => + OptGet(dispatch(ptr))(PointerNullOptNone(blame, ptr))(ptr.o) + case TNonNullPointer(_) => dispatch(ptr) + } + } + override def applyCoercion(e: => Expr[Post], coercion: Coercion[Pre])( implicit o: Origin ): Expr[Post] = coercion match { case CoerceNullPointer(_) => OptNone() + case CoerceNonNullPointer(_) => OptSome(e) case other => super.applyCoercion(e, other) } override def postCoerce(t: Type[Pre]): Type[Post] = t match { case TPointer(_) => TOption(TAxiomatic(pointerAdt.ref, Nil)) + case TNonNullPointer(_) => TAxiomatic(pointerAdt.ref, Nil) case other => rewriteDefault(other) } @@ -108,11 +121,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) obj = FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq( - OptGet(dispatch(pointer))( - PointerNullOptNone(loc.blame, pointer) - )(pointer.o) - ), + args = Seq(unwrapOption(pointer, loc.blame)), typeArgs = Nil, Nil, Nil, @@ -133,12 +142,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) args = Seq( FunctionInvocation[Post]( ref = pointerAdd.ref, - args = Seq( - OptGet(dispatch(pointer))( - PointerNullOptNone(sub.blame, pointer) - ), - dispatch(index), - ), + args = Seq(unwrapOption(pointer, sub.blame), dispatch(index)), typeArgs = Nil, Nil, Nil, @@ -151,46 +155,51 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) field = getPointerField(pointer), )(PointerFieldInsufficientPermission(sub.blame, sub)) case add @ PointerAdd(pointer, offset) => - OptSome( + val inv = FunctionInvocation[Post]( ref = pointerAdd.ref, - args = Seq( - OptGet(dispatch(pointer))(PointerNullOptNone(add.blame, pointer)), - dispatch(offset), - ), + args = Seq(unwrapOption(pointer, add.blame), dispatch(offset)), typeArgs = Nil, Nil, Nil, )(NoContext(PointerBoundsPreconditionFailed(add.blame, pointer))) - ) + pointer.t match { + case TPointer(_) => OptSome(inv) + case TNonNullPointer(_) => inv + } case deref @ DerefPointer(pointer) => - SilverDeref( - obj = - FunctionInvocation[Post]( - ref = pointerDeref.ref, - args = Seq( - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq( - OptGet(dispatch(pointer))( - PointerNullOptNone(deref.blame, pointer) - ), - const(0), - ), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - ), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame("ptr_deref requires nothing.")), - field = getPointerField(pointer), - )(PointerFieldInsufficientPermission(deref.blame, deref)) + if (pointer.o.find[TypeName].isDefined) { + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq(unwrapOption(pointer, deref.blame)), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")) + } else { + SilverDeref( + obj = + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq( + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")), + field = getPointerField(pointer), + )(PointerFieldInsufficientPermission(deref.blame, deref)) + } case len @ PointerBlockLength(pointer) => ADTFunctionInvocation[Post]( typeArgs = Some((blockAdt.ref, Nil)), @@ -198,18 +207,14 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) args = Seq(ADTFunctionInvocation[Post]( typeArgs = Some((pointerAdt.ref, Nil)), ref = pointerBlock.ref, - args = Seq( - OptGet(dispatch(pointer))(PointerNullOptNone(len.blame, pointer)) - ), + args = Seq(unwrapOption(pointer, len.blame)), )), ) case off @ PointerBlockOffset(pointer) => ADTFunctionInvocation[Post]( typeArgs = Some((pointerAdt.ref, Nil)), ref = pointerOffset.ref, - args = Seq( - OptGet(dispatch(pointer))(PointerNullOptNone(off.blame, pointer)) - ), + args = Seq(unwrapOption(pointer, off.blame)), ) case pointerLen @ PointerLength(pointer) => postCoerce( diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index f4ca25352c..955e57c7bd 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -254,7 +254,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) SuccessionMap() val cNameSuccessor: SuccessionMap[CNameTarget[Pre], Variable[Post]] = SuccessionMap() - val cLocalHeapNameSuccessor: SuccessionMap[CNameTarget[Pre], LocalHeapVariable[Post]] = + val cLocalHeapNameSuccessor + : SuccessionMap[CNameTarget[Pre], LocalHeapVariable[Post]] = SuccessionMap() val cGlobalNameSuccessor : SuccessionMap[CNameTarget[Pre], HeapVariable[Post]] = SuccessionMap() @@ -991,10 +992,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) Seq(x), ) = fieldDecl fieldDecl.drop() - val t = - specs.collectFirst { case t: CSpecificationType[Pre] => - rw.dispatch(t.t) - }.get + val t = TNonNullPointer(specs.collectFirst { + case t: CSpecificationType[Pre] => rw.dispatch(t.t) + }.get) cStructFieldsSuccessor((decl, fieldDecl)) = new InstanceField(t = t, flags = Nil)(CStructFieldOrigin(x)) rw.classDeclarations @@ -1147,11 +1147,20 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val targetClass: Class[Post] = cStructSuccessor(ref.decl) val t = TByValueClass[Post](targetClass.ref, Seq()) - val v = new LocalHeapVariable[Post](t)(o.sourceName(info.name)) + val v = + new LocalHeapVariable[Post](TNonNullPointer(t))(o.sourceName(info.name)) cLocalHeapNameSuccessor(RefCLocalDeclaration(decl, 0)) = v if (init.init.isDefined) { - Block(Seq(HeapLocalDecl(v), assignHeapLocal(v.get, rw.dispatch(init.init.get)))) + Block(Seq( + HeapLocalDecl(v), + Assign( + v.get(PanicBlame( + "Dereferencing freshly declared struct should never fail" + )), + rw.dispatch(init.init.get), + )(AssignLocalOk), + )) } else { HeapLocalDecl(v) } } @@ -1315,7 +1324,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))(local.blame) case Some(_) => throw NotAValue(local) } - case ref: RefCLocalDeclaration[Pre] if cLocalHeapNameSuccessor.contains(ref) => HeapLocal(cLocalHeapNameSuccessor.ref(ref)) + case ref: RefCLocalDeclaration[Pre] + if cLocalHeapNameSuccessor.contains(ref) => + DerefPointer(HeapLocal[Post](cLocalHeapNameSuccessor.ref(ref)))( + local.blame + ) case ref: RefCLocalDeclaration[Pre] => Local(cNameSuccessor.ref(ref)) case _: RefCudaVec[Pre] => throw NotAValue(local) } @@ -1365,10 +1378,12 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _: TNotAValue[Pre] => throw TypeUsedAsValue(deref.obj) case _ => ??? } - Deref[Post]( - rw.dispatch(deref.obj), - cStructFieldsSuccessor.ref((struct_ref.decl, struct.decls)), - )(deref.blame) + DerefPointer( + Deref[Post]( + rw.dispatch(deref.obj), + cStructFieldsSuccessor.ref((struct_ref.decl, struct.decls)), + )(deref.blame) + )(NonNullPointerNull) } } @@ -1388,10 +1403,12 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case CTPointer(CTStruct(struct)) => struct case t => throw WrongStructType(t) } - Deref[Post]( - DerefPointer(rw.dispatch(deref.struct))(b), - cStructFieldsSuccessor.ref((structRef.decl, struct.decls)), - )(deref.blame)(deref.o) + DerefPointer( + Deref[Post]( + DerefPointer(rw.dispatch(deref.struct))(b), + cStructFieldsSuccessor.ref((structRef.decl, struct.decls)), + )(deref.blame)(deref.o) + )(NonNullPointerNull)(deref.o) } } @@ -1415,9 +1432,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val newFieldPerms = fields.map(member => { val loc = AmbiguousLocation( - Deref[Post]( - newExpr, - cStructFieldsSuccessor.ref((structType.ref.decl, member)), + DerefPointer( + Deref[Post]( + newExpr, + cStructFieldsSuccessor.ref((structType.ref.decl, member)), + )(blame) )(blame) )(struct.blame) member.specs.collectFirst { diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 488782e83b..b15ae0c8b9 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -332,17 +332,17 @@ case class LangSpecificToCol[Pre <: Generation]( case cast: CCast[Pre] => c.cast(cast) case sizeof: SizeOf[Pre] => throw LangCToCol.UnsupportedSizeof(sizeof) - case Perm(a @ AmbiguousLocation(expr), perm) - if c.getBaseType(expr.t).isInstanceOf[CTStruct[Pre]] => - c.getBaseType(expr.t) match { - case structType: CTStruct[Pre] => - c.unwrapStructPerm( - dispatch(a).asInstanceOf[AmbiguousLocation[Post]], - perm, - structType, - e.o, - ) - } +// case Perm(a @ AmbiguousLocation(expr), perm) +// if c.getBaseType(expr.t).isInstanceOf[CTStruct[Pre]] => +// c.getBaseType(expr.t) match { +// case structType: CTStruct[Pre] => +// c.unwrapStructPerm( +// dispatch(a).asInstanceOf[AmbiguousLocation[Post]], +// perm, +// structType, +// e.o, +// ) +// } case local: CPPLocal[Pre] => cpp.local(local) case deref: CPPClassMethodOrFieldAccess[Pre] => cpp.deref(deref) case inv: CPPInvocation[Pre] => cpp.invocation(inv) From 1666b3577c1eb1969243c89ff940caa28bf87959 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 17 Jun 2024 17:16:06 +0200 Subject: [PATCH 07/47] Add some axioms that speed up pointer verification --- build.sc | 12 +- res/universal/res/adt/pointer.pvl | 14 +- src/col/vct/col/ast/Node.scala | 4 + .../expr/heap/read/RawDerefPointerImpl.scala | 14 ++ .../vct/col/typerules/CoercingRewriter.scala | 2 + src/main/vct/main/stages/Transformation.scala | 3 +- src/rewrite/vct/rewrite/ClassToRef.scala | 68 ++++--- .../vct/rewrite/DisambiguateLocation.scala | 8 +- .../vct/rewrite/EncodeArrayValues.scala | 3 +- .../vct/rewrite/PrepareByValueClass.scala | 76 +++++++- .../vct/rewrite/adt/ImportPointer.scala | 173 ++++++++++++++---- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 21 ++- 12 files changed, 308 insertions(+), 90 deletions(-) create mode 100644 src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala diff --git a/build.sc b/build.sc index 23a2c215fa..2e7c847df6 100644 --- a/build.sc +++ b/build.sc @@ -41,7 +41,7 @@ object external extends Module { object viper extends ScalaModule { object silverGit extends GitModule { def url = T { "https://github.com/viperproject/silver.git" } - def commitish = T { "31c94df4f9792046618d9b4db52444ffe9c7c988" } + def commitish = T { "9cd85c01c10f5f846ed2754d64de08fcc59207ee" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") @@ -51,7 +51,7 @@ object viper extends ScalaModule { object siliconGit extends GitModule { def url = T { "https://github.com/viperproject/silicon.git" } - def commitish = T { "529d2a49108b954d2b0749356faf985d622f54f0" } + def commitish = T { "c4350bd33043a727a0a4f3008f39a4efc7748033" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") @@ -61,7 +61,7 @@ object viper extends ScalaModule { object carbonGit extends GitModule { def url = T { "https://github.com/viperproject/carbon.git" } - def commitish = T { "d7ac8b000e1123a72cbdda0c7679ab88ca8a52d4" } + def commitish = T { "15d74246bb8baef1e3ea88dcc4861c891259a99d" } } object silver extends ScalaModule { @@ -79,6 +79,8 @@ object viper extends ScalaModule { ivy"commons-io:commons-io:2.8.0", ivy"com.google.guava:guava:29.0-jre", ivy"org.jgrapht:jgrapht-core:1.5.0", + ivy"com.lihaoyi::requests:0.3.0", + ivy"com.lihaoyi::upickle:1.0.0", ) } @@ -407,14 +409,14 @@ object vercors extends Module { ) override def moduleDeps = Seq(hre, col, serialize) - val includeVcllvmCross = interp.watchValue { + val includeVcllvmCross = interp.watchValue { if(os.exists(settings.root / ".include-vcllvm")) { Seq("vcllvm") } else { Seq.empty[String] } } - + object vcllvmDep extends Cross[VcllvmDep](includeVcllvmCross) trait VcllvmDep extends Cross.Module[String] { def path = T { diff --git a/res/universal/res/adt/pointer.pvl b/res/universal/res/adt/pointer.pvl index 743584d8b4..6c4da7e5db 100644 --- a/res/universal/res/adt/pointer.pvl +++ b/res/universal/res/adt/pointer.pvl @@ -16,6 +16,7 @@ adt `pointer` { pure `pointer` pointer_of(`block` b, int offset); pure `block` pointer_block(`pointer` p); pure int pointer_offset(`pointer` p); + pure `pointer` pointer_inv(ref r); // the block offset is valid wrt the length of the block axiom (∀ `pointer` p; @@ -26,6 +27,17 @@ adt `pointer` { axiom (∀`block` b, int offset; {:pointer_block(pointer_of(b, offset)):} == b && {:pointer_offset(pointer_of(b, offset)):} == offset); + + axiom (∀ ref r; ptr_deref({:pointer_inv(r):}) == r); + + axiom (∀ `pointer` p; pointer_inv({:ptr_deref(p):}) == p); + + axiom (∀ `pointer` p1, `pointer` p2, int offset; + (0 <= offset && offset < `block`.block_length(pointer_block(p1)) && + pointer_block(p1) == pointer_block(p2) && + {:pointer_of(pointer_block(p1), offset):} == + {:pointer_of(pointer_block(p2), offset):}) ==> p1 == p2 + ); } decreases; @@ -38,4 +50,4 @@ requires `pointer`.pointer_offset(p) + offset < `block`.block_length(`pointer`.p pure `pointer` ptr_add(`pointer` p, int offset) = `pointer`.pointer_of( `pointer`.pointer_block(p), - `pointer`.pointer_offset(p) + offset); \ No newline at end of file + `pointer`.pointer_offset(p) + offset); diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index c46d938b5b..10d52782e1 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1400,6 +1400,10 @@ final case class DerefPointer[G](pointer: Expr[G])( val blame: Blame[PointerDerefError] )(implicit val o: Origin) extends Expr[G] with DerefPointerImpl[G] +final case class RawDerefPointer[G](pointer: Expr[G])( + val blame: Blame[PointerDerefError] +)(implicit val o: Origin) + extends Expr[G] with RawDerefPointerImpl[G] final case class PointerAdd[G](pointer: Expr[G], offset: Expr[G])( val blame: Blame[PointerAddError] )(implicit val o: Origin) diff --git a/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala b/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala new file mode 100644 index 0000000000..d270112e07 --- /dev/null +++ b/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala @@ -0,0 +1,14 @@ +package vct.col.ast.expr.heap.read + +import vct.col.ast.ops.RawDerefPointerOps +import vct.col.ast.{RawDerefPointer, TRef, Type} +import vct.col.print._ + +trait RawDerefPointerImpl[G] extends RawDerefPointerOps[G] { + this: RawDerefPointer[G] => + override def t: Type[G] = TRef() + + override def precedence: Int = Precedence.POSTFIX + override def layout(implicit ctx: Ctx): Doc = + Group(Text("ptr_deref(") <> pointer <> Text(")")) +} diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 4c709d31cd..c26e6c9476 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1245,6 +1245,8 @@ abstract class CoercingRewriter[Pre <: Generation]() case deref @ Deref(obj, ref) => Deref(cls(obj), ref)(deref.blame) case deref @ DerefHeapVariable(ref) => DerefHeapVariable(ref)(deref.blame) case deref @ DerefPointer(p) => DerefPointer(pointer(p)._1)(deref.blame) + case deref @ RawDerefPointer(p) => + RawDerefPointer(pointer(p)._1)(deref.blame) case Drop(xs, count) => Drop(seq(xs)._1, int(count)) case Empty(obj) => Empty(sized(obj)._1) case EmptyProcess() => EmptyProcess() diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index f75ed061ee..91644dede1 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -338,7 +338,6 @@ case class SilverTransformation( QuantifySubscriptAny, // no arr[*] IterationContractToParBlock, PropagateContextEverywhere, // inline context_everywhere into loop invariants - PrepareByValueClass, EncodeArrayValues, // maybe don't target shift lemmas on generated function for \values GivenYieldsToArgs, CheckProcessAlgebra, @@ -378,7 +377,7 @@ case class SilverTransformation( // flatten out functions in the rhs of assignments, making it harder to detect final field assignments where the // value is pure and therefore be put in the contract of the constant function. ConstantifyFinalFields, - + PrepareByValueClass, // Resolve side effects including method invocations, for encodetrythrowsignals. ResolveExpressionSideChecks, ResolveExpressionSideEffects, diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index a3c29fbe64..bebe3d263e 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -364,48 +364,46 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } }, )) - val nonNullAxiom = - new ADTAxiom[Post](forall( - axiomType, - body = - v => { - foldAnd(fieldFunctions.map { f => - adtFunctionInvocation[Post]( - f.ref, - None, - args = Seq(v), - ) !== Null() - }) - }, + val injectivityAxiom1 = + new ADTAxiom[Post](foralls( + Seq(axiomType, axiomType), + body = { case Seq(a0, a1) => + foldAnd(fieldFunctions.combinations(2).map { + case Seq(f0, f1) => + Neq( + adtFunctionInvocation[Post](f0.ref, args = Seq(a0)), + adtFunctionInvocation[Post](f1.ref, args = Seq(a1)), + ) + }.toSeq) + }, + triggers = { case Seq(a0, a1) => + fieldFunctions.combinations(2).map { case Seq(f0, f1) => + Seq( + adtFunctionInvocation[Post](f0.ref, None, args = Seq(a0)), + adtFunctionInvocation[Post](f1.ref, None, args = Seq(a1)), + ) + }.toSeq + }, )) - // TODO: need Non null pointer.... - val injectivityAxiom = + val injectivityAxiom2 = new ADTAxiom[Post](foralls( Seq(axiomType, axiomType), body = { case Seq(a0, a1) => - (a0 !== a1) ==> foldAnd(fieldFunctions.map { f => - DerefPointer( - adtFunctionInvocation[Post](f.ref, args = Seq(a0)) - )(NonNullPointerNull)( - o.withContent(TypeName("helloWorld")) - ) !== DerefPointer( - adtFunctionInvocation[Post](f.ref, args = Seq(a1)) - )(NonNullPointerNull)(o.withContent(TypeName("helloWorld"))) + foldAnd(fieldFunctions.map { f => + Implies( + Eq( + adtFunctionInvocation[Post](f.ref, args = Seq(a0)), + adtFunctionInvocation[Post](f.ref, args = Seq(a1)), + ), + a0 === a1, + ) }) }, triggers = { case Seq(a0, a1) => fieldFunctions.map { f => Seq( - DerefPointer( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a0)) - )(NonNullPointerNull)(o.withContent(TypeName( - "helloWorld" - ))), - DerefPointer( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a1)) - )(NonNullPointerNull)(o.withContent(TypeName( - "helloWorld" - ))), + adtFunctionInvocation[Post](f.ref, None, args = Seq(a0)), + adtFunctionInvocation[Post](f.ref, None, args = Seq(a1)), ) } }, @@ -416,8 +414,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Seq( constructor, destructorAxiom, -// nonNullAxiom, - injectivityAxiom, + injectivityAxiom1, + injectivityAxiom2, ) ++ fieldFunctions, Nil, ) diff --git a/src/rewrite/vct/rewrite/DisambiguateLocation.scala b/src/rewrite/vct/rewrite/DisambiguateLocation.scala index e013f80915..17e0368d9f 100644 --- a/src/rewrite/vct/rewrite/DisambiguateLocation.scala +++ b/src/rewrite/vct/rewrite/DisambiguateLocation.scala @@ -36,6 +36,10 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { implicit o: Origin ): Location[Post] = expr match { + case expr if expr.t.asPointer.isDefined => + PointerLocation(dispatch(expr))(blame) + case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => + ByValueClassLocation(dispatch(expr))(blame) case DerefHeapVariable(ref) => HeapVariableLocation(succ(ref.decl)) case Deref(obj, ref) => FieldLocation(dispatch(obj), succ(ref.decl)) case ModelDeref(obj, ref) => ModelLocation(dispatch(obj), succ(ref.decl)) @@ -43,10 +47,6 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { SilverFieldLocation(dispatch(obj), succ(ref.decl)) case expr @ ArraySubscript(arr, index) => ArrayLocation(dispatch(arr), dispatch(index))(expr.blame) - case expr if expr.t.asPointer.isDefined => - PointerLocation(dispatch(expr))(blame) - case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => - ByValueClassLocation(dispatch(expr))(blame) case PredicateApply(ref, args, WritePerm()) => PredicateLocation(succ(ref.decl), (args.map(dispatch))) case InstancePredicateApply(obj, ref, args, WritePerm()) => diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index a134a63a58..a3d7757469 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -500,8 +500,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val zero = const[Post](0) val pre1 = zero <= i.get && i.get < size val pre2 = zero <= j.get && j.get < size - val body = - (pre1 && pre2 && (i.get !== j.get)) ==> (access(i) !== access(j)) + val body = (pre1 && pre2 && access(i) === access(j)) ==> (i.get === j.get) Forall(Seq(i, j), Seq(triggerUnique), body) } } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 8067a5ec3b..133ad89ec3 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -94,10 +94,11 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { val t = newLocal.t.asPointer.get.element Block(Seq( HeapLocalDecl(newLocal), - Assign( - HeapLocal[Post](newLocal.ref), - NewNonNullPointerArray(t, const(1))(PanicBlame("Size > 0")), - )(AssignLocalOk), +// Assign( +// HeapLocal[Post](newLocal.ref), +// NewNonNullPointerArray(t, const(1))(PanicBlame("Size > 0")), +// )(AssignLocalOk), + // TODO: Only do this if the first use does not overwrite it again (do something similar to what I implemented in ImportPointer).... Assign( newLocal.get(DerefAssignTarget), NewObject(t.asInstanceOf[TByValueClass[Post]].cls), @@ -175,6 +176,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } val newFieldPerms = fields.map(member => { val loc = FieldLocation[Post](obj, succ(member)) + // TODO: Don't go through regular pointers... member.t.asPointer.get.element match { case inner: TByValueClass[Pre] => Perm[Post](loc, perm) &* unwrapClassPerm( @@ -192,8 +194,63 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { foldStar(newFieldPerms) } + private def unwrapClassComp( + comp: (Expr[Post], Expr[Post]) => Expr[Post], + left: Expr[Post], + right: Expr[Post], + structType: TByValueClass[Pre], + visited: Seq[TByValueClass[Pre]] = Nil, + )(implicit o: Origin): Expr[Post] = { + // TODO: Better error + if (visited.contains(structType)) + throw UnsupportedStructPerm(o) + + val blame = PanicBlame("Struct deref can never fail") + val fields = structType.cls.decl.decls.collect { + case f: InstanceField[Pre] => f + } + foldAnd(fields.map(member => { + val l = + RawDerefPointer(Deref[Post](left, succ(member))(blame))( + NonNullPointerNull + ) + val r = + RawDerefPointer(Deref[Post](right, succ(member))(blame))( + NonNullPointerNull + ) + member.t match { +// case p: TNonNullPointer[Pre] if p.element.isInstanceOf[TByValueClass[Pre]] => +// unwrapClassComp(comp, DerefPointer(l)(NonNullPointerNull), r, p.element.asInstanceOf[TByValueClass[Pre]], structType +: visited) + case _ => comp(l, r) + } + })) + } + override def dispatch(node: Expr[Pre]): Expr[Post] = { implicit val o: Origin = node.o + node match { + case Eq(left, right) + if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => + val newLeft = dispatch(left) + val newRight = dispatch(right) + return Eq(newLeft, newRight) && unwrapClassComp( + (l, r) => Eq(l, r), + newLeft, + newRight, + left.t.asInstanceOf[TByValueClass[Pre]], + ) + case Neq(left, right) + if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => + val newLeft = dispatch(left) + val newRight = dispatch(right) + return Neq(newLeft, newRight) && unwrapClassComp( + (l, r) => Neq(l, r), + newLeft, + newRight, + left.t.asInstanceOf[TByValueClass[Pre]], + ) + case _ => {} + } if (inAssignment.nonEmpty) node.rewriteDefault() else @@ -204,6 +261,17 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { dispatch(p), e.t.asInstanceOf[TByValueClass[Pre]], ) + case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) + if v.t.isInstanceOf[TNonNullPointer[Pre]] => + val t = v.t.asInstanceOf[TNonNullPointer[Pre]] + if (t.element.isInstanceOf[TByValueClass[Pre]]) { + val newV: Ref[Post, HeapVariable[Post]] = succ(v) + val newP = dispatch(p) + Perm(HeapVariableLocation(newV), newP) &* Perm( + PointerLocation(DerefHeapVariable(newV)(dhv.blame))(pl.blame), + newP, + ) + } else { node.rewriteDefault() } // What if I get rid of this... // case Perm(loc@PointerLocation(e), p) if e.t.asPointer.exists(t => t.element.isInstanceOf[TByValueClass[Pre]])=> // unwrapClassPerm(DerefPointer(dispatch(e))(PointerLocationDerefBlame(loc.blame))(loc.o), dispatch(p), e.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]]) diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 09f85ec8eb..dfa103b0f1 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -5,7 +5,8 @@ import ImportADT.typeText import vct.col.origin._ import vct.col.ref.Ref import vct.col.rewrite.Generation -import vct.col.util.AstBuildHelpers.{ExprBuildHelpers, const} +import vct.col.util.AstBuildHelpers._ +import vct.col.util.SuccessionMap import scala.collection.mutable @@ -13,6 +14,10 @@ case object ImportPointer extends ImportADTBuilder("pointer") { private def PointerField(t: Type[_]): Origin = Origin(Seq(PreferredName(Seq(typeText(t))), LabelContext("pointer field"))) + private val PointerCreationOrigin: Origin = Origin( + Seq(LabelContext("classToRef, pointer creation method")) + ) + case class PointerNullOptNone(inner: Blame[PointerNull], expr: Expr[_]) extends Blame[OptionNone] { override def blame(error: OptionNone): Unit = inner.blame(PointerNull(expr)) @@ -77,6 +82,57 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) val pointerField: mutable.Map[Type[Post], SilverField[Post]] = mutable.Map() + private val pointerCreationMethods + : SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() + + private def makePointerCreationMethod(t: Type[Post]): Procedure[Post] = { + implicit val o: Origin = PointerCreationOrigin + + val result = new Variable[Post](TAxiomatic(pointerAdt.ref, Nil)) + globalDeclarations.declare(procedure[Post]( + blame = AbstractApplicable, + contractBlame = TrueSatisfiable, + returnType = TVoid(), + outArgs = Seq(result), + ensures = UnitAccountedPredicate( + (ADTFunctionInvocation[Post]( + typeArgs = Some((blockAdt.ref, Nil)), + ref = blockLength.ref, + args = Seq(ADTFunctionInvocation[Post]( + typeArgs = Some((pointerAdt.ref, Nil)), + ref = pointerBlock.ref, + args = Seq(result.get), + )), + ) === const(1)) &* + (ADTFunctionInvocation[Post]( + typeArgs = Some((pointerAdt.ref, Nil)), + ref = pointerOffset.ref, + args = Seq(result.get), + ) === const(0)) &* Perm( + SilverFieldLocation( + obj = + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq(result.get), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")), + field = + pointerField.getOrElseUpdate( + t, { + globalDeclarations + .declare(new SilverField(t)(PointerField(t))) + }, + ).ref, + ), + WritePerm(), + ) + ), + decreases = Some(DecreasesClauseNoRecursion[Post]()), + )) + } + private def getPointerField(ptr: Expr[Pre]): Ref[Post, SilverField[Post]] = { val tElement = dispatch(ptr.t.asPointer.get.element) pointerField.getOrElseUpdate( @@ -114,7 +170,8 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case other => rewriteDefault(other) } - override def postCoerce(location: Location[Pre]): Location[Post] = + override def postCoerce(location: Location[Pre]): Location[Post] = { + implicit val o: Origin = location.o location match { case loc @ PointerLocation(pointer) => SilverFieldLocation( @@ -127,9 +184,46 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) Nil, )(PanicBlame("ptr_deref requires nothing."))(pointer.o), field = getPointerField(pointer), - )(loc.o) + ) case other => rewriteDefault(other) } + } + + override def postCoerce(s: Statement[Pre]): Statement[Post] = { + implicit val o: Origin = s.o + s match { + case scope: Scope[Pre] => + scope.rewrite(body = Block(scope.locals.collect { + case v if v.t.isInstanceOf[TNonNullPointer[Pre]] => { + val firstUse = scope.body.collectFirst { + case l @ Local(Ref(variable)) if variable == v => l + } + if ( + firstUse.isDefined && scope.body.collectFirst { + case Assign(l @ Local(Ref(variable)), _) if variable == v => + System.identityHashCode(l) != + System.identityHashCode(firstUse.get) + }.getOrElse(true) + ) { + val oldT = v.t.asInstanceOf[TNonNullPointer[Pre]].element + val newT = dispatch(oldT) + Seq( + InvokeProcedure[Post]( + pointerCreationMethods + .getOrElseUpdate(oldT, makePointerCreationMethod(newT)).ref, + Nil, + Seq(Local(succ(v))), + Nil, + Nil, + Nil, + )(TrueSatisfiable) + ) + } else { Nil } + } + }.flatten :+ dispatch(scope.body))) + case _ => s.rewriteDefault() + } + } override def postCoerce(e: Expr[Pre]): Expr[Post] = { implicit val o: Origin = e.o @@ -168,38 +262,47 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case TNonNullPointer(_) => inv } case deref @ DerefPointer(pointer) => - if (pointer.o.find[TypeName].isDefined) { - FunctionInvocation[Post]( - ref = pointerDeref.ref, - args = Seq(unwrapOption(pointer, deref.blame)), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame("ptr_deref requires nothing.")) - } else { - SilverDeref( - obj = - FunctionInvocation[Post]( - ref = pointerDeref.ref, - args = Seq( - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(unwrapOption(pointer, deref.blame), const(0)), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - ), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame("ptr_deref requires nothing.")), - field = getPointerField(pointer), - )(PointerFieldInsufficientPermission(deref.blame, deref)) - } + SilverDeref( + obj = + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq( + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")), + field = getPointerField(pointer), + )(PointerFieldInsufficientPermission(deref.blame, deref)) + case deref @ RawDerefPointer(pointer) => + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq( + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")) case len @ PointerBlockLength(pointer) => ADTFunctionInvocation[Post]( typeArgs = Some((blockAdt.ref, Nil)), diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 955e57c7bd..54abadbe31 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -1042,8 +1042,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ) ) case None => + val newT = + if (t.isInstanceOf[TByValueClass[Post]]) { TNonNullPointer(t) } + else { t } cGlobalNameSuccessor(RefCGlobalDeclaration(decl, idx)) = rw - .globalDeclarations.declare(new HeapVariable(t)(init.o)) + .globalDeclarations.declare(new HeapVariable(newT)(init.o)) } } } @@ -1321,7 +1324,21 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case ref @ RefCGlobalDeclaration(decl, initIdx) => C.getDeclaratorInfo(decl.decl.inits(initIdx).decl).params match { case None => - DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))(local.blame) + val t = + decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => + t.t + }.get + if (t.isInstanceOf[CTStruct[Pre]]) { + DerefPointer[Post]( + DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( + local.blame + ) + )(NonNullPointerNull) + } else { + DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( + local.blame + ) + } case Some(_) => throw NotAValue(local) } case ref: RefCLocalDeclaration[Pre] From 0dfde7cdd3eda0101a56c98c2dfc90a8fd935830 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 18 Jun 2024 15:31:39 +0200 Subject: [PATCH 08/47] Do no copy in expressions which do not yield TByValueClass --- examples/concepts/c/structs.c | 8 +----- .../vct/rewrite/EncodeArrayValues.scala | 2 +- .../vct/rewrite/PrepareByValueClass.scala | 26 ++++++++++++++----- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/examples/concepts/c/structs.c b/examples/concepts/c/structs.c index 886ed073f5..960a6fa1ca 100644 --- a/examples/concepts/c/structs.c +++ b/examples/concepts/c/structs.c @@ -52,7 +52,6 @@ void alter_copy_struct(struct point p){ p.y = 0; } -// TODO: Should be auto-generated /*@ context Perm(p, 1\1); @*/ @@ -133,7 +132,6 @@ int main(){ struct point *pp; pp = &p; - /* //@ assert (pp[0] != NULL ); */ assert (pp != NULL ); p.x = 1; @@ -147,7 +145,7 @@ int main(){ alter_struct(pp); assert(pp->x == 0); assert(p.x == 0); - alter_struct_1(pp); //alter_struct_1(&p) is not supported yet + alter_struct_1(pp); assert(p.x == 1 && p.y == 1); struct point p1, p2, p3; @@ -164,10 +162,6 @@ int main(){ struct polygon pol, *ppols; ppols = &pol; pol.ps = ps; - //@ assert Perm(&ppols->ps[0], write); - //@ assert Perm(&ppols->ps[1], write); - //@ assert Perm(&ppols->ps[2], write); - //@ assert (\forall* int i; 0<=i && i<3; Perm(&ppols->ps[i], write)); int avr_pol = avr_x_pol(ppols, 3); // assert sum_seq(inp_to_seq(ppols->ps, 3)) == 6; assert(avr_pol == 2); diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index a3d7757469..52435b4a68 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -193,7 +193,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { // If structure contains structs, the permission for those fields need to be released as well val permFields = t match { -// case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) + case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) case _ => Seq() } requiresT = diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 133ad89ec3..06e1eac550 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -64,6 +64,8 @@ case object PrepareByValueClass extends RewriterBuilder { private case class InAssignmentStatement(assignment: Assign[_]) extends CopyContext + private case class NoCopy() extends CopyContext + case class PointerLocationDerefBlame(blame: Blame[PointerLocationError]) extends Blame[PointerDerefError] { override def blame(error: PointerDerefError): Unit = { @@ -107,9 +109,11 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } case assign: Assign[Pre] => { val target = inAssignment.having(()) { dispatch(assign.target) } - copyContext.having(InAssignmentStatement(assign)) { - assign.rewrite(target = target) - } + if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { + copyContext.having(InAssignmentStatement(assign)) { + assign.rewrite(target = target) + } + } else { assign.rewrite(target = target) } } case _ => node.rewriteDefault() } @@ -277,11 +281,20 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { // unwrapClassPerm(DerefPointer(dispatch(e))(PointerLocationDerefBlame(loc.blame))(loc.o), dispatch(p), e.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]]) case assign: PreAssignExpression[Pre] => val target = inAssignment.having(()) { dispatch(assign.target) } - copyContext.having(InAssignmentExpression(assign)) { - assign.rewrite(target = target) + if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { + copyContext.having(InAssignmentExpression(assign)) { + assign.rewrite(target = target) + } + } else { + // No need for copy semantics in this context + copyContext.having(NoCopy()) { assign.rewrite(target = target) } } case invocation: Invocation[Pre] => { - copyContext.having(InCall(invocation)) { invocation.rewriteDefault() } + invocation.rewrite(args = invocation.args.map { a => + if (a.t.isInstanceOf[TByValueClass[Pre]]) { + copyContext.having(InCall(invocation)) { dispatch(a) } + } else { copyContext.having(NoCopy()) { dispatch(a) } } + }) } // WHOOPSIE WE ALSO MAKE A COPY IF IT WAS A POINTER case dp @ DerefPointer(HeapLocal(Ref(v))) @@ -342,6 +355,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { t, f => ClassCopyInAssignmentFailed(dp.blame, assignment, clazz, f), ) + case NoCopy() => dp.rewriteDefault() } } } From be64b272edd2174e3d6ee5431da4e88b2b04c635 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 18 Jun 2024 17:14:04 +0200 Subject: [PATCH 09/47] Enable use of methods on by-value classes --- src/rewrite/vct/rewrite/ClassToRef.scala | 11 +- .../vct/rewrite/lang/LangCPPToCol.scala | 218 +++++++++++------- 2 files changed, 139 insertions(+), 90 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index bebe3d263e..bdf51b9e52 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -91,7 +91,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { def makeTypeOf: Function[Post] = { implicit val o: Origin = TypeOfOrigin - val obj = new Variable[Post](TRef()) + val obj = new Variable[Post](TAnyValue()) withResult((result: Result[Post]) => function( blame = AbstractApplicable, @@ -175,10 +175,11 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) typeNumber(cls) + val thisType = dispatch(cls.classType(Nil)) cls.decls.foreach { case function: InstanceFunction[Pre] => implicit val o: Origin = function.o - val thisVar = new Variable[Post](TRef())(This) + val thisVar = new Variable[Post](thisType)(This) diz.having(thisVar.get) { functionSucc(function) = globalDeclarations .declare(labelDecls.scope { @@ -213,7 +214,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } case method: InstanceMethod[Pre] => implicit val o: Origin = method.o - val thisVar = new Variable[Post](TRef())(This) + val thisVar = new Variable[Post](thisType)(This) diz.having(thisVar.get) { methodSucc(method) = globalDeclarations.declare(labelDecls.scope { new Procedure( @@ -249,7 +250,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } case cons: Constructor[Pre] => implicit val o: Origin = cons.o - val thisVar = new Variable[Post](TRef())(This) + val thisVar = new Variable[Post](thisType)(This) consSucc(cons) = globalDeclarations.declare(labelDecls.scope { new Procedure( returnType = TVoid(), @@ -293,7 +294,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )) }) case predicate: InstancePredicate[Pre] => - val thisVar = new Variable[Post](TRef())(This) + val thisVar = new Variable[Post](thisType)(This) diz.having(thisVar.get(predicate.o)) { predicateSucc(predicate) = globalDeclarations.declare( new Predicate( diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index debd341db5..d0708ec692 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -738,7 +738,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = f.o Star( Perm[Post](FieldLocation[Post](thisObj, f.ref), ReadPerm()), - Deref[Post](thisObj, f.ref)(new SYCLRangeDerefBlame(f)) >= c_const(0), + DerefPointer(Deref[Post](thisObj, f.ref)(new SYCLRangeDerefBlame(f)))( + NonNullPointerNull + ) >= c_const(0), ) }) } @@ -763,9 +765,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = params(i).o Star( Perm(FieldLocation[Post](result, rangeFields(i).ref), ReadPerm()), - Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( - rangeFields(i) - )) === Local[Post](params(i).ref), + DerefPointer(Deref[Post](result, rangeFields(i).ref)( + new SYCLRangeDerefBlame(rangeFields(i)) + ))(NonNullPointerNull) === Local[Post](params(i).ref), ) }) } @@ -799,21 +801,23 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = params(i).o foldStar(Seq( Perm(FieldLocation[Post](result, rangeFields(i).ref), ReadPerm()), - Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( - rangeFields(i) - )) === FloorDiv( + DerefPointer(Deref[Post](result, rangeFields(i).ref)( + new SYCLRangeDerefBlame(rangeFields(i)) + ))(NonNullPointerNull) === FloorDiv( Local[Post](params(i).ref), Local[Post](params(i + 1).ref), )(ImpossibleDivByZeroBlame()), Perm(FieldLocation[Post](result, rangeFields(i + 1).ref), ReadPerm()), - Deref[Post](result, rangeFields(i + 1).ref)(new SYCLRangeDerefBlame( - rangeFields(i + 1) - )) === Local[Post](params(i + 1).ref), - Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( - rangeFields(i) - )) * Deref[Post](result, rangeFields(i + 1).ref)( + DerefPointer(Deref[Post](result, rangeFields(i + 1).ref)( new SYCLRangeDerefBlame(rangeFields(i + 1)) - ) === Local[Post](params(i).ref), + ))(NonNullPointerNull) === Local[Post](params(i + 1).ref), + DerefPointer(Deref[Post](result, rangeFields(i).ref)( + new SYCLRangeDerefBlame(rangeFields(i)) + ))(NonNullPointerNull) * DerefPointer( + Deref[Post](result, rangeFields(i + 1).ref)(new SYCLRangeDerefBlame( + rangeFields(i + 1) + )) + )(NonNullPointerNull) === Local[Post](params(i).ref), )) }) } @@ -1094,9 +1098,11 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _: SYCLTAccessor[Pre] => // Referencing an accessor variable can only be done in kernels, otherwise an error will already have been thrown val accessor = syclAccessorSuccessor(ref) - Deref[Post](currentThis.get, accessor.instanceField.ref)( - SYCLAccessorFieldInsufficientReferencePermissionBlame(local) - ) + DerefPointer( + Deref[Post](currentThis.get, accessor.instanceField.ref)( + SYCLAccessorFieldInsufficientReferencePermissionBlame(local) + ) + )(NonNullPointerNull) case _: SYCLTLocalAccessor[Pre] if currentKernelType.get.isInstanceOf[BasicKernel] => throw SYCLNoLocalAccessorsInBasicKernel(local) @@ -1234,18 +1240,18 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) getGlobalWorkItemLinearId(inv) case "sycl::accessor::get_range" => classInstance match { - case Some(Deref(_, ref)) => + case Some(DerefPointer(Deref(_, ref))) => val accessor = syclAccessorSuccessor.values .find(acc => ref.decl.equals(acc.instanceField)).get LiteralSeq[Post]( TCInt(), accessor.rangeIndexFields.map(f => - Deref[Post](currentThis.get, f.ref)( + DerefPointer(Deref[Post](currentThis.get, f.ref)( SYCLAccessorRangeIndexFieldInsufficientReferencePermissionBlame( inv ) - ) + ))(NonNullPointerNull) ), ) case _ => throw NotApplicable(inv) @@ -1567,14 +1573,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val rangeFields: mutable.Buffer[InstanceField[Post]] = mutable.Buffer.empty range.indices.foreach(index => { implicit val o: Origin = range(index).o.where(name = s"range$index") - val instanceField = new InstanceField[Post](TCInt(), Nil) + val instanceField = new InstanceField[Post](TNonNullPointer(TCInt()), Nil) rangeFields.append(instanceField) val iterVar = createRangeIterVar( GlobalScope(), index, - Deref[Post](currentThis.get, instanceField.ref)(new SYCLRangeDerefBlame( - instanceField - )), + DerefPointer(Deref[Post](currentThis.get, instanceField.ref)( + new SYCLRangeDerefBlame(instanceField) + ))(NonNullPointerNull), ) currentDimensionIterVars(GlobalScope()).append(iterVar) }) @@ -1643,28 +1649,30 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) { implicit val o: Origin = kernelDimensions.o .where(name = s"group_range$index") - val groupInstanceField = new InstanceField[Post](TCInt(), Nil) + val groupInstanceField = + new InstanceField[Post](TNonNullPointer(TCInt()), Nil) rangeFields.append(groupInstanceField) val groupIterVar = createRangeIterVar( GroupScope(), index, - Deref[Post](currentThis.get, groupInstanceField.ref)( + DerefPointer(Deref[Post](currentThis.get, groupInstanceField.ref)( new SYCLRangeDerefBlame(groupInstanceField) - ), + ))(NonNullPointerNull), ) currentDimensionIterVars(GroupScope()).append(groupIterVar) } { implicit val o: Origin = localRange(index).o .where(name = s"local_range$index") - val localInstanceField = new InstanceField[Post](TCInt(), Nil) + val localInstanceField = + new InstanceField[Post](TNonNullPointer(TCInt()), Nil) rangeFields.append(localInstanceField) val localIterVar = createRangeIterVar( LocalScope(), index, - Deref[Post](currentThis.get, localInstanceField.ref)( + DerefPointer(Deref[Post](currentThis.get, localInstanceField.ref)( new SYCLRangeDerefBlame(localInstanceField) - ), + ))(NonNullPointerNull), ) currentDimensionIterVars(LocalScope()).append(localIterVar) } @@ -1856,10 +1864,13 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case None => // No accessor for buffer exist in the command group, so make fields and permissions val instanceField = - new InstanceField[Post](buffer.generatedVar.t, Nil)(accO) + new InstanceField[Post]( + TNonNullPointer(buffer.generatedVar.t), + Nil, + )(accO) val rangeIndexFields = Seq .range(0, buffer.range.dimensions.size).map(i => - new InstanceField[Post](TCInt(), Nil)( + new InstanceField[Post](TNonNullPointer(TCInt()), Nil)( dimO.where(name = s"${accName}_r$i") ) ) @@ -1931,7 +1942,11 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } val rangeIndexDerefs: Seq[Expr[Post]] = acc.rangeIndexFields.map(f => - Deref[Post](classObj, f.ref)(new SYCLAccessorDimensionDerefBlame(f))(f.o) + DerefPointer( + Deref[Post](classObj, f.ref)(new SYCLAccessorDimensionDerefBlame(f))( + f.o + ) + )(NonNullPointerNull)(f.o) ) ( @@ -1949,18 +1964,22 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ReadPerm()(acc.instanceField.o), )(acc.instanceField.o), ValidArray( - Deref[Post](classObj, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o), + DerefPointer( + Deref[Post](classObj, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o) + )(NonNullPointerNull)(acc.instanceField.o), rangeIndexDerefs.reduce((e1, e2) => (e1 * e2)(acc.buffer.o)), )(acc.instanceField.o), ) )(acc.instanceField.o), Perm( ArrayLocation( - Deref[Post](classObj, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o), + DerefPointer( + Deref[Post](classObj, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o) + )(NonNullPointerNull)(acc.instanceField.o), Any()(PanicBlame( "The accessor field is not null as that was proven in the previous conditions." ))(acc.instanceField.o), @@ -2012,20 +2031,24 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) constructorPostConditions.append( foldStar[Post]( Eq[Post]( - Deref[Post](result, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o), + DerefPointer( + Deref[Post](result, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o) + )(NonNullPointerNull)(acc.instanceField.o), Local[Post](newConstructorAccessorArg.ref)( newConstructorAccessorArg.o ), )(newConstructorAccessorArg.o) +: Seq.range(0, acc.rangeIndexFields.size).map(i => Eq[Post]( - Deref[Post](result, acc.rangeIndexFields(i).ref)( - new SYCLAccessorDimensionDerefBlame( - acc.rangeIndexFields(i) - ) - )(acc.rangeIndexFields(i).o), + DerefPointer( + Deref[Post](result, acc.rangeIndexFields(i).ref)( + new SYCLAccessorDimensionDerefBlame( + acc.rangeIndexFields(i) + ) + )(acc.rangeIndexFields(i).o) + )(NonNullPointerNull)(acc.rangeIndexFields(i).o), Local[Post](newConstructorAccessorDimensionArgs(i).ref)( newConstructorAccessorDimensionArgs(i).o ), @@ -2483,12 +2506,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 1, _) => ArraySubscript[Post]( - Deref[Post]( - currentThis.get, - syclAccessorSuccessor(base.ref.get).instanceField.ref, - )(new SYCLAccessorDerefBlame( - syclAccessorSuccessor(base.ref.get).instanceField - ))(sub.o), + DerefPointer( + Deref[Post]( + currentThis.get, + syclAccessorSuccessor(base.ref.get).instanceField.ref, + )(new SYCLAccessorDerefBlame( + syclAccessorSuccessor(base.ref.get).instanceField + ))(sub.o) + )(NonNullPointerNull)(sub.o), rw.dispatch(index), )(SYCLAccessorArraySubscriptErrorBlame(sub))(sub.o) case t: SYCLTAccessor[Pre] => @@ -2504,19 +2529,25 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val linearizeArgs = Seq( rw.dispatch(indexX), rw.dispatch(indexY), - Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) - ), - Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) - ), + DerefPointer( + Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) + ) + )(NonNullPointerNull), + DerefPointer( + Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) + ) + )(NonNullPointerNull), ) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 2, _) => ArraySubscript[Post]( - Deref[Post](currentThis.get, accessor.instanceField.ref)( - new SYCLAccessorDerefBlame(accessor.instanceField) - ), + DerefPointer( + Deref[Post](currentThis.get, accessor.instanceField.ref)( + new SYCLAccessorDerefBlame(accessor.instanceField) + ) + )(NonNullPointerNull), syclHelperFunctions("sycl_:_:linearize_2")( linearizeArgs, SYCLAccessorArraySubscriptLinearizeInvocationBlame( @@ -2544,22 +2575,30 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) rw.dispatch(indexX), rw.dispatch(indexY), rw.dispatch(indexZ), - Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) - ), - Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) - ), - Deref[Post](currentThis.get, accessor.rangeIndexFields(2).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(2)) - ), + DerefPointer( + Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) + ) + )(NonNullPointerNull), + DerefPointer( + Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) + ) + )(NonNullPointerNull), + DerefPointer( + Deref[Post](currentThis.get, accessor.rangeIndexFields(2).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(2)) + ) + )(NonNullPointerNull), ) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 3, _) => ArraySubscript[Post]( - Deref[Post](currentThis.get, accessor.instanceField.ref)( - new SYCLAccessorDerefBlame(accessor.instanceField) - ), + DerefPointer( + Deref[Post](currentThis.get, accessor.instanceField.ref)( + new SYCLAccessorDerefBlame(accessor.instanceField) + ) + )(NonNullPointerNull), syclHelperFunctions("sycl_:_:linearize_3")( linearizeArgs, SYCLAccessorArraySubscriptLinearizeInvocationBlame( @@ -2637,9 +2676,12 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) removeKernelClassInstancePermissions(right), ) - case ActionPerm(Deref(obj, _), _) if obj.equals(this.currentThis.get) => + case ActionPerm(DerefPointer(Deref(obj, _)), _) + if obj.equals(this.currentThis.get) => + tt + case ModelPerm(DerefPointer(Deref(obj, _)), _) + if obj.equals(this.currentThis.get) => tt - case ModelPerm(Deref(obj, _), _) if obj.equals(this.currentThis.get) => tt case Perm(FieldLocation(obj, _), _) if obj.equals(this.currentThis.get) => tt case PointsTo(FieldLocation(obj, _), _, _) @@ -2647,14 +2689,20 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) tt case Value(FieldLocation(obj, _)) if obj.equals(this.currentThis.get) => tt - case Perm(AmbiguousLocation(ArraySubscript(Deref(obj, _), _)), _) - if obj.equals(this.currentThis.get) => + case Perm( + AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)), + _, + ) if obj.equals(this.currentThis.get) => tt - case PointsTo(AmbiguousLocation(ArraySubscript(Deref(obj, _), _)), _, _) - if obj.equals(this.currentThis.get) => + case PointsTo( + AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)), + _, + _, + ) if obj.equals(this.currentThis.get) => tt - case Value(AmbiguousLocation(ArraySubscript(Deref(obj, _), _))) - if obj.equals(this.currentThis.get) => + case Value( + AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)) + ) if obj.equals(this.currentThis.get) => tt case Implies(left, right) => Implies( @@ -2685,9 +2733,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case SYCLAccessor(buffer, SYCLReadWriteAccess(), instanceField, _) => assignLocal( buffer.generatedVar.get, - Deref[Post](variable, instanceField.ref)(new SYCLAccessorDerefBlame( - instanceField - )), + DerefPointer(Deref[Post](variable, instanceField.ref)( + new SYCLAccessorDerefBlame(instanceField) + ))(NonNullPointerNull), ) })) } From aed3ba6c3295021cd0d24b13ad7d4d65a53dd958 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 19 Jun 2024 11:01:30 +0200 Subject: [PATCH 10/47] Set --useOldAxiomatization to test which test failures are because of me and which aren't --- src/viper/viper/api/backend/silicon/Silicon.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/src/viper/viper/api/backend/silicon/Silicon.scala b/src/viper/viper/api/backend/silicon/Silicon.scala index 0fd0ab0b6e..aa1244309f 100644 --- a/src/viper/viper/api/backend/silicon/Silicon.scala +++ b/src/viper/viper/api/backend/silicon/Silicon.scala @@ -97,6 +97,7 @@ case class Silicon( "--z3ConfigArgs", z3Config, "--ideModeAdvanced", + "--useOldAxiomatization", ) if (proverLogFile.isDefined) { From 53f19a8309d13882c2cd7495f0a95c0152fe898a Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 18 Jul 2024 14:40:29 +0200 Subject: [PATCH 11/47] Improve struct encoding, rewrite tests for new permission syntax --- examples/concepts/c/structs.c | 4 +- res/universal/res/adt/pointer.pvl | 7 - src/col/vct/col/typerules/CoercionUtils.scala | 3 + src/rewrite/vct/rewrite/ClassToRef.scala | 202 +++++++++++++----- .../vct/rewrite/PrepareByValueClass.scala | 103 ++++++--- .../vct/rewrite/adt/ImportOption.scala | 13 ++ .../vct/rewrite/adt/ImportPointer.scala | 60 +++++- .../viper/api/backend/SilverBackend.scala | 2 +- .../viper/api/backend/silicon/Silicon.scala | 1 - .../viper/api/transform/ColToSilver.scala | 14 +- .../vct/test/integration/examples/CSpec.scala | 14 +- 11 files changed, 320 insertions(+), 103 deletions(-) diff --git a/examples/concepts/c/structs.c b/examples/concepts/c/structs.c index 960a6fa1ca..4638fd1ac4 100644 --- a/examples/concepts/c/structs.c +++ b/examples/concepts/c/structs.c @@ -52,6 +52,7 @@ void alter_copy_struct(struct point p){ p.y = 0; } +// TODO: Should be auto-generated /*@ context Perm(p, 1\1); @*/ @@ -132,6 +133,7 @@ int main(){ struct point *pp; pp = &p; + /* //@ assert (pp[0] != NULL ); */ assert (pp != NULL ); p.x = 1; @@ -145,7 +147,7 @@ int main(){ alter_struct(pp); assert(pp->x == 0); assert(p.x == 0); - alter_struct_1(pp); + alter_struct_1(pp); //alter_struct_1(&p) is not supported yet assert(p.x == 1 && p.y == 1); struct point p1, p2, p3; diff --git a/res/universal/res/adt/pointer.pvl b/res/universal/res/adt/pointer.pvl index 6c4da7e5db..9cbbedb16b 100644 --- a/res/universal/res/adt/pointer.pvl +++ b/res/universal/res/adt/pointer.pvl @@ -31,13 +31,6 @@ adt `pointer` { axiom (∀ ref r; ptr_deref({:pointer_inv(r):}) == r); axiom (∀ `pointer` p; pointer_inv({:ptr_deref(p):}) == p); - - axiom (∀ `pointer` p1, `pointer` p2, int offset; - (0 <= offset && offset < `block`.block_length(pointer_block(p1)) && - pointer_block(p1) == pointer_block(p2) && - {:pointer_of(pointer_block(p1), offset):} == - {:pointer_of(pointer_block(p2), offset):}) ==> p1 == p2 - ); } decreases; diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index a71e109e01..5c16be6dd6 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -143,6 +143,9 @@ case object CoercionUtils { case (TNonNullPointer(innerType), TPointer(element)) if innerType == element => CoerceNonNullPointer(innerType) + case (TNonNullPointer(a), TNonNullPointer(b)) + if getAnyCoercion(a, b).isDefined => + CoerceIdentity(target) case ( TPointer(element), CTPointer(innerType), diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index bdf51b9e52..aa4c140937 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -321,7 +321,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case cls: ByValueClass[Pre] => implicit val o: Origin = cls.o val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) - val (fieldFunctions, fieldTypes) = + val (fieldFunctions, fieldInverses, fieldTypes) = cls.decls.collect { case field: Field[Pre] => val newT = dispatch(field.t) byValFieldSucc(field) = @@ -329,27 +329,73 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Seq(new Variable(axiomType)(field.o)), newT, )(field.o) - (byValFieldSucc(field), newT) - }.unzip + ( + byValFieldSucc(field), + new ADTFunction[Post]( + Seq(new Variable(newT)(field.o)), + axiomType, + )( + field.o.copy( + field.o.originContents + .filterNot(_.isInstanceOf[SourceName]) + ).where(name = + "inv_" + field.o.find[SourceName].map(_.name) + .getOrElse("unknown") + ) + ), + newT, + ) + }.unzip3 val constructor = - new ADTFunction[Post](fieldTypes.map(new Variable(_)), axiomType)( - cls.o + new ADTFunction[Post]( + fieldTypes.zipWithIndex.map { case (t, i) => + new Variable(t)(Origin(Seq( + PreferredName(Seq("p_" + i)), + LabelContext("classToRef"), + ))) + }, + axiomType, + )( + cls.o.copy( + cls.o.originContents.filterNot(_.isInstanceOf[SourceName]) + ).where(name = + "new_" + cls.o.find[SourceName].map(_.name) + .getOrElse("unknown") + ) + ) + // TAnyValue is a placeholder the pointer adt doesn't have type parameters + val indexFunction = + new ADTFunction[Post]( + Seq(new Variable(TNonNullPointer(TAnyValue()))(Origin( + Seq(PreferredName(Seq("pointer")), LabelContext("classToRef")) + ))), + TInt(), + )( + cls.o.copy( + cls.o.originContents.filterNot(_.isInstanceOf[SourceName]) + ).where(name = + "index_" + cls.o.find[SourceName].map(_.name) + .getOrElse("unknown") + ) ) val destructorAxiom = new ADTAxiom[Post](foralls( fieldTypes, body = variables => { - foldAnd(variables.zip(fieldFunctions).map { case (v, f) => - adtFunctionInvocation[Post]( - f.ref, - args = Seq(adtFunctionInvocation[Post]( - constructor.ref, - None, - args = variables, - )), - ) === v - }) + foldAnd(variables.combinations(2).map { case Seq(v1, v2) => + v1 !== v2 + }.toSeq) ==> + foldAnd(variables.zip(fieldFunctions).map { case (v, f) => + adtFunctionInvocation[Post]( + f.ref, + args = Seq(adtFunctionInvocation[Post]( + constructor.ref, + None, + args = variables, + )), + ) === v + }) }, triggers = variables => { @@ -365,6 +411,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } }, )) + + // This one generates a matching loop val injectivityAxiom1 = new ADTAxiom[Post](foralls( Seq(axiomType, axiomType), @@ -409,15 +457,56 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } }, )) + val destructorAxioms = fieldFunctions.zip(fieldInverses).map { + case (f, inv) => + new ADTAxiom[Post](forall( + axiomType, + body = { a => + adtFunctionInvocation[Post]( + inv.ref, + None, + args = Seq( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) + ), + ) === a + }, + triggers = { a => + Seq(Seq( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) + )) + }, + )) + } + val indexAxioms = fieldFunctions.zipWithIndex.map { case (f, i) => + new ADTAxiom[Post](forall( + axiomType, + body = { a => + adtFunctionInvocation[Post]( + indexFunction.ref, + None, + args = Seq( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) + ), + ) === const(i) + }, + triggers = { a => + Seq( + Seq(adtFunctionInvocation[Post](f.ref, None, args = Seq(a))) + ) + }, + )) + } byValConsSucc(cls) = constructor byValClassSucc(cls) = new AxiomaticDataType[Post]( Seq( - constructor, - destructorAxiom, - injectivityAxiom1, +// constructor, +// destructorAxiom, + indexFunction, +// injectivityAxiom1, injectivityAxiom2, - ) ++ fieldFunctions, + ) ++ destructorAxioms ++ indexAxioms ++ fieldFunctions ++ + fieldInverses, Nil, ) globalDeclarations.succeed(cls, byValClassSucc(cls)) @@ -448,41 +537,46 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )(PanicBlame("typeOf requires nothing.")) === const(typeNumber(cls)) ), )) - case cls: ByValueClass[Pre] => - val (assigns, vars) = - cls.decls.collect { case field: InstanceField[Pre] => - val element = field.t.asPointer.get.element - val newE = dispatch(element) - val v = new Variable[Post](TNonNullPointer(newE)) - ( - InvokeProcedure[Post]( - pointerCreationMethods - .getOrElseUpdate(element, makePointerCreationMethod(newE)) - .ref, - Nil, - Seq(v.get), - Nil, - Nil, - Nil, - )(TrueSatisfiable), - v, - ) - }.unzip - Scope( - vars, - Block( - assigns ++ Seq( - Assign( - Local(target), - adtFunctionInvocation[Post]( - byValConsSucc.ref(cls), - args = vars.map(_.get), - ), - )(AssignLocalOk) - // TODO: Add back typeOf here (but use a separate definition for the adt) - ) - ), - ) + case cls: ByValueClass[Pre] => throw ExtraNode +// val (assigns, vars) = +// cls.decls.collect { case field: InstanceField[Pre] => +// val element = field.t.asPointer.get.element +// val newE = dispatch(element) +// val v = new Variable[Post](TNonNullPointer(newE)) +// ( +// InvokeProcedure[Post]( +// pointerCreationMethods +// .getOrElseUpdate(element, makePointerCreationMethod(newE)) +// .ref, +// Nil, +// Seq(v.get), +// Nil, +// Nil, +// Nil, +// )(TrueSatisfiable), +// v, +// ) +// }.unzip +// val assertions = if (vars.size > 1) { +// Seq(Assert(foldAnd[Post](vars.combinations(2).map { case Seq(a,b) => a.get !== b.get}.toSeq))(PanicBlame("Newly created pointers should be distinct"))) +// } else { +// Nil +// } +// Scope( +// vars, +// Block( +// assigns ++ assertions ++ Seq( +// Assign( +// Local(target), +// adtFunctionInvocation[Post]( +// byValConsSucc.ref(cls), +// args = vars.map(_.get), +// ), +// )(AssignLocalOk) +// // TODO: Add back typeOf here (but use a separate definition for the adt) +// ) +// ), +// ) } } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 06e1eac550..fdfaa81458 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -7,6 +7,7 @@ import vct.col.ref.Ref import vct.col.resolve.ctx.Referrable import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ +import vct.col.util.SuccessionMap import vct.result.VerificationError.{Unreachable, UserError} // TODO: Think of a better name @@ -86,14 +87,32 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { private val inAssignment: ScopedStack[Unit] = ScopedStack() private val copyContext: ScopedStack[CopyContext] = ScopedStack() + private val classCreationMethods + : SuccessionMap[TByValueClass[Pre], Procedure[Post]] = SuccessionMap() + + def makeClassCreationMethod(t: TByValueClass[Pre]): Procedure[Post] = { + implicit val o: Origin = t.cls.decl.o + + globalDeclarations.declare(withResult((result: Result[Post]) => + procedure[Post]( + blame = AbstractApplicable, + contractBlame = TrueSatisfiable, + returnType = dispatch(t), + ensures = UnitAccountedPredicate( + unwrapClassPerm(result, WritePerm(), t) + ), + decreases = Some(DecreasesClauseNoRecursion[Post]()), + ) + )) + } override def dispatch(node: Statement[Pre]): Statement[Post] = { implicit val o: Origin = node.o node match { case HeapLocalDecl(local) if local.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => { + val t = local.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]] val newLocal = localHeapVariables.dispatch(local) - val t = newLocal.t.asPointer.get.element Block(Seq( HeapLocalDecl(newLocal), // Assign( @@ -103,17 +122,33 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { // TODO: Only do this if the first use does not overwrite it again (do something similar to what I implemented in ImportPointer).... Assign( newLocal.get(DerefAssignTarget), - NewObject(t.asInstanceOf[TByValueClass[Post]].cls), + procedureInvocation[Post]( + TrueSatisfiable, + classCreationMethods + .getOrElseUpdate(t, makeClassCreationMethod(t)).ref, + ), )(AssignLocalOk), )) } - case assign: Assign[Pre] => { + case assign: Assign[Pre] => val target = inAssignment.having(()) { dispatch(assign.target) } if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { copyContext.having(InAssignmentStatement(assign)) { assign.rewrite(target = target) } } else { assign.rewrite(target = target) } + case Instantiate(Ref(cls), out) + if cls.isInstanceOf[ByValueClass[Pre]] => { + // AssignLocalOk doesn't make too much sense since we don't know if out is a local + val t = TByValueClass[Pre](cls.ref, Seq()) + Assign[Post]( + dispatch(out), + procedureInvocation( + TrueSatisfiable, + classCreationMethods.getOrElseUpdate(t, makeClassCreationMethod(t)) + .ref, + ), + )(AssignLocalOk) } case _ => node.rewriteDefault() } @@ -125,9 +160,8 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { blame: InstanceField[Pre] => Blame[InsufficientPermission], ): Expr[Post] = { implicit val o: Origin = obj.o - val ov = new Variable[Post](obj.t) - val v = - new Variable[Post](dispatch(t))(o.withContent(TypeName("HelloWorld"))) + val ov = new Variable[Post](obj.t)(o.where(name = "original")) + val v = new Variable[Post](dispatch(t))(o.where(name = "copy")) val children = t.cls.decl.decls.collect { case f: InstanceField[Pre] => f.t match { case inner: TByValueClass[Pre] => @@ -154,9 +188,14 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { Then( With( assignLocal(ov.get, obj), - PreAssignExpression(v.get, NewObject[Post](succ(t.cls.decl)))( - AssignLocalOk - ), + PreAssignExpression( + v.get, + procedureInvocation[Post]( + TrueSatisfiable, + classCreationMethods + .getOrElseUpdate(t, makeClassCreationMethod(t)).ref, + ), + )(AssignLocalOk), ), Block(children), ), @@ -233,26 +272,34 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(node: Expr[Pre]): Expr[Post] = { implicit val o: Origin = node.o node match { - case Eq(left, right) - if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => - val newLeft = dispatch(left) - val newRight = dispatch(right) - return Eq(newLeft, newRight) && unwrapClassComp( - (l, r) => Eq(l, r), - newLeft, - newRight, - left.t.asInstanceOf[TByValueClass[Pre]], - ) - case Neq(left, right) - if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => - val newLeft = dispatch(left) - val newRight = dispatch(right) - return Neq(newLeft, newRight) && unwrapClassComp( - (l, r) => Neq(l, r), - newLeft, - newRight, - left.t.asInstanceOf[TByValueClass[Pre]], + case NewObject(Ref(cls)) if cls.isInstanceOf[ByValueClass[Pre]] => { + val t = TByValueClass[Pre](cls.ref, Seq()) + procedureInvocation[Post]( + TrueSatisfiable, + classCreationMethods.getOrElseUpdate(t, makeClassCreationMethod(t)) + .ref, ) + } +// case Eq(left, right) +// if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => +// val newLeft = dispatch(left) +// val newRight = dispatch(right) +// return Eq(newLeft, newRight) && unwrapClassComp( +// (l, r) => Eq(l, r), +// newLeft, +// newRight, +// left.t.asInstanceOf[TByValueClass[Pre]], +// ) +// case Neq(left, right) +// if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => +// val newLeft = dispatch(left) +// val newRight = dispatch(right) +// return Neq(newLeft, newRight) && unwrapClassComp( +// (l, r) => Neq(l, r), +// newLeft, +// newRight, +// left.t.asInstanceOf[TByValueClass[Pre]], +// ) case _ => {} } if (inAssignment.nonEmpty) diff --git a/src/rewrite/vct/rewrite/adt/ImportOption.scala b/src/rewrite/vct/rewrite/adt/ImportOption.scala index 0ac9e05c37..8dd885426d 100644 --- a/src/rewrite/vct/rewrite/adt/ImportOption.scala +++ b/src/rewrite/vct/rewrite/adt/ImportOption.scala @@ -68,6 +68,19 @@ case class ImportOption[Pre <: Generation](importer: ImportADTImporter) case other => rewriteDefault(other) } + override def preCoerce(e: Expr[Pre]): Expr[Pre] = + e match { + case OptGet(OptSome(inner)) => inner + case OptGet(OptSomeTyped(_, inner)) => inner + case OptGetOrElse(OptSome(inner), _) => inner + case OptGetOrElse(OptSomeTyped(_, inner), _) => inner + case OptSomeTyped(t, OptGet(inner)) + if inner.t.asOption.get.element == t => + inner + case OptSome(OptGet(inner)) => inner + case _ => super.preCoerce(e) + } + override def postCoerce(e: Expr[Pre]): Expr[Post] = e match { case OptEmpty(opt) => diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index dfa103b0f1..9c5d253e85 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -15,7 +15,7 @@ case object ImportPointer extends ImportADTBuilder("pointer") { Origin(Seq(PreferredName(Seq(typeText(t))), LabelContext("pointer field"))) private val PointerCreationOrigin: Origin = Origin( - Seq(LabelContext("classToRef, pointer creation method")) + Seq(LabelContext("adtPointer, pointer creation method")) ) case class PointerNullOptNone(inner: Blame[PointerNull], expr: Expr[_]) @@ -87,8 +87,10 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) private def makePointerCreationMethod(t: Type[Post]): Procedure[Post] = { implicit val o: Origin = PointerCreationOrigin + .where(name = "create_nonnull_pointer_" + t.toString) - val result = new Variable[Post](TAxiomatic(pointerAdt.ref, Nil)) + val result = + new Variable[Post](TAxiomatic(pointerAdt.ref, Nil))(o.where(name = "res")) globalDeclarations.declare(procedure[Post]( blame = AbstractApplicable, contractBlame = TrueSatisfiable, @@ -225,9 +227,63 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) } } + def rewriteTopLevelPointerSubscriptInTrigger(e: Expr[Pre]): Expr[Post] = { + implicit val o: Origin = e.o + e match { + case sub @ PointerSubscript(pointer, index) => + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq( + FunctionInvocation[Post]( + ref = pointerAdd.ref, + args = Seq(unwrapOption(pointer, sub.blame), dispatch(index)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext(PointerBoundsPreconditionFailed(sub.blame, index))) + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")) + case deref @ DerefPointer(pointer) => + FunctionInvocation[Post]( + ref = pointerDeref.ref, + args = Seq( + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptr_deref requires nothing.")) + case other => rewriteDefault(other) + } + } + override def postCoerce(e: Expr[Pre]): Expr[Post] = { implicit val o: Origin = e.o e match { +// case f @ Forall(_, triggers, _) => +// f.rewrite(triggers = +// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) +// ) +// case s @ Starall(_, triggers, _) => +// s.rewrite(triggers = +// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) +// ) +// case e @ Exists(_, triggers, _) => +// e.rewrite(triggers = +// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) +// ) case sub @ PointerSubscript(pointer, index) => SilverDeref( obj = diff --git a/src/viper/viper/api/backend/SilverBackend.scala b/src/viper/viper/api/backend/SilverBackend.scala index 2e1161e85e..c4a3144068 100644 --- a/src/viper/viper/api/backend/SilverBackend.scala +++ b/src/viper/viper/api/backend/SilverBackend.scala @@ -392,7 +392,7 @@ trait SilverBackend .NegativePermissionValue( info(p).permissionValuePermissionNode.get ) // need to fetch access - case _ => ??? + case r => throw new NotImplementedError("Missing: " + r) } def getDecreasesClause(reason: ErrorReason): col.DecreasesClause[_] = diff --git a/src/viper/viper/api/backend/silicon/Silicon.scala b/src/viper/viper/api/backend/silicon/Silicon.scala index aa1244309f..0fd0ab0b6e 100644 --- a/src/viper/viper/api/backend/silicon/Silicon.scala +++ b/src/viper/viper/api/backend/silicon/Silicon.scala @@ -97,7 +97,6 @@ case class Silicon( "--z3ConfigArgs", z3Config, "--ideModeAdvanced", - "--useOldAxiomatization", ) if (proverLogFile.isDefined) { diff --git a/src/viper/viper/api/transform/ColToSilver.scala b/src/viper/viper/api/transform/ColToSilver.scala index 9ca1a94c77..eb9f55580a 100644 --- a/src/viper/viper/api/transform/ColToSilver.scala +++ b/src/viper/viper/api/transform/ColToSilver.scala @@ -7,7 +7,7 @@ import vct.col.ref.Ref import vct.col.util.AstBuildHelpers.unfoldStar import vct.col.{ast => col} import vct.result.VerificationError.{SystemError, Unreachable} -import viper.silver.ast.TypeVar +import viper.silver.ast.{AnnotationInfo, ConsInfo, TypeVar} import viper.silver.plugin.standard.termination.{ DecreasesClause, DecreasesTuple, @@ -231,7 +231,17 @@ case class ColToSilver(program: col.Program[_]) { function.contract.decreases.toSeq.map(decreases), pred(function.contract.ensures), function.body.map(exp), - )(pos = pos(function), info = NodeInfo(function)) + )( + pos = pos(function), + info = + if (ref(function) == "ptrDerefblahblah") + ConsInfo( + AnnotationInfo(Map("opaque" -> Seq())), + NodeInfo(function), + ) + else + NodeInfo(function), + ) } case procedure: col.Procedure[_] if procedure.returnType == col.TVoid() && !procedure.inline && diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index b5af488aab..d8209460ad 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -18,7 +18,7 @@ class CSpec extends VercorsSpec { int x = 4.0 % 1; } """ - vercors should fail withCode "assignFieldFailed" using silicon in "cannot access field of struct after freeing" c + vercors should fail withCode "ptrPerm" using silicon in "cannot access field of struct after freeing" c """ #include @@ -76,7 +76,7 @@ class CSpec extends VercorsSpec { int main(){ struct d* xs = (struct d*) malloc(sizeof(struct d)*3); struct d* ys = (struct d*) malloc(sizeof(struct d)*3); - //@ exhale Perm(xs[0].x, 1\2); + //@ exhale Perm(&xs[0].x, 1\2); free(xs); } """ @@ -112,7 +112,7 @@ class CSpec extends VercorsSpec { int main(){ struct d s1; struct d* s2 = &s1; - //@ exhale Perm(s2->x, 1\1); + //@ exhale Perm(&s2->x, 1\1); s2->x = 1; } """ @@ -124,7 +124,7 @@ class CSpec extends VercorsSpec { }; int main(){ struct d s; - //@ exhale Perm(s.x, 1\1); + //@ exhale Perm(&s.x, 1\1); s.x = 1; } """ @@ -136,7 +136,7 @@ class CSpec extends VercorsSpec { int main(){ struct d s; s.x = 1; - //@ exhale Perm(s.x, 1\1); + //@ exhale Perm(&s.x, 1\1); int x = s.x; } """ @@ -323,7 +323,7 @@ class CSpec extends VercorsSpec { int main(){ struct d s; - //@ exhale Perm(s.x, 1\1); + //@ exhale Perm(&s.x, 1\1); test(s); } """ @@ -341,7 +341,7 @@ class CSpec extends VercorsSpec { int main(){ struct d s, t; - //@ exhale Perm(s.x, 1\1); + //@ exhale Perm(&s.x, 1\1); t = s; } """ From 7b50ef3a3b3ab531990e094ac4539fde3fc75ecd Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 19 Jul 2024 10:18:02 +0200 Subject: [PATCH 12/47] Make the pointer for struct fields implicit simplifying most locations in the transformations stages --- src/rewrite/vct/rewrite/ClassToRef.scala | 36 ++- .../vct/rewrite/EncodeArrayValues.scala | 9 +- .../vct/rewrite/PrepareByValueClass.scala | 34 ++- src/rewrite/vct/rewrite/TrivialAddrOf.scala | 1 + .../vct/rewrite/VariableToPointer.scala | 23 +- .../vct/rewrite/lang/LangCPPToCol.scala | 218 +++++++----------- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 43 ++-- 7 files changed, 163 insertions(+), 201 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index aa4c140937..ec510b5ec7 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -7,6 +7,7 @@ import vct.col.util.AstBuildHelpers._ import hre.util.ScopedStack import vct.col.rewrite.error.{ExcludedByPassOrder, ExtraNode} import vct.col.ref.Ref +import vct.col.resolve.ctx.Referrable import vct.col.util.SuccessionMap import scala.collection.mutable @@ -34,6 +35,16 @@ case object ClassToRef extends RewriterBuilder { inner.blame(InstanceNull(inv)) } + case class DerefFieldPointerBlame( + inner: Blame[InsufficientPermission], + node: HeapDeref[_], + clazz: ByValueClass[_], + field: String, + ) extends Blame[PointerDerefError] { + override def blame(error: PointerDerefError): Unit = { + inner.blame(InsufficientPermission(node)) + } + } } case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { @@ -323,7 +334,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) val (fieldFunctions, fieldInverses, fieldTypes) = cls.decls.collect { case field: Field[Pre] => - val newT = dispatch(field.t) + val newT = TNonNullPointer(dispatch(field.t)) byValFieldSucc(field) = new ADTFunction[Post]( Seq(new Variable(axiomType)(field.o)), @@ -695,17 +706,30 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { .left(PanicBlame("incorrect instance function type?"), inv.blame), ))(inv.o) case ThisObject(_) => diz.top + case ptrOf @ AddrOf(Deref(obj, Ref(field))) + if obj.t.isInstanceOf[TByValueClass[Pre]] => + adtFunctionInvocation[Post]( + byValFieldSucc.ref(field), + args = Seq(dispatch(obj)), + )(ptrOf.o) case deref @ Deref(obj, Ref(field)) => obj.t match { case _: TByReferenceClass[Pre] => SilverDeref[Post](dispatch(obj), byRefFieldSucc.ref(field))( deref.blame )(deref.o) - case _: TByValueClass[Pre] => - adtFunctionInvocation[Post]( - byValFieldSucc.ref(field), - args = Seq(dispatch(obj)), - )(deref.o) + case t: TByValueClass[Pre] => + DerefPointer( + adtFunctionInvocation[Post]( + byValFieldSucc.ref(field), + args = Seq(dispatch(obj)), + )(deref.o) + )(DerefFieldPointerBlame( + deref.blame, + deref, + t.cls.decl.asInstanceOf[ByValueClass[Pre]], + Referrable.originNameOrEmpty(field), + ))(deref.o) } case TypeValue(t) => t match { diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 52435b4a68..6e963afd04 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -436,10 +436,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { } val newFieldPerms = fields.map(member => { val loc = - (i: Variable[Post]) => - DerefPointer(Deref[Post](struct(i), member.ref)(DerefPerm))( - NonNullPointerNull - ) + (i: Variable[Post]) => Deref[Post](struct(i), member.ref)(DerefPerm) var anns: Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = Seq(( makeStruct.makePerm( i => FieldLocation[Post](struct(i), member.ref), @@ -452,7 +449,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ), )) anns = - if (typeIsRef(member.t.asPointer.get.element)) + if (typeIsRef(member.t)) anns :+ ( makeStruct.makeUnique(loc), @@ -460,7 +457,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ) else anns - member.t.asPointer.get.element match { + member.t match { case newStruct: TClass[Post] => // We recurse, since a field is another struct anns ++ unwrapStructPerm( diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index fdfaa81458..228e875035 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -166,19 +166,13 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { f.t match { case inner: TByValueClass[Pre] => Assign[Post]( - DerefPointer(Deref[Post](v.get, succ(f))(DerefAssignTarget))( - NonNullPointerNull - ), + Deref[Post](v.get, succ(f))(DerefAssignTarget), copyClassValue(Deref[Post](ov.get, succ(f))(blame(f)), inner, blame), )(AssignLocalOk) case _ => Assign[Post]( - DerefPointer(Deref[Post](v.get, succ(f))(DerefAssignTarget))( - NonNullPointerNull - ), - DerefPointer(Deref[Post](ov.get, succ(f))(blame(f)))( - NonNullPointerNull - ), + Deref[Post](v.get, succ(f))(DerefAssignTarget), + Deref[Post](ov.get, succ(f))(blame(f)), )(AssignLocalOk) } @@ -220,12 +214,10 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { val newFieldPerms = fields.map(member => { val loc = FieldLocation[Post](obj, succ(member)) // TODO: Don't go through regular pointers... - member.t.asPointer.get.element match { + member.t match { case inner: TByValueClass[Pre] => Perm[Post](loc, perm) &* unwrapClassPerm( - DerefPointer(Deref[Post](obj, succ(member))(blame))( - NonNullPointerNull - ), + Deref[Post](obj, succ(member))(blame), perm, inner, structType +: visited, @@ -356,12 +348,16 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { dp, v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], ) - case dp @ DerefPointer(Deref(_, Ref(f))) - if f.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - rewriteInCopyContext( - dp, - f.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) + case deref @ Deref(_, Ref(f)) if f.t.isInstanceOf[TByValueClass[Pre]] => + if (copyContext.isEmpty) { deref.rewriteDefault() } + else { + // TODO: Improve blame message here + copyClassValue( + deref.rewriteDefault(), + f.t.asInstanceOf[TByValueClass[Pre]], + f => deref.blame, + ) + } case dp @ DerefPointer(Local(Ref(v))) if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => // This can happen if the user specifies a local of type pointer to TByValueClass diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index 63b5396c0d..6cf0e2c060 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -37,6 +37,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { case AddrOf(sub @ PointerSubscript(p, i)) => PointerAdd(dispatch(p), dispatch(i))(SubscriptErrorAddError(sub))(e.o) + case AddrOf(Deref(_, _)) => e.rewriteDefault() case AddrOf(other) => throw UnsupportedLocation(other) case assign @ PreAssignExpression(target, AddrOf(value)) if value.t.isInstanceOf[TByReferenceClass[Pre]] => diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index ab399c63a5..6e8546627c 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -46,8 +46,9 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { case AddrOf(DerefHeapVariable(Ref(v))) if !v.t.isInstanceOf[TByReferenceClass[Pre]] => v - case AddrOf(Deref(_, Ref(f))) - if !f.t.isInstanceOf[TByReferenceClass[Pre]] => + case AddrOf(Deref(o, Ref(f))) + if !f.t.isInstanceOf[TByReferenceClass[Pre]] && + !o.t.isInstanceOf[TByValueClass[Pre]] => f }) super.dispatch(program) @@ -58,15 +59,15 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { // TODO: Use some sort of NonNull pointer type instead case v: HeapVariable[Pre] if addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations - .succeed(v, new HeapVariable(TPointer(dispatch(v.t)))(v.o)) + .succeed(v, new HeapVariable(TNonNullPointer(dispatch(v.t)))(v.o)) case v: Variable[Pre] if addressedSet.contains(v) => variableMap(v) = variables - .succeed(v, new Variable(TPointer(dispatch(v.t)))(v.o)) + .succeed(v, new Variable(TNonNullPointer(dispatch(v.t)))(v.o)) case f: InstanceField[Pre] if addressedSet.contains(f) => fieldMap(f) = classDeclarations.succeed( f, new InstanceField( - TPointer(dispatch(f.t)), + TNonNullPointer(dispatch(f.t)), f.flags.map { it => dispatch(it) }, )(f.o), ) @@ -84,7 +85,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = local.o Assign( Local[Post](variableMap.ref(local)), - NewPointerArray( + NewNonNullPointerArray( variableMap(local).t.asPointer.get.element, const(1), )(PanicBlame("Size is > 0")), @@ -205,13 +206,11 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { PointerLocation(Deref[Post](dispatch(obj), fieldMap.ref(f))(PanicBlame( "Should always be accessible" )))(PanicBlame("Should always be accessible")) - case PointerLocation( - AddrOf(Deref(obj, Ref(f))) - ) /* if addressedSet.contains(f) always true */ => + case PointerLocation(AddrOf(Deref(obj, Ref(f)))) + if addressedSet.contains(f) => FieldLocation[Post](dispatch(obj), fieldMap.ref(f)) - case PointerLocation( - AddrOf(DerefHeapVariable(Ref(v))) - ) /* if addressedSet.contains(v) always true */ => + case PointerLocation(AddrOf(DerefHeapVariable(Ref(v)))) + if addressedSet.contains(v) => HeapVariableLocation[Post](heapVariableMap.ref(v)) case PointerLocation(AddrOf(local @ Local(_))) => throw UnsupportedAddrOf(local) diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index d0708ec692..debd341db5 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -738,9 +738,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = f.o Star( Perm[Post](FieldLocation[Post](thisObj, f.ref), ReadPerm()), - DerefPointer(Deref[Post](thisObj, f.ref)(new SYCLRangeDerefBlame(f)))( - NonNullPointerNull - ) >= c_const(0), + Deref[Post](thisObj, f.ref)(new SYCLRangeDerefBlame(f)) >= c_const(0), ) }) } @@ -765,9 +763,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = params(i).o Star( Perm(FieldLocation[Post](result, rangeFields(i).ref), ReadPerm()), - DerefPointer(Deref[Post](result, rangeFields(i).ref)( - new SYCLRangeDerefBlame(rangeFields(i)) - ))(NonNullPointerNull) === Local[Post](params(i).ref), + Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( + rangeFields(i) + )) === Local[Post](params(i).ref), ) }) } @@ -801,23 +799,21 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit val o: Origin = params(i).o foldStar(Seq( Perm(FieldLocation[Post](result, rangeFields(i).ref), ReadPerm()), - DerefPointer(Deref[Post](result, rangeFields(i).ref)( - new SYCLRangeDerefBlame(rangeFields(i)) - ))(NonNullPointerNull) === FloorDiv( + Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( + rangeFields(i) + )) === FloorDiv( Local[Post](params(i).ref), Local[Post](params(i + 1).ref), )(ImpossibleDivByZeroBlame()), Perm(FieldLocation[Post](result, rangeFields(i + 1).ref), ReadPerm()), - DerefPointer(Deref[Post](result, rangeFields(i + 1).ref)( + Deref[Post](result, rangeFields(i + 1).ref)(new SYCLRangeDerefBlame( + rangeFields(i + 1) + )) === Local[Post](params(i + 1).ref), + Deref[Post](result, rangeFields(i).ref)(new SYCLRangeDerefBlame( + rangeFields(i) + )) * Deref[Post](result, rangeFields(i + 1).ref)( new SYCLRangeDerefBlame(rangeFields(i + 1)) - ))(NonNullPointerNull) === Local[Post](params(i + 1).ref), - DerefPointer(Deref[Post](result, rangeFields(i).ref)( - new SYCLRangeDerefBlame(rangeFields(i)) - ))(NonNullPointerNull) * DerefPointer( - Deref[Post](result, rangeFields(i + 1).ref)(new SYCLRangeDerefBlame( - rangeFields(i + 1) - )) - )(NonNullPointerNull) === Local[Post](params(i).ref), + ) === Local[Post](params(i).ref), )) }) } @@ -1098,11 +1094,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _: SYCLTAccessor[Pre] => // Referencing an accessor variable can only be done in kernels, otherwise an error will already have been thrown val accessor = syclAccessorSuccessor(ref) - DerefPointer( - Deref[Post](currentThis.get, accessor.instanceField.ref)( - SYCLAccessorFieldInsufficientReferencePermissionBlame(local) - ) - )(NonNullPointerNull) + Deref[Post](currentThis.get, accessor.instanceField.ref)( + SYCLAccessorFieldInsufficientReferencePermissionBlame(local) + ) case _: SYCLTLocalAccessor[Pre] if currentKernelType.get.isInstanceOf[BasicKernel] => throw SYCLNoLocalAccessorsInBasicKernel(local) @@ -1240,18 +1234,18 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) getGlobalWorkItemLinearId(inv) case "sycl::accessor::get_range" => classInstance match { - case Some(DerefPointer(Deref(_, ref))) => + case Some(Deref(_, ref)) => val accessor = syclAccessorSuccessor.values .find(acc => ref.decl.equals(acc.instanceField)).get LiteralSeq[Post]( TCInt(), accessor.rangeIndexFields.map(f => - DerefPointer(Deref[Post](currentThis.get, f.ref)( + Deref[Post](currentThis.get, f.ref)( SYCLAccessorRangeIndexFieldInsufficientReferencePermissionBlame( inv ) - ))(NonNullPointerNull) + ) ), ) case _ => throw NotApplicable(inv) @@ -1573,14 +1567,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val rangeFields: mutable.Buffer[InstanceField[Post]] = mutable.Buffer.empty range.indices.foreach(index => { implicit val o: Origin = range(index).o.where(name = s"range$index") - val instanceField = new InstanceField[Post](TNonNullPointer(TCInt()), Nil) + val instanceField = new InstanceField[Post](TCInt(), Nil) rangeFields.append(instanceField) val iterVar = createRangeIterVar( GlobalScope(), index, - DerefPointer(Deref[Post](currentThis.get, instanceField.ref)( - new SYCLRangeDerefBlame(instanceField) - ))(NonNullPointerNull), + Deref[Post](currentThis.get, instanceField.ref)(new SYCLRangeDerefBlame( + instanceField + )), ) currentDimensionIterVars(GlobalScope()).append(iterVar) }) @@ -1649,30 +1643,28 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) { implicit val o: Origin = kernelDimensions.o .where(name = s"group_range$index") - val groupInstanceField = - new InstanceField[Post](TNonNullPointer(TCInt()), Nil) + val groupInstanceField = new InstanceField[Post](TCInt(), Nil) rangeFields.append(groupInstanceField) val groupIterVar = createRangeIterVar( GroupScope(), index, - DerefPointer(Deref[Post](currentThis.get, groupInstanceField.ref)( + Deref[Post](currentThis.get, groupInstanceField.ref)( new SYCLRangeDerefBlame(groupInstanceField) - ))(NonNullPointerNull), + ), ) currentDimensionIterVars(GroupScope()).append(groupIterVar) } { implicit val o: Origin = localRange(index).o .where(name = s"local_range$index") - val localInstanceField = - new InstanceField[Post](TNonNullPointer(TCInt()), Nil) + val localInstanceField = new InstanceField[Post](TCInt(), Nil) rangeFields.append(localInstanceField) val localIterVar = createRangeIterVar( LocalScope(), index, - DerefPointer(Deref[Post](currentThis.get, localInstanceField.ref)( + Deref[Post](currentThis.get, localInstanceField.ref)( new SYCLRangeDerefBlame(localInstanceField) - ))(NonNullPointerNull), + ), ) currentDimensionIterVars(LocalScope()).append(localIterVar) } @@ -1864,13 +1856,10 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case None => // No accessor for buffer exist in the command group, so make fields and permissions val instanceField = - new InstanceField[Post]( - TNonNullPointer(buffer.generatedVar.t), - Nil, - )(accO) + new InstanceField[Post](buffer.generatedVar.t, Nil)(accO) val rangeIndexFields = Seq .range(0, buffer.range.dimensions.size).map(i => - new InstanceField[Post](TNonNullPointer(TCInt()), Nil)( + new InstanceField[Post](TCInt(), Nil)( dimO.where(name = s"${accName}_r$i") ) ) @@ -1942,11 +1931,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } val rangeIndexDerefs: Seq[Expr[Post]] = acc.rangeIndexFields.map(f => - DerefPointer( - Deref[Post](classObj, f.ref)(new SYCLAccessorDimensionDerefBlame(f))( - f.o - ) - )(NonNullPointerNull)(f.o) + Deref[Post](classObj, f.ref)(new SYCLAccessorDimensionDerefBlame(f))(f.o) ) ( @@ -1964,22 +1949,18 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ReadPerm()(acc.instanceField.o), )(acc.instanceField.o), ValidArray( - DerefPointer( - Deref[Post](classObj, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o) - )(NonNullPointerNull)(acc.instanceField.o), + Deref[Post](classObj, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o), rangeIndexDerefs.reduce((e1, e2) => (e1 * e2)(acc.buffer.o)), )(acc.instanceField.o), ) )(acc.instanceField.o), Perm( ArrayLocation( - DerefPointer( - Deref[Post](classObj, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o) - )(NonNullPointerNull)(acc.instanceField.o), + Deref[Post](classObj, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o), Any()(PanicBlame( "The accessor field is not null as that was proven in the previous conditions." ))(acc.instanceField.o), @@ -2031,24 +2012,20 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) constructorPostConditions.append( foldStar[Post]( Eq[Post]( - DerefPointer( - Deref[Post](result, acc.instanceField.ref)( - new SYCLAccessorDerefBlame(acc.instanceField) - )(acc.instanceField.o) - )(NonNullPointerNull)(acc.instanceField.o), + Deref[Post](result, acc.instanceField.ref)( + new SYCLAccessorDerefBlame(acc.instanceField) + )(acc.instanceField.o), Local[Post](newConstructorAccessorArg.ref)( newConstructorAccessorArg.o ), )(newConstructorAccessorArg.o) +: Seq.range(0, acc.rangeIndexFields.size).map(i => Eq[Post]( - DerefPointer( - Deref[Post](result, acc.rangeIndexFields(i).ref)( - new SYCLAccessorDimensionDerefBlame( - acc.rangeIndexFields(i) - ) - )(acc.rangeIndexFields(i).o) - )(NonNullPointerNull)(acc.rangeIndexFields(i).o), + Deref[Post](result, acc.rangeIndexFields(i).ref)( + new SYCLAccessorDimensionDerefBlame( + acc.rangeIndexFields(i) + ) + )(acc.rangeIndexFields(i).o), Local[Post](newConstructorAccessorDimensionArgs(i).ref)( newConstructorAccessorDimensionArgs(i).o ), @@ -2506,14 +2483,12 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 1, _) => ArraySubscript[Post]( - DerefPointer( - Deref[Post]( - currentThis.get, - syclAccessorSuccessor(base.ref.get).instanceField.ref, - )(new SYCLAccessorDerefBlame( - syclAccessorSuccessor(base.ref.get).instanceField - ))(sub.o) - )(NonNullPointerNull)(sub.o), + Deref[Post]( + currentThis.get, + syclAccessorSuccessor(base.ref.get).instanceField.ref, + )(new SYCLAccessorDerefBlame( + syclAccessorSuccessor(base.ref.get).instanceField + ))(sub.o), rw.dispatch(index), )(SYCLAccessorArraySubscriptErrorBlame(sub))(sub.o) case t: SYCLTAccessor[Pre] => @@ -2529,25 +2504,19 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val linearizeArgs = Seq( rw.dispatch(indexX), rw.dispatch(indexY), - DerefPointer( - Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) - ) - )(NonNullPointerNull), - DerefPointer( - Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) - ) - )(NonNullPointerNull), + Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) + ), + Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) + ), ) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 2, _) => ArraySubscript[Post]( - DerefPointer( - Deref[Post](currentThis.get, accessor.instanceField.ref)( - new SYCLAccessorDerefBlame(accessor.instanceField) - ) - )(NonNullPointerNull), + Deref[Post](currentThis.get, accessor.instanceField.ref)( + new SYCLAccessorDerefBlame(accessor.instanceField) + ), syclHelperFunctions("sycl_:_:linearize_2")( linearizeArgs, SYCLAccessorArraySubscriptLinearizeInvocationBlame( @@ -2575,30 +2544,22 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) rw.dispatch(indexX), rw.dispatch(indexY), rw.dispatch(indexZ), - DerefPointer( - Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) - ) - )(NonNullPointerNull), - DerefPointer( - Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) - ) - )(NonNullPointerNull), - DerefPointer( - Deref[Post](currentThis.get, accessor.rangeIndexFields(2).ref)( - new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(2)) - ) - )(NonNullPointerNull), + Deref[Post](currentThis.get, accessor.rangeIndexFields(0).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(0)) + ), + Deref[Post](currentThis.get, accessor.rangeIndexFields(1).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(1)) + ), + Deref[Post](currentThis.get, accessor.rangeIndexFields(2).ref)( + new SYCLAccessorDimensionDerefBlame(accessor.rangeIndexFields(2)) + ), ) CPP.unwrappedType(base.t) match { case SYCLTAccessor(_, 3, _) => ArraySubscript[Post]( - DerefPointer( - Deref[Post](currentThis.get, accessor.instanceField.ref)( - new SYCLAccessorDerefBlame(accessor.instanceField) - ) - )(NonNullPointerNull), + Deref[Post](currentThis.get, accessor.instanceField.ref)( + new SYCLAccessorDerefBlame(accessor.instanceField) + ), syclHelperFunctions("sycl_:_:linearize_3")( linearizeArgs, SYCLAccessorArraySubscriptLinearizeInvocationBlame( @@ -2676,12 +2637,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) removeKernelClassInstancePermissions(right), ) - case ActionPerm(DerefPointer(Deref(obj, _)), _) - if obj.equals(this.currentThis.get) => - tt - case ModelPerm(DerefPointer(Deref(obj, _)), _) - if obj.equals(this.currentThis.get) => + case ActionPerm(Deref(obj, _), _) if obj.equals(this.currentThis.get) => tt + case ModelPerm(Deref(obj, _), _) if obj.equals(this.currentThis.get) => tt case Perm(FieldLocation(obj, _), _) if obj.equals(this.currentThis.get) => tt case PointsTo(FieldLocation(obj, _), _, _) @@ -2689,20 +2647,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) tt case Value(FieldLocation(obj, _)) if obj.equals(this.currentThis.get) => tt - case Perm( - AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)), - _, - ) if obj.equals(this.currentThis.get) => + case Perm(AmbiguousLocation(ArraySubscript(Deref(obj, _), _)), _) + if obj.equals(this.currentThis.get) => tt - case PointsTo( - AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)), - _, - _, - ) if obj.equals(this.currentThis.get) => + case PointsTo(AmbiguousLocation(ArraySubscript(Deref(obj, _), _)), _, _) + if obj.equals(this.currentThis.get) => tt - case Value( - AmbiguousLocation(ArraySubscript(DerefPointer(Deref(obj, _)), _)) - ) if obj.equals(this.currentThis.get) => + case Value(AmbiguousLocation(ArraySubscript(Deref(obj, _), _))) + if obj.equals(this.currentThis.get) => tt case Implies(left, right) => Implies( @@ -2733,9 +2685,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case SYCLAccessor(buffer, SYCLReadWriteAccess(), instanceField, _) => assignLocal( buffer.generatedVar.get, - DerefPointer(Deref[Post](variable, instanceField.ref)( - new SYCLAccessorDerefBlame(instanceField) - ))(NonNullPointerNull), + Deref[Post](variable, instanceField.ref)(new SYCLAccessorDerefBlame( + instanceField + )), ) })) } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 54abadbe31..1b4fba2294 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -992,9 +992,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) Seq(x), ) = fieldDecl fieldDecl.drop() - val t = TNonNullPointer(specs.collectFirst { - case t: CSpecificationType[Pre] => rw.dispatch(t.t) - }.get) + val t = + specs.collectFirst { case t: CSpecificationType[Pre] => + rw.dispatch(t.t) + }.get cStructFieldsSuccessor((decl, fieldDecl)) = new InstanceField(t = t, flags = Nil)(CStructFieldOrigin(x)) rw.classDeclarations @@ -1329,11 +1330,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) t.t }.get if (t.isInstanceOf[CTStruct[Pre]]) { - DerefPointer[Post]( - DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( - local.blame - ) - )(NonNullPointerNull) + DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( + local.blame + ) } else { DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( local.blame @@ -1395,12 +1394,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _: TNotAValue[Pre] => throw TypeUsedAsValue(deref.obj) case _ => ??? } - DerefPointer( - Deref[Post]( - rw.dispatch(deref.obj), - cStructFieldsSuccessor.ref((struct_ref.decl, struct.decls)), - )(deref.blame) - )(NonNullPointerNull) + Deref[Post]( + rw.dispatch(deref.obj), + cStructFieldsSuccessor.ref((struct_ref.decl, struct.decls)), + )(deref.blame) } } @@ -1420,12 +1417,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case CTPointer(CTStruct(struct)) => struct case t => throw WrongStructType(t) } - DerefPointer( - Deref[Post]( - DerefPointer(rw.dispatch(deref.struct))(b), - cStructFieldsSuccessor.ref((structRef.decl, struct.decls)), - )(deref.blame)(deref.o) - )(NonNullPointerNull)(deref.o) + Deref[Post]( + DerefPointer(rw.dispatch(deref.struct))(b), + cStructFieldsSuccessor.ref((structRef.decl, struct.decls)), + )(deref.blame)(deref.o) } } @@ -1449,11 +1444,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val newFieldPerms = fields.map(member => { val loc = AmbiguousLocation( - DerefPointer( - Deref[Post]( - newExpr, - cStructFieldsSuccessor.ref((structType.ref.decl, member)), - )(blame) + Deref[Post]( + newExpr, + cStructFieldsSuccessor.ref((structType.ref.decl, member)), )(blame) )(struct.blame) member.specs.collectFirst { From 1a31eddbc9eabf4fb141fd9799b71c0d3fa45dfc Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 19 Jul 2024 13:26:37 +0200 Subject: [PATCH 13/47] Fix the type numbers --- src/rewrite/vct/rewrite/ClassToRef.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 96f640cea0..c32762cb8b 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -102,7 +102,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { def makeTypeOf: Function[Post] = { implicit val o: Origin = TypeOfOrigin - val obj = new Variable[Post](TAnyValue()) + val obj = new Variable[Post](TRef()) withResult((result: Result[Post]) => function( blame = AbstractApplicable, @@ -185,7 +185,11 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { "Class type parameters should be encoded using monomorphization earlier" ) - typeNumber(cls) + cls match { + case clazz: ByReferenceClass[Pre] => typeNumber(cls) + case clazz: ByValueClass[Pre] => {} + } + val thisType = dispatch(cls.classType(Nil)) cls.decls.foreach { case function: InstanceFunction[Pre] => From 3f9f02b33adcd9d37afba1f45cb7b506111b20a1 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 19 Jul 2024 14:29:11 +0200 Subject: [PATCH 14/47] Replaced type numbers with constants for ByValueClass --- src/rewrite/vct/rewrite/ClassToRef.scala | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index c32762cb8b..07126eadc6 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -185,10 +185,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { "Class type parameters should be encoded using monomorphization earlier" ) - cls match { - case clazz: ByReferenceClass[Pre] => typeNumber(cls) - case clazz: ByValueClass[Pre] => {} - } + typeNumber(cls) val thisType = dispatch(cls.classType(Nil)) cls.decls.foreach { @@ -734,13 +731,18 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case other => ??? } case TypeOf(value) => - FunctionInvocation[Post]( - typeOf.ref(()), - Seq(dispatch(value)), - Nil, - Nil, - Nil, - )(PanicBlame("typeOf requires nothing"))(e.o) + value.t match { + case cls: TByReferenceClass[Pre] => + FunctionInvocation[Post]( + typeOf.ref(()), + Seq(dispatch(value)), + Nil, + Nil, + Nil, + )(PanicBlame("typeOf requires nothing"))(e.o) + case cls: TByValueClass[Pre] => + const[Post](typeNumber(cls.cls.decl))(e.o) + } case InstanceOf(value, TypeValue(TUnion(ts))) => implicit val o: Origin = e.o dispatch(foldOr(ts.map(t => InstanceOf(value, TypeValue(t))))) From bcd96b56f8364dac6b75b9e859f7125374c16b79 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 19 Jul 2024 15:11:02 +0200 Subject: [PATCH 15/47] Temporarily set a fork of silicon in build.sc to test in CI --- build.sc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.sc b/build.sc index 34fd4c7f28..e0fbcd0ad3 100644 --- a/build.sc +++ b/build.sc @@ -50,8 +50,8 @@ object viper extends ScalaModule { } object siliconGit extends GitModule { - def url = T { "https://github.com/viperproject/silicon.git" } - def commitish = T { "4033dd21614b3bbba9c7615655e41c6cf0b9d80b" } + def url = T { "https://github.com/superaxander/silicon.git" } + def commitish = T { "c63989f64eb759f33bde68c330ce07d6e34134fa" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") From 24197b8d42825b511083e50b6c547bfec8b54b8f Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 23 Jul 2024 09:15:41 +0200 Subject: [PATCH 16/47] Update silver, clean up unused ByValueClass axioms --- build.sc | 2 +- src/rewrite/vct/rewrite/ClassToRef.scala | 68 +------------------ .../vct/rewrite/EncodeArrayValues.scala | 1 - 3 files changed, 4 insertions(+), 67 deletions(-) diff --git a/build.sc b/build.sc index e0fbcd0ad3..00822d675d 100644 --- a/build.sc +++ b/build.sc @@ -41,7 +41,7 @@ object external extends Module { object viper extends ScalaModule { object silverGit extends GitModule { def url = T { "https://github.com/viperproject/silver.git" } - def commitish = T { "4a8065758868eae3414f86f3d96e843a283444fc" } + def commitish = T { "93bc9b7516a710c8f01438e430058c4a54e20512" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 07126eadc6..ee01ae69a1 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -390,63 +390,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { .getOrElse("unknown") ) ) - val destructorAxiom = - new ADTAxiom[Post](foralls( - fieldTypes, - body = - variables => { - foldAnd(variables.combinations(2).map { case Seq(v1, v2) => - v1 !== v2 - }.toSeq) ==> - foldAnd(variables.zip(fieldFunctions).map { case (v, f) => - adtFunctionInvocation[Post]( - f.ref, - args = Seq(adtFunctionInvocation[Post]( - constructor.ref, - None, - args = variables, - )), - ) === v - }) - }, - triggers = - variables => { - fieldFunctions.map { f => - Seq(adtFunctionInvocation[Post]( - f.ref, - args = Seq(adtFunctionInvocation[Post]( - constructor.ref, - None, - args = variables, - )), - )) - } - }, - )) - - // This one generates a matching loop - val injectivityAxiom1 = - new ADTAxiom[Post](foralls( - Seq(axiomType, axiomType), - body = { case Seq(a0, a1) => - foldAnd(fieldFunctions.combinations(2).map { - case Seq(f0, f1) => - Neq( - adtFunctionInvocation[Post](f0.ref, args = Seq(a0)), - adtFunctionInvocation[Post](f1.ref, args = Seq(a1)), - ) - }.toSeq) - }, - triggers = { case Seq(a0, a1) => - fieldFunctions.combinations(2).map { case Seq(f0, f1) => - Seq( - adtFunctionInvocation[Post](f0.ref, None, args = Seq(a0)), - adtFunctionInvocation[Post](f1.ref, None, args = Seq(a1)), - ) - }.toSeq - }, - )) - val injectivityAxiom2 = + val injectivityAxiom = new ADTAxiom[Post](foralls( Seq(axiomType, axiomType), body = { case Seq(a0, a1) => @@ -511,14 +455,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { byValConsSucc(cls) = constructor byValClassSucc(cls) = new AxiomaticDataType[Post]( - Seq( -// constructor, -// destructorAxiom, - indexFunction, -// injectivityAxiom1, - injectivityAxiom2, - ) ++ destructorAxioms ++ indexAxioms ++ fieldFunctions ++ - fieldInverses, + Seq(indexFunction, injectivityAxiom) ++ destructorAxioms ++ + indexAxioms ++ fieldFunctions ++ fieldInverses, Nil, ) globalDeclarations.succeed(cls, byValClassSucc(cls)) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 6e963afd04..d7865bbda9 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -425,7 +425,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ) // We do not allow this notation for recursive structs implicit val o: Origin = origin - // TODO: Instead of doing complicated stuff here just generate a Perm(struct.field, write) and rely on EncodyByValueClass to deal with it :) val fields = structType match { case t: TClass[Post] => From 653cd5fecd07b2118dd9b3b0e30ba65af9eb1d2f Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 24 Jul 2024 16:42:42 +0200 Subject: [PATCH 17/47] First working version pointer casts --- src/rewrite/vct/rewrite/ClassToRef.scala | 97 ++++-------- .../vct/rewrite/EncodeArrayValues.scala | 144 ++++++++++++++--- .../vct/rewrite/PrepareByValueClass.scala | 76 +-------- .../rewrite/SimplifyNestedQuantifiers.scala | 12 +- .../vct/rewrite/adt/ImportPointer.scala | 149 +++++++++++++++++- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 12 +- .../vct/test/integration/examples/CSpec.scala | 2 +- 7 files changed, 313 insertions(+), 179 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index ee01ae69a1..0591765789 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -23,9 +23,11 @@ case object ClassToRef extends RewriterBuilder { private def InstanceOfOrigin: Origin = Origin(Seq(PreferredName(Seq("subtype")), LabelContext("classToRef"))) - private val PointerCreationOrigin: Origin = Origin( - Seq(LabelContext("classToRef, pointer creation method")) - ) +// private val AsTypeOrigin: Origin = Origin( +// Seq(LabelContext("classToRef, asType function")) +// ) +// +// private val ValueAdtOrigin: Origin = Origin(Seq(PreferredName(Seq("Value")), LabelContext("classToRef"))) case class InstanceNullPreconditionFailed( inner: Blame[InstanceNull], @@ -76,26 +78,16 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { var typeNumberStore: mutable.Map[Class[Pre], Int] = mutable.Map() val typeOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() val instanceOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() - private val pointerCreationMethods - : SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() - def makePointerCreationMethod(t: Type[Post]): Procedure[Post] = { - implicit val o: Origin = PointerCreationOrigin - - val result = new Variable[Post](TNonNullPointer(t)) - globalDeclarations.declare(procedure[Post]( - blame = AbstractApplicable, - contractBlame = TrueSatisfiable, - returnType = TVoid(), - outArgs = Seq(result), - ensures = UnitAccountedPredicate( - (PointerBlockLength(result.get)(FramedPtrBlockLength) === const(1)) &* - (PointerBlockOffset(result.get)(FramedPtrOffset) === const(0)) &* - Perm(PointerLocation(result.get)(FramedPtrOffset), WritePerm()) - ), - decreases = Some(DecreasesClauseNoRecursion[Post]()), - )) - } +// val valueAdt: SuccessionMap[Unit, AxiomaticDataType[Post]] = SuccessionMap() +// val valueAdtTypeArgument: SuccessionMap[Unit, Variable[Post]] = SuccessionMap() +// val asTypeFunctions: mutable.Map[Type[Pre], ADTFunction[Post]] = mutable.Map() +// +// def makeAsTypeFunction(typeName: String): ADTFunction[Post] = { +// val typeArg = valueAdtTypeArgument.getOrElseUpdate((), new Variable[Post](TType(TAnyValue()))(AsTypeOrigin.where(name="T"))) +// val value = new Variable[Post](TVar(typeArg.ref))(AsTypeOrigin.where(name="value")) +// new ADTFunction[Post](Seq(value), TNonNullPointer(TAnyValue()))(AsTypeOrigin.where(name="as_"+typeName)) +// } def typeNumber(cls: Class[Pre]): Int = typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1) @@ -118,7 +110,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) } - def transitiveByValuePermissions( + private def transitiveByValuePermissions( obj: Expr[Pre], t: TByValueClass[Pre], amount: Expr[Pre], @@ -174,6 +166,10 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { globalDeclarations.declare(typeOf(())) instanceOf(()) = makeInstanceOf globalDeclarations.declare(instanceOf(())) +// if (asTypeFunctions.nonEmpty) { +// valueAdt(()) = new AxiomaticDataType[Post](asTypeFunctions.values.toSeq, Seq(valueAdtTypeArgument(())))(ValueAdtOrigin) +// globalDeclarations.declare(valueAdt(())) +// } }._1 ) @@ -462,7 +458,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { globalDeclarations.succeed(cls, byValClassSucc(cls)) case _ => cls.drop() } - case decl => rewriteDefault(decl) + case decl => super.dispatch(decl) } def instantiate(cls: Class[Pre], target: Ref[Post, Variable[Post]])( @@ -487,46 +483,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )(PanicBlame("typeOf requires nothing.")) === const(typeNumber(cls)) ), )) - case cls: ByValueClass[Pre] => throw ExtraNode -// val (assigns, vars) = -// cls.decls.collect { case field: InstanceField[Pre] => -// val element = field.t.asPointer.get.element -// val newE = dispatch(element) -// val v = new Variable[Post](TNonNullPointer(newE)) -// ( -// InvokeProcedure[Post]( -// pointerCreationMethods -// .getOrElseUpdate(element, makePointerCreationMethod(newE)) -// .ref, -// Nil, -// Seq(v.get), -// Nil, -// Nil, -// Nil, -// )(TrueSatisfiable), -// v, -// ) -// }.unzip -// val assertions = if (vars.size > 1) { -// Seq(Assert(foldAnd[Post](vars.combinations(2).map { case Seq(a,b) => a.get !== b.get}.toSeq))(PanicBlame("Newly created pointers should be distinct"))) -// } else { -// Nil -// } -// Scope( -// vars, -// Block( -// assigns ++ assertions ++ Seq( -// Assign( -// Local(target), -// adtFunctionInvocation[Post]( -// byValConsSucc.ref(cls), -// args = vars.map(_.get), -// ), -// )(AssignLocalOk) -// // TODO: Add back typeOf here (but use a separate definition for the adt) -// ) -// ), -// ) + case _: ByValueClass[Pre] => throw ExtraNode } } @@ -577,7 +534,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { }, yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) }, )(inv.blame)(inv.o) - case other => rewriteDefault(other) + case other => super.dispatch(other) } override def dispatch(node: ApplyAnyPredicate[Pre]): ApplyAnyPredicate[Post] = @@ -666,6 +623,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { t match { case t: TClass[Pre] if t.typeArgs.isEmpty => const(typeNumber(t.cls.decl))(e.o) + // Keep pointer casts intact for the adtPointer stage + case _: TPointer[Pre] | _: TNonNullPointer[Pre] => e.rewriteDefault() case other => ??? } case TypeOf(value) => @@ -701,7 +660,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(PanicBlame("instanceOf requires nothing"))(e.o) - case Cast(value, typeValue) => + case Cast(value, typeValue) if value.t.asClass.isDefined => dispatch( value ) // Discard for now, should assert instanceOf(value, typeValue) @@ -735,7 +694,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case v @ Value(PredicateLocation(inv: InstancePredicateApply[Pre])) => implicit val o: Origin = e.o Star[Post](v.rewrite(), dispatch(inv.obj) !== Null()) - case _ => rewriteDefault(e) + case _ => super.dispatch(e) } override def dispatch(t: Type[Pre]): Type[Post] = @@ -747,7 +706,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Nil, ) case TAnyClass() => TRef() - case t => rewriteDefault(t) + case t => super.dispatch(t) } override def dispatch(loc: Location[Pre]): Location[Post] = @@ -767,6 +726,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )(loc.o) )(NonNullPointerNull)(loc.o) } - case default => rewriteDefault(default) + case default => super.dispatch(default) } } diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index d7865bbda9..a0f967a0b2 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -117,14 +117,15 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { .Map() def makeFree( - t: Type[Post] + t: Type[Pre], + newT: Type[Post], ): (Procedure[Post], FreePointer[Pre] => PointerFreeFailed[Pre]) = { implicit val o: Origin = freeFuncOrigin var errors: Seq[Expr[Pre] => PointerFreeError] = Seq() val proc = globalDeclarations.declare({ val (vars, ptr) = variables.collect { - val a_var = new Variable[Post](TPointer(t))(o.where(name = "p")) + val a_var = new Variable[Post](TPointer(newT))(o.where(name = "p")) variables.declare(a_var) Local[Post](a_var.ref) } @@ -193,7 +194,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { // If structure contains structs, the permission for those fields need to be released as well val permFields = t match { - case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) + case t: TClass[Pre] => + unwrapStructPerm(access, None, t, o, makeStruct) case _ => Seq() } requiresT = @@ -412,33 +414,75 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { })) } + def unwrapStructCasts( + struct: Variable[Post] => Expr[Post], + pointer: Variable[Post] => Expr[Post], + structType: TClass[Pre], + origin: Origin, + makeStruct: MakeAnns, + visited: Seq[TClass[Pre]] = Seq(), + ): Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = { + if (visited.contains(structType)) { + // We do not allow this notation for recursive structs + throw UnsupportedStructPerm(origin) + } + implicit val o: Origin = origin + val field = structType.cls.decl.declarations.collectFirst { + case field: InstanceField[Pre] => field + } + if (field.isDefined) { + // TODO: I kind of want to only do this one level deep (but for every type) + val fieldType = field.get.t + val result = + if (typeIsRef(fieldType)) { + unwrapStructCasts( + (i: Variable[Post]) => + Deref[Post](struct(i), succ(field.get))(DerefPerm), + pointer, + fieldType.asInstanceOf[TClass[Pre]], + origin, + makeStruct, + structType +: visited, + ) + } else { Nil } + result :+ + (( + makeStruct.makeCast( + pointer, + (i: Variable[Post]) => + AddrOf(Deref[Post](struct(i), succ(field.get))(DerefPerm)), + TNonNullPointer(dispatch(fieldType)), + ), + // Error should never occur since this part should not be emitted in the definition of free + (p: Expr[Pre]) => GenericPointerFreeError(p), + )) + } else { Nil } + } + def unwrapStructPerm( struct: Variable[Post] => Expr[Post], - structType: TClass[Post], + // Needs to be provided to define asType assertions using casts + pointer: Option[Variable[Post] => Expr[Post]], + structType: TClass[Pre], origin: Origin, makeStruct: MakeAnns, - visited: Seq[TClass[Post]] = Seq(), + visited: Seq[TClass[Pre]] = Seq(), ): Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = { - if (visited.contains(structType)) - throw UnsupportedStructPerm( - origin - ) // We do not allow this notation for recursive structs + if (visited.contains(structType)) { + // We do not allow this notation for recursive structs + throw UnsupportedStructPerm(origin) + } implicit val o: Origin = origin - val fields = - structType match { - case t: TClass[Post] => - t.cls.decl.declarations.collect { case field: InstanceField[Post] => - field - } - case _ => Seq() - } + val fields = structType.cls.decl.declarations.collect { + case field: InstanceField[Pre] => field + } val newFieldPerms = fields.map(member => { val loc = - (i: Variable[Post]) => Deref[Post](struct(i), member.ref)(DerefPerm) + (i: Variable[Post]) => Deref[Post](struct(i), succ(member))(DerefPerm) var anns: Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = Seq(( makeStruct.makePerm( - i => FieldLocation[Post](struct(i), member.ref), + i => FieldLocation[Post](struct(i), succ(member)), IteratedPtrInjective, ), (p: Expr[Pre]) => @@ -447,6 +491,18 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Referrable.originName(member), ), )) + if (pointer.isDefined) { + anns = + anns :+ + (( + makeStruct.makeSelfCast( + (i: Variable[Post]) => AddrOf(loc(i)), + TNonNullPointer(dispatch(member.t)), + ), + // Error should never occur since this part should not be emitted in the definition of free + (p: Expr[Pre]) => GenericPointerFreeError(p), + )) + } anns = if (typeIsRef(member.t)) anns :+ @@ -457,10 +513,11 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { else anns member.t match { - case newStruct: TClass[Post] => + case newStruct: TClass[Pre] => // We recurse, since a field is another struct anns ++ unwrapStructPerm( loc, + pointer.map { _ => (i: Variable[Post]) => AddrOf(loc(i)) }, newStruct, origin, makeStruct, @@ -470,7 +527,17 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { } }) - newFieldPerms.flatten + if (pointer.isDefined && fields.nonEmpty) { + // TODO: I kind of want to only do this one level deep (but for every type) + newFieldPerms.flatten ++ unwrapStructCasts( + struct, + pointer.get, + structType, + origin, + makeStruct, + visited, + ) + } else { newFieldPerms.flatten } } case class MakeAnns( @@ -499,6 +566,29 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val body = (pre1 && pre2 && access(i) === access(j)) ==> (i.get === j.get) Forall(Seq(i, j), Seq(triggerUnique), body) } + + def makeCast( + pointer: Variable[Post] => Expr[Post], + fieldPointer: Variable[Post] => Expr[Post], + t: Type[Post], + ): Expr[Post] = { + implicit val o: Origin = arrayCreationOrigin + val zero = const[Post](0) + val body = (zero <= i.get && i.get < size) ==> + (Cast(pointer(i), TypeValue(t)) === Cast(fieldPointer(i), TypeValue(t))) + Forall(Seq(i), Seq(Seq(Cast(pointer(i), TypeValue(t)))), body) + } + + def makeSelfCast( + pointer: Variable[Post] => Expr[Post], + t: Type[Post], + ): Expr[Post] = { + implicit val o: Origin = arrayCreationOrigin + val zero = const[Post](0) + val body = (zero <= i.get && i.get < size) ==> + (Cast(pointer(i), TypeValue(t)) === pointer(i)) + Forall(Seq(i), Seq(Seq(Cast(pointer(i), TypeValue(t)))), body) + } } def typeIsRef(t: Type[_]): Boolean = @@ -526,6 +616,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val j = new Variable[Post](TInt())(o.where(name = "j")) val access = (i: Variable[Post]) => PointerSubscript(result, i.get)(FramedPtrOffset) + val pointerAccess = + (i: Variable[Post]) => PointerAdd(result, i.get)(FramedPtrOffset) val makeStruct = MakeAnns( i, @@ -555,8 +647,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { else { ensures &* makeStruct.makeUnique(access) } val permFields = - dispatch(elementType) match { - case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct) + elementType match { + case t: TClass[Pre] => + unwrapStructPerm(access, Some(pointerAccess), t, o, makeStruct) case _ => Seq() } @@ -635,11 +728,12 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { case free @ FreePointer(xs) => val newXs = dispatch(xs) val TPointer(t) = newXs.t - val (freeFunc, freeBlame) = freeMethods.getOrElseUpdate(t, makeFree(t)) + val (freeFunc, freeBlame) = freeMethods + .getOrElseUpdate(t, makeFree(xs.t.asPointer.get.element, t)) ProcedureInvocation[Post](freeFunc.ref, Seq(newXs), Nil, Nil, Nil, Nil)( freeBlame(free) )(free.o) - case other => rewriteDefault(other) + case other => super.dispatch(other) } } } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 228e875035..cee61769a6 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -66,19 +66,6 @@ case object PrepareByValueClass extends RewriterBuilder { extends CopyContext private case class NoCopy() extends CopyContext - - case class PointerLocationDerefBlame(blame: Blame[PointerLocationError]) - extends Blame[PointerDerefError] { - override def blame(error: PointerDerefError): Unit = { - error match { - case error: PointerLocationError => blame.blame(error) - case _ => - Unreachable( - "Blame of the respective pointer operation should be used not of DerefPointer" - ) - } - } - } } case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { @@ -213,7 +200,6 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } val newFieldPerms = fields.map(member => { val loc = FieldLocation[Post](obj, succ(member)) - // TODO: Don't go through regular pointers... member.t match { case inner: TByValueClass[Pre] => Perm[Post](loc, perm) &* unwrapClassPerm( @@ -229,38 +215,6 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { foldStar(newFieldPerms) } - private def unwrapClassComp( - comp: (Expr[Post], Expr[Post]) => Expr[Post], - left: Expr[Post], - right: Expr[Post], - structType: TByValueClass[Pre], - visited: Seq[TByValueClass[Pre]] = Nil, - )(implicit o: Origin): Expr[Post] = { - // TODO: Better error - if (visited.contains(structType)) - throw UnsupportedStructPerm(o) - - val blame = PanicBlame("Struct deref can never fail") - val fields = structType.cls.decl.decls.collect { - case f: InstanceField[Pre] => f - } - foldAnd(fields.map(member => { - val l = - RawDerefPointer(Deref[Post](left, succ(member))(blame))( - NonNullPointerNull - ) - val r = - RawDerefPointer(Deref[Post](right, succ(member))(blame))( - NonNullPointerNull - ) - member.t match { -// case p: TNonNullPointer[Pre] if p.element.isInstanceOf[TByValueClass[Pre]] => -// unwrapClassComp(comp, DerefPointer(l)(NonNullPointerNull), r, p.element.asInstanceOf[TByValueClass[Pre]], structType +: visited) - case _ => comp(l, r) - } - })) - } - override def dispatch(node: Expr[Pre]): Expr[Post] = { implicit val o: Origin = node.o node match { @@ -272,27 +226,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { .ref, ) } -// case Eq(left, right) -// if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => -// val newLeft = dispatch(left) -// val newRight = dispatch(right) -// return Eq(newLeft, newRight) && unwrapClassComp( -// (l, r) => Eq(l, r), -// newLeft, -// newRight, -// left.t.asInstanceOf[TByValueClass[Pre]], -// ) -// case Neq(left, right) -// if left.t == right.t && left.t.isInstanceOf[TByValueClass[Pre]] => -// val newLeft = dispatch(left) -// val newRight = dispatch(right) -// return Neq(newLeft, newRight) && unwrapClassComp( -// (l, r) => Neq(l, r), -// newLeft, -// newRight, -// left.t.asInstanceOf[TByValueClass[Pre]], -// ) - case _ => {} + case _ => } if (inAssignment.nonEmpty) node.rewriteDefault() @@ -315,9 +249,6 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { newP, ) } else { node.rewriteDefault() } - // What if I get rid of this... -// case Perm(loc@PointerLocation(e), p) if e.t.asPointer.exists(t => t.element.isInstanceOf[TByValueClass[Pre]])=> -// unwrapClassPerm(DerefPointer(dispatch(e))(PointerLocationDerefBlame(loc.blame))(loc.o), dispatch(p), e.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]]) case assign: PreAssignExpression[Pre] => val target = inAssignment.having(()) { dispatch(assign.target) } if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { @@ -328,20 +259,19 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { // No need for copy semantics in this context copyContext.having(NoCopy()) { assign.rewrite(target = target) } } - case invocation: Invocation[Pre] => { + case invocation: Invocation[Pre] => invocation.rewrite(args = invocation.args.map { a => if (a.t.isInstanceOf[TByValueClass[Pre]]) { copyContext.having(InCall(invocation)) { dispatch(a) } } else { copyContext.having(NoCopy()) { dispatch(a) } } }) - } - // WHOOPSIE WE ALSO MAKE A COPY IF IT WAS A POINTER case dp @ DerefPointer(HeapLocal(Ref(v))) if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => rewriteInCopyContext( dp, v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], ) + // TODO: Check for copy semantics in inappropriate places (i.e. when the user has made this a pointer) case dp @ DerefPointer(DerefHeapVariable(Ref(v))) if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => rewriteInCopyContext( diff --git a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala index 968cd7d6e4..9ee0d0da11 100644 --- a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala +++ b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala @@ -78,11 +78,13 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() case e: Forall[Pre] => topLevel = false equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) - mapUnfoldedStar( - e.body, - (b: Expr[Pre]) => - rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)), - ) + if (e.bindings.size > 1) { + mapUnfoldedStar( + e.body, + (b: Expr[Pre]) => + rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)), + ) + } else { e.rewriteDefault() } case e: Starall[Pre] => topLevel = false equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 9c5d253e85..c67fb2a6c4 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -5,7 +5,7 @@ import ImportADT.typeText import vct.col.origin._ import vct.col.ref.Ref import vct.col.rewrite.Generation -import vct.col.util.AstBuildHelpers._ +import vct.col.util.AstBuildHelpers.{functionInvocation, _} import vct.col.util.SuccessionMap import scala.collection.mutable @@ -18,6 +18,10 @@ case object ImportPointer extends ImportADTBuilder("pointer") { Seq(LabelContext("adtPointer, pointer creation method")) ) + private val AsTypeOrigin: Origin = Origin( + Seq(LabelContext("classToRef, asType function")) + ) + case class PointerNullOptNone(inner: Blame[PointerNull], expr: Expr[_]) extends Blame[OptionNone] { override def blame(error: OptionNone): Unit = inner.blame(PointerNull(expr)) @@ -85,6 +89,23 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) private val pointerCreationMethods : SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() + val asTypeFunctions: mutable.Map[Type[Pre], Function[Post]] = mutable.Map() + + private def makeAsTypeFunction(typeName: String): Function[Post] = { + val value = + new Variable[Post](TAxiomatic(pointerAdt.ref, Nil))( + AsTypeOrigin.where(name = "value") + ) + globalDeclarations.declare( + function[Post]( + AbstractApplicable, + TrueSatisfiable, + returnType = TAxiomatic(pointerAdt.ref, Nil), + args = Seq(value), + )(AsTypeOrigin.where(name = "as_" + typeName)) + ) + } + private def makePointerCreationMethod(t: Type[Post]): Procedure[Post] = { implicit val o: Origin = PointerCreationOrigin .where(name = "create_nonnull_pointer_" + t.toString) @@ -169,7 +190,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) t match { case TPointer(_) => TOption(TAxiomatic(pointerAdt.ref, Nil)) case TNonNullPointer(_) => TAxiomatic(pointerAdt.ref, Nil) - case other => rewriteDefault(other) + case other => super.postCoerce(other) } override def postCoerce(location: Location[Pre]): Location[Post] = { @@ -380,7 +401,129 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) PointerBlockLength(pointer)(pointerLen.blame) - PointerBlockOffset(pointer)(pointerLen.blame) ) - case other => rewriteDefault(other) + case Cast(value, typeValue) if value.t.asPointer.isDefined => + // TODO: Check if types are compatible + // TODO: Clean up code duplication + val targetType = typeValue.t.asInstanceOf[TType[Pre]].t + val innerType = targetType.asPointer.get.element + val newValue = dispatch(value) + (targetType, value.t) match { + case (TPointer(_), TPointer(_)) => + Select[Post]( + newValue === OptNone(), + OptNoneTyped(TAxiomatic(pointerAdt.ref, Nil)), + OptSome(functionInvocation[Post]( + PanicBlame("as_type requires nothing"), + asTypeFunctions.getOrElseUpdate( + innerType, + makeAsTypeFunction(innerType.toString), + ).ref, + Seq(value match { + case PointerAdd(_, _) => + OptGet(newValue)(PanicBlame( + "OptGet(Some(_)) should always be optimised away" + )) + case _ => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq( + OptGet(newValue)(PanicBlame( + "Can never be null since this is ensured in the conditional statement" + )), + const(0), + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame( + "Pointer out of bounds in pointer cast (no appropriate blame available)" + )) + }), + )), + ) + case (TNonNullPointer(_), TPointer(_)) => + functionInvocation[Post]( + PanicBlame("as_type requires nothing"), + asTypeFunctions.getOrElseUpdate( + innerType, + makeAsTypeFunction(innerType.toString), + ).ref, + Seq(value match { + case PointerAdd(_, _) => + OptGet(newValue)(PanicBlame( + "OptGet(Some(_)) should always be optimised away" + )) + case _ => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq( + OptGet(newValue)(PanicBlame( + "Casting a pointer to a non-null pointer implies the pointer must be statically known to be non-null" + )), + const(0), + ), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame( + "Pointer out of bounds in pointer cast (no appropriate blame available)" + )) + }), + ) + case (TPointer(_), TNonNullPointer(_)) => + OptSome(functionInvocation[Post]( + PanicBlame("as_type requires nothing"), + asTypeFunctions.getOrElseUpdate( + innerType, + makeAsTypeFunction(innerType.toString), + ).ref, + Seq(value match { + case PointerAdd(_, _) => + OptGet(newValue)(PanicBlame( + "OptGet(Some(_)) should always be optimised away" + )) + case _ => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(newValue, const(0)), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame( + "Pointer out of bounds in pointer cast (no appropriate blame available)" + )) + }), + )) + case (TNonNullPointer(_), TNonNullPointer(_)) => + functionInvocation[Post]( + PanicBlame("as_type requires nothing"), + asTypeFunctions.getOrElseUpdate( + innerType, + makeAsTypeFunction(innerType.toString), + ).ref, + Seq(value match { + case PointerAdd(_, _) => + OptGet(newValue)(PanicBlame( + "OptGet(Some(_)) should always be optimised away" + )) + case _ => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(newValue, const(0)), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame( + "Pointer out of bounds in pointer cast (no appropriate blame available)" + )) + }), + ) + } + case other => super.postCoerce(other) } } } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index a5e11cdf44..517442ae0a 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -412,6 +412,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) => throw UnsupportedMalloc(c) case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n) + // TODO: Check if valid pointer cast + case CCast(e, t) if e.t.asPointer.isDefined && t.asPointer.isDefined => + Cast(rw.dispatch(e), TypeValue(rw.dispatch(t))(t.o))(c.o) case _ => throw UnsupportedCast(c) } @@ -1095,6 +1098,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } + // TODO: (AS) Fixed-size arrays seem to become pointers but they're actually value types def rewriteArrayDeclaration( decl: CLocalDeclaration[Pre], cta: CTArray[Pre], @@ -1330,9 +1334,11 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) t.t }.get if (t.isInstanceOf[CTStruct[Pre]]) { - DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( - local.blame - ) + DerefPointer( + DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( + local.blame + ) + )(local.blame) } else { DerefHeapVariable[Post](cGlobalNameSuccessor.ref(ref))( local.blame diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index d8209460ad..9526f29bbe 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -141,7 +141,7 @@ class CSpec extends VercorsSpec { } """ - vercors should error withCode "unsupportedCast" in "Cast ptr struct to int" c + vercors should verify using silicon in "Cast ptr struct to int" c """ struct d{ int x; From 88305b8c54cc27ef8c38aeab0edea47166f4c1aa Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 24 Jul 2024 17:20:19 +0200 Subject: [PATCH 18/47] Also get rid of casts from Object to another class --- src/rewrite/vct/rewrite/ClassToRef.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 0591765789..ca771b1f6c 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -660,7 +660,9 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(PanicBlame("instanceOf requires nothing"))(e.o) - case Cast(value, typeValue) if value.t.asClass.isDefined => + case Cast(value, typeValue) + if value.t.asClass.isDefined || + value.t.isInstanceOf[TAnyClass[Pre]] => dispatch( value ) // Discard for now, should assert instanceOf(value, typeValue) From d736fdd09f9db77e65cbcdcb379c91ea2af56f9e Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 09:53:03 +0200 Subject: [PATCH 19/47] Ignore quantifier in SimplifyNestedQuantifiers if it has a trigger and clean up ClassToRef --- src/rewrite/vct/rewrite/ClassToRef.scala | 20 ------------- .../rewrite/SimplifyNestedQuantifiers.scala | 30 +++++++++++-------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index ca771b1f6c..5ca41f2cd4 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -23,12 +23,6 @@ case object ClassToRef extends RewriterBuilder { private def InstanceOfOrigin: Origin = Origin(Seq(PreferredName(Seq("subtype")), LabelContext("classToRef"))) -// private val AsTypeOrigin: Origin = Origin( -// Seq(LabelContext("classToRef, asType function")) -// ) -// -// private val ValueAdtOrigin: Origin = Origin(Seq(PreferredName(Seq("Value")), LabelContext("classToRef"))) - case class InstanceNullPreconditionFailed( inner: Blame[InstanceNull], inv: InvokingNode[_], @@ -79,16 +73,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { val typeOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() val instanceOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() -// val valueAdt: SuccessionMap[Unit, AxiomaticDataType[Post]] = SuccessionMap() -// val valueAdtTypeArgument: SuccessionMap[Unit, Variable[Post]] = SuccessionMap() -// val asTypeFunctions: mutable.Map[Type[Pre], ADTFunction[Post]] = mutable.Map() -// -// def makeAsTypeFunction(typeName: String): ADTFunction[Post] = { -// val typeArg = valueAdtTypeArgument.getOrElseUpdate((), new Variable[Post](TType(TAnyValue()))(AsTypeOrigin.where(name="T"))) -// val value = new Variable[Post](TVar(typeArg.ref))(AsTypeOrigin.where(name="value")) -// new ADTFunction[Post](Seq(value), TNonNullPointer(TAnyValue()))(AsTypeOrigin.where(name="as_"+typeName)) -// } - def typeNumber(cls: Class[Pre]): Int = typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1) @@ -166,10 +150,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { globalDeclarations.declare(typeOf(())) instanceOf(()) = makeInstanceOf globalDeclarations.declare(instanceOf(())) -// if (asTypeFunctions.nonEmpty) { -// valueAdt(()) = new AxiomaticDataType[Post](asTypeFunctions.values.toSeq, Seq(valueAdtTypeArgument(())))(ValueAdtOrigin) -// globalDeclarations.declare(valueAdt(())) -// } }._1 ) diff --git a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala index 9ee0d0da11..8d51e2b0e1 100644 --- a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala +++ b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala @@ -78,13 +78,11 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() case e: Forall[Pre] => topLevel = false equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) - if (e.bindings.size > 1) { - mapUnfoldedStar( - e.body, - (b: Expr[Pre]) => - rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)), - ) - } else { e.rewriteDefault() } + mapUnfoldedStar( + e.body, + (b: Expr[Pre]) => + rewriteBinder(Forall(e.bindings, e.triggers, b)(e.o)), + ) case e: Starall[Pre] => topLevel = false equalityChecker = ExpressionEqualityCheck(Some(infoGetter.finalInfo())) @@ -227,6 +225,18 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() )(contract.blame)(contract.o) } + private def hasTriggers(e: Binder[Pre]): Boolean = + e match { + case Forall(body, triggers, _) => + triggers.exists(_.nonEmpty) || body.exists { + case InlinePattern(_, _, _) | InLinePatternLocation(_, _) => true + } + case Starall(body, triggers, _) => + triggers.exists(_.nonEmpty) || body.exists { + case InlinePattern(_, _, _) | InLinePatternLocation(_, _) => true + } + } + def rewriteLinearArray(e: Binder[Pre]): Option[Expr[Post]] = { val originalBody = e match { @@ -239,11 +249,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() return None // PB: do not attempt to reshape quantifiers that already have patterns - if ( - originalBody.exists { - case InlinePattern(_, _, _) | InLinePatternLocation(_, _) => true - } - ) { + if (hasTriggers(e)) { logger.debug(s"Not rewriting $e because it contains patterns") return None } From 261536115eacfaee11a19b9bfeacbd40564ff738 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 10:33:21 +0200 Subject: [PATCH 20/47] Fix compilation error --- src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala index 8d51e2b0e1..b85541768e 100644 --- a/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala +++ b/src/rewrite/vct/rewrite/SimplifyNestedQuantifiers.scala @@ -227,11 +227,11 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() private def hasTriggers(e: Binder[Pre]): Boolean = e match { - case Forall(body, triggers, _) => + case Forall(_, triggers, body) => triggers.exists(_.nonEmpty) || body.exists { case InlinePattern(_, _, _) | InLinePatternLocation(_, _) => true } - case Starall(body, triggers, _) => + case Starall(_, triggers, body) => triggers.exists(_.nonEmpty) || body.exists { case InlinePattern(_, _, _) | InLinePatternLocation(_, _) => true } From 6c8be0afc8e2747f4c54f19844f60201bca62c15 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 11:05:40 +0200 Subject: [PATCH 21/47] Reduce code duplication in adtPointer, remove all non-pointer casts in classToRef --- src/rewrite/vct/rewrite/ClassToRef.scala | 4 +- .../vct/rewrite/adt/ImportPointer.scala | 148 +++++------------- 2 files changed, 44 insertions(+), 108 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 5ca41f2cd4..a606a9a908 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -640,9 +640,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(PanicBlame("instanceOf requires nothing"))(e.o) - case Cast(value, typeValue) - if value.t.asClass.isDefined || - value.t.isInstanceOf[TAnyClass[Pre]] => + case Cast(value, typeValue) if value.t.asPointer.isEmpty => dispatch( value ) // Discard for now, should assert instanceOf(value, typeValue) diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index c67fb2a6c4..8419c465a3 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -403,7 +403,6 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) ) case Cast(value, typeValue) if value.t.asPointer.isDefined => // TODO: Check if types are compatible - // TODO: Clean up code duplication val targetType = typeValue.t.asInstanceOf[TType[Pre]].t val innerType = targetType.asPointer.get.element val newValue = dispatch(value) @@ -412,118 +411,57 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) Select[Post]( newValue === OptNone(), OptNoneTyped(TAxiomatic(pointerAdt.ref, Nil)), - OptSome(functionInvocation[Post]( - PanicBlame("as_type requires nothing"), - asTypeFunctions.getOrElseUpdate( - innerType, - makeAsTypeFunction(innerType.toString), - ).ref, - Seq(value match { - case PointerAdd(_, _) => - OptGet(newValue)(PanicBlame( - "OptGet(Some(_)) should always be optimised away" - )) - case _ => - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq( - OptGet(newValue)(PanicBlame( - "Can never be null since this is ensured in the conditional statement" - )), - const(0), - ), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame( - "Pointer out of bounds in pointer cast (no appropriate blame available)" - )) - }), + OptSome(applyAsTypeFunction( + innerType, + value, + OptGet(newValue)(PanicBlame( + "Can never be null since this is ensured in the conditional expression" + )), )), ) case (TNonNullPointer(_), TPointer(_)) => - functionInvocation[Post]( - PanicBlame("as_type requires nothing"), - asTypeFunctions.getOrElseUpdate( - innerType, - makeAsTypeFunction(innerType.toString), - ).ref, - Seq(value match { - case PointerAdd(_, _) => - OptGet(newValue)(PanicBlame( - "OptGet(Some(_)) should always be optimised away" - )) - case _ => - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq( - OptGet(newValue)(PanicBlame( - "Casting a pointer to a non-null pointer implies the pointer must be statically known to be non-null" - )), - const(0), - ), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame( - "Pointer out of bounds in pointer cast (no appropriate blame available)" - )) - }), + applyAsTypeFunction( + innerType, + value, + OptGet(newValue)(PanicBlame( + "Casting a pointer to a non-null pointer implies the pointer must be statically known to be non-null" + )), ) case (TPointer(_), TNonNullPointer(_)) => - OptSome(functionInvocation[Post]( - PanicBlame("as_type requires nothing"), - asTypeFunctions.getOrElseUpdate( - innerType, - makeAsTypeFunction(innerType.toString), - ).ref, - Seq(value match { - case PointerAdd(_, _) => - OptGet(newValue)(PanicBlame( - "OptGet(Some(_)) should always be optimised away" - )) - case _ => - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(newValue, const(0)), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame( - "Pointer out of bounds in pointer cast (no appropriate blame available)" - )) - }), - )) + OptSome(applyAsTypeFunction(innerType, value, newValue)) case (TNonNullPointer(_), TNonNullPointer(_)) => - functionInvocation[Post]( - PanicBlame("as_type requires nothing"), - asTypeFunctions.getOrElseUpdate( - innerType, - makeAsTypeFunction(innerType.toString), - ).ref, - Seq(value match { - case PointerAdd(_, _) => - OptGet(newValue)(PanicBlame( - "OptGet(Some(_)) should always be optimised away" - )) - case _ => - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(newValue, const(0)), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame( - "Pointer out of bounds in pointer cast (no appropriate blame available)" - )) - }), - ) + applyAsTypeFunction(innerType, value, newValue) } case other => super.postCoerce(other) } } + + private def applyAsTypeFunction( + innerType: Type[Pre], + preExpr: Expr[Pre], + postExpr: Expr[Post], + )(implicit o: Origin): Expr[Post] = { + functionInvocation[Post]( + PanicBlame("as_type requires nothing"), + asTypeFunctions + .getOrElseUpdate(innerType, makeAsTypeFunction(innerType.toString)).ref, + Seq(preExpr match { + case PointerAdd(_, _) => + OptGet(postExpr)(PanicBlame( + "OptGet(Some(_)) should always be optimised away" + )) + case _ => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(postExpr, const(0)), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame( + "Pointer out of bounds in pointer cast (no appropriate blame available)" + )) + }), + ) + } } From e141bcf6d0279c284d5a180f44ff0526ae47899e Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 13:43:40 +0200 Subject: [PATCH 22/47] Fix duplicate OptGet and add asType function to primitive pointer arrays --- src/rewrite/vct/rewrite/EncodeArrayValues.scala | 8 +++++++- src/rewrite/vct/rewrite/adt/ImportPointer.scala | 5 +---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index a0f967a0b2..9f135261c0 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -650,7 +650,13 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { elementType match { case t: TClass[Pre] => unwrapStructPerm(access, Some(pointerAccess), t, o, makeStruct) - case _ => Seq() + case t => + Seq(( + makeStruct + .makeSelfCast(pointerAccess, TNonNullPointer(dispatch(t))), + // Will never be used + (p: Expr[Pre]) => GenericPointerFreeError(p), + )) } ensures = diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 8419c465a3..89c7c4e6c5 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -446,10 +446,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) asTypeFunctions .getOrElseUpdate(innerType, makeAsTypeFunction(innerType.toString)).ref, Seq(preExpr match { - case PointerAdd(_, _) => - OptGet(postExpr)(PanicBlame( - "OptGet(Some(_)) should always be optimised away" - )) + case PointerAdd(_, _) => postExpr case _ => FunctionInvocation[Post]( ref = pointerAdd.ref, From 401c3d9fc9190f89e5575dd5c5b1cd7ddb92b77d Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 15:06:26 +0200 Subject: [PATCH 23/47] Get rid of more 'unknown' names in the C frontend --- src/rewrite/vct/rewrite/EncodeArrayValues.scala | 11 +---------- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 5 +++-- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 9f135261c0..4fa2d90e4a 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -171,15 +171,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { (p: Expr[Pre]) => PointerInsufficientFreePermission(p), ), ) - var requires = (ptr !== Null()) &* - (PointerBlockOffset(ptr)(FramedPtrBlockOffset) === zero) &* - makeStruct.makePerm( - i => - PointerLocation(PointerAdd(ptr, i.get)(FramedPtrOffset))( - FramedPtrOffset - ), - IteratedPtrInjective, - ) requiresT = if (!typeIsRef(t)) requiresT @@ -216,7 +207,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { body = None, requires = requiresPred, decreases = Some(DecreasesClauseNoRecursion[Post]()), - )(o.where("free_" + t.toString)) + )(o.where(name = "free_" + t.toString)) }) (proc, (node: FreePointer[Pre]) => PointerFreeFailed(node, errors)) } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 517442ae0a..eca1812fb8 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -1042,7 +1042,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) contract = rw.dispatch(decl.decl.contract), pure = pure, inline = inline, - )(AbstractApplicable)(init.o) + )(AbstractApplicable)(init.o.sourceName(info.name)) ) ) case None => @@ -1050,7 +1050,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (t.isInstanceOf[TByValueClass[Post]]) { TNonNullPointer(t) } else { t } cGlobalNameSuccessor(RefCGlobalDeclaration(decl, idx)) = rw - .globalDeclarations.declare(new HeapVariable(newT)(init.o)) + .globalDeclarations + .declare(new HeapVariable(newT)(init.o.sourceName(info.name))) } } } From a1773b771ff07eca3bb6bbfb83d3df40e6e01f9e Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 25 Jul 2024 16:18:42 +0200 Subject: [PATCH 24/47] Remove Viper field access from trigger with top-level PointerSubscript or DerefPointer * Gives a 12% performance improvement on Silicon for examples/concepts/c/structs.c --- .../vct/rewrite/adt/ImportPointer.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 89c7c4e6c5..a8753a98a4 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -286,25 +286,25 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) Nil, Nil, )(PanicBlame("ptr_deref requires nothing.")) - case other => rewriteDefault(other) + case other => dispatch(other) } } override def postCoerce(e: Expr[Pre]): Expr[Post] = { implicit val o: Origin = e.o e match { -// case f @ Forall(_, triggers, _) => -// f.rewrite(triggers = -// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) -// ) -// case s @ Starall(_, triggers, _) => -// s.rewrite(triggers = -// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) -// ) -// case e @ Exists(_, triggers, _) => -// e.rewrite(triggers = -// triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) -// ) + case f @ Forall(_, triggers, _) => + f.rewrite(triggers = + triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) + ) + case s @ Starall(_, triggers, _) => + s.rewrite(triggers = + triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) + ) + case e @ Exists(_, triggers, _) => + e.rewrite(triggers = + triggers.map(_.map(rewriteTopLevelPointerSubscriptInTrigger)) + ) case sub @ PointerSubscript(pointer, index) => SilverDeref( obj = From 42aca99b11ea27e815b9b4c52774a7dd406ad6d4 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 21 Aug 2024 14:48:18 +0200 Subject: [PATCH 25/47] Implement basic pointer casts --- src/rewrite/vct/rewrite/ClassToRef.scala | 301 ++++++++++++++---- .../vct/rewrite/EncodeArrayValues.scala | 109 +------ .../vct/rewrite/adt/ImportPointer.scala | 31 +- .../viper/api/transform/ColToSilver.scala | 12 +- 4 files changed, 275 insertions(+), 178 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index a606a9a908..78e944ea18 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -23,6 +23,11 @@ case object ClassToRef extends RewriterBuilder { private def InstanceOfOrigin: Origin = Origin(Seq(PreferredName(Seq("subtype")), LabelContext("classToRef"))) + private def ValueAdtOrigin: Origin = + Origin(Seq(PreferredName(Seq("Value")), LabelContext("classToRef"))) + + private def CastHelperOrigin: Origin = Origin(Seq(LabelContext("classToRef"))) + case class InstanceNullPreconditionFailed( inner: Blame[InstanceNull], inv: InvokingNode[_], @@ -73,6 +78,15 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { val typeOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() val instanceOf: SuccessionMap[Unit, Function[Post]] = SuccessionMap() + val valueAdt: SuccessionMap[Unit, AxiomaticDataType[Post]] = SuccessionMap() + val valueAdtTypeArgument: Variable[Post] = + new Variable(TType(TAnyValue()))(ValueAdtOrigin.where(name = "V")) + val valueAsFunctions: mutable.Map[Type[Pre], ADTFunction[Post]] = mutable + .Map() + + val castHelpers: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() + val castHelperCalls: ScopedStack[mutable.Set[Statement[Post]]] = ScopedStack() + def typeNumber(cls: Class[Pre]): Int = typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1) @@ -141,15 +155,68 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) } + private def makeValueAdt: AxiomaticDataType[Post] = { + new AxiomaticDataType[Post]( + valueAsFunctions.values.toSeq, + Seq(valueAdtTypeArgument), + )(ValueAdtOrigin) + } + + // TODO: Also generate value as axioms for arrays once those are properly supported for C/CPP/LLVM + private def makeValueAsFunction( + typeName: String, + t: Type[Post], + ): ADTFunction[Post] = { + new ADTFunction[Post]( + Seq(new Variable(TVar[Post](valueAdtTypeArgument.ref))( + ValueAdtOrigin.where(name = "v") + )), + TNonNullPointer(t), + )(ValueAdtOrigin.where(name = "value_as_" + typeName)) + } + + private def unwrapValueAs( + axiomType: TAxiomatic[Post], + oldT: Type[Pre], + newT: Type[Post], + fieldRef: Ref[Post, ADTFunction[Post]], + )(implicit o: Origin): Seq[ADTAxiom[Post]] = { + (oldT match { + case t: TByValueClass[Pre] => { + // TODO: If there are no fields we should ignore the first field and add the axioms for the second field + t.cls.decl.decls.collectFirst({ case field: InstanceField[Pre] => + unwrapValueAs(axiomType, field.t, dispatch(field.t), fieldRef) + }).getOrElse(Nil) + } + case _ => Nil + }) :+ new ADTAxiom[Post](forall( + axiomType, + body = { a => + InlinePattern(adtFunctionInvocation[Post]( + valueAsFunctions + .getOrElseUpdate(oldT, makeValueAsFunction(oldT.toString, newT)) + .ref, + typeArgs = Some((valueAdt.ref(()), Seq(axiomType))), + args = Seq(a), + )) === Cast( + adtFunctionInvocation(fieldRef, args = Seq(a)), + TypeValue(TNonNullPointer(newT)), + ) + }, + )) + } + override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = program.rewrite(declarations = globalDeclarations.collect { program.declarations.foreach(dispatch) - implicit val o: Origin = TypeOfOrigin typeOf(()) = makeTypeOf globalDeclarations.declare(typeOf(())) instanceOf(()) = makeInstanceOf globalDeclarations.declare(instanceOf(())) + if (valueAsFunctions.nonEmpty) { + globalDeclarations.declare(valueAdt.getOrElseUpdate((), makeValueAdt)) + } }._1 ) @@ -309,18 +376,49 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case cls: ByValueClass[Pre] => implicit val o: Origin = cls.o val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) + val classType = cls.classType(Nil) + var valueAsAxioms: Seq[ADTAxiom[Post]] = Seq() val (fieldFunctions, fieldInverses, fieldTypes) = cls.decls.collect { case field: Field[Pre] => - val newT = TNonNullPointer(dispatch(field.t)) + val newT = dispatch(field.t) + val nonnullT = TNonNullPointer(newT) byValFieldSucc(field) = new ADTFunction[Post]( Seq(new Variable(axiomType)(field.o)), - newT, + nonnullT, )(field.o) + if (valueAsAxioms.isEmpty) { + // This is the first field + valueAsAxioms = + valueAsAxioms :+ new ADTAxiom[Post](forall( + axiomType, + body = { a => + InlinePattern(adtFunctionInvocation[Post]( + valueAsFunctions.getOrElseUpdate( + field.t, + makeValueAsFunction(field.t.toString, newT), + ).ref, + typeArgs = Some((valueAdt.ref(()), Seq(axiomType))), + args = Seq(a), + )) === adtFunctionInvocation( + byValFieldSucc.ref(field), + args = Seq(a), + ) + }, + )) + + valueAsAxioms = + valueAsAxioms ++ unwrapValueAs( + axiomType, + field.t, + newT, + byValFieldSucc.ref(field), + ) + } ( byValFieldSucc(field), new ADTFunction[Post]( - Seq(new Variable(newT)(field.o)), + Seq(new Variable(nonnullT)(field.o)), axiomType, )( field.o.copy( @@ -331,7 +429,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { .getOrElse("unknown") ) ), - newT, + nonnullT, ) }.unzip3 val constructor = @@ -432,7 +530,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { byValClassSucc(cls) = new AxiomaticDataType[Post]( Seq(indexFunction, injectivityAxiom) ++ destructorAxioms ++ - indexAxioms ++ fieldFunctions ++ fieldInverses, + indexAxioms ++ fieldFunctions ++ fieldInverses ++ + valueAsAxioms, Nil, ) globalDeclarations.succeed(cls, byValClassSucc(cls)) @@ -467,55 +566,63 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } } - override def dispatch(stat: Statement[Pre]): Statement[Post] = - stat match { - case Instantiate(Ref(cls), Local(Ref(v))) => - instantiate(cls, succ(v))(stat.o) - case inv @ InvokeMethod( - obj, - Ref(method), - args, - outArgs, - typeArgs, - givenMap, - yields, - ) => - InvokeProcedure[Post]( - ref = methodSucc.ref(method), - args = dispatch(obj) +: args.map(dispatch), - outArgs = outArgs.map(dispatch), - typeArgs = typeArgs.map(dispatch), - givenMap = givenMap.map { case (Ref(v), e) => - (succ(v), dispatch(e)) - }, - yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) }, - )(PreBlameSplit.left( - InstanceNullPreconditionFailed(inv.blame, inv), - PreBlameSplit - .left(PanicBlame("incorrect instance method type?"), inv.blame), - ))(inv.o) - case inv @ InvokeConstructor( - Ref(cons), - _, - out, - args, - outArgs, - typeArgs, - givenMap, - yields, - ) => - InvokeProcedure[Post]( - ref = consSucc.ref(cons), - args = args.map(dispatch), - outArgs = dispatch(out) +: outArgs.map(dispatch), - typeArgs = typeArgs.map(dispatch), - givenMap = givenMap.map { case (Ref(v), e) => - (succ(v), dispatch(e)) - }, - yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) }, - )(inv.blame)(inv.o) - case other => super.dispatch(other) - } + override def dispatch(stat: Statement[Pre]): Statement[Post] = { + val helpers: mutable.Set[Statement[Post]] = mutable.Set() + val result = + castHelperCalls.having(helpers) { + stat match { + case Instantiate(Ref(cls), Local(Ref(v))) => + instantiate(cls, succ(v))(stat.o) + case inv @ InvokeMethod( + obj, + Ref(method), + args, + outArgs, + typeArgs, + givenMap, + yields, + ) => + InvokeProcedure[Post]( + ref = methodSucc.ref(method), + args = dispatch(obj) +: args.map(dispatch), + outArgs = outArgs.map(dispatch), + typeArgs = typeArgs.map(dispatch), + givenMap = givenMap.map { case (Ref(v), e) => + (succ(v), dispatch(e)) + }, + yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) }, + )(PreBlameSplit.left( + InstanceNullPreconditionFailed(inv.blame, inv), + PreBlameSplit + .left(PanicBlame("incorrect instance method type?"), inv.blame), + ))(inv.o) + case inv @ InvokeConstructor( + Ref(cons), + _, + out, + args, + outArgs, + typeArgs, + givenMap, + yields, + ) => + InvokeProcedure[Post]( + ref = consSucc.ref(cons), + args = args.map(dispatch), + outArgs = dispatch(out) +: outArgs.map(dispatch), + typeArgs = typeArgs.map(dispatch), + givenMap = givenMap.map { case (Ref(v), e) => + (succ(v), dispatch(e)) + }, + yields = yields.map { case (e, Ref(v)) => (dispatch(e), succ(v)) }, + )(inv.blame)(inv.o) + case other => super.dispatch(other) + } + } + + if (helpers.nonEmpty) { Block(helpers.toSeq :+ result)(stat.o) } + else { result } + } override def dispatch(node: ApplyAnyPredicate[Pre]): ApplyAnyPredicate[Post] = node match { @@ -527,6 +634,77 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case other => other.rewriteDefault() } + private def unwrapCastConstraints(outerType: Type[Post], t: Type[Pre])( + implicit o: Origin + ): Expr[Post] = { + val newT = dispatch(t) + val constraint = forall[Post]( + TNonNullPointer(outerType), + body = { p => + PolarityDependent( + Greater( + CurPerm(PointerLocation(p)(PanicBlame( + "Referring to a non-null pointer should not cause any verification failures" + ))), + NoPerm(), + ) ==> + (InlinePattern(Cast(p, TypeValue(TNonNullPointer(newT)))) === + adtFunctionInvocation( + valueAsFunctions + .getOrElseUpdate(t, makeValueAsFunction(t.toString, newT)) + .ref, + typeArgs = Some((valueAdt.ref(()), Seq(outerType))), + args = Seq(DerefPointer(p)(PanicBlame( + "Pointer deref is safe since the permission is framed" + ))), + )), + tt, + ) + }, + ) + + if (t.isInstanceOf[TByValueClass[Pre]]) { + constraint &* + t.asInstanceOf[TByValueClass[Pre]].cls.decl.decls.collectFirst { + case field: InstanceField[Pre] => + unwrapCastConstraints(outerType, field.t) + }.getOrElse(tt) + } else { constraint } + } + + private def makeCastHelper(t: Type[Pre]): Procedure[Post] = { + implicit val o: Origin = CastHelperOrigin + .where(name = "constraints_" + t.toString) + globalDeclarations.declare(procedure( + AbstractApplicable, + TrueSatisfiable, + ensures = UnitAccountedPredicate(unwrapCastConstraints(dispatch(t), t)), + )) + } + + private def addCastHelpers(t: Type[Pre], calls: mutable.Set[Statement[Post]])( + implicit o: Origin + ): Unit = { + t match { + case cls: TByValueClass[Pre] => { + calls.add( + InvokeProcedure[Post]( + castHelpers.getOrElseUpdate(t, makeCastHelper(t)).ref, + Nil, + Nil, + Nil, + Nil, + Nil, + )(TrueSatisfiable)(o) + ) + cls.cls.decl.decls.collectFirst { case field: InstanceField[Pre] => + addCastHelpers(field.t, calls) + } + } + case _ => + } + } + override def dispatch(e: Expr[Pre]): Expr[Post] = e match { case inv @ MethodInvocation( @@ -640,7 +818,18 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(PanicBlame("instanceOf requires nothing"))(e.o) - case Cast(value, typeValue) if value.t.asPointer.isEmpty => + case Cast(value, typeValue) if value.t.asPointer.isDefined => { + // Keep pointer casts and add extra annotations + // TODO: Check if we need to get rid of the pointer add's here since in my testing that broke some of the reasoning + if (castHelperCalls.nonEmpty) { + addCastHelpers(value.t.asPointer.get.element, castHelperCalls.top)( + e.o + ) + } + + e.rewriteDefault() + } + case Cast(value, typeValue) => dispatch( value ) // Discard for now, should assert instanceOf(value, typeValue) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 4fa2d90e4a..afc4a6b820 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -185,8 +185,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { // If structure contains structs, the permission for those fields need to be released as well val permFields = t match { - case t: TClass[Pre] => - unwrapStructPerm(access, None, t, o, makeStruct) + case t: TClass[Pre] => unwrapStructPerm(access, t, o, makeStruct) case _ => Seq() } requiresT = @@ -405,55 +404,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { })) } - def unwrapStructCasts( - struct: Variable[Post] => Expr[Post], - pointer: Variable[Post] => Expr[Post], - structType: TClass[Pre], - origin: Origin, - makeStruct: MakeAnns, - visited: Seq[TClass[Pre]] = Seq(), - ): Seq[(Expr[Post], Expr[Pre] => PointerFreeError)] = { - if (visited.contains(structType)) { - // We do not allow this notation for recursive structs - throw UnsupportedStructPerm(origin) - } - implicit val o: Origin = origin - val field = structType.cls.decl.declarations.collectFirst { - case field: InstanceField[Pre] => field - } - if (field.isDefined) { - // TODO: I kind of want to only do this one level deep (but for every type) - val fieldType = field.get.t - val result = - if (typeIsRef(fieldType)) { - unwrapStructCasts( - (i: Variable[Post]) => - Deref[Post](struct(i), succ(field.get))(DerefPerm), - pointer, - fieldType.asInstanceOf[TClass[Pre]], - origin, - makeStruct, - structType +: visited, - ) - } else { Nil } - result :+ - (( - makeStruct.makeCast( - pointer, - (i: Variable[Post]) => - AddrOf(Deref[Post](struct(i), succ(field.get))(DerefPerm)), - TNonNullPointer(dispatch(fieldType)), - ), - // Error should never occur since this part should not be emitted in the definition of free - (p: Expr[Pre]) => GenericPointerFreeError(p), - )) - } else { Nil } - } - def unwrapStructPerm( struct: Variable[Post] => Expr[Post], - // Needs to be provided to define asType assertions using casts - pointer: Option[Variable[Post] => Expr[Post]], structType: TClass[Pre], origin: Origin, makeStruct: MakeAnns, @@ -482,18 +434,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Referrable.originName(member), ), )) - if (pointer.isDefined) { - anns = - anns :+ - (( - makeStruct.makeSelfCast( - (i: Variable[Post]) => AddrOf(loc(i)), - TNonNullPointer(dispatch(member.t)), - ), - // Error should never occur since this part should not be emitted in the definition of free - (p: Expr[Pre]) => GenericPointerFreeError(p), - )) - } anns = if (typeIsRef(member.t)) anns :+ @@ -508,7 +448,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { // We recurse, since a field is another struct anns ++ unwrapStructPerm( loc, - pointer.map { _ => (i: Variable[Post]) => AddrOf(loc(i)) }, newStruct, origin, makeStruct, @@ -518,17 +457,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { } }) - if (pointer.isDefined && fields.nonEmpty) { - // TODO: I kind of want to only do this one level deep (but for every type) - newFieldPerms.flatten ++ unwrapStructCasts( - struct, - pointer.get, - structType, - origin, - makeStruct, - visited, - ) - } else { newFieldPerms.flatten } + newFieldPerms.flatten } case class MakeAnns( @@ -557,29 +486,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val body = (pre1 && pre2 && access(i) === access(j)) ==> (i.get === j.get) Forall(Seq(i, j), Seq(triggerUnique), body) } - - def makeCast( - pointer: Variable[Post] => Expr[Post], - fieldPointer: Variable[Post] => Expr[Post], - t: Type[Post], - ): Expr[Post] = { - implicit val o: Origin = arrayCreationOrigin - val zero = const[Post](0) - val body = (zero <= i.get && i.get < size) ==> - (Cast(pointer(i), TypeValue(t)) === Cast(fieldPointer(i), TypeValue(t))) - Forall(Seq(i), Seq(Seq(Cast(pointer(i), TypeValue(t)))), body) - } - - def makeSelfCast( - pointer: Variable[Post] => Expr[Post], - t: Type[Post], - ): Expr[Post] = { - implicit val o: Origin = arrayCreationOrigin - val zero = const[Post](0) - val body = (zero <= i.get && i.get < size) ==> - (Cast(pointer(i), TypeValue(t)) === pointer(i)) - Forall(Seq(i), Seq(Seq(Cast(pointer(i), TypeValue(t)))), body) - } } def typeIsRef(t: Type[_]): Boolean = @@ -639,15 +545,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { val permFields = elementType match { - case t: TClass[Pre] => - unwrapStructPerm(access, Some(pointerAccess), t, o, makeStruct) - case t => - Seq(( - makeStruct - .makeSelfCast(pointerAccess, TNonNullPointer(dispatch(t))), - // Will never be used - (p: Expr[Pre]) => GenericPointerFreeError(p), - )) + case t: TClass[Pre] => unwrapStructPerm(access, t, o, makeStruct) + case _ => Nil } ensures = diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index a8753a98a4..8c807ccb93 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -2,6 +2,7 @@ package vct.col.rewrite.adt import vct.col.ast._ import ImportADT.typeText +import hre.util.ScopedStack import vct.col.origin._ import vct.col.ref.Ref import vct.col.rewrite.Generation @@ -90,6 +91,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) : SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() val asTypeFunctions: mutable.Map[Type[Pre], Function[Post]] = mutable.Map() + private val inAxiom: ScopedStack[Unit] = ScopedStack() private def makeAsTypeFunction(typeName: String): Function[Post] = { val value = @@ -186,6 +188,21 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case other => super.applyCoercion(e, other) } + override def postCoerce(decl: Declaration[Pre]): Unit = { + decl match { + case axiom: ADTAxiom[Pre] => + inAxiom.having(()) { + allScopes.anySucceed(axiom, axiom.rewriteDefault()) + } + // TODO: This is an ugly way to exempt this one bit of generated code from having ptrAdd's added + case proc: Procedure[Pre] + if proc.o.find[LabelContext].exists(_.label == "classToRef") && + proc.o.getPreferredNameOrElse().snake.startsWith("constraints_") => + inAxiom.having(()) { allScopes.anySucceed(proc, proc.rewriteDefault()) } + case _ => super.postCoerce(decl) + } + } + override def postCoerce(t: Type[Pre]): Type[Post] = t match { case TPointer(_) => TOption(TAxiomatic(pointerAdt.ref, Nil)) @@ -270,7 +287,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case deref @ DerefPointer(pointer) => FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq( + args = Seq(if (inAxiom.isEmpty) { FunctionInvocation[Post]( ref = pointerAdd.ref, // Always index with zero, otherwise quantifiers with pointers do not get triggered @@ -281,7 +298,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) )(NoContext( DerefPointerBoundsPreconditionFailed(deref.blame, pointer) )) - ), + } else { unwrapOption(pointer, deref.blame) }), typeArgs = Nil, Nil, Nil, @@ -343,7 +360,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) obj = FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq( + args = Seq(if (inAxiom.isEmpty) { FunctionInvocation[Post]( ref = pointerAdd.ref, // Always index with zero, otherwise quantifiers with pointers do not get triggered @@ -354,7 +371,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) )(NoContext( DerefPointerBoundsPreconditionFailed(deref.blame, pointer) )) - ), + } else { unwrapOption(pointer, deref.blame) }), typeArgs = Nil, Nil, Nil, @@ -364,7 +381,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case deref @ RawDerefPointer(pointer) => FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq( + args = Seq(if (inAxiom.isEmpty) { FunctionInvocation[Post]( ref = pointerAdd.ref, // Always index with zero, otherwise quantifiers with pointers do not get triggered @@ -375,7 +392,7 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) )(NoContext( DerefPointerBoundsPreconditionFailed(deref.blame, pointer) )) - ), + } else { unwrapOption(pointer, deref.blame) }), typeArgs = Nil, Nil, Nil, @@ -447,6 +464,8 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) .getOrElseUpdate(innerType, makeAsTypeFunction(innerType.toString)).ref, Seq(preExpr match { case PointerAdd(_, _) => postExpr + // Don't add ptrAdd in an ADT axiom since we cannot use functions with preconditions there + case _ if inAxiom.nonEmpty => postExpr case _ => FunctionInvocation[Post]( ref = pointerAdd.ref, diff --git a/src/viper/viper/api/transform/ColToSilver.scala b/src/viper/viper/api/transform/ColToSilver.scala index a3aaf43270..9e27216415 100644 --- a/src/viper/viper/api/transform/ColToSilver.scala +++ b/src/viper/viper/api/transform/ColToSilver.scala @@ -230,17 +230,7 @@ case class ColToSilver(program: col.Program[_]) { function.contract.decreases.toSeq.map(decreases), accountedPred(function.contract.ensures), function.body.map(exp), - )( - pos = pos(function), - info = - if (ref(function) == "ptrDerefblahblah") - ConsInfo( - AnnotationInfo(Map("opaque" -> Seq())), - NodeInfo(function), - ) - else - NodeInfo(function), - ) + )(pos = pos(function), info = NodeInfo(function)) } case procedure: col.Procedure[_] if procedure.returnType == col.TVoid() && !procedure.inline && From 9b96b334cd5bb59a7e078c505534cb0eb94ee8dd Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 22 Aug 2024 13:59:38 +0200 Subject: [PATCH 26/47] Add pointer cast helpers in loops --- examples/concepts/c/pointer_casts.c | 71 ++++++++++++ src/rewrite/vct/rewrite/ClassToRef.scala | 103 ++++++++++++++---- .../vct/rewrite/adt/ImportPointer.scala | 100 ++++++++++------- .../vct/test/integration/examples/CSpec.scala | 15 +-- 4 files changed, 219 insertions(+), 70 deletions(-) create mode 100644 examples/concepts/c/pointer_casts.c diff --git a/examples/concepts/c/pointer_casts.c b/examples/concepts/c/pointer_casts.c new file mode 100644 index 0000000000..4d5e06ab6e --- /dev/null +++ b/examples/concepts/c/pointer_casts.c @@ -0,0 +1,71 @@ +#include + +struct A { + int integer; + bool boolean; +}; + +struct B { + struct A struct_a; +}; + +void canCastToInteger() { + struct B struct_b; + struct_b.struct_a.integer = 5; + int *pointer_to_integer = (int *)&struct_b; + //@ assert *pointer_to_integer == 5; + //@ assert pointer_to_integer == &struct_b.struct_a.integer; + //@ assert pointer_to_integer == (int *)&struct_b.struct_a; + // The following is not implemented yet + // assert pointer_to_integer == &struct_b + // assert pointer_to_integer == &struct_b.struct_a + *pointer_to_integer = 10; + //@ assert struct_b.struct_a.integer == 10; +} + +void cannotCastToBoolean() { + struct B struct_b; + struct_b.struct_a.boolean = true == true; // We currently don't support boolean literals + // TODO: Do proper type checks for casts + bool *pointer_to_boolean = (bool *)&struct_b; + /*[/expect ptrPerm]*/ + //@ assert *pointer_to_boolean == 5; + /*[/end]*/ + //@ assert pointer_to_boolean == &struct_b.struct_a.boolean; + //@ assert pointer_to_boolean == (bool *)&struct_b.struct_a; +} + +void castRemainsValidInLoop() { + struct B struct_b; + struct_b.struct_a.integer = 10; + + int *pointer_to_integer = (int *)&struct_b; + + //@ loop_invariant 0 <= i && i <= 10; + //@ loop_invariant Perm(&struct_b, write); + //@ loop_invariant Perm(struct_b, write); + //@ loop_invariant pointer_to_integer == (int *)&struct_b; + //@ loop_invariant *pointer_to_integer == 10 - i; + for (int i = 0; i < 10; i++) { + *pointer_to_integer = *pointer_to_integer - 1; + } + + //@ assert struct_b.struct_a.integer == 0; +} + +//@ requires a != NULL; +//@ context Perm(a, write); +//@ ensures *a == \old(*a) + 1; +void increaseByOne(int *a) { + *a += 1; +} + +void callWithCast() { + struct B struct_b; + struct_b.struct_a.integer = 15; + + int *pointer_to_integer = (int *)&struct_b; + increaseByOne(pointer_to_integer); + + //@ assert struct_b.struct_a.integer == 16; +} diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 78e944ea18..f56042400b 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -26,7 +26,8 @@ case object ClassToRef extends RewriterBuilder { private def ValueAdtOrigin: Origin = Origin(Seq(PreferredName(Seq("Value")), LabelContext("classToRef"))) - private def CastHelperOrigin: Origin = Origin(Seq(LabelContext("classToRef"))) + private def CastHelperOrigin: Origin = + Origin(Seq(LabelContext("classToRef cast helpers"))) case class InstanceNullPreconditionFailed( inner: Blame[InstanceNull], @@ -85,7 +86,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { .Map() val castHelpers: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() - val castHelperCalls: ScopedStack[mutable.Set[Statement[Post]]] = ScopedStack() + val requiredCastHelpers: ScopedStack[mutable.Set[Type[Pre]]] = ScopedStack() def typeNumber(cls: Class[Pre]): Int = typeNumberStore.getOrElseUpdate(cls, typeNumberStore.size + 1) @@ -566,10 +567,65 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } } + private def addCastConstraints( + expr: Expr[Pre], + totalHelpers: mutable.Set[Type[Pre]], + ): Expr[Post] = { + val helpers: mutable.Set[Type[Pre]] = mutable.Set() + var result: Seq[Expr[Post]] = Nil + for (clause <- expr.unfoldStar) { + val newClause = requiredCastHelpers.having(helpers) { dispatch(clause) } + if (helpers.nonEmpty) { + result ++= helpers.map { t => + unwrapCastConstraints(dispatch(t), t)(CastHelperOrigin) + }.toSeq + totalHelpers.addAll(helpers) + helpers.clear() + } + result = result :+ newClause + } + foldStar(result)(expr.o) + } + + // For loops add cast helpers before and as an invariant (since otherwise the contract might not be well-formed) + override def dispatch(node: LoopContract[Pre]): LoopContract[Post] = { + implicit val o: Origin = node.o + val helpers: mutable.Set[Type[Pre]] = mutable.Set() + node match { + case LoopInvariant(invariant, decreases) => { + val result = + LoopInvariant( + addCastConstraints(invariant, helpers), + decreases.map(dispatch), + )(node.o) + if (requiredCastHelpers.nonEmpty) { + requiredCastHelpers.top.addAll(helpers) + } + result + } + case contract @ IterationContract( + requires, + ensures, + context_everywhere, + ) => { + val result = + IterationContract( + addCastConstraints(requires, helpers), + addCastConstraints(ensures, helpers), + addCastConstraints(context_everywhere, helpers), + )(contract.blame)(node.o) + if (requiredCastHelpers.nonEmpty) { + requiredCastHelpers.top.addAll(helpers) + } + result + } + } + } + override def dispatch(stat: Statement[Pre]): Statement[Post] = { - val helpers: mutable.Set[Statement[Post]] = mutable.Set() + val helpers: mutable.Set[Type[Pre]] = mutable.Set() val result = - castHelperCalls.having(helpers) { + requiredCastHelpers.having(helpers) { stat match { case Instantiate(Ref(cls), Local(Ref(v))) => instantiate(cls, succ(v))(stat.o) @@ -620,8 +676,18 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } } - if (helpers.nonEmpty) { Block(helpers.toSeq :+ result)(stat.o) } - else { result } + if (helpers.nonEmpty) { + Block(helpers.map { t => + InvokeProcedure[Post]( + castHelpers.getOrElseUpdate(t, makeCastHelper(t)).ref, + Nil, + Nil, + Nil, + Nil, + Nil, + )(TrueSatisfiable)(CastHelperOrigin) + }.toSeq :+ result)(stat.o) + } else { result } } override def dispatch(node: ApplyAnyPredicate[Pre]): ApplyAnyPredicate[Post] = @@ -682,23 +748,15 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )) } - private def addCastHelpers(t: Type[Pre], calls: mutable.Set[Statement[Post]])( - implicit o: Origin + private def addCastHelpers( + t: Type[Pre], + helpers: mutable.Set[Type[Pre]], ): Unit = { t match { case cls: TByValueClass[Pre] => { - calls.add( - InvokeProcedure[Post]( - castHelpers.getOrElseUpdate(t, makeCastHelper(t)).ref, - Nil, - Nil, - Nil, - Nil, - Nil, - )(TrueSatisfiable)(o) - ) + helpers.add(t) cls.cls.decl.decls.collectFirst { case field: InstanceField[Pre] => - addCastHelpers(field.t, calls) + addCastHelpers(field.t, helpers) } } case _ => @@ -820,11 +878,8 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )(PanicBlame("instanceOf requires nothing"))(e.o) case Cast(value, typeValue) if value.t.asPointer.isDefined => { // Keep pointer casts and add extra annotations - // TODO: Check if we need to get rid of the pointer add's here since in my testing that broke some of the reasoning - if (castHelperCalls.nonEmpty) { - addCastHelpers(value.t.asPointer.get.element, castHelperCalls.top)( - e.o - ) + if (requiredCastHelpers.nonEmpty) { + addCastHelpers(value.t.asPointer.get.element, requiredCastHelpers.top) } e.rewriteDefault() diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 8c807ccb93..72b8432099 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -196,8 +196,8 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) } // TODO: This is an ugly way to exempt this one bit of generated code from having ptrAdd's added case proc: Procedure[Pre] - if proc.o.find[LabelContext].exists(_.label == "classToRef") && - proc.o.getPreferredNameOrElse().snake.startsWith("constraints_") => + if proc.o.find[LabelContext] + .exists(_.label == "classToRef cast helpers") => inAxiom.having(()) { allScopes.anySucceed(proc, proc.rewriteDefault()) } case _ => super.postCoerce(decl) } @@ -287,18 +287,24 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case deref @ DerefPointer(pointer) => FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq(if (inAxiom.isEmpty) { - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(unwrapOption(pointer, deref.blame), const(0)), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - } else { unwrapOption(pointer, deref.blame) }), + args = Seq( + if ( + inAxiom.isEmpty && + !deref.o.find[LabelContext] + .exists(_.label == "classToRef cast helpers") + ) { + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + } else { unwrapOption(pointer, deref.blame) } + ), typeArgs = Nil, Nil, Nil, @@ -360,18 +366,24 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) obj = FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq(if (inAxiom.isEmpty) { - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(unwrapOption(pointer, deref.blame), const(0)), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - } else { unwrapOption(pointer, deref.blame) }), + args = Seq( + if ( + inAxiom.isEmpty && + !deref.o.find[LabelContext] + .exists(_.label == "classToRef cast helpers") + ) { + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + } else { unwrapOption(pointer, deref.blame) } + ), typeArgs = Nil, Nil, Nil, @@ -381,18 +393,24 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case deref @ RawDerefPointer(pointer) => FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq(if (inAxiom.isEmpty) { - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(unwrapOption(pointer, deref.blame), const(0)), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - } else { unwrapOption(pointer, deref.blame) }), + args = Seq( + if ( + inAxiom.isEmpty && + !deref.o.find[LabelContext] + .exists(_.label == "classToRef cast helpers") + ) { + FunctionInvocation[Post]( + ref = pointerAdd.ref, + // Always index with zero, otherwise quantifiers with pointers do not get triggered + args = Seq(unwrapOption(pointer, deref.blame), const(0)), + typeArgs = Nil, + Nil, + Nil, + )(NoContext( + DerefPointerBoundsPreconditionFailed(deref.blame, pointer) + )) + } else { unwrapOption(pointer, deref.blame) } + ), typeArgs = Nil, Nil, Nil, @@ -465,7 +483,11 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) Seq(preExpr match { case PointerAdd(_, _) => postExpr // Don't add ptrAdd in an ADT axiom since we cannot use functions with preconditions there - case _ if inAxiom.nonEmpty => postExpr + case _ + if inAxiom.nonEmpty || + !preExpr.o.find[LabelContext] + .exists(_.label == "classToRef cast helpers") => + postExpr case _ => FunctionInvocation[Post]( ref = pointerAdd.ref, diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index 9526f29bbe..d234f8ef0a 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -11,6 +11,7 @@ class CSpec extends VercorsSpec { vercors should verify using silicon example "concepts/c/structs.c" vercors should verify using silicon example "concepts/c/vector_add.c" vercors should verify using silicon example "concepts/c/vector_type.c" + vercors should verify using silicon example "concepts/c/pointer_casts.c" vercors should error withCode "resolutionError:type" in "float should not be demoted" c """ @@ -377,17 +378,17 @@ class CSpec extends VercorsSpec { #include struct nested { - struct nested *inner; + struct nested *inner; }; void main() { - int *ip = NULL; - double *dp = NULL; - struct nested *np = NULL; - np = (struct nested*) NULL; + int *ip = NULL; + double *dp = NULL; + struct nested *np = NULL; + np = (struct nested*) NULL; np = (struct nested*) malloc(sizeof(struct nested)); np->inner = NULL; - np->inner = (struct nested*) NULL; + np->inner = (struct nested*) NULL; } """ @@ -562,4 +563,4 @@ class CSpec extends VercorsSpec { return; } """ -} \ No newline at end of file +} From 800f1ddae5dadb193e9ed90f36557d65c7212ab3 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 22 Aug 2024 15:23:18 +0200 Subject: [PATCH 27/47] Add type checking for pointer casts --- examples/concepts/c/pointer_casts.c | 41 ++++++++++++++----- src/col/vct/col/typerules/CoercionUtils.scala | 12 ++++++ src/rewrite/vct/rewrite/ClassToRef.scala | 33 +++++++-------- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 12 +++++- .../vct/test/integration/examples/CSpec.scala | 9 ++++ 5 files changed, 78 insertions(+), 29 deletions(-) diff --git a/examples/concepts/c/pointer_casts.c b/examples/concepts/c/pointer_casts.c index 4d5e06ab6e..d09264f7aa 100644 --- a/examples/concepts/c/pointer_casts.c +++ b/examples/concepts/c/pointer_casts.c @@ -23,17 +23,6 @@ void canCastToInteger() { //@ assert struct_b.struct_a.integer == 10; } -void cannotCastToBoolean() { - struct B struct_b; - struct_b.struct_a.boolean = true == true; // We currently don't support boolean literals - // TODO: Do proper type checks for casts - bool *pointer_to_boolean = (bool *)&struct_b; - /*[/expect ptrPerm]*/ - //@ assert *pointer_to_boolean == 5; - /*[/end]*/ - //@ assert pointer_to_boolean == &struct_b.struct_a.boolean; - //@ assert pointer_to_boolean == (bool *)&struct_b.struct_a; -} void castRemainsValidInLoop() { struct B struct_b; @@ -50,6 +39,36 @@ void castRemainsValidInLoop() { *pointer_to_integer = *pointer_to_integer - 1; } + //@ assert struct_b.struct_a.integer == 0; + struct_b.struct_a.integer = 10; + + // We can also specify the permission through the pointer + //@ loop_invariant 0 <= i && i <= 10; + //@ loop_invariant Perm(pointer_to_integer, write); + //@ loop_invariant *pointer_to_integer == 10 - i; + for (int i = 0; i < 10; i++) { + *pointer_to_integer = *pointer_to_integer - 1; + } + + //@ assert struct_b.struct_a.integer == 0; +} + +void castRemainsValidInParBlock() { + struct B struct_b; + struct_b.struct_a.integer = 10; + + int *pointer_to_integer = (int *)&struct_b; + + //@ context i == 8 ==> Perm(pointer_to_integer, write); + //@ ensures i == 8 ==> *pointer_to_integer == 0; + for (int i = 0; i < 10; i++) { + if (i == 8) { + *pointer_to_integer = *pointer_to_integer - 10; + } + } + + // Unfortunately we don't support a par block where we specify permission to the struct and then access through the cast (the generated cast helper is put too far away) + //@ assert struct_b.struct_a.integer == 0; } diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index 5c16be6dd6..ca97d5ef3f 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -449,6 +449,18 @@ case object CoercionUtils { case _ => None } + def firstElementIsType[G](aggregate: Type[G], innerType: Type[G]): Boolean = + aggregate match { + case aggregate if getAnyCoercion(aggregate, innerType).isDefined => true + case clazz: TByValueClass[G] => + clazz.cls.decl.decls.collectFirst { case field: InstanceField[G] => + firstElementIsType(field.t, innerType) + }.getOrElse(false) + case TArray(element) => firstElementIsType(element, innerType) + // TODO: Add LLVM types + case _ => false + } + def getAnyCArrayCoercion[G]( source: Type[G] ): Option[(Coercion[G], CTArray[G])] = diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index f56042400b..240edb33be 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -603,22 +603,23 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { } result } - case contract @ IterationContract( - requires, - ensures, - context_everywhere, - ) => { - val result = - IterationContract( - addCastConstraints(requires, helpers), - addCastConstraints(ensures, helpers), - addCastConstraints(context_everywhere, helpers), - )(contract.blame)(node.o) - if (requiredCastHelpers.nonEmpty) { - requiredCastHelpers.top.addAll(helpers) - } - result - } +// case contract @ IterationContract( +// requires, +// ensures, +// context_everywhere, +// ) => { +// val result = +// IterationContract( +// addCastConstraints(requires, helpers), +// addCastConstraints(ensures, helpers), +// addCastConstraints(context_everywhere, helpers), +// )(contract.blame)(node.o) +// if (requiredCastHelpers.nonEmpty) { +// requiredCastHelpers.top.addAll(helpers) +// } +// result +// } + case _: IterationContract[Pre] => throw ExtraNode } } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index eca1812fb8..3b98839ed6 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -14,6 +14,7 @@ import vct.col.resolve.ctx._ import vct.col.resolve.lang.C.nameFromDeclarator import vct.col.resolve.lang.Java.logger import vct.col.rewrite.{Generation, Rewritten} +import vct.col.typerules.CoercionUtils import vct.col.typerules.CoercionUtils.getCoercion import vct.col.util.SuccessionMap import vct.col.util.AstBuildHelpers._ @@ -412,9 +413,16 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) => throw UnsupportedMalloc(c) case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n) - // TODO: Check if valid pointer cast case CCast(e, t) if e.t.asPointer.isDefined && t.asPointer.isDefined => - Cast(rw.dispatch(e), TypeValue(rw.dispatch(t))(t.o))(c.o) + val newE = rw.dispatch(e) + val newT = rw.dispatch(t) + if ( + CoercionUtils.firstElementIsType( + newE.t.asPointer.get.element, + newT.asPointer.get.element, + ) + ) { Cast(newE, TypeValue(newT)(t.o))(c.o) } + else { throw UnsupportedCast(c) } case _ => throw UnsupportedCast(c) } diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index d234f8ef0a..943fcc7085 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -563,4 +563,13 @@ class CSpec extends VercorsSpec { return; } """ + + vercors should error withCode "unsupportedCast" in "Casting struct pointers only works for the first element" c + """ + void cannotCastToBoolean() { + struct B struct_b; + struct_b.struct_a.boolean = true == true; // We currently don't support boolean literals + bool *pointer_to_boolean = (bool *)&struct_b; + } + """ } From c7a723ebf7e364eb22e161c08a5da5ebe2988605 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 22 Aug 2024 16:27:02 +0200 Subject: [PATCH 28/47] Add back blame erroneously removed by the previous commit --- src/rewrite/vct/rewrite/ClassToRef.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 240edb33be..0b084bacbb 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -589,15 +589,14 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { // For loops add cast helpers before and as an invariant (since otherwise the contract might not be well-formed) override def dispatch(node: LoopContract[Pre]): LoopContract[Post] = { - implicit val o: Origin = node.o val helpers: mutable.Set[Type[Pre]] = mutable.Set() node match { - case LoopInvariant(invariant, decreases) => { + case inv @ LoopInvariant(invariant, decreases) => { val result = LoopInvariant( addCastConstraints(invariant, helpers), decreases.map(dispatch), - )(node.o) + )(inv.blame)(node.o) if (requiredCastHelpers.nonEmpty) { requiredCastHelpers.top.addAll(helpers) } From f0257061f1b00deb377a0210ad7c9b73fe62cbf7 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 26 Aug 2024 10:45:43 +0200 Subject: [PATCH 29/47] Fix unsupported cast test --- test/main/vct/test/integration/examples/CSpec.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/main/vct/test/integration/examples/CSpec.scala b/test/main/vct/test/integration/examples/CSpec.scala index 943fcc7085..02abb280ad 100644 --- a/test/main/vct/test/integration/examples/CSpec.scala +++ b/test/main/vct/test/integration/examples/CSpec.scala @@ -566,6 +566,15 @@ class CSpec extends VercorsSpec { vercors should error withCode "unsupportedCast" in "Casting struct pointers only works for the first element" c """ + #include + struct A { + int integer; + bool boolean; + }; + + struct B { + struct A struct_a; + }; void cannotCastToBoolean() { struct B struct_b; struct_b.struct_a.boolean = true == true; // We currently don't support boolean literals From ca363da633a9cf33b6d6d95f7b6af6c8f7e50b88 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 26 Aug 2024 14:34:41 +0200 Subject: [PATCH 30/47] Make the LLVM file verify again --- src/rewrite/vct/rewrite/PrepareByValueClass.scala | 2 +- src/rewrite/vct/rewrite/VariableToPointer.scala | 8 +++++--- src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala | 8 +++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index cee61769a6..2250989297 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -220,7 +220,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { node match { case NewObject(Ref(cls)) if cls.isInstanceOf[ByValueClass[Pre]] => { val t = TByValueClass[Pre](cls.ref, Seq()) - procedureInvocation[Post]( + return procedureInvocation[Post]( TrueSatisfiable, classCreationMethods.getOrElseUpdate(t, makeClassCreationMethod(t)) .ref, diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 6e8546627c..65d9904d94 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -92,7 +92,8 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { )(PanicBlame("Initialisation should always succeed")) } ++ Seq(dispatch(s.body))), ) - case i @ Instantiate(cls, out) => + case i @ Instantiate(cls, out) + if cls.decl.isInstanceOf[ByValueClass[Pre]] => // TODO: Make sure that we recursively build newobject for byvalueclasses // maybe get rid this entirely and only have it in encode by value class Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap { @@ -152,8 +153,9 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))( PanicBlame("Should always be accessible") ) - case newObject @ NewObject(Ref(cls)) => - val obj = new Variable[Post](TByReferenceClass(succ(cls), Seq())) + case newObject @ NewObject(Ref(cls)) + if cls.isInstanceOf[ByValueClass[Pre]] => + val obj = new Variable[Post](TByValueClass(succ(cls), Seq())) ScopedExpr( Seq(obj), With( diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 0ac4229aa2..75b96d1210 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -249,9 +249,11 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) decl, rw.globalDeclarations.declare( new HeapVariable[Post]( - new TByValueClass[Post]( - new DirectRef[Post, Class[Post]](structMap(struct)), - Seq(), + new TNonNullPointer[Post]( + new TByValueClass[Post]( + new DirectRef[Post, Class[Post]](structMap(struct)), + Seq(), + )(struct.o) )(struct.o) )(decl.o) ), From 845dac407dd14e6b79a3cd24f1bb57df19109f8e Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 28 Aug 2024 13:29:31 +0200 Subject: [PATCH 31/47] Implement basic pointer type inference for LLVM --- src/col/vct/col/ast/Node.scala | 5 +- .../vct/col/typerules/CoercingRewriter.scala | 8 +- .../Instruction/MemoryOpTransform.cpp | 2 + .../vct/rewrite/lang/LangLLVMToCol.scala | 206 ++++++++++++------ .../vct/rewrite/lang/LangSpecificToCol.scala | 6 + 5 files changed, 157 insertions(+), 70 deletions(-) diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 1cb1e81c58..414800e0b7 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -3597,14 +3597,15 @@ final case class LLVMLoad[G]( loadType: Type[G], pointer: Expr[G], ordering: LLVMMemoryOrdering[G], -)(implicit val o: Origin) +)(val blame: Blame[PointerDerefError])(implicit val o: Origin) extends LLVMExpr[G] with LLVMLoadImpl[G] +// TODO: Figure out how to deal with the blames here (I need a super type of AssignFailed and PointerDerefError) final case class LLVMStore[G]( value: Expr[G], pointer: Expr[G], ordering: LLVMMemoryOrdering[G], -)(implicit val o: Origin) +)(val blame: Blame[VerificationFailure])(implicit val o: Origin) extends LLVMStatement[G] with LLVMStoreImpl[G] final case class LLVMGetElementPointer[G]( diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 0fe01ca060..7319e60fb9 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -2143,8 +2143,8 @@ abstract class CoercingRewriter[Pre <: Generation]() case Message(_) => e case LLVMLocal(name) => e case LLVMAllocA(allocationType, numElements) => e - case LLVMLoad(loadType, p, ordering) => - LLVMLoad(loadType, llvmPointer(p, loadType)._1, ordering) + case load @ LLVMLoad(loadType, p, ordering) => + LLVMLoad(loadType, llvmPointer(p, loadType)._1, ordering)(load.blame) case LLVMGetElementPointer(structureType, resultType, pointer, indices) => LLVMGetElementPointer( structureType, @@ -2275,8 +2275,8 @@ abstract class CoercingRewriter[Pre <: Generation]() Loop(init, bool(cond), update, contract, body) case LLVMLoop(cond, contract, body) => LLVMLoop(bool(cond), contract, body) - case LLVMStore(value, p, ordering) => - LLVMStore(value, llvmPointer(p, value.t)._1, ordering) + case store @ LLVMStore(value, p, ordering) => + LLVMStore(value, llvmPointer(p, value.t)._1, ordering)(store.blame) case ModelDo(model, perm, after, action, impl) => ModelDo(model, rat(perm), after, action, impl) case n @ Notify(obj) => Notify(cls(obj))(n.blame) diff --git a/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp index ac8e3511ee..9c154aa250 100644 --- a/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp @@ -96,6 +96,7 @@ void llvm2col::transformLoad(llvm::LoadInst &loadInstruction, col::LlvmLoad *load = loadExpr->mutable_llvm_load(); load->set_allocated_origin( llvm2col::generateSingleStatementOrigin(loadInstruction)); + load->set_allocated_blame(new col::Blame()); llvm::errs() << "Working on " << loadInstruction << " has type " << *loadInstruction.getType() << "\n"; llvm2col::transformAndSetType(*loadInstruction.getType(), @@ -114,6 +115,7 @@ void llvm2col::transformStore(llvm::StoreInst &storeInstruction, col::LlvmStore *store = colBlock.add_statements()->mutable_llvm_store(); store->set_allocated_origin( llvm2col::generateSingleStatementOrigin(storeInstruction)); + store->set_allocated_blame(new col::Blame()); llvm2col::transformAndSetExpr(funcCursor, storeInstruction, *storeInstruction.getValueOperand(), *store->mutable_value()); diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 75b96d1210..bd1f79b2dc 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -2,7 +2,15 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import vct.col.ast._ -import vct.col.origin.{Origin, PanicBlame, SourceName, TypeName} +import vct.col.origin.{ + AssignFailed, + Blame, + PointerDerefError, + Origin, + PanicBlame, + SourceName, + TypeName, +} import vct.col.ref.{DirectRef, LazyRef, Ref} import vct.col.resolve.ctx.RefLLVMFunctionDefinition import vct.col.rewrite.{Generation, Rewritten} @@ -53,21 +61,75 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) : SuccessionMap[(LLVMTStruct[Pre], Int), InstanceField[Post]] = SuccessionMap() - private val globalVariableTypeGuesses - : mutable.HashMap[LLVMGlobalVariable[Pre], mutable.HashSet[Type[Pre]]] = - mutable.HashMap() - private val structFieldTypeGuesses - : mutable.HashMap[(LLVMTStruct[Pre], Int), mutable.HashSet[Type[Pre]]] = - mutable.HashMap() - private val localTypeGuesses - : mutable.HashMap[Variable[Pre], mutable.HashSet[Type[Pre]]] = mutable - .HashMap() + private val globalVariableInferredType + : mutable.HashMap[LLVMGlobalVariable[Pre], Type[Pre]] = mutable.HashMap() + private val localVariableInferredType + : mutable.HashMap[Variable[Pre], Type[Pre]] = mutable.HashMap() + + def gatherTypeHints(program: Program[Pre]): Unit = { + val globalVariableTypeGuesses + : mutable.HashMap[LLVMGlobalVariable[Pre], mutable.HashSet[Type[Pre]]] = + mutable.HashMap() +// val structFieldTypeGuesses: mutable.HashMap[(LLVMTStruct[Pre], Int), mutable.HashSet[Type[Pre]]] = mutable.HashMap() + val localTypeGuesses + : mutable.HashMap[Variable[Pre], mutable.HashSet[Type[Pre]]] = mutable + .HashMap() + + def addTypeGuess(pointer: Expr[Pre], inferredType: Type[Pre]): Unit = + pointer match { + case Local(Ref(v)) => + localTypeGuesses.getOrElseUpdate(v, { mutable.HashSet() }) + .add(LLVMTPointer[Pre](Some(inferredType))) + case LLVMPointerValue(Ref(g)) => + globalVariableTypeGuesses.getOrElseUpdate( + g.asInstanceOf[LLVMGlobalVariable[Pre]], + { mutable.HashSet() }, + ).add(inferredType) + case it => ??? + } + + program.collect { + case gep: LLVMGetElementPointer[Pre] => + addTypeGuess(gep.pointer, gep.structureType) + case load: LLVMLoad[Pre] => addTypeGuess(load.pointer, load.loadType) + case store: LLVMStore[Pre] => addTypeGuess(store.pointer, store.value.t) + } + + def findSuperType(types: mutable.HashSet[Type[Pre]]): Option[Type[Pre]] = { + types.map(Some(_)).reduce[Option[Type[Pre]]] { (a, b) => + (a, b) match { + case (None, _) | (_, None) => None + case (Some(a), Some(b)) if a == b || a.superTypeOf(b) => Some(a) + case (Some(a), Some(b)) if b.superTypeOf(a) => Some(b) + case _ => None + } + } + } + + globalVariableTypeGuesses.foreachEntry { case (v, types) => + findSuperType(types).foreach(globalVariableInferredType(v) = _) + } + localTypeGuesses.foreachEntry { case (v, types) => + findSuperType(types).foreach(localVariableInferredType(v) = _) + } + + } def rewriteLocal(local: LLVMLocal[Pre]): Expr[Post] = { implicit val o: Origin = local.o Local(rw.succ(local.ref.get.decl)) } + def rewriteLocalVariable(v: Variable[Pre]): Unit = { + implicit val o: Origin = v.o; + rw.variables.succeed( + v, + new Variable[Post](rw.dispatch( + localVariableInferredType.getOrElse(v, v.t) + )), + ) + } + def rewriteFunctionDef(func: LLVMFunctionDefinition[Pre]): Unit = { implicit val o: Origin = func.o val importedDecl = rw.importedDeclarations.find { @@ -242,7 +304,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) def rewriteGlobalVariable(decl: LLVMGlobalVariable[Pre]): Unit = { // TODO: Handle the initializer // TODO: Include array and vector bounds somehow - decl.variableType match { + globalVariableInferredType.getOrElse(decl, decl.variableType) match { case struct: LLVMTStruct[Pre] => { rewriteStruct(struct) globalVariableMap.update( @@ -319,79 +381,87 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } - def derefUntil( + private def derefUntil( pointer: Expr[Post], currentType: Type[Pre], untilType: Type[Pre], - ): (Expr[Post], Type[Pre]) = { + ): Option[(Expr[Post], Type[Pre])] = { implicit val o: Origin = pointer.o currentType match { - case _ if currentType == untilType => (AddrOf(pointer), currentType) - case LLVMTPointer(None) => (pointer, LLVMTPointer[Pre](Some(untilType))) + case _ if currentType == untilType => Some((AddrOf(pointer), currentType)) + case LLVMTPointer(None) => + Some((pointer, LLVMTPointer[Pre](Some(untilType)))) case LLVMTPointer(Some(inner)) if inner == untilType => - (pointer, currentType) + Some((pointer, currentType)) case LLVMTPointer(Some(LLVMTArray(numElements, elementType))) => { - val (expr, inner) = derefUntil( + derefUntil( PointerSubscript[Post]( DerefPointer(pointer)(pointer.o), IntegerValue(BigInt(0)), )(pointer.o), elementType, untilType, - ) - (expr, LLVMTPointer[Pre](Some(LLVMTArray(numElements, inner)))) + ).map { case (expr, inner) => + (expr, LLVMTPointer[Pre](Some(LLVMTArray(numElements, inner)))) + } } case LLVMTArray(numElements, elementType) => { - val (expr, inner) = derefUntil( + derefUntil( PointerSubscript[Post](pointer, IntegerValue(BigInt(0)))(pointer.o), elementType, untilType, - ) - (expr, LLVMTArray[Pre](numElements, inner)) + ).map { case (expr, inner) => + (expr, LLVMTArray[Pre](numElements, inner)) + } } case LLVMTPointer(Some(LLVMTVector(numElements, elementType))) => { - val (expr, inner) = derefUntil( + derefUntil( PointerSubscript[Post]( DerefPointer(pointer)(pointer.o), IntegerValue(BigInt(0)), )(pointer.o), elementType, untilType, - ) - (expr, LLVMTPointer[Pre](Some(LLVMTVector(numElements, inner)))) + ).map { case (expr, inner) => + (expr, LLVMTPointer[Pre](Some(LLVMTVector(numElements, inner)))) + } } case LLVMTVector(numElements, elementType) => { - val (expr, inner) = derefUntil( + derefUntil( PointerSubscript[Post](pointer, IntegerValue(BigInt(0)))(pointer.o), elementType, untilType, - ) - (expr, LLVMTVector[Pre](numElements, inner)) + ).map { case (expr, inner) => + (expr, LLVMTVector[Pre](numElements, inner)) + } } case LLVMTPointer(Some(struct @ LLVMTStruct(name, packed, elements))) => { - val (expr, inner) = derefUntil( + derefUntil( Deref[Post]( DerefPointer(pointer)(pointer.o), structFieldMap.ref((struct, 0)), )(pointer.o), elements.head, untilType, - ) - ( - expr, - LLVMTPointer[Pre](Some( - LLVMTStruct(name, packed, inner +: elements.tail) - )), - ) + ).map { case (expr, inner) => + ( + expr, + LLVMTPointer[Pre](Some( + LLVMTStruct(name, packed, inner +: elements.tail) + )), + ) + } } case struct @ LLVMTStruct(name, packed, elements) => { - val (expr, inner) = derefUntil( + derefUntil( Deref[Post](pointer, structFieldMap.ref((struct, 0)))(pointer.o), elements.head, untilType, - ) - (expr, LLVMTStruct[Pre](name, packed, inner +: elements.tail)) + ).map { case (expr, inner) => + (expr, LLVMTStruct[Pre](name, packed, inner +: elements.tail)) + } } + case _ => None } } @@ -425,12 +495,14 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(o) AddrOf(rewritePointerChain(structPointer, struct, gep.indices.tail)) case LLVMTPointer(Some(_)) => + val pointerInferredType = getInferredType(gep.pointer) val (pointer, inferredType) = derefUntil( rw.dispatch(gep.pointer), - gep.pointer.t, + pointerInferredType, t, + ).getOrElse( + (Cast(rw.dispatch(gep.pointer), TypeValue(rw.dispatch(t))), t) ) - addTypeGuess(gep.pointer, inferredType) val structPointer = DerefPointer( PointerAdd(pointer, rw.dispatch(gep.indices.head))(o) @@ -447,25 +519,47 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) // Deref might not be the correct thing to use here since technically the pointer is only dereferenced in the load or store instruction } + private def getInferredType(e: Expr[Pre]): Type[Pre] = + e match { + case Local(Ref(v)) => localVariableInferredType.getOrElse(v, e.t) + // Making assumption here that LLVMPointerValue only contains LLVMGlobalVariables whereas LLVMGlobalVariableImpl assumes it can also contain HeapVariables + case LLVMPointerValue(Ref(v)) => + globalVariableInferredType + .getOrElse(v.asInstanceOf[LLVMGlobalVariable[Pre]], e.t) + } + def rewriteStore(store: LLVMStore[Pre]): Statement[Post] = { implicit val o: Origin = store.o + val pointerInferredType = getInferredType(store.pointer) val (pointer, inferredType) = derefUntil( rw.dispatch(store.pointer), - store.pointer.t, + pointerInferredType, store.value.t, + ).getOrElse(( + Cast( + rw.dispatch(store.pointer), + TypeValue(TPointer(rw.dispatch(store.value.t))), + ), + store.value.t, + )) + // TODO: Fix assignfailed blame + Assign(DerefPointer(pointer)(store.blame), rw.dispatch(store.value))( + store.blame ) - addTypeGuess(store.pointer, inferredType) - Assign(DerefPointer(pointer)(store.o), rw.dispatch(store.value))(store.o) } def rewriteLoad(load: LLVMLoad[Pre]): Expr[Post] = { + implicit val o: Origin = load.o + val pointerInferredType = getInferredType(load.pointer) val (pointer, inferredType) = derefUntil( rw.dispatch(load.pointer), - load.pointer.t, + pointerInferredType, load.loadType, - ) - addTypeGuess(load.pointer, inferredType) - DerefPointer(pointer)(load.o)(load.o) + ).getOrElse(( + Cast(rw.dispatch(load.pointer), TypeValue(rw.dispatch(load.loadType))), + load.loadType, + )) + DerefPointer(pointer)(load.blame) } def rewriteAllocA(alloc: LLVMAllocA[Pre]): Expr[Post] = { @@ -498,22 +592,6 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } - private def addTypeGuess(pointer: Expr[Pre], inferredType: Type[Pre]): Unit = - pointer match { - case Local(Ref(v)) => - localTypeGuesses.getOrElseUpdate(v, { mutable.HashSet() }) - .add(LLVMTPointer[Pre](Some(inferredType))) - case LLVMPointerValue(Ref(g)) => - globalVariableTypeGuesses.getOrElseUpdate( - g.asInstanceOf[LLVMGlobalVariable[Pre]], - { mutable.HashSet() }, - ).add(inferredType) - case it => { - println(it) - ??? - } - } - def rewritePointerValue(pointer: LLVMPointerValue[Pre]): Expr[Post] = { implicit val o: Origin = pointer.o // Will be transformed by VariableToPointer pass diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 9f08d2b457..b28e680536 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -165,6 +165,11 @@ case class LangSpecificToCol[Pre <: Generation]( } } + override def dispatch(program: Program[Pre]): Program[Post] = { + llvm.gatherTypeHints(program) + super.dispatch(program) + } + override def dispatch(decl: Declaration[Pre]): Unit = decl match { case model: Model[Pre] => @@ -224,6 +229,7 @@ case class LangSpecificToCol[Pre <: Generation]( case glue: JavaBipGlueContainer[Pre] => bip.rewriteGlue(glue) case chor: PVLChoreography[Pre] => veymont.rewriteChoreography(chor) + case v: Variable[Pre] => llvm.rewriteLocalVariable(v) case other => rewriteDefault(other) } From 0e9751a5291114db99b0312cca96e29ed1a948a3 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 29 Aug 2024 10:34:21 +0200 Subject: [PATCH 32/47] Pass-through debug locations from LLVM --- .../vct/col/serialize/SerializeOrigin.scala | 36 ++++- src/hre/hre/io/ChecksumReadableFile.scala | 55 +++++++ src/llvm/include/Origin/OriginProvider.h | 2 + src/llvm/lib/Origin/OriginProvider.cpp | 149 +++++++++++++----- src/serialize/vct/col/ast/Origin.proto | 16 ++ 5 files changed, 214 insertions(+), 44 deletions(-) create mode 100644 src/hre/hre/io/ChecksumReadableFile.scala diff --git a/src/col/vct/col/serialize/SerializeOrigin.scala b/src/col/vct/col/serialize/SerializeOrigin.scala index bcee9d6735..dbe9d00398 100644 --- a/src/col/vct/col/serialize/SerializeOrigin.scala +++ b/src/col/vct/col/serialize/SerializeOrigin.scala @@ -1,11 +1,18 @@ package vct.col.serialize +import hre.io.{ChecksumReadableFile, RWFile} import vct.col.ast.{serialize => ser} import vct.col.origin._ +import java.nio.file.Path import scala.annotation.unused +import scala.collection.mutable +import com.typesafe.scalalogging.LazyLogging; + +object SerializeOrigin extends LazyLogging { + private def fileMap: mutable.HashMap[Path, hre.io.Readable] = + mutable.HashMap() -object SerializeOrigin { def deserialize( @unused origin: ser.Origin @@ -20,6 +27,33 @@ object SerializeOrigin { context.inlineContext, context.shortPosition, ) + case ser.OriginContent.Content.ReadableOrigin(context) => + val path = Path.of(context.directory, context.filename) + ReadableOrigin(fileMap.getOrElseUpdate( + path, { + if (context.checksum.isDefined && context.checksumKind.isDefined) { + val file = ChecksumReadableFile( + path, + doWatch = false, + context.checksumKind.get, + ) + if (file.getChecksum != context.checksum.get) { + logger.warn( + "The checksum of the file " + path + + " does not match the LLVM checksum error locations are likely inaccurate" + ) + } + file + } else { RWFile(path) } + }, + )) + case ser.OriginContent.Content.PositionRange(range) => + // TODO: Preserve the start col idx even if end col idx is missing and improve the origins in LangLLVMToCol? Maybe we could even set the preferred name correctly + PositionRange( + range.startLineIdx, + range.endLineIdx, + range.startColIdx.flatMap { start => range.endColIdx.map((start, _)) }, + ) }) def serialize( diff --git a/src/hre/hre/io/ChecksumReadableFile.scala b/src/hre/hre/io/ChecksumReadableFile.scala new file mode 100644 index 0000000000..f6ea9747ac --- /dev/null +++ b/src/hre/hre/io/ChecksumReadableFile.scala @@ -0,0 +1,55 @@ +package hre.io + +import vct.result.VerificationError.SystemError + +import java.io.Reader +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path} +import java.security.{MessageDigest, NoSuchAlgorithmException} + +case class UnknownChecksumKind(checksumKind: String) extends SystemError { + override def text: String = + s"Attempted to calculate checksum using unsupported algorithm: $checksumKind" +} + +case class ChecksumReadableFile( + file: Path, + doWatch: Boolean = true, + checksumKind: String, +) extends InMemoryCachedReadable { + override def underlyingPath: Option[Path] = Some(file) + override def fileName: String = file.toString + override def isRereadable: Boolean = true + private var checksumCache: Option[String] = None + + override protected def getReaderImpl: Reader = { + val bytes = Files.readAllBytes(file) + try { + val digest = MessageDigest.getInstance(checksumKind) + checksumCache = Some( + digest.digest(bytes).map(_.asInstanceOf[Int] & 0xff).map { b => + if (b > 0xf) { Integer.toHexString(b) } + else { "0" + Integer.toHexString(b) } + }.mkString + ) + } catch { + case _: NoSuchAlgorithmException => + throw UnknownChecksumKind(checksumKind) + } + Files.newBufferedReader(file, StandardCharsets.UTF_8) + } + + def getChecksum: String = { + if (checksumCache.isEmpty) { + // Calls ensureCache + super.readToCompletion() + } + checksumCache.get + } + + override def enroll(watch: Watch): Unit = { + if (doWatch) + watch.enroll(file) + watch.invalidate(this) + } +} diff --git a/src/llvm/include/Origin/OriginProvider.h b/src/llvm/include/Origin/OriginProvider.h index e11c079662..2fc05ad316 100644 --- a/src/llvm/include/Origin/OriginProvider.h +++ b/src/llvm/include/Origin/OriginProvider.h @@ -2,6 +2,8 @@ #define PALLAS_ORIGINPROVIDER_H #include "vct/col/ast/Origin.pb.h" +#include +#include #include #include diff --git a/src/llvm/lib/Origin/OriginProvider.cpp b/src/llvm/lib/Origin/OriginProvider.cpp index 21bea6dbfb..d16c42de89 100644 --- a/src/llvm/lib/Origin/OriginProvider.cpp +++ b/src/llvm/lib/Origin/OriginProvider.cpp @@ -123,6 +123,53 @@ col::Origin *llvm2col::generateLabelOrigin(llvm::BasicBlock &llvmBlock) { return origin; } +bool generateDebugOrigin(llvm::Instruction &llvmInstruction, + col::Origin *origin) { + const llvm::DebugLoc &loc = llvmInstruction.getDebugLoc(); + if (!loc) + return false; + int line = loc.getLine() - 1; + int col = loc.getCol() - 1; + col::OriginContent *positionRangeContent = origin->add_content(); + col::PositionRange *positionRange = new col::PositionRange(); + positionRange->set_start_line_idx(line); + positionRange->set_end_line_idx(line); + positionRange->set_start_col_idx(col); + positionRangeContent->set_allocated_position_range(positionRange); + auto *scope = llvm::cast(loc.getScope()); + auto *file = scope->getFile(); + llvm::StringRef filename = file->getFilename(); + llvm::StringRef directory = file->getDirectory(); + auto checksumOpt = file->getChecksum(); + col::OriginContent *readableOriginContent = origin->add_content(); + col::ReadableOrigin *readableOrigin = new col::ReadableOrigin(); + readableOrigin->set_allocated_filename(new std::string(filename)); + readableOrigin->set_allocated_directory(new std::string(directory)); + if (checksumOpt != std::nullopt) { + auto checksum = checksumOpt.value(); + readableOrigin->set_allocated_checksum(new std::string(checksum.Value)); + switch (checksum.Kind) { + case llvm::DIFile::ChecksumKind::CSK_MD5: + readableOrigin->set_allocated_checksum_kind(new std::string("MD5")); + break; + case llvm::DIFile::ChecksumKind::CSK_SHA1: + readableOrigin->set_allocated_checksum_kind( + new std::string("SHA-1")); + break; + case llvm::DIFile::ChecksumKind::CSK_SHA256: + readableOrigin->set_allocated_checksum_kind( + new std::string("SHA-256")); + break; + default: + // TODO: Properly add this error to the ErrorReported + llvm::errs() << "Unknown checksum kind " << checksum.Kind << "\n"; + break; + } + } + readableOriginContent->set_allocated_readable_origin(readableOrigin); + return true; +} + col::Origin * llvm2col::generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { col::Origin *origin = new col::Origin(); @@ -132,13 +179,16 @@ llvm2col::generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { deriveOperandPreferredName(llvmInstruction)); preferredNameContent->set_allocated_preferred_name(preferredName); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionContext(llvmInstruction)); - context->set_short_position( - deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); + if (!generateDebugOrigin(llvmInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context( + deriveSurroundingInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionContext(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + } return origin; } @@ -151,13 +201,15 @@ llvm2col::generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { preferredName->add_preferred_name("var"); preferredNameContent->set_allocated_preferred_name(preferredName); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionLhs(llvmInstruction)); - context->set_short_position( - deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); + if (!generateDebugOrigin(llvmInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionLhs(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + } return origin; } @@ -165,13 +217,16 @@ llvm2col::generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { col::Origin * llvm2col::generateBinExprOrigin(llvm::Instruction &llvmInstruction) { col::Origin *origin = new col::Origin(); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(llvmInstruction)); - context->set_inline_context(deriveInstructionContext(llvmInstruction)); - context->set_short_position( - deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); + if (!generateDebugOrigin(llvmInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context( + deriveSurroundingInstructionContext(llvmInstruction)); + context->set_inline_context(deriveInstructionContext(llvmInstruction)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + } return origin; } @@ -185,13 +240,16 @@ llvm2col::generateFunctionCallOrigin(llvm::CallInst &callInstruction) { callInstruction.getCalledFunction()->getName().str()); preferredNameContent->set_allocated_preferred_name(preferredName); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveSurroundingInstructionContext(callInstruction)); - context->set_inline_context(deriveInstructionRhs(callInstruction)); - context->set_short_position( - deriveInstructionShortPosition(callInstruction)); - contextContent->set_allocated_context(context); + if (!generateDebugOrigin(callInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context( + deriveSurroundingInstructionContext(callInstruction)); + context->set_inline_context(deriveInstructionRhs(callInstruction)); + context->set_short_position( + deriveInstructionShortPosition(callInstruction)); + contextContent->set_allocated_context(context); + } return origin; } @@ -204,13 +262,15 @@ col::Origin *llvm2col::generateOperandOrigin(llvm::Instruction &llvmInstruction, preferredName->add_preferred_name(deriveOperandPreferredName(llvmOperand)); preferredNameContent->set_allocated_preferred_name(preferredName); - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveInstructionContext(llvmInstruction)); - context->set_inline_context(deriveOperandContext(llvmOperand)); - context->set_short_position( - deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); + if (!generateDebugOrigin(llvmInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context(deriveOperandContext(llvmOperand)); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + } return origin; } @@ -260,14 +320,17 @@ llvm2col::generateVoidOperandOrigin(llvm::Instruction &llvmInstruction) { col::PreferredName *preferredName = new col::PreferredName(); preferredName->add_preferred_name("void"); preferredNameContent->set_allocated_preferred_name(preferredName); - - col::OriginContent *contextContent = origin->add_content(); - col::Context *context = new col::Context(); - context->set_context(deriveInstructionContext(llvmInstruction)); - context->set_inline_context("void"); - context->set_short_position( - deriveInstructionShortPosition(llvmInstruction)); - contextContent->set_allocated_context(context); + generateDebugOrigin(llvmInstruction, origin); + + if (!generateDebugOrigin(llvmInstruction, origin)) { + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveInstructionContext(llvmInstruction)); + context->set_inline_context("void"); + context->set_short_position( + deriveInstructionShortPosition(llvmInstruction)); + contextContent->set_allocated_context(context); + } return origin; } diff --git a/src/serialize/vct/col/ast/Origin.proto b/src/serialize/vct/col/ast/Origin.proto index 5d6373c97b..2de2eb1ef0 100644 --- a/src/serialize/vct/col/ast/Origin.proto +++ b/src/serialize/vct/col/ast/Origin.proto @@ -20,6 +20,8 @@ message OriginContent { SourceName source_name = 1; PreferredName preferred_name = 2; Context context = 3; + ReadableOrigin readable_origin = 4; + PositionRange position_range = 5; } } @@ -36,3 +38,17 @@ message Context { required string inline_context = 2; required string short_position = 3; } + +message ReadableOrigin { + required string directory = 1; + required string filename = 2; + optional string checksum = 3; + optional string checksum_kind = 4; +} + +message PositionRange { + required sint32 start_line_idx = 1; + required sint32 end_line_idx = 2; + optional sint32 start_col_idx = 3; + optional sint32 end_col_idx = 4; +} From c8548db39b71dd36fb58d553b7dbe59f3c5dd711 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 2 Sep 2024 15:52:03 +0200 Subject: [PATCH 33/47] Improved pointer type inference --- src/col/vct/col/ast/Node.scala | 18 +- .../col/ast/lang/llvm/LLVMAllocAImpl.scala | 1 - .../vct/col/ast/lang/llvm/LLVMLoadImpl.scala | 1 - src/col/vct/col/resolve/Resolve.scala | 13 +- .../vct/col/serialize/SerializeOrigin.scala | 17 + .../vct/col/typerules/CoercingRewriter.scala | 9 +- .../vct/col/util/SubstituteReferences.scala | 59 +++ .../Function/FunctionContractDeclarer.cpp | 3 +- .../Instruction/MemoryOpTransform.cpp | 42 +- src/main/vct/main/stages/Resolution.scala | 5 +- .../vct/rewrite/lang/LangLLVMToCol.scala | 369 ++++++++++++------ .../vct/rewrite/lang/LangSpecificToCol.scala | 11 +- src/serialize/vct/col/ast/Origin.proto | 5 + 13 files changed, 401 insertions(+), 152 deletions(-) create mode 100644 src/col/vct/col/util/SubstituteReferences.scala diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 414800e0b7..49d578010d 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -3529,7 +3529,10 @@ final class LLVMFunctionDefinition[G]( )(val blame: Blame[CallableFailure])(implicit val o: Origin) extends LLVMCallable[G] with Applicable[G] - with LLVMFunctionDefinitionImpl[G] + with LLVMFunctionDefinitionImpl[G] { + var importedArguments: Option[Seq[Variable[G]]] = None + var importedReturnType: Option[Type[G]] = None +} @scopes[LabelDecl] final class LLVMSpecFunction[G]( val name: String, @@ -3589,16 +3592,21 @@ final case class LLVMAmbiguousFunctionInvocation[G]( var ref: Option[Ref[G, LLVMCallable[G]]] = None } -final case class LLVMAllocA[G](allocationType: Type[G], numElements: Expr[G])( - implicit val o: Origin -) extends LLVMExpr[G] with LLVMAllocAImpl[G] +// TODO: It would probably be more consistent if LLVMAllocA and LLVMLoad use the Expr type for the variable but it should never be necessary +final case class LLVMAllocA[G]( + variable: Ref[G, Variable[G]], + allocationType: Type[G], + numElements: Expr[G], +)(implicit val o: Origin) + extends LLVMStatement[G] with LLVMAllocAImpl[G] final case class LLVMLoad[G]( + variable: Ref[G, Variable[G]], loadType: Type[G], pointer: Expr[G], ordering: LLVMMemoryOrdering[G], )(val blame: Blame[PointerDerefError])(implicit val o: Origin) - extends LLVMExpr[G] with LLVMLoadImpl[G] + extends LLVMStatement[G] with LLVMLoadImpl[G] // TODO: Figure out how to deal with the blames here (I need a super type of AssignFailed and PointerDerefError) final case class LLVMStore[G]( diff --git a/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala index b14e42e820..8276e1544c 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala @@ -5,5 +5,4 @@ import vct.col.ast.{LLVMAllocA, Type, LLVMTPointer} trait LLVMAllocAImpl[G] extends LLVMAllocAOps[G] { this: LLVMAllocA[G] => - override val t: Type[G] = LLVMTPointer(Some(this.allocationType)) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala index af05a52038..8142e0ea91 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala @@ -5,5 +5,4 @@ import vct.col.ast.ops.LLVMLoadOps trait LLVMLoadImpl[G] extends LLVMLoadOps[G] { this: LLVMLoad[G] => - override val t: Type[G] = this.loadType } diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index d0ef709925..8b839a2b41 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -1117,8 +1117,18 @@ case object ResolveReferences extends LazyLogging { ) => portName.data = Some((cls, getLit(name))) + case func: LLVMFunctionDefinition[G] => + val importedDecl = ctx.importedDeclarations.find { + case procedure: Procedure[G] => + func.contract.name == procedure.o.get[SourceName].name + } + if (importedDecl.isDefined) { + val importedProcedure = importedDecl.get.asInstanceOf[Procedure[G]] + func.importedArguments = Some(importedProcedure.args) + func.importedReturnType = Some(importedProcedure.returnType) + } case contract: LLVMFunctionContract[G] => - implicit val o: Origin = contract.o +// implicit val o: Origin = contract.o val llvmFunction = ctx.currentResult.get.asInstanceOf[RefLLVMFunctionDefinition[G]].decl val applicableContract = ctx.llvmSpecParser.parse(contract, contract.o) @@ -1129,6 +1139,7 @@ case object ResolveReferences extends LazyLogging { if (importedDecl.isDefined) { val importedProcedure = importedDecl.get.asInstanceOf[Procedure[G]] val importedContract = importedProcedure.contract + implicit val o: Origin = importedContract.o val substitute = Substitute[G]( ((Result[G](importedProcedure.ref) -> AmbiguousResult[G]()) +: importedProcedure.args.zipWithIndex.map { case (l, idx) => diff --git a/src/col/vct/col/serialize/SerializeOrigin.scala b/src/col/vct/col/serialize/SerializeOrigin.scala index dbe9d00398..8e4a01afee 100644 --- a/src/col/vct/col/serialize/SerializeOrigin.scala +++ b/src/col/vct/col/serialize/SerializeOrigin.scala @@ -37,6 +37,7 @@ object SerializeOrigin extends LazyLogging { doWatch = false, context.checksumKind.get, ) + // TODO: Should we do the checksum check later? Potentially this causes a lot of file loading if (file.getChecksum != context.checksum.get) { logger.warn( "The checksum of the file " + path + @@ -54,6 +55,8 @@ object SerializeOrigin extends LazyLogging { range.endLineIdx, range.startColIdx.flatMap { start => range.endColIdx.map((start, _)) }, ) + case ser.OriginContent.Content.LabelContext(label) => + LabelContext(label.label) }) def serialize( @@ -74,6 +77,20 @@ object SerializeOrigin extends LazyLogging { ser.OriginContent.Content .Context(ser.Context(context, inlineContext, shortPosition)) ) + case ReadableOrigin(readable) => + // Not sure how to best deal with directory/filename here + Seq(ser.OriginContent.Content.ReadableOrigin( + ser.ReadableOrigin("", readable.fileName, None, None) + )) + case PositionRange(startLineIdx, endLineIdx, startEndColIdx) => + Seq(ser.OriginContent.Content.PositionRange(ser.PositionRange( + startLineIdx, + endLineIdx, + startEndColIdx.map(_._1), + startEndColIdx.map(_._2), + ))) + case LabelContext(label) => + Seq(ser.OriginContent.Content.LabelContext(ser.LabelContext(label))) case _ => Nil }.map(ser.OriginContent(_)) ) diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 7319e60fb9..1581fa56eb 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -2142,9 +2142,6 @@ abstract class CoercingRewriter[Pre <: Generation]() case Receiver(_) => e case Message(_) => e case LLVMLocal(name) => e - case LLVMAllocA(allocationType, numElements) => e - case load @ LLVMLoad(loadType, p, ordering) => - LLVMLoad(loadType, llvmPointer(p, loadType)._1, ordering)(load.blame) case LLVMGetElementPointer(structureType, resultType, pointer, indices) => LLVMGetElementPointer( structureType, @@ -2275,6 +2272,12 @@ abstract class CoercingRewriter[Pre <: Generation]() Loop(init, bool(cond), update, contract, body) case LLVMLoop(cond, contract, body) => LLVMLoop(bool(cond), contract, body) + case LLVMAllocA(variable, allocationType, numElements) => + LLVMAllocA(variable, allocationType, int(numElements)) + case load @ LLVMLoad(variable, loadType, p, ordering) => + LLVMLoad(variable, loadType, llvmPointer(p, loadType)._1, ordering)( + load.blame + ) case store @ LLVMStore(value, p, ordering) => LLVMStore(value, llvmPointer(p, value.t)._1, ordering)(store.blame) case ModelDo(model, perm, after, action, impl) => diff --git a/src/col/vct/col/util/SubstituteReferences.scala b/src/col/vct/col/util/SubstituteReferences.scala new file mode 100644 index 0000000000..95dbfed39f --- /dev/null +++ b/src/col/vct/col/util/SubstituteReferences.scala @@ -0,0 +1,59 @@ +package vct.col.util + +import vct.col.ast._ +import vct.col.origin.Origin +import vct.col.ref.{DirectRef, Ref} +import vct.col.rewrite.NonLatchingRewriter + +import scala.reflect.ClassTag + +/** Substitute all references in expressions, resulting AST can be used for + * analysis but not output since it doesn't contain the right declarations + */ +case class SubstituteReferences[G](subs: Map[Object, Object]) + extends NonLatchingRewriter[G, G] { + + case class SuccOrIdentity() extends SuccessorsProviderTrafo[G, G](allScopes) { + override def postTransform[T <: Declaration[G]]( + pre: Declaration[G], + post: Option[T], + ): Option[T] = Some(post.getOrElse(pre.asInstanceOf[T])) + } + + override def succProvider: SuccessorsProvider[G, G] = SuccOrIdentity() + + private def substitute[T <: Declaration[G]](obj: T)( + implicit tag: ClassTag[T] + ): DirectRef[G, T] = new DirectRef(subs.getOrElse(obj, obj).asInstanceOf[T]) + + override def dispatch(e: Expr[G]): Expr[G] = { + implicit val o: Origin = e.o + e match { + // Matching on everything with a reference in it + case Local(Ref(v)) => Local(substitute(v)) + case HeapLocal(Ref(v)) => HeapLocal(substitute(v)) + case EnumUse(Ref(a), Ref(b)) => EnumUse(substitute(a), substitute(b)) + case deref @ DerefHeapVariable(Ref(v)) => + DerefHeapVariable[G](substitute(v))(deref.blame) + case deref @ Deref(obj, Ref(f)) => + Deref[G](dispatch(obj), substitute(f))(deref.blame) + case deref @ ModelDeref(obj, Ref(f)) => + ModelDeref[G](dispatch(obj), substitute(f))(deref.blame) + case FunctionOf(Ref(b), vars) => + FunctionOf[G](substitute(b), vars.map { case Ref(v) => substitute(v) }) + case NewObject(Ref(c)) => NewObject(substitute(c)) + case old @ Old(expr, None) => Old(dispatch(expr), None)(old.blame) + case old @ Old(expr, Some(Ref(l))) => + Old[G](dispatch(expr), Some(substitute(l)))(old.blame) + case ProcessApply(Ref(p), args) => + ProcessApply(substitute(p), args.map(dispatch)) + case EndpointName(Ref(e)) => EndpointName(substitute(e)) + case ChorPerm(Ref(e), loc, perm) => + ChorPerm(substitute(e), dispatch(loc), dispatch(perm)) + case Sender(Ref(s)) => Sender(substitute(s)) + case Receiver(Ref(r)) => Receiver(substitute(r)) + case Message(Ref(m)) => Message(substitute(m)) + case _ => e.rewriteDefault() + } + } +} diff --git a/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp b/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp index f97ca35cfe..ea6e103214 100644 --- a/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp +++ b/src/llvm/lib/Passes/Function/FunctionContractDeclarer.cpp @@ -52,7 +52,8 @@ FunctionContractDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) { if (!F.hasMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD)) { // set contract to a tautology colContract.set_value("requires true;"); - colContract.set_allocated_origin(new col::Origin()); + colContract.set_allocated_origin( + llvm2col::generateFunctionContractOrigin(F, "requires true;")); return PreservedAnalyses::all(); } // concatenate all contract lines with new lines diff --git a/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp index 9c154aa250..864245eab8 100644 --- a/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/MemoryOpTransform.cpp @@ -4,6 +4,7 @@ #include "Transform/BlockTransform.h" #include "Transform/Transform.h" #include "Util/Exceptions.h" +#include const std::string SOURCE_LOC = "Transform::Instruction::MemoryOp"; @@ -36,15 +37,33 @@ void llvm2col::transformMemoryOp(llvm::Instruction &llvmInstruction, void llvm2col::transformAllocA(llvm::AllocaInst &allocAInstruction, col::Block &colBlock, pallas::FunctionCursor &funcCursor) { - col::Assign &assignment = funcCursor.createAssignmentAndDeclaration( - allocAInstruction, colBlock, - /* pointer type*/ allocAInstruction.getAllocatedType()); - col::Expr *allocAExpr = assignment.mutable_value(); - col::LlvmAllocA *allocA = allocAExpr->mutable_llvm_alloc_a(); + col::LlvmAllocA *allocA = colBlock.add_statements()->mutable_llvm_alloc_a(); allocA->set_allocated_origin( llvm2col::generateSingleStatementOrigin(allocAInstruction)); - llvm2col::transformAndSetType(*allocAInstruction.getAllocatedType(), - *allocA->mutable_allocation_type()); + + if (allocAInstruction.getAllocatedType()->getTypeID() == + llvm::Type::PointerTyID) { + // Pointers are opaque so we'll use the metadata to try and figure out + // what this pointer will point to + for (llvm::DbgDeclareInst *dbg : + llvm::FindDbgDeclareUses(&allocAInstruction)) { + llvm::errs() << "Use of AllocA ptr " << *dbg << "\n"; + llvm::CallInst *dbgCall = llvm::cast(dbg); + llvm::Metadata *metadata = + llvm::cast(dbgCall->getOperand(1)) + ->getMetadata(); + // TODO: Translate this information where possible + } + llvm2col::transformAndSetType(*allocAInstruction.getAllocatedType(), + *allocA->mutable_allocation_type()); + } else { + llvm2col::transformAndSetType(*allocAInstruction.getAllocatedType(), + *allocA->mutable_allocation_type()); + } + col::Variable &varDecl = funcCursor.declareVariable( + allocAInstruction, allocAInstruction.getAllocatedType()); + allocA->mutable_variable()->set_id(varDecl.id()); + llvm2col::transformAndSetExpr(funcCursor, allocAInstruction, *allocAInstruction.getArraySize(), *allocA->mutable_num_elements()); @@ -90,15 +109,12 @@ void llvm2col::transformLoad(llvm::LoadInst &loadInstruction, col::Block &colBlock, pallas::FunctionCursor &funcCursor) { // We are not storing isVolatile and getAlign - col::Assign &assignment = - funcCursor.createAssignmentAndDeclaration(loadInstruction, colBlock); - col::Expr *loadExpr = assignment.mutable_value(); - col::LlvmLoad *load = loadExpr->mutable_llvm_load(); + col::LlvmLoad *load = colBlock.add_statements()->mutable_llvm_load(); load->set_allocated_origin( llvm2col::generateSingleStatementOrigin(loadInstruction)); load->set_allocated_blame(new col::Blame()); - llvm::errs() << "Working on " << loadInstruction << " has type " - << *loadInstruction.getType() << "\n"; + col::Variable &varDecl = funcCursor.declareVariable(loadInstruction); + load->mutable_variable()->set_id(varDecl.id()); llvm2col::transformAndSetType(*loadInstruction.getType(), *load->mutable_load_type()); llvm2col::transformAndSetExpr(funcCursor, loadInstruction, diff --git a/src/main/vct/main/stages/Resolution.scala b/src/main/vct/main/stages/Resolution.scala index 13d437ca70..c7f89e0c99 100644 --- a/src/main/vct/main/stages/Resolution.scala +++ b/src/main/vct/main/stages/Resolution.scala @@ -155,8 +155,8 @@ case class Resolution[G <: Generation]( val ast = LangTypesToCol() .dispatch(Program(importedDeclarations)(blameProvider())) ResolveReferences.resolve(ast, javaParser, llvmParser, Seq()) - LangSpecificToCol(generatePermissions, veymontAllowAssign, Seq()) - .dispatch(ast).asInstanceOf[Program[Rewritten[G]]].declarations + LangSpecificToCol(generatePermissions, veymontAllowAssign).dispatch(ast) + .asInstanceOf[Program[Rewritten[G]]].declarations } ResolveReferences .resolve(typedProgram, javaParser, llvmParser, typedImports) match { @@ -166,7 +166,6 @@ case class Resolution[G <: Generation]( val resolvedProgram = LangSpecificToCol( generatePermissions, veymontAllowAssign, - typedImports, ).dispatch(typedProgram) resolvedProgram.check match { case Nil => // ok diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index bd1f79b2dc..47cc92ea6f 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -2,20 +2,12 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import vct.col.ast._ -import vct.col.origin.{ - AssignFailed, - Blame, - PointerDerefError, - Origin, - PanicBlame, - SourceName, - TypeName, -} +import vct.col.origin.{Origin, PanicBlame, TypeName} import vct.col.ref.{DirectRef, LazyRef, Ref} import vct.col.resolve.ctx.RefLLVMFunctionDefinition import vct.col.rewrite.{Generation, Rewritten} -import vct.col.util.AstBuildHelpers.{VarBuildHelpers, assignLocal, const, tt} -import vct.col.util.{CurrentProgramContext, SuccessionMap} +import vct.col.util.AstBuildHelpers.assignLocal +import vct.col.util.{CurrentProgramContext, SubstituteReferences, SuccessionMap} import vct.result.VerificationError.{SystemError, UserError} import scala.collection.mutable @@ -67,52 +59,187 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) : mutable.HashMap[Variable[Pre], Type[Pre]] = mutable.HashMap() def gatherTypeHints(program: Program[Pre]): Unit = { - val globalVariableTypeGuesses - : mutable.HashMap[LLVMGlobalVariable[Pre], mutable.HashSet[Type[Pre]]] = - mutable.HashMap() -// val structFieldTypeGuesses: mutable.HashMap[(LLVMTStruct[Pre], Int), mutable.HashSet[Type[Pre]]] = mutable.HashMap() - val localTypeGuesses - : mutable.HashMap[Variable[Pre], mutable.HashSet[Type[Pre]]] = mutable - .HashMap() - - def addTypeGuess(pointer: Expr[Pre], inferredType: Type[Pre]): Unit = - pointer match { - case Local(Ref(v)) => - localTypeGuesses.getOrElseUpdate(v, { mutable.HashSet() }) - .add(LLVMTPointer[Pre](Some(inferredType))) - case LLVMPointerValue(Ref(g)) => - globalVariableTypeGuesses.getOrElseUpdate( - g.asInstanceOf[LLVMGlobalVariable[Pre]], - { mutable.HashSet() }, - ).add(inferredType) - case it => ??? + // TODO: We also need to do something where we only keep structurally distinct types + def moreSpecific(self: Type[Pre], other: Type[Pre]): Boolean = { + (self, other) match { + case (a, b) if a == b => false + case (LLVMTPointer(None), _) => false + case (_, LLVMTPointer(None)) => true + case (LLVMTPointer(Some(a)), LLVMTPointer(Some(b))) => + moreSpecific(a, b) + case (LLVMTPointer(Some(a)), TPointer(b)) => moreSpecific(a, b) + case (TPointer(a), LLVMTPointer(Some(b))) => moreSpecific(a, b) + case (TPointer(a), TPointer(b)) => moreSpecific(a, b) + case (LLVMTStruct(_, _, a), LLVMTStruct(_, _, b)) => + a.headOption.exists(ta => b.exists(tb => moreSpecific(ta, tb))) + case (LLVMTStruct(_, _, _), _) => true + case (LLVMTArray(_, a), LLVMTArray(_, b)) => moreSpecific(a, b) + case (LLVMTArray(_, _), _) => true + case _ => false } - - program.collect { - case gep: LLVMGetElementPointer[Pre] => - addTypeGuess(gep.pointer, gep.structureType) - case load: LLVMLoad[Pre] => addTypeGuess(load.pointer, load.loadType) - case store: LLVMStore[Pre] => addTypeGuess(store.pointer, store.value.t) } - def findSuperType(types: mutable.HashSet[Type[Pre]]): Option[Type[Pre]] = { + // TODO: This sorting is non-stable which might cause nondeterministic bugs if there's something wrong with moreSpecific + def findMostSpecific( + types: mutable.ArrayBuffer[Type[Pre]] + ): Option[Type[Pre]] = { types.map(Some(_)).reduce[Option[Type[Pre]]] { (a, b) => (a, b) match { case (None, _) | (_, None) => None - case (Some(a), Some(b)) if a == b || a.superTypeOf(b) => Some(a) - case (Some(a), Some(b)) if b.superTypeOf(a) => Some(b) + case (Some(a), Some(b)) + if a == b || rw.dispatch(a) == rw.dispatch(b) || + moreSpecific(a, b) => + Some(a) + case (Some(a), Some(b)) if moreSpecific(b, a) => Some(b) case _ => None } } } - globalVariableTypeGuesses.foreachEntry { case (v, types) => - findSuperType(types).foreach(globalVariableInferredType(v) = _) + class TypeGuess( + val depends: mutable.Set[Object] = mutable.Set(), + val dependents: mutable.Set[Object] = mutable.Set(), + val getGuesses: mutable.ArrayBuffer[Unit => Type[Pre]] = mutable + .ArrayBuffer(), + var currentType: Type[Pre], + ) { + def add(dependencies: Set[Object], inferType: Unit => Type[Pre]): Unit = { + depends.addAll(dependencies) + getGuesses.addOne(inferType) + } + + def update(): Boolean = { + val superType = findMostSpecific(getGuesses.map(_())) + if (superType.isEmpty) { false } + else { + val updated = currentType == superType.get + currentType = superType.get + updated + } + } } - localTypeGuesses.foreachEntry { case (v, types) => - findSuperType(types).foreach(localVariableInferredType(v) = _) + + val typeGuesses: mutable.HashMap[Object, TypeGuess] = mutable.HashMap() + + def findDependencies(expr: Expr[Pre]): Set[Object] = { + expr.collect { + case Local(Ref(v)) => v + case LLVMPointerValue(Ref(g)) => g + // These two below probably don't do anything + case v: Variable[Pre] => v + case v: LLVMGlobalVariable[Pre] => v + }.toSet } + def replaceWithGuesses( + value: Expr[Pre], + dependencies: Set[Object], + ): Expr[Pre] = { + val subMap = dependencies.filter(typeGuesses.contains).collect { + case v: Variable[Pre] if typeGuesses(v).currentType != v.t => + (v, new Variable[Pre](typeGuesses(v).currentType)(v.o)) + case v: LLVMGlobalVariable[Pre] + if typeGuesses(v).currentType != v.variableType => + ( + v, + new LLVMGlobalVariable[Pre]( + typeGuesses(v).currentType, + v.value, + v.constant, + )(v.o), + ) + } + if (subMap.isEmpty) { value } + else { + // TODO: Support multiple guesses? + SubstituteReferences(subMap.toMap).dispatch(value) + } + } + + def getVariable(expr: Expr[Pre]): Object = { + expr match { + case Local(Ref(v)) => v + case LLVMPointerValue(Ref(g)) => g + case _ => ??? + } + } + + def addTypeGuess( + obj: Object, + dependencies: Set[Object], + inferType: Unit => Type[Pre], + ): Unit = + typeGuesses + .getOrElseUpdate(obj, new TypeGuess(currentType = inferType(()))) + .add(dependencies, inferType) + + // TODO: This could be made more generic and also work with Assign nodes + program.collect { + case func: LLVMFunctionDefinition[Pre] => + func.args.zipWithIndex.foreach { case (a, i) => + addTypeGuess( + a, + Set.empty, + _ => func.importedArguments.map(_(i).t).getOrElse(a.t), + ) + } + case alloc: LLVMAllocA[Pre] => + addTypeGuess( + alloc.variable.decl, + Set.empty, + _ => LLVMTPointer(Some(alloc.allocationType)), + ) + case gep: LLVMGetElementPointer[Pre] => + addTypeGuess( + getVariable(gep.pointer), + Set.empty, + _ => LLVMTPointer(Some(gep.structureType)), + ) + case load: LLVMLoad[Pre] => + addTypeGuess( + getVariable(load.pointer), + Set(load.variable.decl), + _ => + LLVMTPointer(Some( + typeGuesses.get(load.variable.decl).map(_.currentType) + .getOrElse(load.variable.decl.t) + )), + ) + addTypeGuess(load.variable.decl, Set.empty, _ => load.variable.decl.t) + case store: LLVMStore[Pre] => + val dependencies = findDependencies(store.value) + addTypeGuess( + getVariable(store.pointer), + dependencies, + _ => + LLVMTPointer(Some(replaceWithGuesses(store.value, dependencies).t)), + ) + case inv: LLVMFunctionInvocation[Pre] => + inv.ref.decl.importedArguments.getOrElse(inv.ref.decl.args).zipWithIndex + .foreach { case (a, i) => + addTypeGuess(getVariable(inv.args(i)), Set.empty, _ => a.t) + } + } + + typeGuesses.foreachEntry((k, v) => + v.depends.filter(typeGuesses.contains) + .foreach(typeGuesses.get(_).foreach(_.dependents.add(k))) + ) + val updateQueue = mutable.ArrayDeque.from(typeGuesses.keys) + + while (updateQueue.nonEmpty) { + val obj = updateQueue.removeLast() + val guess = typeGuesses(obj) + if (guess.update()) { updateQueue.appendAll(guess.dependents) } + } + + typeGuesses.foreachEntry((e, t) => + e match { + case v: Variable[Pre] => localVariableInferredType(v) = t.currentType + case v: LLVMGlobalVariable[Pre] => + globalVariableInferredType(v) = t.currentType + } + ) } def rewriteLocal(local: LLVMLocal[Pre]): Expr[Post] = { @@ -132,16 +259,14 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) def rewriteFunctionDef(func: LLVMFunctionDefinition[Pre]): Unit = { implicit val o: Origin = func.o - val importedDecl = rw.importedDeclarations.find { - case procedure: Procedure[Pre] => - func.contract.name == procedure.o.get[SourceName].name - } val procedure = rw.labelDecls.scope { - rw.globalDeclarations.declare(if (importedDecl.isDefined) { - val importedProcedure = importedDecl.get.asInstanceOf[Procedure[Pre]] - val newArgs = importedProcedure.args.map { it => it.rewriteDefault() } + val newArgs = func.importedArguments.getOrElse(func.args).map { it => + it.rewriteDefault() + } + rw.globalDeclarations.declare( new Procedure[Post]( - returnType = rw.dispatch(importedProcedure.returnType), + returnType = rw + .dispatch(func.importedReturnType.getOrElse(func.returnType)), args = rw.variables.collect { func.args.zip(newArgs).foreach { case (a, b) => @@ -165,28 +290,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) contract = rw.dispatch(func.contract.data.get), pure = func.pure, )(func.blame) - } else { - new Procedure[Post]( - returnType = rw.dispatch(func.returnType), - args = rw.variables.collect { func.args.foreach(rw.dispatch) }._1, - outArgs = Nil, - typeArgs = Nil, - body = - func.functionBody match { - case None => None - case Some(functionBody) => - if (func.pure) - Some(GotoEliminator(functionBody match { - case scope: Scope[Pre] => scope; - case other => throw UnexpectedLLVMNode(other) - }).eliminate()) - else - Some(rw.dispatch(functionBody)) - }, - contract = rw.dispatch(func.contract.data.get), - pure = func.pure, - )(func.blame) - }) + ) } llvmFunctionMap.update(func, procedure) } @@ -295,7 +399,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } }._1, Seq(), - )(t.o.withContent(new TypeName("struct"))) + )(t.o.withContent(TypeName("struct"))) rw.globalDeclarations.declare(newStruct) structMap(t) = newStruct @@ -461,6 +565,10 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (expr, LLVMTStruct[Pre](name, packed, inner +: elements.tail)) } } + // Save the expensive check for last. This check is for when we're mixing PVL and LLVM types + case LLVMTPointer(Some(inner)) + if rw.dispatch(inner) == rw.dispatch(untilType) => + Some((pointer, currentType)) case _ => None } } @@ -526,69 +634,100 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case LLVMPointerValue(Ref(v)) => globalVariableInferredType .getOrElse(v.asInstanceOf[LLVMGlobalVariable[Pre]], e.t) + case _ => e.t } def rewriteStore(store: LLVMStore[Pre]): Statement[Post] = { implicit val o: Origin = store.o val pointerInferredType = getInferredType(store.pointer) + val valueInferredType = getInferredType(store.value) val (pointer, inferredType) = derefUntil( rw.dispatch(store.pointer), pointerInferredType, - store.value.t, - ).getOrElse(( - Cast( - rw.dispatch(store.pointer), - TypeValue(TPointer(rw.dispatch(store.value.t))), - ), - store.value.t, - )) + valueInferredType, + ).map { case (pointer, typ) => (DerefPointer(pointer)(store.blame), typ) } + .getOrElse { + if (store.value.t.asPointer.isDefined) { + // TODO: How do we deal with this + ??? + } else { + ( + DerefPointer(Cast( + rw.dispatch(store.pointer), + TypeValue(TPointer(rw.dispatch(valueInferredType))), + ))(store.blame), + pointerInferredType, + ) + } + } // TODO: Fix assignfailed blame - Assign(DerefPointer(pointer)(store.blame), rw.dispatch(store.value))( - store.blame - ) + Assign(pointer, rw.dispatch(store.value))(store.blame) } - def rewriteLoad(load: LLVMLoad[Pre]): Expr[Post] = { + def rewriteLoad(load: LLVMLoad[Pre]): Statement[Post] = { implicit val o: Origin = load.o val pointerInferredType = getInferredType(load.pointer) + val destinationInferredType = localVariableInferredType + .getOrElse(load.variable.decl, load.loadType) val (pointer, inferredType) = derefUntil( rw.dispatch(load.pointer), pointerInferredType, - load.loadType, - ).getOrElse(( - Cast(rw.dispatch(load.pointer), TypeValue(rw.dispatch(load.loadType))), - load.loadType, - )) - DerefPointer(pointer)(load.blame) + destinationInferredType, + ).map { case (pointer, typ) => (DerefPointer(pointer)(load.blame), typ) } + .getOrElse { + if (destinationInferredType.asPointer.isDefined) { + // We need to dereference before casting + ( + Cast( + DerefPointer(rw.dispatch(load.pointer))(load.blame), + TypeValue(rw.dispatch(destinationInferredType)), + ), + pointerInferredType, + ) + } else { + ( + DerefPointer(Cast( + rw.dispatch(load.pointer), + TypeValue(TPointer(rw.dispatch(destinationInferredType))), + ))(load.blame), + pointerInferredType, + ) + } + } + assignLocal(Local(rw.succ(load.variable.decl)), pointer) } - def rewriteAllocA(alloc: LLVMAllocA[Pre]): Expr[Post] = { + def rewriteAllocA(alloc: LLVMAllocA[Pre]): Statement[Post] = { implicit val o: Origin = alloc.o - val t = rw.dispatch(alloc.allocationType) - val v = new Variable[Post](TPointer(t))(alloc.o) - alloc.allocationType match { + val t = + localVariableInferredType.getOrElse( + alloc.variable.decl, + LLVMTPointer(Some(alloc.allocationType)), + ).asPointer.get.element + val newT = rw.dispatch(t) + val v = Local[Post](rw.succ(alloc.variable.decl)) + val elements = rw.dispatch(alloc.numElements) + t match { case structType: LLVMTStruct[Pre] => - With( - Block(Seq( - LocalDecl(v), - assignLocal( - v.get, - NewPointerArray[Post]( - rw.dispatch(alloc.allocationType), - rw.dispatch(alloc.numElements), - )(PanicBlame("allocation should never fail")), - ), - Assign( - DerefPointer(v.get)(alloc.o), - NewObject[Post](structMap.ref(structType)), - )(PanicBlame("assignment should never fail")), + Block(Seq( + assignLocal( + v, + NewPointerArray[Post](newT, elements)(PanicBlame( + "allocation should never fail" + )), + ), + Assign( + DerefPointer(v)(PanicBlame("pointer is framed in allocation")), + NewObject[Post](structMap.ref(structType)), + )(PanicBlame("assignment should never fail")), + )) + case _ => + assignLocal( + v, + NewPointerArray[Post](newT, elements)(PanicBlame( + "allocation should never fail" )), - v.get, ) - case _ => - NewPointerArray[Post](t, rw.dispatch(alloc.numElements))(PanicBlame( - "allocation should never fail" - )) } } diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index b28e680536..855b23a150 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -22,12 +22,6 @@ case object LangSpecificToCol extends RewriterBuilderArg2[Boolean, Boolean] { override def desc: String = "Translate language-specific constructs to a common subset of nodes." - override def apply[Pre <: Generation]( - veymontGeneratePermissions: Boolean, - veymontAllowAssign: Boolean, - ): AbstractRewriter[Pre, _ <: Generation] = - LangSpecificToCol(veymontGeneratePermissions, veymontAllowAssign, Seq()) - def ThisVar(): Origin = Origin(Seq(PreferredName(Seq("this")), LabelContext("constructor this"))) @@ -41,7 +35,6 @@ case object LangSpecificToCol extends RewriterBuilderArg2[Boolean, Boolean] { case class LangSpecificToCol[Pre <: Generation]( generatePermissions: Boolean = false, veymontAllowAssign: Boolean = false, - importedDeclarations: Seq[GlobalDeclaration[Pre]] = Seq(), ) extends Rewriter[Pre] with LazyLogging { val java: LangJavaToCol[Pre] = LangJavaToCol(this) val bip: LangBipToCol[Pre] = LangBipToCol(this) @@ -282,7 +275,9 @@ case class LangSpecificToCol[Pre <: Generation]( cpp.checkPredicateFoldingAllowed(unfold.res) unfold.rewriteDefault() + case load: LLVMLoad[Pre] => llvm.rewriteLoad(load) case store: LLVMStore[Pre] => llvm.rewriteStore(store) + case alloc: LLVMAllocA[Pre] => llvm.rewriteAllocA(alloc) case other => other.rewriteDefault() } @@ -404,8 +399,6 @@ case class LangSpecificToCol[Pre <: Generation]( llvm.rewriteFunctionPointer(pointer) case pointer: LLVMPointerValue[Pre] => llvm.rewritePointerValue(pointer) case gep: LLVMGetElementPointer[Pre] => llvm.rewriteGetElementPointer(gep) - case load: LLVMLoad[Pre] => llvm.rewriteLoad(load) - case alloc: LLVMAllocA[Pre] => llvm.rewriteAllocA(alloc) case int: LLVMIntegerValue[Pre] => IntegerValue(int.value)(int.o) case other => rewriteDefault(other) diff --git a/src/serialize/vct/col/ast/Origin.proto b/src/serialize/vct/col/ast/Origin.proto index 2de2eb1ef0..5d5e0ad1a7 100644 --- a/src/serialize/vct/col/ast/Origin.proto +++ b/src/serialize/vct/col/ast/Origin.proto @@ -22,6 +22,7 @@ message OriginContent { Context context = 3; ReadableOrigin readable_origin = 4; PositionRange position_range = 5; + LabelContext label_context = 6; } } @@ -52,3 +53,7 @@ message PositionRange { optional sint32 start_col_idx = 3; optional sint32 end_col_idx = 4; } + +message LabelContext { + required string label = 1; +} \ No newline at end of file From 05be047354be7fb0eaaeec42d7d6b901aa5215ba Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 2 Sep 2024 17:14:49 +0200 Subject: [PATCH 34/47] Fix crash when transforming fib.ll --- .../vct/rewrite/lang/LangLLVMToCol.scala | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 47cc92ea6f..21f273a86e 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -156,11 +156,11 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } - def getVariable(expr: Expr[Pre]): Object = { + def getVariable(expr: Expr[Pre]): Option[Object] = { expr match { - case Local(Ref(v)) => v - case LLVMPointerValue(Ref(g)) => g - case _ => ??? + case Local(Ref(v)) => Some(v) + case LLVMPointerValue(Ref(g)) => Some(g) + case _ => None } } @@ -190,34 +190,39 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) _ => LLVMTPointer(Some(alloc.allocationType)), ) case gep: LLVMGetElementPointer[Pre] => - addTypeGuess( - getVariable(gep.pointer), - Set.empty, - _ => LLVMTPointer(Some(gep.structureType)), + getVariable(gep.pointer).foreach(v => + addTypeGuess(v, Set.empty, _ => LLVMTPointer(Some(gep.structureType))) ) case load: LLVMLoad[Pre] => - addTypeGuess( - getVariable(load.pointer), - Set(load.variable.decl), - _ => - LLVMTPointer(Some( - typeGuesses.get(load.variable.decl).map(_.currentType) - .getOrElse(load.variable.decl.t) - )), + getVariable(load.pointer).foreach(v => + addTypeGuess( + v, + Set(load.variable.decl), + _ => + LLVMTPointer(Some( + typeGuesses.get(load.variable.decl).map(_.currentType) + .getOrElse(load.variable.decl.t) + )), + ) ) addTypeGuess(load.variable.decl, Set.empty, _ => load.variable.decl.t) case store: LLVMStore[Pre] => val dependencies = findDependencies(store.value) - addTypeGuess( - getVariable(store.pointer), - dependencies, - _ => - LLVMTPointer(Some(replaceWithGuesses(store.value, dependencies).t)), + getVariable(store.pointer).foreach(v => + addTypeGuess( + v, + dependencies, + _ => + LLVMTPointer( + Some(replaceWithGuesses(store.value, dependencies).t) + ), + ) ) case inv: LLVMFunctionInvocation[Pre] => inv.ref.decl.importedArguments.getOrElse(inv.ref.decl.args).zipWithIndex .foreach { case (a, i) => - addTypeGuess(getVariable(inv.args(i)), Set.empty, _ => a.t) + getVariable(inv.args(i)) + .foreach(v => addTypeGuess(v, Set.empty, _ => a.t)) } } From 6ebf84b01a6ce226b05f37a8355575315f65a04d Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 4 Sep 2024 10:42:48 +0200 Subject: [PATCH 35/47] Convert LLVM loops into COL --- src/col/vct/col/ast/Node.scala | 29 ++++++----- .../ast/lang/llvm/LLVMBasicBlockImpl.scala | 11 ++++ .../ast/lang/llvm/LLVMLoopContractImpl.scala | 9 ---- .../vct/col/ast/lang/llvm/LLVMLoopImpl.scala | 5 +- .../ast/lang/llvm/LLVMLoopInvariantImpl.scala | 9 ---- src/col/vct/col/resolve/Resolve.scala | 17 ++++++- .../ctx/ReferenceResolutionContext.scala | 2 + src/col/vct/col/resolve/lang/LLVM.scala | 4 ++ .../vct/col/typerules/CoercingRewriter.scala | 5 +- src/llvm/include/Origin/OriginProvider.h | 5 ++ .../Passes/Function/FunctionBodyTransformer.h | 2 +- src/llvm/lib/Origin/OriginProvider.cpp | 50 +++++++++++++++---- .../Function/FunctionBodyTransformer.cpp | 15 +++--- src/llvm/lib/Transform/BlockTransform.cpp | 34 +++++++++++-- .../Transform/Instruction/TermOpTransform.cpp | 6 +-- .../ResolveExpressionSideEffects.scala | 23 --------- .../vct/rewrite/lang/LangLLVMToCol.scala | 47 ++++++++++++++++- .../vct/rewrite/lang/LangSpecificToCol.scala | 3 ++ 18 files changed, 187 insertions(+), 89 deletions(-) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala delete mode 100644 src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 49d578010d..637dc32d48 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -3555,22 +3555,25 @@ final case class LLVMFunctionInvocation[G]( )(val blame: Blame[InvocationFailure])(implicit val o: Origin) extends Apply[G] with LLVMFunctionInvocationImpl[G] -final case class LLVMLoop[G]( - cond: Expr[G], - contract: LLVMLoopContract[G], - body: Statement[G], +final class LLVMBasicBlock[G]( + val label: LabelDecl[G], + val loop: Option[LLVMLoop[G]], + val body: Statement[G], )(implicit val o: Origin) - extends CompositeStatement[G] with LLVMLoopImpl[G] + extends LLVMStatement[G] with LLVMBasicBlockImpl[G] @family -sealed trait LLVMLoopContract[G] - extends NodeFamily[G] with LLVMLoopContractImpl[G] - -final case class LLVMLoopInvariant[G]( - value: String, - references: Seq[(String, Ref[G, Declaration[G]])], -)(val blame: Blame[LoopInvariantFailure])(implicit val o: Origin) - extends LLVMLoopContract[G] with LLVMLoopInvariantImpl[G] +final case class LLVMLoop[G]( + contract: LoopContract[G], + header: Ref[G, LabelDecl[G]], + latch: Ref[G, LabelDecl[G]], + blockLabels: Seq[Ref[G, LabelDecl[G]]], +)(implicit val o: Origin) + extends NodeFamily[G] with LLVMLoopImpl[G] { + var headerBlock: Option[LLVMBasicBlock[G]] = None + var latchBlock: Option[LLVMBasicBlock[G]] = None + var blocks: Option[Seq[LLVMBasicBlock[G]]] = None +} sealed trait LLVMStatement[G] extends Statement[G] with LLVMStatementImpl[G] diff --git a/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala new file mode 100644 index 0000000000..1f27dc0183 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala @@ -0,0 +1,11 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.LLVMBasicBlock +import vct.col.ast.ops.LLVMBasicBlockOps +import vct.col.check.{CheckContext, CheckError} + +trait LLVMBasicBlockImpl[G] extends LLVMBasicBlockOps[G] { + this: LLVMBasicBlock[G] => + + override def check(context: CheckContext[G]): Seq[CheckError] = Nil +} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala deleted file mode 100644 index 17bcc24632..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LLVMLoopContractImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LLVMLoopContract -import vct.col.ast.ops.LLVMLoopContractFamilyOps - -trait LLVMLoopContractImpl[G] extends LLVMLoopContractFamilyOps[G] { - this: LLVMLoopContract[G] => - -} diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala index 49c130ba3a..095c2edbec 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoopImpl.scala @@ -2,8 +2,11 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMLoop import vct.col.ast.ops.LLVMLoopOps +import vct.col.check.{CheckContext, CheckError} +import vct.col.ast.ops.{LLVMLoopOps, LLVMLoopFamilyOps} -trait LLVMLoopImpl[G] extends LLVMLoopOps[G] { +trait LLVMLoopImpl[G] extends LLVMLoopOps[G] with LLVMLoopFamilyOps[G] { this: LLVMLoop[G] => + override def check(context: CheckContext[G]): Seq[CheckError] = Nil } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala deleted file mode 100644 index ee4eeb2c2e..0000000000 --- a/src/col/vct/col/ast/lang/llvm/LLVMLoopInvariantImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.lang.llvm - -import vct.col.ast.LLVMLoopInvariant -import vct.col.ast.ops.LLVMLoopInvariantOps - -trait LLVMLoopInvariantImpl[G] extends LLVMLoopInvariantOps[G] { - this: LLVMLoopInvariant[G] => - -} diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index 8b839a2b41..c39f2e972f 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -557,7 +557,10 @@ case object ResolveReferences extends LazyLogging { func.contract.givenArgs ++ func.contract.yieldsArgs ) case func: LLVMFunctionDefinition[G] => - ctx.copy(currentResult = Some(RefLLVMFunctionDefinition(func))) + ctx.copy( + currentResult = Some(RefLLVMFunctionDefinition(func)), + llvmBlocks = LLVM.scanBlocks(func), + ) case func: LLVMSpecFunction[G] => ctx.copy(currentResult = Some(RefLLVMSpecFunction(func))) .declare(func.args) @@ -1127,6 +1130,18 @@ case object ResolveReferences extends LazyLogging { func.importedArguments = Some(importedProcedure.args) func.importedReturnType = Some(importedProcedure.returnType) } + case loop: LLVMLoop[G] => + loop.blocks = Some(loop.blockLabels.map { label => + ctx.llvmBlocks.get(label.decl) match { + case Some(block) => block + case None => + throw Unreachable( + "LLVM Loop information must only refer to basic blocks contained in the function" + ) + } + }) + loop.headerBlock = Some(ctx.llvmBlocks(loop.header.decl)) + loop.latchBlock = Some(ctx.llvmBlocks(loop.latch.decl)) case contract: LLVMFunctionContract[G] => // implicit val o: Origin = contract.o val llvmFunction = diff --git a/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala b/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala index 3d88794ea6..9e660a24a4 100644 --- a/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala +++ b/src/col/vct/col/resolve/ctx/ReferenceResolutionContext.scala @@ -33,6 +33,8 @@ case class ReferenceResolutionContext[G]( // When true and resolving a local, guard names should also be considered javaBipGuardsEnabled: Boolean = false, typeEnv: Map[Variable[G], Type[G]] = Map.empty[Variable[G], Type[G]], + llvmBlocks: Map[LabelDecl[G], LLVMBasicBlock[G]] = Map + .empty[LabelDecl[G], LLVMBasicBlock[G]], ) { def asTypeResolutionContext: TypeResolutionContext[G] = TypeResolutionContext( diff --git a/src/col/vct/col/resolve/lang/LLVM.scala b/src/col/vct/col/resolve/lang/LLVM.scala index f81a8bbfbf..ca397c788d 100644 --- a/src/col/vct/col/resolve/lang/LLVM.scala +++ b/src/col/vct/col/resolve/lang/LLVM.scala @@ -35,4 +35,8 @@ object LLVM { } } + def scanBlocks[G](node: Node[G]): Map[LabelDecl[G], LLVMBasicBlock[G]] = { + node.collect { case b: LLVMBasicBlock[G] => (b.label, b) }.toMap + } + } diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 1581fa56eb..3920a11599 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -357,7 +357,6 @@ abstract class CoercingRewriter[Pre <: Generation]() case node: BipGlueDataWire[Pre] => node case node: BipTransitionSignature[Pre] => node case node: LLVMFunctionContract[Pre] => node - case node: LLVMLoopContract[Pre] => node case node: LLVMMemoryOrdering[Pre] => node case node: ProverLanguage[Pre] => node case node: SmtlibFunctionSymbol[Pre] => node @@ -2270,8 +2269,6 @@ abstract class CoercingRewriter[Pre <: Generation]() case l @ Lock(obj) => Lock(cls(obj))(l.blame) case Loop(init, cond, update, contract, body) => Loop(init, bool(cond), update, contract, body) - case LLVMLoop(cond, contract, body) => - LLVMLoop(bool(cond), contract, body) case LLVMAllocA(variable, allocationType, numElements) => LLVMAllocA(variable, allocationType, int(numElements)) case load @ LLVMLoad(variable, loadType, p, ordering) => @@ -2977,7 +2974,7 @@ abstract class CoercingRewriter[Pre <: Generation]() def coerce(node: JavaBipGlueName[Pre]): JavaBipGlueName[Pre] = node def coerce(node: LLVMFunctionContract[Pre]): LLVMFunctionContract[Pre] = node - def coerce(node: LLVMLoopContract[Pre]): LLVMLoopContract[Pre] = node + def coerce(node: LLVMLoop[Pre]): LLVMLoop[Pre] = node def coerce(node: LLVMMemoryOrdering[Pre]): LLVMMemoryOrdering[Pre] = node def coerce(node: ProverLanguage[Pre]): ProverLanguage[Pre] = node diff --git a/src/llvm/include/Origin/OriginProvider.h b/src/llvm/include/Origin/OriginProvider.h index 2fc05ad316..7970ce789e 100644 --- a/src/llvm/include/Origin/OriginProvider.h +++ b/src/llvm/include/Origin/OriginProvider.h @@ -2,6 +2,7 @@ #define PALLAS_ORIGINPROVIDER_H #include "vct/col/ast/Origin.pb.h" +#include #include #include #include @@ -16,6 +17,8 @@ namespace llvm2col { namespace col = vct::col::ast; +col::Origin *generateLabelledOrigin(const std::string label); + col::Origin *generateProgramOrigin(llvm::Module &llvmModule); col::Origin *generateFuncDefOrigin(llvm::Function &llvmFunction); @@ -32,6 +35,8 @@ col::Origin *generateBlockOrigin(llvm::BasicBlock &llvmBlock); col::Origin *generateLabelOrigin(llvm::BasicBlock &llvmBlock); +col::Origin *generateLoopOrigin(llvm::Loop &llvmLoop); + col::Origin *generateSingleStatementOrigin(llvm::Instruction &llvmInstruction); col::Origin *generateAssignTargetOrigin(llvm::Instruction &llvmInstruction); diff --git a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h index ce34ab2688..3b5b6938b3 100644 --- a/src/llvm/include/Passes/Function/FunctionBodyTransformer.h +++ b/src/llvm/include/Passes/Function/FunctionBodyTransformer.h @@ -16,7 +16,7 @@ using namespace llvm; namespace col = vct::col::ast; struct LabeledColBlock { - col::Label &label; + col::LlvmBasicBlock &bb; col::Block █ }; diff --git a/src/llvm/lib/Origin/OriginProvider.cpp b/src/llvm/lib/Origin/OriginProvider.cpp index d16c42de89..00d24d7960 100644 --- a/src/llvm/lib/Origin/OriginProvider.cpp +++ b/src/llvm/lib/Origin/OriginProvider.cpp @@ -8,6 +8,15 @@ namespace col = vct::col::ast; +col::Origin *llvm2col::generateLabelledOrigin(const std::string label) { + col::Origin *origin = new col::Origin(); + col::OriginContent *labelContent = origin->add_content(); + col::LabelContext *labelContext = labelContent->mutable_label_context(); + labelContext->set_label(label); + + return origin; +} + col::Origin *llvm2col::generateProgramOrigin(llvm::Module &llvmModule) { col::Origin *origin = new col::Origin(); col::OriginContent *preferredNameContent = origin->add_content(); @@ -123,9 +132,8 @@ col::Origin *llvm2col::generateLabelOrigin(llvm::BasicBlock &llvmBlock) { return origin; } -bool generateDebugOrigin(llvm::Instruction &llvmInstruction, - col::Origin *origin) { - const llvm::DebugLoc &loc = llvmInstruction.getDebugLoc(); +bool generateDebugOrigin(const llvm::DebugLoc &loc, col::Origin *origin, + const llvm::DebugLoc &endLoc = NULL) { if (!loc) return false; int line = loc.getLine() - 1; @@ -133,8 +141,14 @@ bool generateDebugOrigin(llvm::Instruction &llvmInstruction, col::OriginContent *positionRangeContent = origin->add_content(); col::PositionRange *positionRange = new col::PositionRange(); positionRange->set_start_line_idx(line); - positionRange->set_end_line_idx(line); positionRange->set_start_col_idx(col); + if (endLoc) { + positionRange->set_end_line_idx(endLoc.getLine() - 1); + // Would it be better without setting the end col? + positionRange->set_end_col_idx(endLoc.getCol() - 1); + } else { + positionRange->set_end_line_idx(line); + } positionRangeContent->set_allocated_position_range(positionRange); auto *scope = llvm::cast(loc.getScope()); auto *file = scope->getFile(); @@ -170,6 +184,21 @@ bool generateDebugOrigin(llvm::Instruction &llvmInstruction, return true; } +col::Origin *llvm2col::generateLoopOrigin(llvm::Loop &llvmLoop) { + col::Origin *origin = new col::Origin(); + llvm::Loop::LocRange range = llvmLoop.getLocRange(); + if (!generateDebugOrigin(range.getStart(), origin, range.getEnd())) { + llvm::BasicBlock *llvmBlock = llvmLoop.getHeader(); + col::OriginContent *contextContent = origin->add_content(); + col::Context *context = new col::Context(); + context->set_context(deriveBlockContext(*llvmBlock)); + context->set_inline_context(deriveBlockContext(*llvmBlock)); + context->set_short_position(deriveBlockShortPosition(*llvmBlock)); + contextContent->set_allocated_context(context); + } + return origin; +} + col::Origin * llvm2col::generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { col::Origin *origin = new col::Origin(); @@ -179,7 +208,7 @@ llvm2col::generateSingleStatementOrigin(llvm::Instruction &llvmInstruction) { deriveOperandPreferredName(llvmInstruction)); preferredNameContent->set_allocated_preferred_name(preferredName); - if (!generateDebugOrigin(llvmInstruction, origin)) { + if (!generateDebugOrigin(llvmInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context( @@ -201,7 +230,7 @@ llvm2col::generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { preferredName->add_preferred_name("var"); preferredNameContent->set_allocated_preferred_name(preferredName); - if (!generateDebugOrigin(llvmInstruction, origin)) { + if (!generateDebugOrigin(llvmInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context(deriveInstructionContext(llvmInstruction)); @@ -217,7 +246,7 @@ llvm2col::generateAssignTargetOrigin(llvm::Instruction &llvmInstruction) { col::Origin * llvm2col::generateBinExprOrigin(llvm::Instruction &llvmInstruction) { col::Origin *origin = new col::Origin(); - if (!generateDebugOrigin(llvmInstruction, origin)) { + if (!generateDebugOrigin(llvmInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context( @@ -240,7 +269,7 @@ llvm2col::generateFunctionCallOrigin(llvm::CallInst &callInstruction) { callInstruction.getCalledFunction()->getName().str()); preferredNameContent->set_allocated_preferred_name(preferredName); - if (!generateDebugOrigin(callInstruction, origin)) { + if (!generateDebugOrigin(callInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context( @@ -262,7 +291,7 @@ col::Origin *llvm2col::generateOperandOrigin(llvm::Instruction &llvmInstruction, preferredName->add_preferred_name(deriveOperandPreferredName(llvmOperand)); preferredNameContent->set_allocated_preferred_name(preferredName); - if (!generateDebugOrigin(llvmInstruction, origin)) { + if (!generateDebugOrigin(llvmInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context(deriveInstructionContext(llvmInstruction)); @@ -320,9 +349,8 @@ llvm2col::generateVoidOperandOrigin(llvm::Instruction &llvmInstruction) { col::PreferredName *preferredName = new col::PreferredName(); preferredName->add_preferred_name("void"); preferredNameContent->set_allocated_preferred_name(preferredName); - generateDebugOrigin(llvmInstruction, origin); - if (!generateDebugOrigin(llvmInstruction, origin)) { + if (!generateDebugOrigin(llvmInstruction.getDebugLoc(), origin)) { col::OriginContent *contextContent = origin->add_content(); col::Context *context = new col::Context(); context->set_context(deriveInstructionContext(llvmInstruction)); diff --git a/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp index 928f1f201a..5905d6ac08 100644 --- a/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp +++ b/src/llvm/lib/Passes/Function/FunctionBodyTransformer.cpp @@ -79,23 +79,24 @@ bool FunctionCursor::isComplete(col::Block &colBlock) { LabeledColBlock & FunctionCursor::getOrSetLLVMBlock2LabeledColBlockEntry(BasicBlock &llvmBlock) { if (!llvmBlock2LabeledColBlock.contains(&llvmBlock)) { - // create label in buffer - col::Label *label = functionBody.add_statements()->mutable_label(); - // set label origin - label->set_allocated_origin(llvm2col::generateLabelOrigin(llvmBlock)); + // create basic block in buffer + col::LlvmBasicBlock *bb = + functionBody.add_statements()->mutable_llvm_basic_block(); + // set basic block origin + bb->set_allocated_origin(llvm2col::generateLabelOrigin(llvmBlock)); // create label declaration in buffer - col::LabelDecl *labelDecl = label->mutable_decl(); + col::LabelDecl *labelDecl = bb->mutable_label(); // set label decl origin labelDecl->set_allocated_origin( llvm2col::generateLabelOrigin(llvmBlock)); // set label decl id llvm2col::setColNodeId(labelDecl); // create block inside label statement - col::Block *block = label->mutable_stat()->mutable_block(); + col::Block *block = bb->mutable_body()->mutable_block(); // set block origin block->set_allocated_origin(llvm2col::generateBlockOrigin(llvmBlock)); // add labeled block to the block2block lut - LabeledColBlock labeledColBlock = {*label, *block}; + LabeledColBlock labeledColBlock = {*bb, *block}; llvmBlock2LabeledColBlock.insert({&llvmBlock, labeledColBlock}); } return llvmBlock2LabeledColBlock.at(&llvmBlock); diff --git a/src/llvm/lib/Transform/BlockTransform.cpp b/src/llvm/lib/Transform/BlockTransform.cpp index 9ce1085009..122d71f957 100644 --- a/src/llvm/lib/Transform/BlockTransform.cpp +++ b/src/llvm/lib/Transform/BlockTransform.cpp @@ -1,5 +1,6 @@ #include "Transform/BlockTransform.h" +#include "Origin/OriginProvider.h" #include "Transform/Instruction/BinaryOpTransform.h" #include "Transform/Instruction/CastOpTransform.h" #include "Transform/Instruction/FuncletPadOpTransform.h" @@ -16,15 +17,38 @@ void llvm2col::transformLLVMBlock(llvm::BasicBlock &llvmBlock, if (functionCursor.isVisited(llvmBlock)) { return; } - col::Block &colBlock = functionCursor.visitLLVMBlock(llvmBlock).block; + pallas::LabeledColBlock &labeled = functionCursor.visitLLVMBlock(llvmBlock); + col::Block &colBlock = labeled.block; /* for (auto *B : llvm::predecessors(&llvmBlock)) { */ /* if (!functionCursor.isVisited(*B)) */ /* return; */ /* } */ - /* if (functionCursor.getLoopInfo().isLoopHeader(&llvmBlock)) { */ - /* transformLoop(llvmBlock, functionCursor); */ - /* return; */ - /* } */ + if (functionCursor.getLoopInfo().isLoopHeader(&llvmBlock)) { + llvm::Loop *llvmLoop = + functionCursor.getLoopInfo().getLoopFor(&llvmBlock); + col::LlvmLoop *loop = labeled.bb.mutable_loop(); + loop->set_allocated_origin(generateLoopOrigin(*llvmLoop)); + col::LoopContract *contract = loop->mutable_contract(); + col::LoopInvariant *invariant = contract->mutable_loop_invariant(); + col::BooleanValue *tt = + invariant->mutable_invariant()->mutable_boolean_value(); + tt->set_value(true); + tt->set_allocated_origin(generateLabelledOrigin("constant true")); + invariant->set_allocated_origin( + generateLabelledOrigin("constant true")); + invariant->mutable_blame(); + + loop->mutable_header()->set_id(labeled.bb.label().id()); + pallas::LabeledColBlock labeled_latch = + functionCursor.getOrSetLLVMBlock2LabeledColBlockEntry( + *llvmLoop->getLoopLatch()); + loop->mutable_latch()->set_id(labeled_latch.bb.label().id()); + for (auto &bb : llvmLoop->blocks()) { + pallas::LabeledColBlock labeled_bb = + functionCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*bb); + loop->add_block_labels()->set_id(labeled_bb.bb.label().id()); + } + } for (auto &I : llvmBlock) { transformInstruction(functionCursor, I, colBlock); } diff --git a/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp b/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp index 66ec7d8565..8bb1acab7b 100644 --- a/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/TermOpTransform.cpp @@ -88,7 +88,7 @@ void llvm2col::transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmTrueBlock); // goto statement to true block col::Goto *trueGoto = colTrueBranch->mutable_v2()->mutable_goto_(); - trueGoto->mutable_lbl()->set_id(labeledTrueColBlock.label.decl().id()); + trueGoto->mutable_lbl()->set_id(labeledTrueColBlock.bb.label().id()); // set origin for goto to true block trueGoto->set_allocated_origin( generateSingleStatementOrigin(llvmBrInstruction)); @@ -113,7 +113,7 @@ void llvm2col::transformConditionalBranch(llvm::BranchInst &llvmBrInstruction, funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmFalseBlock); // goto statement to false block col::Goto *falseGoto = colFalseBranch->mutable_v2()->mutable_goto_(); - falseGoto->mutable_lbl()->set_id(labeledFalseColBlock.label.decl().id()); + falseGoto->mutable_lbl()->set_id(labeledFalseColBlock.bb.label().id()); // set origin for goto to false block falseGoto->set_allocated_origin( llvm2col::generateSingleStatementOrigin(llvmBrInstruction)); @@ -132,7 +132,7 @@ void llvm2col::transformUnConditionalBranch( funcCursor.getOrSetLLVMBlock2LabeledColBlockEntry(*llvmTargetBlock); // create goto to target labeled block col::Goto *colGoto = colBlock.add_statements()->mutable_goto_(); - colGoto->mutable_lbl()->set_id(labeledColBlock.label.decl().id()); + colGoto->mutable_lbl()->set_id(labeledColBlock.bb.label().id()); // set origin of goto statement colGoto->set_allocated_origin( llvm2col::generateSingleStatementOrigin(llvmBrInstruction)); diff --git a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala index fe24bd8f4e..34d69d79e1 100644 --- a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala +++ b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala @@ -434,29 +434,6 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() case proof: FramedProof[Pre] => rewriteDefault(proof) case extract: Extract[Pre] => rewriteDefault(extract) case branch: IndetBranch[Pre] => rewriteDefault(branch) - case LLVMLoop(cond, contract, body) => - evaluateOne(cond) match { - case (Nil, Nil, cond) => - LLVMLoop(cond, dispatch(contract), dispatch(body)) - case (variables, sideEffects, cond) => - val break = new LabelDecl[Post]()(BreakOrigin) - Block(Seq( - LLVMLoop( - tt, - dispatch(contract), - Block(Seq( - Scope( - variables, - Block( - sideEffects :+ Branch(Seq(Not(cond) -> Goto(break.ref))) - ), - ), - dispatch(body), - )), - ), - Label(break, Block(Nil)), - )) - } case rangedFor: RangedFor[Pre] => rewriteDefault(rangedFor) case assign: VeyMontAssignExpression[Pre] => rewriteDefault(assign) case comm: CommunicateX[Pre] => rewriteDefault(comm) diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 21f273a86e..407e0ae0f6 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -2,11 +2,11 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import vct.col.ast._ -import vct.col.origin.{Origin, PanicBlame, TypeName} +import vct.col.origin.{DiagnosticOrigin, Origin, PanicBlame, TypeName} import vct.col.ref.{DirectRef, LazyRef, Ref} import vct.col.resolve.ctx.RefLLVMFunctionDefinition import vct.col.rewrite.{Generation, Rewritten} -import vct.col.util.AstBuildHelpers.assignLocal +import vct.col.util.AstBuildHelpers._ import vct.col.util.{CurrentProgramContext, SubstituteReferences, SuccessionMap} import vct.result.VerificationError.{SystemError, UserError} @@ -57,6 +57,15 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) : mutable.HashMap[LLVMGlobalVariable[Pre], Type[Pre]] = mutable.HashMap() private val localVariableInferredType : mutable.HashMap[Variable[Pre], Type[Pre]] = mutable.HashMap() + private val loopBlocks: mutable.ArrayBuffer[LLVMBasicBlock[Pre]] = mutable + .ArrayBuffer() + private val elidedBackEdges: mutable.Set[LabelDecl[Pre]] = mutable.Set() + + def gatherBackEdges(program: Program[Pre]): Unit = { + program.collect { case loop: LLVMLoop[Pre] => + elidedBackEdges.add(loop.header.decl) + } + } def gatherTypeHints(program: Program[Pre]): Unit = { // TODO: We also need to do something where we only keep structurally distinct types @@ -750,6 +759,40 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) implicit o: Origin ): Expr[Post] = Result[Post](llvmFunctionMap.ref(ref.decl)) + private def blockToLabel(block: LLVMBasicBlock[Pre]): Statement[Post] = + if (elidedBackEdges.contains(block.label)) { rw.dispatch(block.body) } + else { + Label(rw.labelDecls.dispatch(block.label), rw.dispatch(block.body))( + block.o + ) + } + + def rewriteBasicBlock(block: LLVMBasicBlock[Pre]): Statement[Post] = { + if (loopBlocks.contains(block)) + return Block(Nil)(DiagnosticOrigin) + if (block.loop.isEmpty) { blockToLabel(block) } + else { + val loop = block.loop.get + loopBlocks.addAll(loop.blocks.get) + Loop( + Block(Nil)(block.o), + tt[Post], + Block(Nil)(block.o), + rw.dispatch(loop.contract), + Block(blockToLabel(loop.headerBlock.get) +: loop.blocks.get.filterNot { + b => b == loop.headerBlock.get || b == loop.latchBlock.get + }.map(blockToLabel) :+ blockToLabel(loop.latchBlock.get))(block.o), + )(block.o) + } + } + + def rewriteGoto(goto: Goto[Pre]): Statement[Post] = { + if (elidedBackEdges.contains(goto.lbl.decl)) { + // TODO: Verify that the correct block always follows this one + Block(Nil)(goto.o) + } else { goto.rewriteDefault() } + } + /* Elimination works by replacing every goto with the block its referring too effectively transforming the CFG into a tree. More efficient restructuring algorithms but this works for now. diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 855b23a150..349f712c7d 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -159,6 +159,7 @@ case class LangSpecificToCol[Pre <: Generation]( } override def dispatch(program: Program[Pre]): Program[Post] = { + llvm.gatherBackEdges(program) llvm.gatherTypeHints(program) super.dispatch(program) } @@ -263,6 +264,7 @@ case class LangSpecificToCol[Pre <: Generation]( case CPPDeclarationStatement(decl) => cpp.rewriteLocalDecl(decl) case scope: CPPLifetimeScope[Pre] => cpp.rewriteLifetimeScope(scope) case goto: CGoto[Pre] => c.rewriteGoto(goto) + case goto: Goto[Pre] => llvm.rewriteGoto(goto) case barrier: GpgpuBarrier[Pre] => c.gpuBarrier(barrier) case eval @ Eval(CPPInvocation(_, _, _, _)) => @@ -278,6 +280,7 @@ case class LangSpecificToCol[Pre <: Generation]( case load: LLVMLoad[Pre] => llvm.rewriteLoad(load) case store: LLVMStore[Pre] => llvm.rewriteStore(store) case alloc: LLVMAllocA[Pre] => llvm.rewriteAllocA(alloc) + case block: LLVMBasicBlock[Pre] => llvm.rewriteBasicBlock(block) case other => other.rewriteDefault() } From 55b69e320c07ca68c2d7fef523c0523a2b60a2cd Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 4 Sep 2024 13:56:40 +0200 Subject: [PATCH 36/47] Fix broken test with LLVM pure functions --- src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 407e0ae0f6..ab1a31ba7a 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -802,11 +802,11 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) the program. */ case class GotoEliminator(bodyScope: Scope[Pre]) extends LazyLogging { - val labelDeclMap: Map[LabelDecl[Pre], Label[Pre]] = + val labelDeclMap: Map[LabelDecl[Pre], LLVMBasicBlock[Pre]] = bodyScope.body match { case block: Block[Pre] => block.statements.map { - case label: Label[Pre] => (label.decl, label) + case bb: LLVMBasicBlock[Pre] => (bb.label, bb) case other => throw UnexpectedLLVMNode(other) }.toMap case other => throw UnexpectedLLVMNode(other) @@ -820,7 +820,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) scope.body match { case bodyBlock: Block[Pre] => Block[Post](bodyBlock.statements.head match { - case label: Label[Pre] => Seq(eliminate(label)) + case label: LLVMBasicBlock[Pre] => Seq(eliminate(label)) case other => throw UnexpectedLLVMNode(other) })(scope.body.o) case other => throw UnexpectedLLVMNode(other) @@ -830,9 +830,9 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } - def eliminate(label: Label[Pre]): Block[Post] = { - implicit val o: Origin = label.o - label.stat match { + def eliminate(bb: LLVMBasicBlock[Pre]): Block[Post] = { + implicit val o: Origin = bb.o + bb.body match { case block: Block[Pre] => block.statements.last match { case goto: Goto[Pre] => From 72cb4219966080b6d9aa1db4a7053e3093d88f00 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 13 Sep 2024 13:41:42 +0200 Subject: [PATCH 37/47] Fix unsoundness in pointer cast encoding --- src/rewrite/vct/rewrite/ClassToRef.scala | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 77c3e6b648..15c2ce2b40 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -377,7 +377,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case cls: ByValueClass[Pre] => implicit val o: Origin = cls.o val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) - val classType = cls.classType(Nil) var valueAsAxioms: Seq[ADTAxiom[Post]] = Seq() val (fieldFunctions, fieldInverses, fieldTypes) = cls.decls.collect { case field: Field[Pre] => @@ -409,12 +408,21 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { )) valueAsAxioms = - valueAsAxioms ++ unwrapValueAs( - axiomType, - field.t, - newT, - byValFieldSucc.ref(field), - ) + valueAsAxioms ++ + (field.t match { + case t: TByValueClass[Pre] => + // TODO: If there are no fields we should ignore the first field and add the axioms for the second field + t.cls.decl.decls + .collectFirst({ case innerF: InstanceField[Pre] => + unwrapValueAs( + axiomType, + innerF.t, + dispatch(innerF.t), + byValFieldSucc.ref(field), + ) + }).getOrElse(Nil) + case _ => Nil + }) } ( byValFieldSucc(field), From bede2fc89ee8f143fac6453bed502d1235170583 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 17 Sep 2024 17:18:10 +0200 Subject: [PATCH 38/47] Allow casting back up to "greater" type --- src/rewrite/vct/rewrite/ClassToRef.scala | 66 ++++++++++++------- .../vct/rewrite/DisambiguateLocation.scala | 37 ++++++++++- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 3 + .../viper/api/transform/SilverToCol.scala | 2 +- 4 files changed, 81 insertions(+), 27 deletions(-) diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index 15c2ce2b40..c7f48ef2b1 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -712,30 +712,48 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { implicit o: Origin ): Expr[Post] = { val newT = dispatch(t) - val constraint = forall[Post]( - TNonNullPointer(outerType), - body = { p => - PolarityDependent( - Greater( - CurPerm(PointerLocation(p)(PanicBlame( - "Referring to a non-null pointer should not cause any verification failures" - ))), - NoPerm(), - ) ==> - (InlinePattern(Cast(p, TypeValue(TNonNullPointer(newT)))) === - adtFunctionInvocation( - valueAsFunctions - .getOrElseUpdate(t, makeValueAsFunction(t.toString, newT)) - .ref, - typeArgs = Some((valueAdt.ref(()), Seq(outerType))), - args = Seq(DerefPointer(p)(PanicBlame( - "Pointer deref is safe since the permission is framed" - ))), - )), - tt, - ) - }, - ) + val constraint = + forall[Post]( + TNonNullPointer(outerType), + body = { p => + PolarityDependent( + Greater( + CurPerm(PointerLocation(p)(PanicBlame( + "Referring to a non-null pointer should not cause any verification failures" + ))), + NoPerm(), + ) ==> + (InlinePattern(Cast(p, TypeValue(TNonNullPointer(newT)))) === + adtFunctionInvocation( + valueAsFunctions + .getOrElseUpdate(t, makeValueAsFunction(t.toString, newT)) + .ref, + typeArgs = Some((valueAdt.ref(()), Seq(outerType))), + args = Seq(DerefPointer(p)(PanicBlame( + "Pointer deref is safe since the permission is framed" + ))), + )), + tt, + ) + }, + ) &* forall[Post]( + TNonNullPointer(outerType), + body = { p => + PolarityDependent( + Greater( + CurPerm(PointerLocation(p)(PanicBlame( + "Referring to a non-null pointer should not cause any verification failures" + ))), + NoPerm(), + ) ==> + (InlinePattern(Cast( + Cast(p, TypeValue(TNonNullPointer(newT))), + TypeValue(TNonNullPointer(outerType)), + )) === p), + tt, + ) + }, + ) if (t.isInstanceOf[TByValueClass[Pre]]) { constraint &* diff --git a/src/rewrite/vct/rewrite/DisambiguateLocation.scala b/src/rewrite/vct/rewrite/DisambiguateLocation.scala index f6e40c4f94..b6f8fe71e6 100644 --- a/src/rewrite/vct/rewrite/DisambiguateLocation.scala +++ b/src/rewrite/vct/rewrite/DisambiguateLocation.scala @@ -2,8 +2,17 @@ package vct.col.rewrite import vct.col.ast._ import vct.col.rewrite.DisambiguateLocation.NotALocation -import vct.col.origin.{Blame, Origin, PointerLocationError} +import vct.col.origin.{ + Blame, + Origin, + PanicBlame, + PointerAddError, + PointerBounds, + PointerLocationError, + PointerNull, +} import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.util.AstBuildHelpers.const import vct.result.VerificationError.UserError case object DisambiguateLocation extends RewriterBuilder { @@ -25,6 +34,18 @@ case object DisambiguateLocation extends RewriterBuilder { "This expression is not a heap location." + hint.getOrElse("") ) } + + case class PointerAddRedirect(blame: Blame[PointerLocationError]) + extends Blame[PointerAddError] { + override def blame(error: PointerAddError): Unit = + error match { + case nil: PointerNull => blame.blame(nil) + // It should not be possible to acquire an out-of-bounds pointer and pass it around + case bounds: PointerBounds => + PanicBlame("Got location of pointer that was out of bounds") + } + } + override def key: String = "disambiguateLocation" override def desc: String = @@ -32,12 +53,24 @@ case object DisambiguateLocation extends RewriterBuilder { } case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { + import DisambiguateLocation._ + def exprToLoc(expr: Expr[Pre], blame: Blame[PointerLocationError])( implicit o: Origin ): Location[Post] = expr match { case expr if expr.t.asPointer.isDefined => - PointerLocation(dispatch(expr))(blame) + expr match { + case e: PointerAdd[Pre] => PointerLocation(dispatch(e))(blame) + // Adding ptr + 0 for triggering purposes (is there a better place to do this transformation?) + case e => + PointerLocation( + PointerAdd[Post](dispatch(e), const[Post](0))(PointerAddRedirect( + blame + )) + )(blame) + } + case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => ByValueClassLocation(dispatch(expr))(blame) case DerefHeapVariable(ref) => HeapVariableLocation(succ(ref.decl)) diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 3b98839ed6..c56f4f2cb1 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -420,6 +420,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) CoercionUtils.firstElementIsType( newE.t.asPointer.get.element, newT.asPointer.get.element, + ) || CoercionUtils.firstElementIsType( + newT.asPointer.get.element, + newE.t.asPointer.get.element, ) ) { Cast(newE, TypeValue(newT)(t.o))(c.o) } else { throw UnsupportedCast(c) } diff --git a/src/viper/viper/api/transform/SilverToCol.scala b/src/viper/viper/api/transform/SilverToCol.scala index 8bd3090061..28517b466e 100644 --- a/src/viper/viper/api/transform/SilverToCol.scala +++ b/src/viper/viper/api/transform/SilverToCol.scala @@ -559,7 +559,7 @@ case class SilverToCol[G]( case silver.ForPerm(variables, resource, body) => ??(e) case silver.EpsilonPerm() => ??(e) - case silver.InhaleExhaleExp(in, ex) => ??(e) + case silver.InhaleExhaleExp(in, ex) => col.PolarityDependent(f(in), f(ex)) case silver.MagicWand(left, right) => ??(e) case silver.Applying(wand, body) => ??(e) case silver.BackendFuncApp(backendFunc, args) => ??(e) From 6cda75cecf4b1f1568f07f853c668d69f827ea38 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 30 Sep 2024 16:55:58 +0200 Subject: [PATCH 39/47] Move PointerAdd logic for PointerLocations to ImportPointer --- src/rewrite/vct/rewrite/DisambiguateLocation.scala | 12 +----------- src/rewrite/vct/rewrite/adt/ImportPointer.scala | 14 +++++++++++++- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/rewrite/vct/rewrite/DisambiguateLocation.scala b/src/rewrite/vct/rewrite/DisambiguateLocation.scala index b6f8fe71e6..ff006ee9cb 100644 --- a/src/rewrite/vct/rewrite/DisambiguateLocation.scala +++ b/src/rewrite/vct/rewrite/DisambiguateLocation.scala @@ -60,17 +60,7 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { ): Location[Post] = expr match { case expr if expr.t.asPointer.isDefined => - expr match { - case e: PointerAdd[Pre] => PointerLocation(dispatch(e))(blame) - // Adding ptr + 0 for triggering purposes (is there a better place to do this transformation?) - case e => - PointerLocation( - PointerAdd[Post](dispatch(e), const[Post](0))(PointerAddRedirect( - blame - )) - )(blame) - } - + PointerLocation(dispatch(expr))(blame) case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => ByValueClassLocation(dispatch(expr))(blame) case DerefHeapVariable(ref) => HeapVariableLocation(succ(ref.decl)) diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 72b8432099..b71432b895 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -214,11 +214,23 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) implicit val o: Origin = location.o location match { case loc @ PointerLocation(pointer) => + val arg = + unwrapOption(pointer, loc.blame) match { + case ptr @ PointerAdd(_, _) => ptr + case ptr => + FunctionInvocation[Post]( + ref = pointerAdd.ref, + args = Seq(ptr, const(0)), + typeArgs = Nil, + Nil, + Nil, + )(PanicBlame("ptrAdd(ptr, 0) should be infallible")) + } SilverFieldLocation( obj = FunctionInvocation[Post]( ref = pointerDeref.ref, - args = Seq(unwrapOption(pointer, loc.blame)), + args = Seq(arg), typeArgs = Nil, Nil, Nil, From 82b53fefee74f0ce45fb6ed7f0dead041fed8a79 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 1 Oct 2024 13:39:18 +0200 Subject: [PATCH 40/47] Add pointer post condition in attempt to fix injectivity issue --- res/universal/res/adt/pointer.pvl | 1 + 1 file changed, 1 insertion(+) diff --git a/res/universal/res/adt/pointer.pvl b/res/universal/res/adt/pointer.pvl index 9cbbedb16b..0d4e6b4ebf 100644 --- a/res/universal/res/adt/pointer.pvl +++ b/res/universal/res/adt/pointer.pvl @@ -40,6 +40,7 @@ pure ref ptr_deref(`pointer` p) = decreases; requires 0 <= `pointer`.pointer_offset(p) + offset; requires `pointer`.pointer_offset(p) + offset < `block`.block_length(`pointer`.pointer_block(p)); +ensures \polarity_dependent(offset == 0 ==> \result == p, true); pure `pointer` ptr_add(`pointer` p, int offset) = `pointer`.pointer_of( `pointer`.pointer_block(p), From e2ce374aa0a03c0840b5856b3567aea3b93ad33d Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 4 Oct 2024 13:27:25 +0200 Subject: [PATCH 41/47] Make all but one HaliVer example verify again with the new struct encoding --- src/rewrite/vct/rewrite/ParBlockEncoder.scala | 1 + src/rewrite/vct/rewrite/adt/ImportPointer.scala | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rewrite/vct/rewrite/ParBlockEncoder.scala b/src/rewrite/vct/rewrite/ParBlockEncoder.scala index c872c993a9..d7ee94c478 100644 --- a/src/rewrite/vct/rewrite/ParBlockEncoder.scala +++ b/src/rewrite/vct/rewrite/ParBlockEncoder.scala @@ -229,6 +229,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { case l: Local[_] if isConstType(l.t) => true case _: Constant[_] => true case op: BinExpr[Post] => isConstant(op.left) && isConstant(op.right) + case op: UnExpr[Post] => isConstant(op.arg) case _ => false } diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index b71432b895..454873e74e 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -174,7 +174,10 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) ): Expr[Post] = { ptr.t match { case TPointer(_) => - OptGet(dispatch(ptr))(PointerNullOptNone(blame, ptr))(ptr.o) + dispatch(ptr) match { + case OptSome(inner) => inner + case newPtr => OptGet(newPtr)(PointerNullOptNone(blame, ptr))(ptr.o) + } case TNonNullPointer(_) => dispatch(ptr) } } @@ -216,7 +219,9 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) case loc @ PointerLocation(pointer) => val arg = unwrapOption(pointer, loc.blame) match { - case ptr @ PointerAdd(_, _) => ptr + case inv @ FunctionInvocation(ref, _, _, _, _) + if ref.decl == pointerAdd.ref.decl => + inv case ptr => FunctionInvocation[Post]( ref = pointerAdd.ref, From dee46c0801edde2f09318e0e98ea0f5e74a3a8c8 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 4 Oct 2024 15:20:46 +0200 Subject: [PATCH 42/47] Merge PointerArray fallibility and nullability --- src/col/vct/col/ast/Node.scala | 8 ++---- .../vct/col/typerules/CoercingRewriter.scala | 4 +-- .../vct/rewrite/EncodeArrayValues.scala | 28 ++++++++++--------- src/rewrite/vct/rewrite/TrivialAddrOf.scala | 4 ++- .../vct/rewrite/VariableToPointer.scala | 6 ++-- .../vct/rewrite/lang/LangCPPToCol.scala | 11 ++++++-- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 19 +++++++------ .../vct/rewrite/lang/LangLLVMToCol.scala | 4 +-- 8 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 8496dfd7c4..637dc32d48 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1892,11 +1892,9 @@ final case class NewArray[G]( initialize: Boolean, )(val blame: Blame[ArraySizeError])(implicit val o: Origin) extends Expr[G] with NewArrayImpl[G] -final case class NewPointerArray[G]( - element: Type[G], - size: Expr[G], - fallible: Boolean, -)(val blame: Blame[ArraySizeError])(implicit val o: Origin) +final case class NewPointerArray[G](element: Type[G], size: Expr[G])( + val blame: Blame[ArraySizeError] +)(implicit val o: Origin) extends Expr[G] with NewPointerArrayImpl[G] final case class NewNonNullPointerArray[G](element: Type[G], size: Expr[G])( val blame: Blame[ArraySizeError] diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 15c23c1047..f092c829d0 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1569,8 +1569,8 @@ abstract class CoercingRewriter[Pre <: Generation]() Neq(coerce(left, sharedType), coerce(right, sharedType)) case na @ NewArray(element, dims, moreDims, initialize) => NewArray(element, dims.map(int), moreDims, initialize)(na.blame) - case na @ NewPointerArray(element, size, fallible) => - NewPointerArray(element, size, fallible)(na.blame) + case na @ NewPointerArray(element, size) => + NewPointerArray(element, size)(na.blame) case na @ NewNonNullPointerArray(element, size) => NewNonNullPointerArray(element, size)(na.blame) case NewObject(cls) => NewObject(cls) diff --git a/src/rewrite/vct/rewrite/EncodeArrayValues.scala b/src/rewrite/vct/rewrite/EncodeArrayValues.scala index 9f774eb3ad..868cf16329 100644 --- a/src/rewrite/vct/rewrite/EncodeArrayValues.scala +++ b/src/rewrite/vct/rewrite/EncodeArrayValues.scala @@ -106,7 +106,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { : mutable.Map[(Type[Pre], Int, Int, Boolean), Procedure[Post]] = mutable .Map() - val pointerArrayCreationMethods: mutable.Map[(Type[Pre], Boolean), Procedure[Post]] = + val pointerArrayCreationMethods: mutable.Map[Type[Pre], Procedure[Post]] = mutable.Map() val nonNullPointerArrayCreationMethods : mutable.Map[Type[Pre], Procedure[Post]] = mutable.Map() @@ -496,10 +496,9 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { def makePointerCreationMethodFor( elementType: Type[Pre], nullable: Boolean, - fallible: Boolean, ) = { implicit val o: Origin = arrayCreationOrigin - // fallible? then 'ar != null ==> ...'; otherwise 'ar != null ** ...' + // !nullable? then 'ar != null ==> ...'; otherwise 'ar != null ** ...' // ar.length == size // forall ar[i] :: Perm(ar[i], write) // (if type ar[i] is pointer or struct): @@ -528,7 +527,6 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { (PointerBlockLength(result)(FramedPtrBlockLength) === sizeArg.get) &* (PointerBlockOffset(result)(FramedPtrBlockOffset) === zero) - if (nullable) { ensures = (result !== Null()) &* ensures } // Pointer location needs pointer add, not pointer subscript ensures = ensures &* makeStruct.makePerm( @@ -556,10 +554,8 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { ensures &* foldStar(permFields.map(_._1)) ensures = - if (!fallible) - (result !== Null()) &* ensures - else - Star(Implies(result !== Null(), ensures), tt) + if (nullable) { Star(Implies(result !== Null(), ensures), tt) } + else { ensures } procedure( blame = AbstractApplicable, @@ -570,7 +566,13 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { args = Seq(sizeArg), requires = UnitAccountedPredicate(requires), ensures = UnitAccountedPredicate(ensures), - )(o.where(name = "make_pointer_array_" + elementType.toString + (if (fallible) "_fallible" else ""))) + )(o.where(name = + "make_pointer_array_" + elementType.toString + + (if (nullable) + "_nullable" + else + "") + )) })) } @@ -601,10 +603,10 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { Nil, Nil, )(ArrayCreationFailed(newArr)) - case newPointerArr @ NewPointerArray(element, size, fallible) => + case newPointerArr @ NewPointerArray(element, size) => val method = pointerArrayCreationMethods.getOrElseUpdate( - (element, fallible), - makePointerCreationMethodFor(element, nullable = true, fallible=fallible), + element, + makePointerCreationMethodFor(element, nullable = true), ) ProcedureInvocation[Post]( method.ref, @@ -617,7 +619,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] { case newPointerArr @ NewNonNullPointerArray(element, size) => val method = nonNullPointerArrayCreationMethods.getOrElseUpdate( element, - makePointerCreationMethodFor(element, nullable = false, fallible=false), + makePointerCreationMethodFor(element, nullable = false), ) ProcedureInvocation[Post]( method.ref, diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index 1282a7b6cd..2ae340aaac 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -103,7 +103,9 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { val newPointer = Eval( PreAssignExpression( newTarget, - NewPointerArray(newValue.t, const[Post](1), fallible=false)(PanicBlame("Size is > 0")), + NewNonNullPointerArray(newValue.t, const[Post](1))(PanicBlame( + "Size is > 0" + )), )(blame) ) (newPointer, newTarget, newValue) diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 65d9904d94..139f6098c0 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -104,7 +104,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( "Initialisation should always succeed" )), - NewPointerArray( + NewNonNullPointerArray( fieldMap(f).t.asPointer.get.element, const(1), )(PanicBlame("Size is > 0")), @@ -125,7 +125,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { Deref[Post](dispatch(out), fieldMap.ref(f))(PanicBlame( "Initialisation should always succeed" )), - NewPointerArray( + NewNonNullPointerArray( fieldMap(f).t.asPointer.get.element, const(1), )(PanicBlame("Size is > 0")), @@ -178,7 +178,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { Deref[Post](obj.get, fieldMap.ref(f))(PanicBlame( "Initialisation should always succeed" )), - NewPointerArray( + NewNonNullPointerArray( fieldMap(f).t.asPointer.get.element, const(1), )(PanicBlame("Size is > 0")), diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index 4b1357dcf4..612b704211 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -2740,11 +2740,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (sizeOption, init.init) match { case (None, None) => throw WrongCPPType(decl) case (Some(size), None) => - val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame) + val newArr = + NewNonNullPointerArray[Post](t, rw.dispatch(size))(cta.blame) Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) case (None, Some(CPPLiteralArray(exprs))) => val newArr = - NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame) + NewNonNullPointerArray[Post](t, c_const[Post](exprs.size))( + cta.blame + ) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs, o) @@ -2755,7 +2758,9 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (realSize < exprs.size) logger.warn(s"Excess elements in array initializer: '${decl}'") val newArr = - NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame) + NewNonNullPointerArray[Post](t, c_const[Post](realSize))( + cta.blame + ) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs.take(realSize.intValue), o) diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 7370e2c9be..61e3d4195a 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -409,7 +409,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (t1, rw.dispatch(r)) case _ => throw UnsupportedMalloc(c) } - NewPointerArray(rw.dispatch(t1), size, fallible=true)(ArrayMallocFailed(inv))(c.o) + NewPointerArray(rw.dispatch(t1), size)(ArrayMallocFailed(inv))(c.o) case CCast(CInvocation(CLocal("__vercors_malloc"), _, _, _), _) => throw UnsupportedMalloc(c) case CCast(n @ Null(), t) if t.asPointer.isDefined => rw.dispatch(n) @@ -645,10 +645,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) // val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) val assign: Statement[Post] = assignLocal( Local(cNameSuccessor(d).ref), - NewPointerArray[Post]( + NewNonNullPointerArray[Post]( getInnerType(cNameSuccessor(d).t), Local(v.ref), - fallible=false, )(PanicBlame("Shared memory sizes cannot be negative.")), ) declarations ++= Seq(cNameSuccessor(d)) @@ -660,10 +659,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val assign: Statement[Post] = assignLocal( Local(cNameSuccessor(d).ref), // Since we set the size and blame together, we can assume the blame is not None - NewPointerArray[Post]( + NewNonNullPointerArray[Post]( getInnerType(cNameSuccessor(d).t), CIntegerValue(size), - fallible=false )(blame.get), ) declarations ++= Seq(cNameSuccessor(d)) @@ -1131,11 +1129,14 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (sizeOption, init.init) match { case (None, None) => throw WrongCType(decl) case (Some(size), None) => - val newArr = NewPointerArray[Post](t, rw.dispatch(size), fallible=false)(cta.blame) + val newArr = + NewNonNullPointerArray[Post](t, rw.dispatch(size))(cta.blame) Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) case (None, Some(CLiteralArray(exprs))) => val newArr = - NewPointerArray[Post](t, c_const[Post](exprs.size), fallible=false)(cta.blame) + NewNonNullPointerArray[Post](t, c_const[Post](exprs.size))( + cta.blame + ) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs, o) @@ -1146,7 +1147,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) if (realSize < exprs.size) logger.warn(s"Excess elements in array initializer: '${decl}'") val newArr = - NewPointerArray[Post](t, c_const[Post](realSize), fallible=false)(cta.blame) + NewNonNullPointerArray[Post](t, c_const[Post](realSize))( + cta.blame + ) Block( Seq(LocalDecl(v), assignLocal(v.get, newArr)) ++ assignliteralArray(v, exprs.take(realSize.intValue), o) diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index ab1a31ba7a..e94c8dba85 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -726,7 +726,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) Block(Seq( assignLocal( v, - NewPointerArray[Post](newT, elements)(PanicBlame( + NewNonNullPointerArray[Post](newT, elements)(PanicBlame( "allocation should never fail" )), ), @@ -738,7 +738,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) case _ => assignLocal( v, - NewPointerArray[Post](newT, elements)(PanicBlame( + NewNonNullPointerArray[Post](newT, elements)(PanicBlame( "allocation should never fail" )), ) From 8566c4bf248648d4b21e1ae77fbd00c525c36d2c Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Mon, 7 Oct 2024 14:08:51 +0200 Subject: [PATCH 43/47] Clean up the PR --- build.sc | 13 +++------- src/col/vct/col/ast/Node.scala | 4 --- .../expr/heap/read/RawDerefPointerImpl.scala | 14 ----------- .../vct/col/typerules/CoercingRewriter.scala | 2 -- src/main/vct/main/stages/Transformation.scala | 1 - .../vct/rewrite/LowerLocalHeapVariables.scala | 15 ----------- .../vct/rewrite/PrepareByValueClass.scala | 1 - .../vct/rewrite/VariableToPointer.scala | 3 --- .../vct/rewrite/adt/ImportPointer.scala | 25 ------------------- .../vct/rewrite/lang/LangLLVMToCol.scala | 3 +-- .../SpecializeEndpointClasses.scala | 2 +- .../backend/silicon/SiliconLogListener.scala | 2 +- .../viper/api/transform/ColToSilver.scala | 2 +- 13 files changed, 7 insertions(+), 80 deletions(-) delete mode 100644 src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala diff --git a/build.sc b/build.sc index f81b31ff25..1fcf94b416 100644 --- a/build.sc +++ b/build.sc @@ -28,13 +28,6 @@ trait CppSharedModule extends CppModule { } def compile: T[PathRef] = T { -// def temp = //transitiveStaticObjects()).map(_.path)) - print("Definitely a new file") - print(compileOnly() ++ T.traverse(moduleDeps){ - case it: CppModule => it.compileOnly - case it: LinkableModule => it.staticObjects - case _ => T.task { Result.Success(Seq.empty) } - }().flatten) PathRef(toolchain.linkExecutable((compileOnly() ++ T.traverse(moduleDeps){ case it: CppModule => it.compileOnly case it: LinkableModule => it.staticObjects @@ -68,7 +61,7 @@ object external extends Module { object viper extends ScalaModule { object silverGit extends GitModule { def url = T { "https://github.com/viperproject/silver.git" } - def commitish = T { "93bc9b7516a710c8f01438e430058c4a54e20512" } + def commitish = T { "10b1b26a20957e5f000bf1bbcd4017145148afd7" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") @@ -78,7 +71,7 @@ object viper extends ScalaModule { object siliconGit extends GitModule { def url = T { "https://github.com/superaxander/silicon.git" } - def commitish = T { "c63989f64eb759f33bde68c330ce07d6e34134fa" } + def commitish = T { "2030e3eb63f4b1c92ddc8885f7c937673effc9bd" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") @@ -89,7 +82,7 @@ object viper extends ScalaModule { object carbonGit extends GitModule { def url = T { "https://github.com/viperproject/carbon.git" } - def commitish = T { "758481ef42f42720c36406bb278820ba802c7e68" } + def commitish = T { "d14a703fc6428fbae54e7333d8ede7efbbf850f0" } def filteredRepo = T { val workspace = repo() os.remove.all(workspace / "src" / "test") diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 637dc32d48..73e887de8a 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1409,10 +1409,6 @@ final case class DerefPointer[G](pointer: Expr[G])( val blame: Blame[PointerDerefError] )(implicit val o: Origin) extends Expr[G] with DerefPointerImpl[G] -final case class RawDerefPointer[G](pointer: Expr[G])( - val blame: Blame[PointerDerefError] -)(implicit val o: Origin) - extends Expr[G] with RawDerefPointerImpl[G] final case class PointerAdd[G](pointer: Expr[G], offset: Expr[G])( val blame: Blame[PointerAddError] )(implicit val o: Origin) diff --git a/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala b/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala deleted file mode 100644 index d270112e07..0000000000 --- a/src/col/vct/col/ast/expr/heap/read/RawDerefPointerImpl.scala +++ /dev/null @@ -1,14 +0,0 @@ -package vct.col.ast.expr.heap.read - -import vct.col.ast.ops.RawDerefPointerOps -import vct.col.ast.{RawDerefPointer, TRef, Type} -import vct.col.print._ - -trait RawDerefPointerImpl[G] extends RawDerefPointerOps[G] { - this: RawDerefPointer[G] => - override def t: Type[G] = TRef() - - override def precedence: Int = Precedence.POSTFIX - override def layout(implicit ctx: Ctx): Doc = - Group(Text("ptr_deref(") <> pointer <> Text(")")) -} diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index f092c829d0..3fa5927427 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1261,8 +1261,6 @@ abstract class CoercingRewriter[Pre <: Generation]() case deref @ Deref(obj, ref) => Deref(cls(obj), ref)(deref.blame) case deref @ DerefHeapVariable(ref) => DerefHeapVariable(ref)(deref.blame) case deref @ DerefPointer(p) => DerefPointer(pointer(p)._1)(deref.blame) - case deref @ RawDerefPointer(p) => - RawDerefPointer(pointer(p)._1)(deref.blame) case Drop(xs, count) => Drop(seq(xs)._1, int(count)) case Empty(obj) => Empty(sized(obj)._1) case EmptyProcess() => EmptyProcess() diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 22058a0f81..5907a0af3b 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -341,7 +341,6 @@ case class SilverTransformation( EncodeString, // Encode spec string as seq EncodeChar, CollectLocalDeclarations, // all decls in Scope -// EncodeByValueClass, VariableToPointer, // should happen before ParBlockEncoder so it can distinguish between variables which can and can't altered in a parallel block DesugarPermissionOperators, // no PointsTo, \pointer, etc. ReadToValue, // resolve wildcard into fractional permission diff --git a/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala index 9e3fdad2a7..43ee901793 100644 --- a/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala +++ b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala @@ -13,28 +13,13 @@ case object LowerLocalHeapVariables extends RewriterBuilder { override def desc: String = "Lower LocalHeapVariables to Variables if their address is never taken" - - private val pointerCreationOrigin: Origin = Origin( - Seq(LabelContext("pointer creation method")) - ) } case class LowerLocalHeapVariables[Pre <: Generation]() extends Rewriter[Pre] { - import LowerLocalHeapVariables._ - private val stripped: SuccessionMap[LocalHeapVariable[Pre], Variable[Post]] = SuccessionMap() private val lowered: SuccessionMap[LocalHeapVariable[Pre], Variable[Post]] = SuccessionMap() -// private val pointerCreationMethods: SuccessionMap[Type[Pre], Procedure[Post]] = SuccessionMap() -// -// def makePointerCreationMethod(t: Type[Pre]): Procedure[Post] = { -// implicit val o: Origin = pointerCreationOrigin -// -// val proc = globalDeclarations.declare(withResult((result: Result[Post]) => { -// -// })) -// } override def dispatch(program: Program[Pre]): Program[Post] = { val dereferencedHeapLocals = program.collect { diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 2250989297..950ae4f8b0 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -271,7 +271,6 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { dp, v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], ) - // TODO: Check for copy semantics in inappropriate places (i.e. when the user has made this a pointer) case dp @ DerefPointer(DerefHeapVariable(Ref(v))) if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => rewriteInCopyContext( diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 139f6098c0..c303f1d21d 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -56,7 +56,6 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(decl: Declaration[Pre]): Unit = decl match { - // TODO: Use some sort of NonNull pointer type instead case v: HeapVariable[Pre] if addressedSet.contains(v) => heapVariableMap(v) = globalDeclarations .succeed(v, new HeapVariable(TNonNullPointer(dispatch(v.t)))(v.o)) @@ -94,8 +93,6 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { ) case i @ Instantiate(cls, out) if cls.decl.isInstanceOf[ByValueClass[Pre]] => - // TODO: Make sure that we recursively build newobject for byvalueclasses - // maybe get rid this entirely and only have it in encode by value class Block(Seq(i.rewriteDefault()) ++ cls.decl.declarations.flatMap { case f: InstanceField[Pre] => if (f.t.asClass.isDefined) { diff --git a/src/rewrite/vct/rewrite/adt/ImportPointer.scala b/src/rewrite/vct/rewrite/adt/ImportPointer.scala index 454873e74e..6556d9c1b5 100644 --- a/src/rewrite/vct/rewrite/adt/ImportPointer.scala +++ b/src/rewrite/vct/rewrite/adt/ImportPointer.scala @@ -407,31 +407,6 @@ case class ImportPointer[Pre <: Generation](importer: ImportADTImporter) )(PanicBlame("ptr_deref requires nothing.")), field = getPointerField(pointer), )(PointerFieldInsufficientPermission(deref.blame, deref)) - case deref @ RawDerefPointer(pointer) => - FunctionInvocation[Post]( - ref = pointerDeref.ref, - args = Seq( - if ( - inAxiom.isEmpty && - !deref.o.find[LabelContext] - .exists(_.label == "classToRef cast helpers") - ) { - FunctionInvocation[Post]( - ref = pointerAdd.ref, - // Always index with zero, otherwise quantifiers with pointers do not get triggered - args = Seq(unwrapOption(pointer, deref.blame), const(0)), - typeArgs = Nil, - Nil, - Nil, - )(NoContext( - DerefPointerBoundsPreconditionFailed(deref.blame, pointer) - )) - } else { unwrapOption(pointer, deref.blame) } - ), - typeArgs = Nil, - Nil, - Nil, - )(PanicBlame("ptr_deref requires nothing.")) case len @ PointerBlockLength(pointer) => ADTFunctionInvocation[Post]( typeArgs = Some((blockAdt.ref, Nil)), diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index e94c8dba85..1480f0b909 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -596,7 +596,6 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) // TODO: Use an actual Blame // Acquire the actual struct through a PointerAdd - // TODO: Can we somehow wrap the rw.dispatch(gep.pointer) to add the known type structureType? gep.pointer.t match { case LLVMTPointer(None) => val structPointer = @@ -798,7 +797,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) effectively transforming the CFG into a tree. More efficient restructuring algorithms but this works for now. This of course only works for acyclic CFGs as otherwise replacement would be infinitely recursive. - Loop restructuring should be handled by pallas as it has much more analytical and contextual information about + Loop restructuring should be handled by Pallas as it has much more analytical and contextual information about the program. */ case class GotoEliminator(bodyScope: Scope[Pre]) extends LazyLogging { diff --git a/src/rewrite/vct/rewrite/veymont/generation/SpecializeEndpointClasses.scala b/src/rewrite/vct/rewrite/veymont/generation/SpecializeEndpointClasses.scala index eef433d8ea..ed34fe037f 100644 --- a/src/rewrite/vct/rewrite/veymont/generation/SpecializeEndpointClasses.scala +++ b/src/rewrite/vct/rewrite/veymont/generation/SpecializeEndpointClasses.scala @@ -1,4 +1,4 @@ -package vct.rewrite.veymont +package vct.rewrite.veymont.generation import com.typesafe.scalalogging.LazyLogging import vct.col.ast.{ diff --git a/src/viper/viper/api/backend/silicon/SiliconLogListener.scala b/src/viper/viper/api/backend/silicon/SiliconLogListener.scala index 47a8690f1c..949907ddd3 100644 --- a/src/viper/viper/api/backend/silicon/SiliconLogListener.scala +++ b/src/viper/viper/api/backend/silicon/SiliconLogListener.scala @@ -281,7 +281,7 @@ class SiliconMemberLogListener( if (log.traceBranchConditions) { val textCond = branchConditions.head match { - case BranchConditionExp(e) => e.toString() + case BranchConditionExp(e) => e.toString case BranchConditionTerm(e) => e.toString case BranchConditionNone(at, count) => s"alternative $at/$count" } diff --git a/src/viper/viper/api/transform/ColToSilver.scala b/src/viper/viper/api/transform/ColToSilver.scala index 9e27216415..badc42b153 100644 --- a/src/viper/viper/api/transform/ColToSilver.scala +++ b/src/viper/viper/api/transform/ColToSilver.scala @@ -6,7 +6,7 @@ import vct.col.ref.Ref import vct.col.util.AstBuildHelpers.unfoldStar import vct.col.{ast => col} import vct.result.VerificationError.{SystemError, Unreachable} -import viper.silver.ast.{AnnotationInfo, ConsInfo, TypeVar, WildcardPerm} +import viper.silver.ast.{TypeVar, WildcardPerm} import viper.silver.plugin.standard.termination.{ DecreasesClause, DecreasesTuple, From 60b419d31ed2af9bd92ce70480dfc8d50080c947 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Wed, 9 Oct 2024 17:00:17 +0200 Subject: [PATCH 44/47] Integrate Bob's feedback --- .../contract/ApplicableContractImpl.scala | 1 - .../col/ast/lang/llvm/LLVMAllocAImpl.scala | 8 +- .../ast/lang/llvm/LLVMArrayValueImpl.scala | 3 +- .../ast/lang/llvm/LLVMBasicBlockImpl.scala | 5 +- .../llvm/LLVMFunctionDefinitionImpl.scala | 12 ++ .../llvm/LLVMFunctionInvocationImpl.scala | 2 +- .../llvm/LLVMFunctionPointerValueImpl.scala | 2 +- .../lang/llvm/LLVMGetElementPointerImpl.scala | 4 +- .../lang/llvm/LLVMGlobalVariableImpl.scala | 9 +- .../ast/lang/llvm/LLVMIntegerValueImpl.scala | 2 +- .../vct/col/ast/lang/llvm/LLVMLoadImpl.scala | 13 +- .../ast/lang/llvm/LLVMMemoryAcquireImpl.scala | 2 +- .../llvm/LLVMMemoryAcquireReleaseImpl.scala | 2 +- .../lang/llvm/LLVMMemoryMonotonicImpl.scala | 2 +- .../lang/llvm/LLVMMemoryNotAtomicImpl.scala | 2 - .../lang/llvm/LLVMMemoryOrderingImpl.scala | 2 - .../ast/lang/llvm/LLVMMemoryReleaseImpl.scala | 2 +- ...LLVMMemorySequentiallyConsistentImpl.scala | 2 +- .../lang/llvm/LLVMMemoryUnorderedImpl.scala | 2 +- .../ast/lang/llvm/LLVMPointerValueImpl.scala | 2 +- .../ast/lang/llvm/LLVMSignExtendImpl.scala | 3 +- .../vct/col/ast/lang/llvm/LLVMStoreImpl.scala | 10 +- .../ast/lang/llvm/LLVMStructValueImpl.scala | 3 +- .../col/ast/lang/llvm/LLVMTArrayImpl.scala | 3 +- .../vct/col/ast/lang/llvm/LLVMTIntImpl.scala | 3 +- .../col/ast/lang/llvm/LLVMTMetadataImpl.scala | 2 + .../col/ast/lang/llvm/LLVMTPointerImpl.scala | 6 + .../col/ast/lang/llvm/LLVMTStructImpl.scala | 12 +- .../col/ast/lang/llvm/LLVMTVectorImpl.scala | 3 +- .../col/ast/lang/llvm/LLVMTruncateImpl.scala | 3 +- .../ast/lang/llvm/LLVMVectorValueImpl.scala | 3 +- .../ast/lang/llvm/LLVMZeroExtendImpl.scala | 3 +- .../llvm/LLVMZeroedAggregateValueImpl.scala | 2 +- src/col/vct/col/resolve/lang/Spec.scala | 32 +-- .../vct/col/typerules/CoercingRewriter.scala | 1 + src/col/vct/col/util/AstBuildHelpers.scala | 17 ++ src/hre/hre/io/ChecksumReadableFile.scala | 7 +- .../vct/rewrite/ConstantifyFinalFields.scala | 21 +- src/rewrite/vct/rewrite/EncodeForkJoin.scala | 8 +- .../vct/rewrite/EncodeResourceValues.scala | 14 +- .../GenerateSingleOwnerPermissions.scala | 16 +- .../vct/rewrite/LowerLocalHeapVariables.scala | 35 +-- .../vct/rewrite/MonomorphizeClass.scala | 45 ++-- .../vct/rewrite/PrepareByValueClass.scala | 202 ++++++++---------- .../ResolveExpressionSideEffects.scala | 26 +-- .../vct/rewrite/VariableToPointer.scala | 3 +- .../vct/rewrite/lang/LangCPPToCol.scala | 10 +- .../vct/rewrite/lang/LangSpecificToCol.scala | 11 - .../vct/rewrite/lang/LangTypesToCol.scala | 5 +- .../vct/rewrite/lang/NoSupportSelfLoop.scala | 16 +- 50 files changed, 271 insertions(+), 333 deletions(-) diff --git a/src/col/vct/col/ast/family/contract/ApplicableContractImpl.scala b/src/col/vct/col/ast/family/contract/ApplicableContractImpl.scala index 2f624e3dd4..f73e13925b 100644 --- a/src/col/vct/col/ast/family/contract/ApplicableContractImpl.scala +++ b/src/col/vct/col/ast/family/contract/ApplicableContractImpl.scala @@ -58,7 +58,6 @@ trait ApplicableContractImpl[G] def nonEmpty: Boolean = !isEmpty - // PB: please keep in sync with CDeclarationImpl def layoutSpec(implicit ctx: Ctx): Doc = Doc.stack(Seq( Doc.stack(givenArgs.map(Text("given") <+> _.show <> ";")), diff --git a/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala index 8276e1544c..cc84f2a7ac 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMAllocAImpl.scala @@ -1,8 +1,14 @@ package vct.col.ast.lang.llvm import vct.col.ast.ops.LLVMAllocAOps -import vct.col.ast.{LLVMAllocA, Type, LLVMTPointer} +import vct.col.ast.LLVMAllocA +import vct.col.print._ trait LLVMAllocAImpl[G] extends LLVMAllocAOps[G] { this: LLVMAllocA[G] => + override def layout(implicit ctx: Ctx): Doc = + Group( + Text(ctx.name(variable)) <+> "=" <+> + Group(Text("alloca") <+> allocationType <> "," <+> numElements) + ) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala index d7be0dbc2e..b00f6f9267 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMArrayValueImpl.scala @@ -7,5 +7,6 @@ import vct.col.print._ trait LLVMArrayValueImpl[G] extends LLVMArrayValueOps[G] { this: LLVMArrayValue[G] => override def t: Type[G] = arrayType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("[") <> Doc.args(value) <> "]") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala index 1f27dc0183..2ffe3b8d9c 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMBasicBlockImpl.scala @@ -2,10 +2,11 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMBasicBlock import vct.col.ast.ops.LLVMBasicBlockOps -import vct.col.check.{CheckContext, CheckError} +import vct.col.print._ trait LLVMBasicBlockImpl[G] extends LLVMBasicBlockOps[G] { this: LLVMBasicBlock[G] => - override def check(context: CheckContext[G]): Seq[CheckError] = Nil + override def layout(implicit ctx: Ctx): Doc = + label.show <> ":" <+> body.layoutAsBlock } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala index 326a86a142..0aefb302bf 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionDefinitionImpl.scala @@ -4,6 +4,7 @@ import vct.col.ast.declaration.category.ApplicableImpl import vct.col.ast.{Declaration, LLVMFunctionDefinition, Statement} import vct.col.ast.util.Declarator import vct.col.ast.ops.LLVMFunctionDefinitionOps +import vct.col.print._ trait LLVMFunctionDefinitionImpl[G] extends Declarator[G] @@ -13,4 +14,15 @@ trait LLVMFunctionDefinitionImpl[G] override def declarations: Seq[Declaration[G]] = args override def body: Option[Statement[G]] = functionBody + + override def layout(implicit ctx: Ctx): Doc = + Doc.stack(Seq( + contract, + Group( + (if (pure) + Text("pure") <+> returnType + else + returnType.show) <+> ctx.name(this) <> "(" <> Doc.args(args) <> ")" + ) <+> body.map(_.layoutAsBlock).getOrElse(Text("")), + )) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala index e92e32fb16..525765f515 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionInvocationImpl.scala @@ -10,7 +10,7 @@ trait LLVMFunctionInvocationImpl[G] extends LLVMFunctionInvocationOps[G] { override def layout(implicit ctx: Ctx): Doc = Group( - Group(Text(ctx.name(ref)) <> "(") <> Doc.args(args) <> ")" <> + Group(Text("@") <> ctx.name(ref) <> "(") <> Doc.args(args) <> ")" <> DocUtil.givenYields(givenMap, yields) ) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala index 4513befdb2..be8c8e2749 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMFunctionPointerValueImpl.scala @@ -8,5 +8,5 @@ trait LLVMFunctionPointerValueImpl[G] extends LLVMFunctionPointerValueOps[G] { this: LLVMFunctionPointerValue[G] => // TODO: Do we want a separate type for function pointers? For now we don't support function pointers anyway override def t: Type[G] = LLVMTPointer(None) - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("@") <> ctx.name(value) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala index 795dc90cc6..d7c7f4f86f 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMGetElementPointerImpl.scala @@ -7,5 +7,7 @@ import vct.col.print._ trait LLVMGetElementPointerImpl[G] extends LLVMGetElementPointerOps[G] { this: LLVMGetElementPointer[G] => override def t: Type[G] = LLVMTPointer(Some(resultType)) - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Text("getelementptr") <+> structureType <> "," <+> pointer <> "," <+> + Doc.args(indices) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala index 80a7d93b43..e385749350 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMGlobalVariableImpl.scala @@ -6,5 +6,12 @@ import vct.col.print._ trait LLVMGlobalVariableImpl[G] extends LLVMGlobalVariableOps[G] { this: LLVMGlobalVariable[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Text("@") <> ctx.name(this) <+> "=" <+> + (if (constant) + "constant" + else + "global") <+> + (if (value.isDefined) { variableType.show <+> value.get } + else { variableType }) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala index 6edc11d6b0..f9b11b97fb 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMIntegerValueImpl.scala @@ -7,5 +7,5 @@ import vct.col.print._ trait LLVMIntegerValueImpl[G] extends LLVMIntegerValueOps[G] { this: LLVMIntegerValue[G] => override def t: Type[G] = integerType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text(value.toString) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala index 8142e0ea91..b0e84a7d78 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMLoadImpl.scala @@ -1,8 +1,19 @@ package vct.col.ast.lang.llvm -import vct.col.ast.{LLVMLoad, Type} +import vct.col.ast.{LLVMLoad, LLVMMemoryNotAtomic} import vct.col.ast.ops.LLVMLoadOps +import vct.col.print._ trait LLVMLoadImpl[G] extends LLVMLoadOps[G] { this: LLVMLoad[G] => + + private def layoutOrdering(inner: Doc)(implicit ctx: Ctx): Doc = + if (ordering.isInstanceOf[LLVMMemoryNotAtomic[_]]) { Group(inner) } + else { Group(inner <+> ordering) } + + override def layout(implicit ctx: Ctx): Doc = + Group( + Text(ctx.name(variable)) <+> "=" <+> + layoutOrdering(Text("store") <+> pointer) + ) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala index 4aaafdb0b5..7feab889d3 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireImpl.scala @@ -6,5 +6,5 @@ import vct.col.print._ trait LLVMMemoryAcquireImpl[G] extends LLVMMemoryAcquireOps[G] { this: LLVMMemoryAcquire[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("acquire") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala index 87fd27b4ea..a38904eadc 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryAcquireReleaseImpl.scala @@ -6,5 +6,5 @@ import vct.col.print._ trait LLVMMemoryAcquireReleaseImpl[G] extends LLVMMemoryAcquireReleaseOps[G] { this: LLVMMemoryAcquireRelease[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("acq_rel") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala index b3d700bc42..3d1b9d8cf5 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryMonotonicImpl.scala @@ -6,5 +6,5 @@ import vct.col.print._ trait LLVMMemoryMonotonicImpl[G] extends LLVMMemoryMonotonicOps[G] { this: LLVMMemoryMonotonic[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("monotonic") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala index 796ceb4c7e..db9ca29b96 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryNotAtomicImpl.scala @@ -2,9 +2,7 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMMemoryNotAtomic import vct.col.ast.ops.LLVMMemoryNotAtomicOps -import vct.col.print._ trait LLVMMemoryNotAtomicImpl[G] extends LLVMMemoryNotAtomicOps[G] { this: LLVMMemoryNotAtomic[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala index b4ece7c46b..bd022e138e 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryOrderingImpl.scala @@ -2,9 +2,7 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMMemoryOrdering import vct.col.ast.ops.LLVMMemoryOrderingFamilyOps -import vct.col.print._ trait LLVMMemoryOrderingImpl[G] extends LLVMMemoryOrderingFamilyOps[G] { this: LLVMMemoryOrdering[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala index e856bc3b35..954dfe88b3 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryReleaseImpl.scala @@ -6,5 +6,5 @@ import vct.col.print._ trait LLVMMemoryReleaseImpl[G] extends LLVMMemoryReleaseOps[G] { this: LLVMMemoryRelease[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("release") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala index 9576bcb1e6..3e89491722 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemorySequentiallyConsistentImpl.scala @@ -7,5 +7,5 @@ import vct.col.print._ trait LLVMMemorySequentiallyConsistentImpl[G] extends LLVMMemorySequentiallyConsistentOps[G] { this: LLVMMemorySequentiallyConsistent[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("seq_cst") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala index 7f929bf5fb..6df5798a1d 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMMemoryUnorderedImpl.scala @@ -6,5 +6,5 @@ import vct.col.print._ trait LLVMMemoryUnorderedImpl[G] extends LLVMMemoryUnorderedOps[G] { this: LLVMMemoryUnordered[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("unordered") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala index 890bab6ca6..32b61c7384 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMPointerValueImpl.scala @@ -19,5 +19,5 @@ trait LLVMPointerValueImpl[G] extends LLVMPointerValueOps[G] { case v: HeapVariable[G] => LLVMTPointer(Some(v.t)) } } - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("@") <> ctx.name(value) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala index 5fb7c7b1c3..8d7c697947 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMSignExtendImpl.scala @@ -7,5 +7,6 @@ import vct.col.print._ trait LLVMSignExtendImpl[G] extends LLVMSignExtendOps[G] { this: LLVMSignExtend[G] => override def t: Type[G] = outputType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("sext") <+> inputType <+> value <+> "to" <+> outputType) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala index 9b3b82337d..6c764856e7 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMStoreImpl.scala @@ -1,8 +1,16 @@ package vct.col.ast.lang.llvm -import vct.col.ast.LLVMStore +import vct.col.ast.{LLVMMemoryNotAtomic, LLVMStore} import vct.col.ast.ops.LLVMStoreOps +import vct.col.print._ trait LLVMStoreImpl[G] extends LLVMStoreOps[G] { this: LLVMStore[G] => + + private def layoutOrdering(inner: Doc)(implicit ctx: Ctx): Doc = + if (ordering.isInstanceOf[LLVMMemoryNotAtomic[_]]) { Group(inner) } + else { Group(inner <+> ordering) } + + override def layout(implicit ctx: Ctx): Doc = + layoutOrdering(Text("store") <+> value <> "," <+> pointer) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala index a37d6ab04e..4aa5b9f6d8 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMStructValueImpl.scala @@ -7,5 +7,6 @@ import vct.col.print._ trait LLVMStructValueImpl[G] extends LLVMStructValueOps[G] { this: LLVMStructValue[G] => override def t: Type[G] = structType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("{") <> Doc.args(value) <> "}") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala index fbceb6ba2b..d20fc91563 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTArrayImpl.scala @@ -6,5 +6,6 @@ import vct.col.print._ trait LLVMTArrayImpl[G] extends LLVMTArrayOps[G] { this: LLVMTArray[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("[") <> numElements.toString <+> "x" <+> elementType <> "]") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala index 697c024544..0fd2017ab7 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTIntImpl.scala @@ -6,5 +6,6 @@ import vct.col.print._ trait LLVMTIntImpl[G] extends LLVMTIntOps[G] { this: LLVMTInt[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + + override def layout(implicit ctx: Ctx): Doc = Text("i") <> bitWidth.toString } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala index bbf522bc87..4e6bc7770e 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTMetadataImpl.scala @@ -2,8 +2,10 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMTMetadata import vct.col.ast.ops.LLVMTMetadataOps +import vct.col.print.{Ctx, Doc, Group, Text} trait LLVMTMetadataImpl[G] extends LLVMTMetadataOps[G] { this: LLVMTMetadata[G] => + override def layout(implicit ctx: Ctx): Doc = Text("metadata") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala index 770fa9295c..8bcb560c1b 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTPointerImpl.scala @@ -2,8 +2,14 @@ package vct.col.ast.lang.llvm import vct.col.ast.LLVMTPointer import vct.col.ast.ops.LLVMTPointerOps +import vct.col.print._ trait LLVMTPointerImpl[G] extends LLVMTPointerOps[G] { this: LLVMTPointer[G] => + override def layout(implicit ctx: Ctx): Doc = + if (innerType.isDefined) + Group(Text("ptr") <+> innerType.get) + else + Text("ptr") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala index cb839186b9..02ef6b9444 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTStructImpl.scala @@ -6,5 +6,15 @@ import vct.col.print._ trait LLVMTStructImpl[G] extends LLVMTStructOps[G] { this: LLVMTStruct[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + + private def layoutPacked(inner: Doc)(implicit ctx: Ctx): Doc = + if (packed) { Text("<") <> inner <> ">" } + else { inner } + + override def layout(implicit ctx: Ctx): Doc = { + if (name.isDefined) + Text(name.get) + else + (layoutPacked(Text("{") <> Doc.args(elements) <> "}")) + } } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala index 6bc89d8b85..ea055aa147 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTVectorImpl.scala @@ -6,5 +6,6 @@ import vct.col.print._ trait LLVMTVectorImpl[G] extends LLVMTVectorOps[G] { this: LLVMTVector[G] => - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("<") <> numElements.toString <+> "x" <+> elementType <> ">") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala index 25344cf63a..e6f4b51fe8 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMTruncateImpl.scala @@ -7,5 +7,6 @@ import vct.col.print._ trait LLVMTruncateImpl[G] extends LLVMTruncateOps[G] { this: LLVMTruncate[G] => override def t: Type[G] = outputType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("trunc") <+> inputType <+> value <+> "to" <+> outputType) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala index 3661272542..d0d20c6735 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMVectorValueImpl.scala @@ -7,5 +7,6 @@ import vct.col.ast.ops.LLVMVectorValueOps trait LLVMVectorValueImpl[G] extends LLVMVectorValueOps[G] { this: LLVMVectorValue[G] => override def t: Type[G] = vectorType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("<") <> Doc.args(value) <> ">") } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala index eca0bfddc0..119b63be77 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMZeroExtendImpl.scala @@ -7,5 +7,6 @@ import vct.col.print._ trait LLVMZeroExtendImpl[G] extends LLVMZeroExtendOps[G] { this: LLVMZeroExtend[G] => override def t: Type[G] = outputType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = + Group(Text("zext") <+> inputType <+> value <+> "to" <+> outputType) } diff --git a/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala index e8217e0c40..ae524a1520 100644 --- a/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala +++ b/src/col/vct/col/ast/lang/llvm/LLVMZeroedAggregateValueImpl.scala @@ -8,5 +8,5 @@ trait LLVMZeroedAggregateValueImpl[G] extends LLVMZeroedAggregateValueOps[G] { this: LLVMZeroedAggregateValue[G] => override def value: Unit = () override def t: Type[G] = aggregateType - // override def layout(implicit ctx: Ctx): Doc = ??? + override def layout(implicit ctx: Ctx): Doc = Text("zeroinitializer") } diff --git a/src/col/vct/col/resolve/lang/Spec.scala b/src/col/vct/col/resolve/lang/Spec.scala index 52b6cd9a32..f2e00f57bf 100644 --- a/src/col/vct/col/resolve/lang/Spec.scala +++ b/src/col/vct/col/resolve/lang/Spec.scala @@ -348,12 +348,8 @@ case object Spec { def findMethod[G](obj: Expr[G], name: String): Option[InstanceMethod[G]] = obj.t match { - case TByReferenceClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { - case ref @ RefInstanceMethod(decl) if ref.name == name => decl - } - case TByValueClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { + case cls: TClass[G] => + cls.cls.decl.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstanceMethod(decl) if ref.name == name => decl } case _ => None @@ -364,12 +360,8 @@ case object Spec { name: String, ): Option[InstanceFunction[G]] = obj.t match { - case TByReferenceClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { - case ref @ RefInstanceFunction(decl) if ref.name == name => decl - } - case TByValueClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { + case cls: TClass[G] => + cls.cls.decl.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstanceFunction(decl) if ref.name == name => decl } case _ => None @@ -380,12 +372,8 @@ case object Spec { name: String, ): Option[InstancePredicate[G]] = obj.t match { - case TByReferenceClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { - case ref @ RefInstancePredicate(decl) if ref.name == name => decl - } - case TByValueClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { + case cls: TClass[G] => + cls.cls.decl.decls.flatMap(Referrable.from).collectFirst { case ref @ RefInstancePredicate(decl) if ref.name == name => decl } case JavaTClass(Ref(cls), _) => @@ -397,12 +385,8 @@ case object Spec { def findField[G](obj: Expr[G], name: String): Option[InstanceField[G]] = obj.t match { - case TByReferenceClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { - case ref @ RefField(decl) if ref.name == name => decl - } - case TByValueClass(Ref(cls), _) => - cls.decls.flatMap(Referrable.from).collectFirst { + case cls: TClass[G] => + cls.cls.decl.decls.flatMap(Referrable.from).collectFirst { case ref @ RefField(decl) if ref.name == name => decl } case _ => None diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index 3fa5927427..f2fa435ea0 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -2267,6 +2267,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case l @ Lock(obj) => Lock(cls(obj))(l.blame) case Loop(init, cond, update, contract, body) => Loop(init, bool(cond), update, contract, body) + case block: LLVMBasicBlock[Pre] => block case LLVMAllocA(variable, allocationType, numElements) => LLVMAllocA(variable, allocationType, int(numElements)) case load @ LLVMLoad(variable, loadType, p, ordering) => diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 1ae4392aeb..1953c351a9 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -221,6 +221,23 @@ object AstBuildHelpers { } } + implicit class ClassBuildHelpers[Pre, Post](cls: Class[Pre])( + implicit rewriter: AbstractRewriter[Pre, Post] + ) { + def rewrite( + typeArgs: Seq[Variable[Post]] = rewriter.variables + .dispatch(cls.typeArgs), + decls: Seq[ClassDeclaration[Post]] = rewriter.classDeclarations + .dispatch(cls.decls), + supports: Seq[Type[Post]] = cls.supports.map(rewriter.dispatch), + ): Class[Post] = + cls match { + case cls: ByReferenceClass[Pre] => + cls.rewrite(typeArgs, decls, supports) + case cls: ByValueClass[Pre] => cls.rewrite(typeArgs, decls, supports) + } + } + implicit class MethodBuildHelpers[Pre, Post](method: AbstractMethod[Pre])( implicit rewriter: AbstractRewriter[Pre, Post] ) { diff --git a/src/hre/hre/io/ChecksumReadableFile.scala b/src/hre/hre/io/ChecksumReadableFile.scala index f6ea9747ac..fac5d4c7d8 100644 --- a/src/hre/hre/io/ChecksumReadableFile.scala +++ b/src/hre/hre/io/ChecksumReadableFile.scala @@ -2,7 +2,7 @@ package hre.io import vct.result.VerificationError.SystemError -import java.io.Reader +import java.io.{BufferedReader, ByteArrayInputStream, InputStreamReader, Reader} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} import java.security.{MessageDigest, NoSuchAlgorithmException} @@ -36,7 +36,10 @@ case class ChecksumReadableFile( case _: NoSuchAlgorithmException => throw UnknownChecksumKind(checksumKind) } - Files.newBufferedReader(file, StandardCharsets.UTF_8) + new BufferedReader(new InputStreamReader( + new ByteArrayInputStream(bytes), + StandardCharsets.UTF_8, + )) } def getChecksum: String = { diff --git a/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala b/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala index 3405775b89..18e8270385 100644 --- a/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala +++ b/src/rewrite/vct/rewrite/ConstantifyFinalFields.scala @@ -80,22 +80,11 @@ case class ConstantifyFinalFields[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = field.o if (isFinal(field)) { val `this` = - currentClass.top match { - case _: ByReferenceClass[Pre] => - new Variable[Post](TByReferenceClass( - succ(currentClass.top), - currentClass.top.typeArgs.map { v: Variable[Pre] => - TVar(succ(v)) - }, - )) - case _: ByValueClass[Pre] => - new Variable[Post](TByValueClass( - succ(currentClass.top), - currentClass.top.typeArgs.map { v: Variable[Pre] => - TVar(succ(v)) - }, - )) - } + new Variable(dispatch( + currentClass.top.classType(currentClass.top.typeArgs.map { + v: Variable[Pre] => TVar(v.ref) + }) + )) fieldFunction(field) = globalDeclarations .declare(withResult((result: Result[Post]) => function[Post]( diff --git a/src/rewrite/vct/rewrite/EncodeForkJoin.scala b/src/rewrite/vct/rewrite/EncodeForkJoin.scala index d41770483d..fd9d2d4c38 100644 --- a/src/rewrite/vct/rewrite/EncodeForkJoin.scala +++ b/src/rewrite/vct/rewrite/EncodeForkJoin.scala @@ -131,13 +131,7 @@ case class EncodeForkJoin[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = e.o cls.decls.collectFirst { case run: RunMethod[Pre] => run } match { case Some(_) => - val obj = - cls match { - case _: ByReferenceClass[Pre] => - new Variable[Post](TByReferenceClass(succ(cls), Seq())) - case _: ByValueClass[Pre] => - new Variable[Post](TByValueClass(succ(cls), Seq())) - } + val obj = new Variable(dispatch(cls.classType(Nil))) ScopedExpr( Seq(obj), With( diff --git a/src/rewrite/vct/rewrite/EncodeResourceValues.scala b/src/rewrite/vct/rewrite/EncodeResourceValues.scala index 0c307bb3a2..34f4552f91 100644 --- a/src/rewrite/vct/rewrite/EncodeResourceValues.scala +++ b/src/rewrite/vct/rewrite/EncodeResourceValues.scala @@ -184,11 +184,7 @@ case class EncodeResourceValues[Pre <: Generation]() case ResourcePattern.HeapVariableLocation(_) => Nil case ResourcePattern.FieldLocation(f) => nonGeneric(fieldOwner(f)) - Seq(fieldOwner(f) match { - case cls: ByReferenceClass[Pre] => - TByReferenceClass(succ(cls), Seq()) - case cls: ByValueClass[Pre] => TByValueClass(succ(cls), Seq()) - }) + Seq(dispatch(fieldOwner(f).classType(Nil))) case ResourcePattern.ModelLocation(f) => Seq(TModel(succ(modelFieldOwner(f)))) case ResourcePattern.SilverFieldLocation(_) => Seq(TRef()) @@ -200,12 +196,8 @@ case class EncodeResourceValues[Pre <: Generation]() ref.args.map(_.t).map(dispatch) case ResourcePattern.InstancePredicateLocation(ref) => nonGeneric(predicateOwner(ref)) - (predicateOwner(ref) match { - case cls: ByReferenceClass[Pre] => - TByReferenceClass(succ[Class[Post]](cls), Seq()) - case cls: ByValueClass[Pre] => - TByValueClass(succ[Class[Post]](cls), Seq()) - }) +: ref.args.map(_.t).map(dispatch) + dispatch(predicateOwner(ref).classType(Nil)) +: ref.args.map(_.t) + .map(dispatch) } def freeTypes(pattern: ResourcePattern): Seq[Type[Post]] = diff --git a/src/rewrite/vct/rewrite/GenerateSingleOwnerPermissions.scala b/src/rewrite/vct/rewrite/GenerateSingleOwnerPermissions.scala index 24ba02e8ff..7d2cc07949 100644 --- a/src/rewrite/vct/rewrite/GenerateSingleOwnerPermissions.scala +++ b/src/rewrite/vct/rewrite/GenerateSingleOwnerPermissions.scala @@ -287,22 +287,20 @@ case class GenerateSingleOwnerPermissions[Pre <: Generation]( u, )), ) - case t: TByReferenceClass[Pre] - if !generatingClasses.contains(t.cls.decl) => - generatingClasses.having(t.cls.decl) { - foldStar(t.cls.decl.collect { case f: InstanceField[Pre] => + case TByReferenceClass(Ref(cls), _) if !generatingClasses.contains(cls) => + generatingClasses.having(cls) { + foldStar(cls.collect { case f: InstanceField[Pre] => fieldTransitivePerm(e, f)(f.o) }) } - case t: TByReferenceClass[Pre] => + case TByReferenceClass(Ref(cls), _) => // The class we are generating permission for has already been encountered when going through the chain // of fields. So we cut off the computation - if (!warnedClasses.contains(t.cls.decl)) { + if (!warnedClasses.contains(cls)) { logger.warn( - s"Not generating permissions for recursive occurrence of ${t.cls - .decl.o.getPreferredNameOrElse().ucamel}. Circular datastructures are not supported by permission generation" + s"Not generating permissions for recursive occurrence of ${cls.o.getPreferredNameOrElse().ucamel}. Circular datastructures are not supported by permission generation" ) - warnedClasses.addOne(t.cls.decl) + warnedClasses.addOne(cls) } tt case _ => tt diff --git a/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala index 43ee901793..c19c6b0fce 100644 --- a/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala +++ b/src/rewrite/vct/rewrite/LowerLocalHeapVariables.scala @@ -31,32 +31,15 @@ case class LowerLocalHeapVariables[Pre <: Generation]() extends Rewriter[Pre] { v } VerificationError.withContext(CurrentRewriteProgramContext(program)) { - localHeapVariables.scope { - variables.scope { - enumConstants.scope { - modelDeclarations.scope { - aDTDeclarations.scope { - classDeclarations.scope { - globalDeclarations.scope { - program.collect { - case HeapLocal(Ref(v)) if !nakedHeapLocals.contains(v) => - v - }.foreach(v => - stripped(v) = - new Variable[Post](dispatch(v.t.asPointer.get.element))( - v.o - ) - ) - Program(globalDeclarations.dispatch(program.declarations))( - dispatch(program.blame) - )(program.o) - } - } - } - } - } - } - } + program.rewrite(declarations = { + program.collect { + case HeapLocal(Ref(v)) if !nakedHeapLocals.contains(v) => v + }.foreach(v => + stripped(v) = + new Variable[Post](dispatch(v.t.asPointer.get.element))(v.o) + ) + globalDeclarations.dispatch(program.declarations) + }) } } diff --git a/src/rewrite/vct/rewrite/MonomorphizeClass.scala b/src/rewrite/vct/rewrite/MonomorphizeClass.scala index dd91b27be4..07ed5a9f87 100644 --- a/src/rewrite/vct/rewrite/MonomorphizeClass.scala +++ b/src/rewrite/vct/rewrite/MonomorphizeClass.scala @@ -3,9 +3,9 @@ package vct.rewrite import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast._ -import vct.col.ref.{LazyRef, Ref} -import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten} -import vct.col.util.AstBuildHelpers.ContractApplicableBuildHelpers +import vct.col.ref.Ref +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.util.AstBuildHelpers.ClassBuildHelpers import vct.col.util.{Substitute, SuccessionMap} import scala.collection.mutable @@ -48,7 +48,7 @@ case class MonomorphizeClass[Pre <: Generation]() cls: Class[Pre], typeValues: Seq[Type[Pre]], keepBodies: Boolean, - ): Unit = { + ): Class[Post] = { /* Known limitation: the knownInstantations set does not take into account how a class was instantiated. A class can be instantiated both abstractly (without method bodies) and concretely (with method bodies) for the same sequence of type arguments, maybe. If that's the case, the knownInstantiations should take the @@ -60,7 +60,7 @@ case class MonomorphizeClass[Pre <: Generation]() logger.debug( s"Class ${cls.o.getPreferredNameOrElse().ucamel} with type args $typeValues is already instantiated, so skipping instantiation" ) - return + return genericSucc((key, cls)).asInstanceOf[Class[Post]] } val mode = if (keepBodies) { "concretely" } @@ -83,19 +83,15 @@ case class MonomorphizeClass[Pre <: Generation]() classDeclarations.scope { variables.scope { localHeapVariables.scope { - allScopes.anyDeclare(allScopes.anySucceedOnly( - cls, - cls match { - case cls: ByReferenceClass[Pre] => - cls.rewrite(typeArgs = Seq()) - case cls: ByValueClass[Pre] => cls.rewrite(typeArgs = Seq()) - }, - )) + allScopes.anyDeclare( + allScopes.anySucceedOnly(cls, cls.rewrite(typeArgs = Seq())) + ) } } } } } + genericSucc((key, cls)).asInstanceOf[Class[Post]] } override def dispatch(decl: Declaration[Pre]): Unit = @@ -137,28 +133,13 @@ case class MonomorphizeClass[Pre <: Generation]() override def dispatch(t: Type[Pre]): Type[Post] = (t, ctx.topOption) match { - case (TByReferenceClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => + case (cls: TClass[Pre], ctx) if cls.typeArgs.nonEmpty => val typeValues = ctx match { - case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) - case None => typeArgs + case Some(ctx) => cls.typeArgs.map(ctx.substitute.dispatch) + case None => cls.typeArgs } - instantiate(cls, typeValues, false) - TByReferenceClass[Post]( - genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), - Seq(), - ) - case (TByValueClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => - val typeValues = - ctx match { - case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) - case None => typeArgs - } - instantiate(cls, typeValues, false) - TByValueClass[Post]( - genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), - Seq(), - ) + instantiate(cls.cls.decl, typeValues, false).classType(Seq()) case (tvar @ TVar(_), Some(ctx)) => dispatch(ctx.substitutions(tvar)) case _ => t.rewriteDefault() } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 950ae4f8b0..330bd3c588 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -74,7 +74,8 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { private val inAssignment: ScopedStack[Unit] = ScopedStack() private val copyContext: ScopedStack[CopyContext] = ScopedStack() - private val classCreationMethods + copyContext.push(NoCopy()) + private val classCreationMethodsSucc : SuccessionMap[TByValueClass[Pre], Procedure[Post]] = SuccessionMap() def makeClassCreationMethod(t: TByValueClass[Pre]): Procedure[Post] = { @@ -111,7 +112,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { newLocal.get(DerefAssignTarget), procedureInvocation[Post]( TrueSatisfiable, - classCreationMethods + classCreationMethodsSucc .getOrElseUpdate(t, makeClassCreationMethod(t)).ref, ), )(AssignLocalOk), @@ -119,24 +120,13 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } case assign: Assign[Pre] => val target = inAssignment.having(()) { dispatch(assign.target) } - if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { - copyContext.having(InAssignmentStatement(assign)) { - assign.rewrite(target = target) - } - } else { assign.rewrite(target = target) } - case Instantiate(Ref(cls), out) - if cls.isInstanceOf[ByValueClass[Pre]] => { - // AssignLocalOk doesn't make too much sense since we don't know if out is a local - val t = TByValueClass[Pre](cls.ref, Seq()) - Assign[Post]( - dispatch(out), - procedureInvocation( - TrueSatisfiable, - classCreationMethods.getOrElseUpdate(t, makeClassCreationMethod(t)) - .ref, - ), - )(AssignLocalOk) - } + assign.target.t match { + case _: TByValueClass[Pre] => + copyContext.having(InAssignmentStatement(assign)) { + assign.rewrite(target = target) + } + case _ => assign.rewrite(target = target) + } case _ => node.rewriteDefault() } } @@ -150,19 +140,14 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { val ov = new Variable[Post](obj.t)(o.where(name = "original")) val v = new Variable[Post](dispatch(t))(o.where(name = "copy")) val children = t.cls.decl.decls.collect { case f: InstanceField[Pre] => - f.t match { - case inner: TByValueClass[Pre] => - Assign[Post]( - Deref[Post](v.get, succ(f))(DerefAssignTarget), - copyClassValue(Deref[Post](ov.get, succ(f))(blame(f)), inner, blame), - )(AssignLocalOk) - case _ => - Assign[Post]( - Deref[Post](v.get, succ(f))(DerefAssignTarget), - Deref[Post](ov.get, succ(f))(blame(f)), - )(AssignLocalOk) - - } + Assign[Post]( + Deref[Post](v.get, succ(f))(DerefAssignTarget), + f.t match { + case inner: TByValueClass[Pre] => + copyClassValue(Deref[Post](ov.get, succ(f))(blame(f)), inner, blame) + case _ => Deref[Post](ov.get, succ(f))(blame(f)) + }, + )(AssignLocalOk) } ScopedExpr( Seq(ov, v), @@ -173,7 +158,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { v.get, procedureInvocation[Post]( TrueSatisfiable, - classCreationMethods + classCreationMethodsSucc .getOrElseUpdate(t, makeClassCreationMethod(t)).ref, ), )(AssignLocalOk), @@ -218,114 +203,101 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(node: Expr[Pre]): Expr[Post] = { implicit val o: Origin = node.o node match { - case NewObject(Ref(cls)) if cls.isInstanceOf[ByValueClass[Pre]] => { + case NewObject(Ref(cls: ByValueClass[Pre])) => val t = TByValueClass[Pre](cls.ref, Seq()) - return procedureInvocation[Post]( + procedureInvocation[Post]( TrueSatisfiable, - classCreationMethods.getOrElseUpdate(t, makeClassCreationMethod(t)) - .ref, + classCreationMethodsSucc + .getOrElseUpdate(t, makeClassCreationMethod(t)).ref, ) - } - case _ => - } - if (inAssignment.nonEmpty) - node.rewriteDefault() - else - node match { - case Perm(ByValueClassLocation(e), p) => - unwrapClassPerm( - dispatch(e), - dispatch(p), - e.t.asInstanceOf[TByValueClass[Pre]], - ) - case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) - if v.t.isInstanceOf[TNonNullPointer[Pre]] => - val t = v.t.asInstanceOf[TNonNullPointer[Pre]] - if (t.element.isInstanceOf[TByValueClass[Pre]]) { - val newV: Ref[Post, HeapVariable[Post]] = succ(v) - val newP = dispatch(p) - Perm(HeapVariableLocation(newV), newP) &* Perm( - PointerLocation(DerefHeapVariable(newV)(dhv.blame))(pl.blame), - newP, - ) - } else { node.rewriteDefault() } - case assign: PreAssignExpression[Pre] => - val target = inAssignment.having(()) { dispatch(assign.target) } - if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { - copyContext.having(InAssignmentExpression(assign)) { - assign.rewrite(target = target) - } - } else { - // No need for copy semantics in this context - copyContext.having(NoCopy()) { assign.rewrite(target = target) } - } - case invocation: Invocation[Pre] => - invocation.rewrite(args = invocation.args.map { a => - if (a.t.isInstanceOf[TByValueClass[Pre]]) { - copyContext.having(InCall(invocation)) { dispatch(a) } - } else { copyContext.having(NoCopy()) { dispatch(a) } } - }) - case dp @ DerefPointer(HeapLocal(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) - case dp @ DerefPointer(DerefHeapVariable(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + case _ if inAssignment.nonEmpty => node.rewriteDefault() + case Perm(ByValueClassLocation(e), p) => + unwrapClassPerm( + dispatch(e), + dispatch(p), + e.t.asInstanceOf[TByValueClass[Pre]], + ) + case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) + if v.t.isInstanceOf[TNonNullPointer[Pre]] => + val t = v.t.asInstanceOf[TNonNullPointer[Pre]] + if (t.element.isInstanceOf[TByValueClass[Pre]]) { + val newV: Ref[Post, HeapVariable[Post]] = succ(v) + val newP = dispatch(p) + Perm(HeapVariableLocation(newV), newP) &* Perm( + PointerLocation(DerefHeapVariable(newV)(dhv.blame))(pl.blame), + newP, ) - case deref @ Deref(_, Ref(f)) if f.t.isInstanceOf[TByValueClass[Pre]] => - if (copyContext.isEmpty) { deref.rewriteDefault() } - else { - // TODO: Improve blame message here - copyClassValue( - deref.rewriteDefault(), - f.t.asInstanceOf[TByValueClass[Pre]], - f => deref.blame, - ) + } else { node.rewriteDefault() } + case assign: PreAssignExpression[Pre] => + val target = inAssignment.having(()) { dispatch(assign.target) } + if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { + copyContext.having(InAssignmentExpression(assign)) { + assign.rewrite(target = target) } - case dp @ DerefPointer(Local(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - // This can happen if the user specifies a local of type pointer to TByValueClass - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) - case _ => node.rewriteDefault() - } + } else { + // No need for copy semantics in this context + copyContext.having(NoCopy()) { assign.rewrite(target = target) } + } + case invocation: Invocation[Pre] => + invocation.rewrite(args = invocation.args.map { a => + if (a.t.isInstanceOf[TByValueClass[Pre]]) { + copyContext.having(InCall(invocation)) { dispatch(a) } + } else { copyContext.having(NoCopy()) { dispatch(a) } } + }) + case dp @ DerefPointer(HeapLocal(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case dp @ DerefPointer(DerefHeapVariable(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case deref @ Deref(_, Ref(f)) if f.t.isInstanceOf[TByValueClass[Pre]] => + // TODO: Improve blame message here + copyClassValue( + deref.rewriteDefault(), + f.t.asInstanceOf[TByValueClass[Pre]], + f => deref.blame, + ) + case dp @ DerefPointer(Local(Ref(v))) + if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + // This can happen if the user specifies a local of type pointer to TByValueClass + rewriteInCopyContext( + dp, + v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], + ) + case _ => node.rewriteDefault() + } } private def rewriteInCopyContext( dp: DerefPointer[Pre], t: TByValueClass[Pre], ): Expr[Post] = { - if (copyContext.isEmpty) { - // If we are in other kinds of expressions like if statements - return dp.rewriteDefault() - } - val clazz = t.cls.decl.asInstanceOf[ByValueClass[Pre]] + val cls = t.cls.decl.asInstanceOf[ByValueClass[Pre]] copyContext.top match { case InCall(invocation) => copyClassValue( dp.rewriteDefault(), t, - f => ClassCopyInCallFailed(dp.blame, invocation, clazz, f), + f => ClassCopyInCallFailed(dp.blame, invocation, cls, f), ) case InAssignmentExpression(assignment) => copyClassValue( dp.rewriteDefault(), t, - f => ClassCopyInAssignmentFailed(dp.blame, assignment, clazz, f), + f => ClassCopyInAssignmentFailed(dp.blame, assignment, cls, f), ) case InAssignmentStatement(assignment) => copyClassValue( dp.rewriteDefault(), t, - f => ClassCopyInAssignmentFailed(dp.blame, assignment, clazz, f), + f => ClassCopyInAssignmentFailed(dp.blame, assignment, cls, f), ) case NoCopy() => dp.rewriteDefault() } diff --git a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala index 34d69d79e1..54d56e0ba7 100644 --- a/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala +++ b/src/rewrite/vct/rewrite/ResolveExpressionSideEffects.scala @@ -659,19 +659,7 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() givenMap, yields, ) => - val typ = - cons.cls.decl match { - case cls: ByReferenceClass[Pre] => - TByReferenceClass[Post]( - succ[Class[Post]](cls), - classTypeArgs.map(dispatch), - ) - case cls: ByValueClass[Pre] => - TByValueClass[Post]( - succ[Class[Post]](cls), - classTypeArgs.map(dispatch), - ) - } + val typ = dispatch(cons.cls.decl.classType(classTypeArgs)) val res = new Variable[Post](typ)(ResultVar) variables.succeed(res.asInstanceOf[Variable[Pre]], res) effect( @@ -691,17 +679,7 @@ case class ResolveExpressionSideEffects[Pre <: Generation]() cons.cls.decl.classType(classTypeArgs), ) case NewObject(Ref(cls)) => - val res = - cls match { - case cls: ByReferenceClass[Pre] => - new Variable[Post]( - TByReferenceClass(succ[Class[Post]](cls), Seq()) - )(ResultVar) - case cls: ByValueClass[Pre] => - new Variable[Post](TByValueClass(succ[Class[Post]](cls), Seq()))( - ResultVar - ) - } + val res = new Variable[Post](dispatch(cls.classType(Seq())))(ResultVar) variables.succeed(res.asInstanceOf[Variable[Pre]], res) effect(Instantiate[Post](succ(cls), res.get(ResultVar))(e.o)) stored(res.get(SideEffectOrigin), cls.ref.decl.classType(Seq())) diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index c303f1d21d..48e02ac91f 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -150,8 +150,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { DerefPointer(Deref[Post](dispatch(obj), fieldMap.ref(f))(deref.blame))( PanicBlame("Should always be accessible") ) - case newObject @ NewObject(Ref(cls)) - if cls.isInstanceOf[ByValueClass[Pre]] => + case newObject @ NewObject(Ref(cls: ByValueClass[Pre])) => val obj = new Variable[Post](TByValueClass(succ(cls), Seq())) ScopedExpr( Seq(obj), diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index 612b704211..8f4fd738ee 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -1351,7 +1351,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) // Create a class that can be used to create a 'this' object // It will be linked to the class made near the end of this method. val preEventClass: Class[Pre] = - new ByValueClass(Nil, Nil, Nil)(commandGroup.o) + new ByReferenceClass(Nil, Nil, Nil, tt)(commandGroup.o) this.currentThis = Some( rw.dispatch(ThisObject[Pre](preEventClass.ref)(preEventClass.o)) ) @@ -1478,9 +1478,8 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(KernelLambdaRunMethodBlame(kernelDeclaration))(commandGroup.o) // Create the surrounding class - // cl::sycl::event has a default copy constructor hence a ByValueClass val postEventClass = - new ByValueClass[Post]( + new ByReferenceClass[Post]( typeArgs = Seq(), decls = currentKernelType.get.getRangeFields ++ @@ -1488,12 +1487,13 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) .flatMap(acc => acc.instanceField +: acc.rangeIndexFields) ++ Seq(kernelRunner), supports = Seq(), + intrinsicLockInvariant = tt, )(commandGroup.o.where(name = "SYCL_EVENT_CLASS")) rw.globalDeclarations.succeed(preEventClass, postEventClass) // Create a variable to refer to the class instance val eventClassRef = - new Variable[Post](TByValueClass(postEventClass.ref, Seq()))( + new Variable[Post](TByReferenceClass(postEventClass.ref, Seq()))( commandGroup.o.where(name = "sycl_event_ref") ) // Store the class ref and read-write accessors to be used when the kernel is done running @@ -1979,7 +1979,7 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) preClass: Class[Pre], commandGroupO: Origin, ): Procedure[Post] = { - val t = rw.dispatch(TByValueClass[Pre](preClass.ref, Seq())) + val t = rw.dispatch(TByReferenceClass[Pre](preClass.ref, Seq())) rw.globalDeclarations.declare( withResult((result: Result[Post]) => { val constructorPostConditions: mutable.Buffer[Expr[Post]] = diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index 6db3b7852c..8bf5bf9044 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -336,17 +336,6 @@ case class LangSpecificToCol[Pre <: Generation]( case cast: CCast[Pre] => c.cast(cast) case sizeof: SizeOf[Pre] => throw LangCToCol.UnsupportedSizeof(sizeof) -// case Perm(a @ AmbiguousLocation(expr), perm) -// if c.getBaseType(expr.t).isInstanceOf[CTStruct[Pre]] => -// c.getBaseType(expr.t) match { -// case structType: CTStruct[Pre] => -// c.unwrapStructPerm( -// dispatch(a).asInstanceOf[AmbiguousLocation[Post]], -// perm, -// structType, -// e.o, -// ) -// } case local: CPPLocal[Pre] => cpp.local(local) case deref: CPPClassMethodOrFieldAccess[Pre] => cpp.deref(deref) case inv: CPPInvocation[Pre] => cpp.invocation(inv) diff --git a/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala b/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala index adfbeb4be1..a19f731c8a 100644 --- a/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangTypesToCol.scala @@ -88,10 +88,7 @@ case class LangTypesToCol[Pre <: Generation]() extends Rewriter[Pre] { case t @ PVLNamedType(_, typeArgs) => t.ref.get match { case spec: SpecTypeNameTarget[Pre] => specType(spec, typeArgs) - case RefClass(decl: ByReferenceClass[Pre]) => - TByReferenceClass(succ[Class[Post]](decl), typeArgs.map(dispatch)) - case RefClass(decl: ByValueClass[Pre]) => - TByValueClass(succ[Class[Post]](decl), typeArgs.map(dispatch)) + case RefClass(decl: Class[Pre]) => dispatch(decl.classType(typeArgs)) } case t @ CPrimitiveType(specs) => dispatch(C.getPrimitiveType(specs, context = Some(t))) diff --git a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala index f59b38fcb1..97d39d7b0a 100644 --- a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala +++ b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala @@ -2,7 +2,7 @@ package vct.rewrite.lang import vct.col.ast._ import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} -import RewriteHelpers._ +import vct.col.util.AstBuildHelpers.ClassBuildHelpers case object NoSupportSelfLoop extends RewriterBuilder { override def key: String = "removeSupportSelfLoop" @@ -13,22 +13,14 @@ case object NoSupportSelfLoop extends RewriterBuilder { case class NoSupportSelfLoop[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(decl: Declaration[Pre]): Unit = decl match { - case cls: ByReferenceClass[Pre] => + case cls: Class[Pre] => globalDeclarations.succeed( cls, cls.rewrite(supports = cls.supports.filter(_.asClass.get.cls.decl != cls) - .map(_.rewriteDefault()) + .map(_.rewriteDefault) ), ) - case cls: ByValueClass[Pre] => - globalDeclarations.succeed( - cls, - cls.rewrite(supports = - cls.supports.filter(_.asClass.get.cls.decl != cls) - .map(_.rewriteDefault()) - ), - ) - case other => rewriteDefault(other) + case other => super.dispatch(other) } } From ea45b988fec1a4d286721fa85d79906de8934be7 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Thu, 10 Oct 2024 11:42:13 +0200 Subject: [PATCH 45/47] Revert some changes which caused tests to fail --- src/col/vct/col/util/AstBuildHelpers.scala | 11 ++++---- .../vct/rewrite/MonomorphizeClass.scala | 28 ++++++++++++++----- .../vct/rewrite/PrepareByValueClass.scala | 4 ++- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 1953c351a9..0522190f81 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -225,16 +225,17 @@ object AstBuildHelpers { implicit rewriter: AbstractRewriter[Pre, Post] ) { def rewrite( - typeArgs: Seq[Variable[Post]] = rewriter.variables + typeArgs: => Seq[Variable[Post]] = rewriter.variables .dispatch(cls.typeArgs), - decls: Seq[ClassDeclaration[Post]] = rewriter.classDeclarations + decls: => Seq[ClassDeclaration[Post]] = rewriter.classDeclarations .dispatch(cls.decls), - supports: Seq[Type[Post]] = cls.supports.map(rewriter.dispatch), + supports: => Seq[Type[Post]] = cls.supports.map(rewriter.dispatch), ): Class[Post] = cls match { case cls: ByReferenceClass[Pre] => - cls.rewrite(typeArgs, decls, supports) - case cls: ByValueClass[Pre] => cls.rewrite(typeArgs, decls, supports) + cls.rewrite(typeArgs = typeArgs, decls = decls, supports = supports) + case cls: ByValueClass[Pre] => + cls.rewrite(typeArgs = typeArgs, decls = decls, supports = supports) } } diff --git a/src/rewrite/vct/rewrite/MonomorphizeClass.scala b/src/rewrite/vct/rewrite/MonomorphizeClass.scala index 07ed5a9f87..b581e1f32b 100644 --- a/src/rewrite/vct/rewrite/MonomorphizeClass.scala +++ b/src/rewrite/vct/rewrite/MonomorphizeClass.scala @@ -48,7 +48,7 @@ case class MonomorphizeClass[Pre <: Generation]() cls: Class[Pre], typeValues: Seq[Type[Pre]], keepBodies: Boolean, - ): Class[Post] = { + ): Unit = { /* Known limitation: the knownInstantations set does not take into account how a class was instantiated. A class can be instantiated both abstractly (without method bodies) and concretely (with method bodies) for the same sequence of type arguments, maybe. If that's the case, the knownInstantiations should take the @@ -60,7 +60,7 @@ case class MonomorphizeClass[Pre <: Generation]() logger.debug( s"Class ${cls.o.getPreferredNameOrElse().ucamel} with type args $typeValues is already instantiated, so skipping instantiation" ) - return genericSucc((key, cls)).asInstanceOf[Class[Post]] + return } val mode = if (keepBodies) { "concretely" } @@ -91,7 +91,6 @@ case class MonomorphizeClass[Pre <: Generation]() } } } - genericSucc((key, cls)).asInstanceOf[Class[Post]] } override def dispatch(decl: Declaration[Pre]): Unit = @@ -133,13 +132,28 @@ case class MonomorphizeClass[Pre <: Generation]() override def dispatch(t: Type[Pre]): Type[Post] = (t, ctx.topOption) match { - case (cls: TClass[Pre], ctx) if cls.typeArgs.nonEmpty => + case (TByReferenceClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => val typeValues = ctx match { - case Some(ctx) => cls.typeArgs.map(ctx.substitute.dispatch) - case None => cls.typeArgs + case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) + case None => typeArgs } - instantiate(cls.cls.decl, typeValues, false).classType(Seq()) + instantiate(cls, typeValues, false) + TByReferenceClass[Post]( + genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), + Seq(), + ) + case (TByValueClass(Ref(cls), typeArgs), ctx) if typeArgs.nonEmpty => + val typeValues = + ctx match { + case Some(ctx) => typeArgs.map(ctx.substitute.dispatch) + case None => typeArgs + } + instantiate(cls, typeValues, false) + TByValueClass[Post]( + genericSucc.ref[Post, Class[Post]](((cls, typeValues), cls)), + Seq(), + ) case (tvar @ TVar(_), Some(ctx)) => dispatch(ctx.substitutions(tvar)) case _ => t.rewriteDefault() } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/PrepareByValueClass.scala index 330bd3c588..42a128696c 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/PrepareByValueClass.scala @@ -256,7 +256,9 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { dp, v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], ) - case deref @ Deref(_, Ref(f)) if f.t.isInstanceOf[TByValueClass[Pre]] => + case deref @ Deref(_, Ref(f)) + if f.t.isInstanceOf[TByValueClass[Pre]] && + copyContext.top != NoCopy() => // TODO: Improve blame message here copyClassValue( deref.rewriteDefault(), From 033b7ad47316f2445045bb72ffb7afe99cab6384 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 11 Oct 2024 12:02:59 +0200 Subject: [PATCH 46/47] Remove supports from ByValueClass and remove useless LLVM coercions --- src/col/vct/col/ast/Node.scala | 7 --- .../declaration/global/ByValueClassImpl.scala | 1 + .../family/coercion/CoerceLLVMArrayImpl.scala | 8 --- .../coercion/CoerceLLVMPointerImpl.scala | 9 ---- .../ast/family/coercion/CoercionImpl.scala | 1 - .../vct/col/typerules/CoercingRewriter.scala | 24 ++------- src/col/vct/col/typerules/CoercionUtils.scala | 52 ------------------- src/col/vct/col/util/AstBuildHelpers.scala | 5 +- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 1 - .../vct/rewrite/lang/LangLLVMToCol.scala | 1 - .../vct/rewrite/lang/NoSupportSelfLoop.scala | 5 +- 11 files changed, 8 insertions(+), 106 deletions(-) delete mode 100644 src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala delete mode 100644 src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 73e887de8a..9f46327b76 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -682,7 +682,6 @@ final class ByReferenceClass[G]( final class ByValueClass[G]( val typeArgs: Seq[Variable[G]], val decls: Seq[ClassDeclaration[G]], - val supports: Seq[Type[G]], )(implicit val o: Origin) extends Class[G] with ByValueClassImpl[G] final class Model[G](val declarations: Seq[ModelDeclaration[G]])( @@ -1172,12 +1171,6 @@ final case class CoerceZFracFrac[G]()(implicit val o: Origin) final case class CoerceLLVMIntInt[G]()(implicit val o: Origin) extends Coercion[G] with CoerceLLVMIntIntImpl[G] -final case class CoerceLLVMPointer[G](from: Option[Type[G]], to: Type[G])( - implicit val o: Origin -) extends Coercion[G] with CoerceLLVMPointerImpl[G] -final case class CoerceLLVMArray[G](source: Type[G], target: Type[G])( - implicit val o: Origin -) extends Coercion[G] with CoerceLLVMArrayImpl[G] @family sealed trait Expr[G] extends NodeFamily[G] with ExprImpl[G] diff --git a/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala index e776781309..452501211e 100644 --- a/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala +++ b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala @@ -7,6 +7,7 @@ import vct.col.util.AstBuildHelpers._ trait ByValueClassImpl[G] extends ByValueClassOps[G] { this: ByValueClass[G] => override def intrinsicLockInvariant: Expr[G] = tt + override def supports: Seq[Type[G]] = Nil override def classType(typeArgs: Seq[Type[G]]): TByValueClass[G] = TByValueClass[G](this.ref, typeArgs) } diff --git a/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala deleted file mode 100644 index a77db8453b..0000000000 --- a/src/col/vct/col/ast/family/coercion/CoerceLLVMArrayImpl.scala +++ /dev/null @@ -1,8 +0,0 @@ -package vct.col.ast.family.coercion - -import vct.col.ast.ops.CoerceLLVMArrayOps -import vct.col.ast.CoerceLLVMArray - -trait CoerceLLVMArrayImpl[G] extends CoerceLLVMArrayOps[G] { - this: CoerceLLVMArray[G] => -} diff --git a/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala b/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala deleted file mode 100644 index e686d1c593..0000000000 --- a/src/col/vct/col/ast/family/coercion/CoerceLLVMPointerImpl.scala +++ /dev/null @@ -1,9 +0,0 @@ -package vct.col.ast.family.coercion - -import vct.col.ast.{CoerceLLVMPointer, TPointer} -import vct.col.ast.ops.CoerceLLVMPointerOps - -trait CoerceLLVMPointerImpl[G] extends CoerceLLVMPointerOps[G] { - this: CoerceLLVMPointer[G] => - override def target: TPointer[G] = TPointer(to) -} diff --git a/src/col/vct/col/ast/family/coercion/CoercionImpl.scala b/src/col/vct/col/ast/family/coercion/CoercionImpl.scala index 998074fa51..a8958f885c 100644 --- a/src/col/vct/col/ast/family/coercion/CoercionImpl.scala +++ b/src/col/vct/col/ast/family/coercion/CoercionImpl.scala @@ -89,6 +89,5 @@ trait CoercionImpl[G] extends CoercionFamilyOps[G] { case CoerceCFloatCInt(_) => false case CoerceLLVMIntInt() => true - case CoerceLLVMPointer(_, _) => true } } diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index f2fa435ea0..c6a2e54d0f 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -306,8 +306,6 @@ abstract class CoercingRewriter[Pre <: Generation]() case CoerceCFloatFloat(_, _) => e case CoerceLLVMIntInt() => e - case CoerceLLVMPointer(_, _) => e - case CoerceLLVMArray(_, _) => e } } @@ -551,15 +549,6 @@ abstract class CoercingRewriter[Pre <: Generation]() (ApplyCoercion(e, coercion)(coercionOrigin(e)), t) case None => throw IncoercibleText(e, s"pointer") } - def llvmPointer( - e: Expr[Pre], - innerType: Type[Pre], - ): (Expr[Pre], TPointer[Pre]) = - CoercionUtils.getAnyLLVMPointerCoercion(e.t, innerType) match { - case Some((coercion, t)) => - (ApplyCoercion(e, coercion)(coercionOrigin(e)), t) - case None => throw IncoercibleText(e, s"llvm pointer of $innerType") - } def matrix(e: Expr[Pre]): (Expr[Pre], TMatrix[Pre]) = CoercionUtils.getAnyMatrixCoercion(e.t) match { case Some((coercion, t)) => @@ -2140,12 +2129,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case Message(_) => e case LLVMLocal(name) => e case LLVMGetElementPointer(structureType, resultType, pointer, indices) => - LLVMGetElementPointer( - structureType, - resultType, - llvmPointer(pointer, structureType)._1, - indices, - ) + LLVMGetElementPointer(structureType, resultType, pointer, indices) case LLVMSignExtend(inputType, outputType, value) => e case LLVMZeroExtend(inputType, outputType, value) => e case LLVMTruncate(inputType, outputType, value) => e @@ -2271,11 +2255,9 @@ abstract class CoercingRewriter[Pre <: Generation]() case LLVMAllocA(variable, allocationType, numElements) => LLVMAllocA(variable, allocationType, int(numElements)) case load @ LLVMLoad(variable, loadType, p, ordering) => - LLVMLoad(variable, loadType, llvmPointer(p, loadType)._1, ordering)( - load.blame - ) + LLVMLoad(variable, loadType, p, ordering)(load.blame) case store @ LLVMStore(value, p, ordering) => - LLVMStore(value, llvmPointer(p, value.t)._1, ordering)(store.blame) + LLVMStore(value, p, ordering)(store.blame) case ModelDo(model, perm, after, action, impl) => ModelDo(model, rat(perm), after, action, impl) case n @ Notify(obj) => Notify(cls(obj))(n.blame) diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index a7b16b8e81..bee93e6e2e 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -123,8 +123,6 @@ case object CoercionUtils { case (TNull(), TAnyClass()) => CoerceNullAnyClass() case (TNull(), TPointer(target)) => CoerceNullPointer(target) case (TNull(), CTPointer(target)) => CoerceNullPointer(target) - case (TNull(), LLVMTPointer(Some(target))) => CoerceNullPointer(target) - case (TNull(), LLVMTPointer(None)) => CoerceNullPointer(TAny()) case (TNull(), TEnum(target)) => CoerceNullEnum(target) case (CTArray(_, innerType), TArray(element)) if element == innerType => @@ -309,23 +307,6 @@ case object CoercionUtils { case None => return None } - // TODO: Back and forth should not be needed... - case (LLVMTPointer(Some(_)), LLVMTPointer(None)) => - CoerceIdentity(LLVMTPointer(None)) - case (LLVMTPointer(None), LLVMTPointer(Some(innerType))) => - CoerceIdentity(LLVMTPointer(Some(innerType))) - case (TPointer(_), LLVMTPointer(None)) => - CoerceIdentity(LLVMTPointer(None)) - case (LLVMTPointer(None), TPointer(innerType)) => - CoerceLLVMPointer(None, innerType) - case ( - LLVMTPointer(Some(LLVMTArray(numElements, elementType))), - TPointer(innerType), - ) if numElements > 0 => - getAnyCoercion(elementType, innerType).getOrElse(return None) - case (LLVMTPointer(Some(leftInner)), TPointer(rightInner)) => - getAnyCoercion(leftInner, rightInner).getOrElse(return None) - case (TPointer(TAny()), TPointer(any)) => CoerceIdentity(TPointer(any)) case (TPointer(any), TPointer(TAny())) => CoerceIdentity(TPointer(any)) @@ -495,24 +476,6 @@ case object CoercionUtils { case _ => false } - def getAnyLLVMPointerCoercion[G]( - source: Type[G], - innerType: Type[G], - ): Option[(Coercion[G], TPointer[G])] = - source match { - case LLVMTPointer(None) => - Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) - case LLVMTPointer(Some(t)) if firstElementIsType(t, innerType) => - Some(CoerceLLVMPointer(Some(t), innerType), TPointer[G](innerType)) - case TPointer(TAny()) => - Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) - case TPointer(t) if firstElementIsType(t, innerType) => - Some(CoerceLLVMPointer(Some(t), innerType), TPointer[G](innerType)) - case _: TNull[G] => - Some((CoerceLLVMPointer(None, innerType), TPointer[G](innerType))) - case _ => None - } - def getAnyCArrayCoercion[G]( source: Type[G] ): Option[(Coercion[G], CTArray[G])] = @@ -550,21 +513,6 @@ case object CoercionUtils { .asInstanceOf[TArray[G]], )) case t: TArray[G] => Some((CoerceIdentity(source), t)) - case t: LLVMTArray[G] => { - val t2 = TArray[G](t.elementType) - Some(CoerceLLVMArray(t, t2), t2) - } - case LLVMTPointer(None) => - Some(CoerceIdentity(source), TArray[G](TAnyValue())) - case LLVMTPointer(Some(t)) => - getAnyArrayCoercion(t) match { - case Some(inner) => - Some( - CoercionSequence(Seq(inner._1, CoerceIdentity(source))), - inner._2, - ) - case None => None - } case _: TNull[G] => val t = TArray[G](TAnyValue()) Some((CoerceNullArray(t), t)) diff --git a/src/col/vct/col/util/AstBuildHelpers.scala b/src/col/vct/col/util/AstBuildHelpers.scala index 0522190f81..68d6bf8b8d 100644 --- a/src/col/vct/col/util/AstBuildHelpers.scala +++ b/src/col/vct/col/util/AstBuildHelpers.scala @@ -229,13 +229,12 @@ object AstBuildHelpers { .dispatch(cls.typeArgs), decls: => Seq[ClassDeclaration[Post]] = rewriter.classDeclarations .dispatch(cls.decls), - supports: => Seq[Type[Post]] = cls.supports.map(rewriter.dispatch), ): Class[Post] = cls match { case cls: ByReferenceClass[Pre] => - cls.rewrite(typeArgs = typeArgs, decls = decls, supports = supports) + cls.rewrite(typeArgs = typeArgs, decls = decls) case cls: ByValueClass[Pre] => - cls.rewrite(typeArgs = typeArgs, decls = decls, supports = supports) + cls.rewrite(typeArgs = typeArgs, decls = decls) } } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 61e3d4195a..2e766c9c97 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -1016,7 +1016,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) .declare(cStructFieldsSuccessor((decl, fieldDecl))) } }._1, - Seq(), )(CStructOrigin(sdecl)) rw.globalDeclarations.declare(newStruct) diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 1480f0b909..4574c1e8df 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -412,7 +412,6 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) rw.classDeclarations.declare(structFieldMap((t, idx))) } }._1, - Seq(), )(t.o.withContent(TypeName("struct"))) rw.globalDeclarations.declare(newStruct) diff --git a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala index 97d39d7b0a..89f3494fa7 100644 --- a/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala +++ b/src/rewrite/vct/rewrite/lang/NoSupportSelfLoop.scala @@ -2,7 +2,6 @@ package vct.rewrite.lang import vct.col.ast._ import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} -import vct.col.util.AstBuildHelpers.ClassBuildHelpers case object NoSupportSelfLoop extends RewriterBuilder { override def key: String = "removeSupportSelfLoop" @@ -13,12 +12,12 @@ case object NoSupportSelfLoop extends RewriterBuilder { case class NoSupportSelfLoop[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(decl: Declaration[Pre]): Unit = decl match { - case cls: Class[Pre] => + case cls: ByReferenceClass[Pre] => globalDeclarations.succeed( cls, cls.rewrite(supports = cls.supports.filter(_.asClass.get.cls.decl != cls) - .map(_.rewriteDefault) + .map(_.rewriteDefault()) ), ) case other => super.dispatch(other) From bbe95cc1cd88633115f62f9502368edf45bb6bb4 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Tue, 15 Oct 2024 14:13:32 +0200 Subject: [PATCH 47/47] Incorporate last remaining feedback from Bob --- .../global/ByReferenceClassImpl.scala | 12 + .../declaration/global/ByValueClassImpl.scala | 6 +- .../ast/declaration/global/ClassImpl.scala | 16 +- .../declaration/singular/EndpointImpl.scala | 2 +- .../vct/col/ast/type/typeclass/TypeImpl.scala | 11 + src/col/vct/col/resolve/Resolve.scala | 1 + .../vct/col/util/SubstituteReferences.scala | 6 +- src/hre/hre/util/ScopedStack.scala | 1 - src/main/vct/main/stages/Transformation.scala | 4 +- src/rewrite/vct/rewrite/ClassToRef.scala | 364 ++++++++---------- .../vct/rewrite/DisambiguateLocation.scala | 2 +- src/rewrite/vct/rewrite/EncodeAutoValue.scala | 161 ++++---- ...ss.scala => EncodeByValueClassUsage.scala} | 58 ++- .../vct/rewrite/EncodeIntrinsicLock.scala | 7 +- .../rewrite/PropagateContextEverywhere.scala | 14 +- src/rewrite/vct/rewrite/TrivialAddrOf.scala | 4 +- .../vct/rewrite/VariableToPointer.scala | 10 +- .../vct/rewrite/lang/LangCPPToCol.scala | 8 +- src/rewrite/vct/rewrite/lang/LangCToCol.scala | 2 +- .../vct/rewrite/lang/LangLLVMToCol.scala | 7 +- 20 files changed, 326 insertions(+), 370 deletions(-) rename src/rewrite/vct/rewrite/{PrepareByValueClass.scala => EncodeByValueClassUsage.scala} (85%) diff --git a/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala b/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala index ec95adbe7c..a19eb5672d 100644 --- a/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala +++ b/src/col/vct/col/ast/declaration/global/ByReferenceClassImpl.scala @@ -2,9 +2,21 @@ package vct.col.ast.declaration.global import vct.col.ast.{ByReferenceClass, TByReferenceClass, Type} import vct.col.ast.ops.ByReferenceClassOps +import vct.col.print._ +import vct.col.util.AstBuildHelpers.tt trait ByReferenceClassImpl[G] extends ByReferenceClassOps[G] { this: ByReferenceClass[G] => override def classType(typeArgs: Seq[Type[G]]): TByReferenceClass[G] = TByReferenceClass[G](this.ref, typeArgs) + + override def layoutLockInvariant(implicit ctx: Ctx): Doc = + if (intrinsicLockInvariant == tt) { Empty } + else { + Doc.spec(Show.lazily { c: Ctx => + implicit val ctx: Ctx = c + Text("lock_invariant") <+> + Nest(intrinsicLockInvariant.show <> ";" <+/> Empty) + }) + } } diff --git a/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala index 452501211e..416f595106 100644 --- a/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala +++ b/src/col/vct/col/ast/declaration/global/ByValueClassImpl.scala @@ -1,13 +1,13 @@ package vct.col.ast.declaration.global -import vct.col.ast.{ByValueClass, Expr, InstanceField, TByValueClass, Type} +import vct.col.ast.{ByValueClass, TByValueClass, Type} import vct.col.ast.ops.ByValueClassOps -import vct.col.util.AstBuildHelpers._ +import vct.col.print._ trait ByValueClassImpl[G] extends ByValueClassOps[G] { this: ByValueClass[G] => - override def intrinsicLockInvariant: Expr[G] = tt override def supports: Seq[Type[G]] = Nil override def classType(typeArgs: Seq[Type[G]]): TByValueClass[G] = TByValueClass[G](this.ref, typeArgs) + override def layoutLockInvariant(implicit ctx: Ctx): Doc = Empty } diff --git a/src/col/vct/col/ast/declaration/global/ClassImpl.scala b/src/col/vct/col/ast/declaration/global/ClassImpl.scala index 251557ee8d..ef58b5528c 100644 --- a/src/col/vct/col/ast/declaration/global/ClassImpl.scala +++ b/src/col/vct/col/ast/declaration/global/ClassImpl.scala @@ -21,14 +21,12 @@ trait ClassImpl[G] extends Declarator[G] { def typeArgs: Seq[Variable[G]] def decls: Seq[ClassDeclaration[G]] def supports: Seq[Type[G]] - def intrinsicLockInvariant: Expr[G] def classType(typeArgs: Seq[Type[G]]): TClass[G] def transSupportArrowsHelper( seen: Set[TClass[G]] ): Seq[(TClass[G], TClass[G])] = { - // TODO: Does this break things if we have a ByValueClass with supers? val t: TClass[G] = classType( typeArgs.map((v: Variable[G]) => TVar(v.ref[Variable[G]])) ) @@ -49,9 +47,7 @@ trait ClassImpl[G] extends Declarator[G] { override def declarations: Seq[Declaration[G]] = decls ++ typeArgs - def layoutLockInvariant(implicit ctx: Ctx): Doc = - Text("lock_invariant") <+> Nest(intrinsicLockInvariant.show) <> ";" <+/> - Empty + def layoutLockInvariant(implicit ctx: Ctx): Doc def layoutLock(implicit ctx: Ctx): Doc = Text("Lock") <+> "intrinsicLock$" <+> "=" <+> "new" <+> @@ -59,10 +55,7 @@ trait ClassImpl[G] extends Declarator[G] { "intrinsicLock$" <> "." <> "newCondition()" <> ";" def layoutJava(implicit ctx: Ctx): Doc = - (if (intrinsicLockInvariant == tt[G]) - Empty - else - Doc.spec(Show.lazily(layoutLockInvariant(_)))) <+/> Group( + layoutLockInvariant <+/> Group( Text("class") <+> ctx.name(this) <> (if (typeArgs.nonEmpty) Text("<") <> Doc.args(typeArgs) <> ">" @@ -78,10 +71,7 @@ trait ClassImpl[G] extends Declarator[G] { ) <>> Doc.stack2(layoutLock +: decls) <+/> "}" def layoutPvl(implicit ctx: Ctx): Doc = - (if (intrinsicLockInvariant == tt[G]) - Empty - else - Doc.spec(Show.lazily(layoutLockInvariant(_)))) <+/> Group( + layoutLockInvariant <+/> Group( Text("class") <+> ctx.name(this) <> (if (typeArgs.nonEmpty) Text("<") <> Doc.args(typeArgs) <> ">" diff --git a/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala b/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala index 8701db45e9..1c243f0473 100644 --- a/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala +++ b/src/col/vct/col/ast/declaration/singular/EndpointImpl.scala @@ -12,7 +12,7 @@ trait EndpointImpl[G] override def layout(implicit ctx: Ctx): Doc = Group(Text("endpoint") <+> ctx.name(this) <+> "=" <+> init) - def t: TClass[G] = TByReferenceClass(cls, typeArgs) + def t: TClass[G] = cls.decl.classType(typeArgs) override def check(ctx: CheckContext[G]): Seq[CheckError] = super.check(ctx) diff --git a/src/col/vct/col/ast/type/typeclass/TypeImpl.scala b/src/col/vct/col/ast/type/typeclass/TypeImpl.scala index efdaf2191f..3d9fd87c0c 100644 --- a/src/col/vct/col/ast/type/typeclass/TypeImpl.scala +++ b/src/col/vct/col/ast/type/typeclass/TypeImpl.scala @@ -50,6 +50,17 @@ trait TypeImpl[G] extends TypeFamilyOps[G] { CoercionUtils.getAnySmtlibArrayCoercion(this).map(_._2) def asSmtlibSeq: Option[TSmtlibSeq[G]] = CoercionUtils.getAnySmtlibSeqCoercion(this).map(_._2) + def asByReferenceClass: Option[TByReferenceClass[G]] = + this match { + // TODO: Check uses should this be a coercion to also catch null and the like? + case cls: TByReferenceClass[G] => Some(cls) + case _ => None + } + def asByValueClass: Option[TByValueClass[G]] = + this match { + case cls: TByValueClass[G] => Some(cls) + case _ => None + } /*def asVector: Option[TVector] = optMatch(this) { case vec: TVector => vec }*/ def particularize(substitutions: Map[Variable[G], Type[G]]): Type[G] = { diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index c39f2e972f..4612e9d097 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -1143,6 +1143,7 @@ case object ResolveReferences extends LazyLogging { loop.headerBlock = Some(ctx.llvmBlocks(loop.header.decl)) loop.latchBlock = Some(ctx.llvmBlocks(loop.latch.decl)) case contract: LLVMFunctionContract[G] => + // WONTFIX: // implicit val o: Origin = contract.o val llvmFunction = ctx.currentResult.get.asInstanceOf[RefLLVMFunctionDefinition[G]].decl diff --git a/src/col/vct/col/util/SubstituteReferences.scala b/src/col/vct/col/util/SubstituteReferences.scala index 95dbfed39f..c778f76aaa 100644 --- a/src/col/vct/col/util/SubstituteReferences.scala +++ b/src/col/vct/col/util/SubstituteReferences.scala @@ -8,7 +8,11 @@ import vct.col.rewrite.NonLatchingRewriter import scala.reflect.ClassTag /** Substitute all references in expressions, resulting AST can be used for - * analysis but not output since it doesn't contain the right declarations + * analysis but not output since it doesn't contain the right declarations. + * + * This is unsafe for rewriting if the substituted declarations are not + * declared. In general, I would advise against using this class if you are not + * 100% certain this is what you need. */ case class SubstituteReferences[G](subs: Map[Object, Object]) extends NonLatchingRewriter[G, G] { diff --git a/src/hre/hre/util/ScopedStack.scala b/src/hre/hre/util/ScopedStack.scala index 1270cd14a1..de4e2f0117 100644 --- a/src/hre/hre/util/ScopedStack.scala +++ b/src/hre/hre/util/ScopedStack.scala @@ -22,7 +22,6 @@ case class ScopedStack[T]() { def isEmpty: Boolean = stack.isEmpty def nonEmpty: Boolean = stack.nonEmpty def push(t: T): Unit = stack.push(t) - def pop(): T = stack.pop() def top: T = stack.top def topOption: Option[T] = stack.headOption def find(f: T => Boolean): Option[T] = stack.find(f) diff --git a/src/main/vct/main/stages/Transformation.scala b/src/main/vct/main/stages/Transformation.scala index 5907a0af3b..d6dc2ed33b 100644 --- a/src/main/vct/main/stages/Transformation.scala +++ b/src/main/vct/main/stages/Transformation.scala @@ -29,7 +29,7 @@ import vct.rewrite.adt.ImportSetCompat import vct.rewrite.{ DisambiguatePredicateExpression, EncodeAutoValue, - PrepareByValueClass, + EncodeByValueClassUsage, EncodeRange, EncodeResourceValues, ExplicitResourceValues, @@ -410,7 +410,7 @@ case class SilverTransformation( // flatten out functions in the rhs of assignments, making it harder to detect final field assignments where the // value is pure and therefore be put in the contract of the constant function. ConstantifyFinalFields, - PrepareByValueClass, + EncodeByValueClassUsage, // Resolve side effects including method invocations, for encodetrythrowsignals. ResolveExpressionSideChecks, ResolveExpressionSideEffects, diff --git a/src/rewrite/vct/rewrite/ClassToRef.scala b/src/rewrite/vct/rewrite/ClassToRef.scala index c7f48ef2b1..6fa42781e9 100644 --- a/src/rewrite/vct/rewrite/ClassToRef.scala +++ b/src/rewrite/vct/rewrite/ClassToRef.scala @@ -109,27 +109,6 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ) } - private def transitiveByValuePermissions( - obj: Expr[Pre], - t: TByValueClass[Pre], - amount: Expr[Pre], - )(implicit o: Origin): Expr[Pre] = { - t.cls.decl.decls.collect[Expr[Pre]] { case field: InstanceField[Pre] => - field.t match { - case field_t: TByValueClass[Pre] => - fieldPerm[Pre](obj, field.ref, amount) &* - transitiveByValuePermissions( - Deref[Pre](obj, field.ref)(PanicBlame( - "Permission should already be ensured" - )), - field_t, - amount, - ) - case _ => fieldPerm(obj, field.ref, amount) - } - }.reduce[Expr[Pre]] { (a, b) => a &* b } - } - def makeInstanceOf: Function[Post] = { implicit val o: Origin = InstanceOfOrigin val sub = new Variable[Post](TInt()) @@ -374,181 +353,170 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { case _ => throw ExtraNode } cls match { - case cls: ByValueClass[Pre] => - implicit val o: Origin = cls.o - val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) - var valueAsAxioms: Seq[ADTAxiom[Post]] = Seq() - val (fieldFunctions, fieldInverses, fieldTypes) = - cls.decls.collect { case field: Field[Pre] => - val newT = dispatch(field.t) - val nonnullT = TNonNullPointer(newT) - byValFieldSucc(field) = - new ADTFunction[Post]( - Seq(new Variable(axiomType)(field.o)), - nonnullT, - )(field.o) - if (valueAsAxioms.isEmpty) { - // This is the first field - valueAsAxioms = - valueAsAxioms :+ new ADTAxiom[Post](forall( - axiomType, - body = { a => - InlinePattern(adtFunctionInvocation[Post]( - valueAsFunctions.getOrElseUpdate( - field.t, - makeValueAsFunction(field.t.toString, newT), - ).ref, - typeArgs = Some((valueAdt.ref(()), Seq(axiomType))), - args = Seq(a), - )) === adtFunctionInvocation( - byValFieldSucc.ref(field), - args = Seq(a), - ) - }, - )) - - valueAsAxioms = - valueAsAxioms ++ - (field.t match { - case t: TByValueClass[Pre] => - // TODO: If there are no fields we should ignore the first field and add the axioms for the second field - t.cls.decl.decls - .collectFirst({ case innerF: InstanceField[Pre] => - unwrapValueAs( - axiomType, - innerF.t, - dispatch(innerF.t), - byValFieldSucc.ref(field), - ) - }).getOrElse(Nil) - case _ => Nil - }) - } - ( - byValFieldSucc(field), - new ADTFunction[Post]( - Seq(new Variable(nonnullT)(field.o)), - axiomType, - )( - field.o.copy( - field.o.originContents - .filterNot(_.isInstanceOf[SourceName]) - ).where(name = - "inv_" + field.o.find[SourceName].map(_.name) - .getOrElse("unknown") - ) - ), - nonnullT, - ) - }.unzip3 - val constructor = - new ADTFunction[Post]( - fieldTypes.zipWithIndex.map { case (t, i) => - new Variable(t)(Origin(Seq( - PreferredName(Seq("p_" + i)), - LabelContext("classToRef"), - ))) - }, - axiomType, - )( - cls.o.copy( - cls.o.originContents.filterNot(_.isInstanceOf[SourceName]) - ).where(name = - "new_" + cls.o.find[SourceName].map(_.name) - .getOrElse("unknown") - ) - ) - // TAnyValue is a placeholder the pointer adt doesn't have type parameters - val indexFunction = - new ADTFunction[Post]( - Seq(new Variable(TNonNullPointer(TAnyValue()))(Origin( - Seq(PreferredName(Seq("pointer")), LabelContext("classToRef")) - ))), - TInt(), - )( - cls.o.copy( - cls.o.originContents.filterNot(_.isInstanceOf[SourceName]) - ).where(name = - "index_" + cls.o.find[SourceName].map(_.name) - .getOrElse("unknown") - ) - ) - val injectivityAxiom = - new ADTAxiom[Post](foralls( - Seq(axiomType, axiomType), - body = { case Seq(a0, a1) => - foldAnd(fieldFunctions.map { f => - Implies( - Eq( - adtFunctionInvocation[Post](f.ref, args = Seq(a0)), - adtFunctionInvocation[Post](f.ref, args = Seq(a1)), - ), - a0 === a1, - ) - }) - }, - triggers = { case Seq(a0, a1) => - fieldFunctions.map { f => - Seq( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a0)), - adtFunctionInvocation[Post](f.ref, None, args = Seq(a1)), - ) - } - }, - )) - val destructorAxioms = fieldFunctions.zip(fieldInverses).map { - case (f, inv) => - new ADTAxiom[Post](forall( - axiomType, - body = { a => - adtFunctionInvocation[Post]( - inv.ref, - None, - args = Seq( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) - ), - ) === a - }, - triggers = { a => - Seq(Seq( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) - )) - }, - )) - } - val indexAxioms = fieldFunctions.zipWithIndex.map { case (f, i) => - new ADTAxiom[Post](forall( - axiomType, - body = { a => - adtFunctionInvocation[Post]( - indexFunction.ref, - None, - args = Seq( - adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) - ), - ) === const(i) - }, - triggers = { a => - Seq( - Seq(adtFunctionInvocation[Post](f.ref, None, args = Seq(a))) - ) - }, - )) - } - byValConsSucc(cls) = constructor - byValClassSucc(cls) = - new AxiomaticDataType[Post]( - Seq(indexFunction, injectivityAxiom) ++ destructorAxioms ++ - indexAxioms ++ fieldFunctions ++ fieldInverses ++ - valueAsAxioms, - Nil, - ) - globalDeclarations.succeed(cls, byValClassSucc(cls)) + case cls: ByValueClass[Pre] => encodeByValueClass(cls) case _ => cls.drop() } case decl => super.dispatch(decl) } + private def encodeByValueClass(cls: ByValueClass[Pre]) = { + implicit val o: Origin = cls.o + val axiomType = TAxiomatic[Post](byValClassSucc.ref(cls), Nil) + var valueAsAxioms: Seq[ADTAxiom[Post]] = Seq() + val (fieldFunctions, fieldInverses, fieldTypes) = + cls.decls.collect { case field: Field[Pre] => + val newT = dispatch(field.t) + val nonnullT = TNonNullPointer(newT) + byValFieldSucc(field) = + new ADTFunction[Post]( + Seq(new Variable(axiomType)(field.o)), + nonnullT, + )(field.o) + if (valueAsAxioms.isEmpty) { + // This is the first field + valueAsAxioms = + valueAsAxioms :+ new ADTAxiom[Post](forall( + axiomType, + body = { a => + InlinePattern(adtFunctionInvocation[Post]( + valueAsFunctions.getOrElseUpdate( + field.t, + makeValueAsFunction(field.t.toString, newT), + ).ref, + typeArgs = Some((valueAdt.ref(()), Seq(axiomType))), + args = Seq(a), + )) === adtFunctionInvocation( + byValFieldSucc.ref(field), + args = Seq(a), + ) + }, + )) + + valueAsAxioms = + valueAsAxioms ++ + (field.t match { + case t: TByValueClass[Pre] => + // TODO: If there are no fields we should ignore the first field and add the axioms for the second field + t.cls.decl.decls + .collectFirst({ case innerF: InstanceField[Pre] => + unwrapValueAs( + axiomType, + innerF.t, + dispatch(innerF.t), + byValFieldSucc.ref(field), + ) + }).getOrElse(Nil) + case _ => Nil + }) + } + ( + byValFieldSucc(field), + new ADTFunction[Post]( + Seq(new Variable(nonnullT)(field.o)), + axiomType, + )( + field.o.copy( + field.o.originContents.filterNot(_.isInstanceOf[SourceName]) + ).where(name = + "inv_" + field.o.find[SourceName].map(_.name).getOrElse("unknown") + ) + ), + nonnullT, + ) + }.unzip3 + val constructor = + new ADTFunction[Post]( + fieldTypes.zipWithIndex.map { case (t, i) => + new Variable(t)(Origin( + Seq(PreferredName(Seq("p_" + i)), LabelContext("classToRef")) + )) + }, + axiomType, + )( + cls.o.copy(cls.o.originContents.filterNot(_.isInstanceOf[SourceName])) + .where(name = + "new_" + cls.o.find[SourceName].map(_.name).getOrElse("unknown") + ) + ) + // TAnyValue is a placeholder the pointer adt doesn't have type parameters + val indexFunction = + new ADTFunction[Post]( + Seq(new Variable(TNonNullPointer(TAnyValue()))(Origin( + Seq(PreferredName(Seq("pointer")), LabelContext("classToRef")) + ))), + TInt(), + )( + cls.o.copy(cls.o.originContents.filterNot(_.isInstanceOf[SourceName])) + .where(name = + "index_" + cls.o.find[SourceName].map(_.name).getOrElse("unknown") + ) + ) + val injectivityAxiom = + new ADTAxiom[Post](foralls( + Seq(axiomType, axiomType), + body = { case Seq(a0, a1) => + foldAnd(fieldFunctions.map { f => + Implies( + Eq( + adtFunctionInvocation[Post](f.ref, args = Seq(a0)), + adtFunctionInvocation[Post](f.ref, args = Seq(a1)), + ), + a0 === a1, + ) + }) + }, + triggers = { case Seq(a0, a1) => + fieldFunctions.map { f => + Seq( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a0)), + adtFunctionInvocation[Post](f.ref, None, args = Seq(a1)), + ) + } + }, + )) + val destructorAxioms = fieldFunctions.zip(fieldInverses).map { + case (f, inv) => + new ADTAxiom[Post](forall( + axiomType, + body = { a => + adtFunctionInvocation[Post]( + inv.ref, + None, + args = Seq( + adtFunctionInvocation[Post](f.ref, None, args = Seq(a)) + ), + ) === a + }, + triggers = { a => + Seq(Seq(adtFunctionInvocation[Post](f.ref, None, args = Seq(a)))) + }, + )) + } + val indexAxioms = fieldFunctions.zipWithIndex.map { case (f, i) => + new ADTAxiom[Post](forall( + axiomType, + body = { a => + adtFunctionInvocation[Post]( + indexFunction.ref, + None, + args = Seq(adtFunctionInvocation[Post](f.ref, None, args = Seq(a))), + ) === const(i) + }, + triggers = { a => + Seq(Seq(adtFunctionInvocation[Post](f.ref, None, args = Seq(a)))) + }, + )) + } + byValConsSucc(cls) = constructor + byValClassSucc(cls) = + new AxiomaticDataType[Post]( + Seq(indexFunction, injectivityAxiom) ++ destructorAxioms ++ + indexAxioms ++ fieldFunctions ++ fieldInverses ++ valueAsAxioms, + Nil, + ) + globalDeclarations.succeed(cls, byValClassSucc(cls)) + } + def instantiate(cls: Class[Pre], target: Ref[Post, Variable[Post]])( implicit o: Origin ): Statement[Post] = { @@ -755,13 +723,13 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { }, ) - if (t.isInstanceOf[TByValueClass[Pre]]) { - constraint &* - t.asInstanceOf[TByValueClass[Pre]].cls.decl.decls.collectFirst { - case field: InstanceField[Pre] => - unwrapCastConstraints(outerType, field.t) + t match { + case TByValueClass(Ref(cls), _) => + constraint &* cls.decls.collectFirst { case field: InstanceField[Pre] => + unwrapCastConstraints(outerType, field.t) }.getOrElse(tt) - } else { constraint } + case _ => constraint + } } private def makeCastHelper(t: Type[Pre]): Procedure[Post] = { @@ -837,7 +805,7 @@ case class ClassToRef[Pre <: Generation]() extends Rewriter[Pre] { ))(inv.o) case ThisObject(_) => diz.top case ptrOf @ AddrOf(Deref(obj, Ref(field))) - if obj.t.isInstanceOf[TByValueClass[Pre]] => + if obj.t.asByValueClass.isDefined => adtFunctionInvocation[Post]( byValFieldSucc.ref(field), args = Seq(dispatch(obj)), diff --git a/src/rewrite/vct/rewrite/DisambiguateLocation.scala b/src/rewrite/vct/rewrite/DisambiguateLocation.scala index ff006ee9cb..5d7bc6db64 100644 --- a/src/rewrite/vct/rewrite/DisambiguateLocation.scala +++ b/src/rewrite/vct/rewrite/DisambiguateLocation.scala @@ -61,7 +61,7 @@ case class DisambiguateLocation[Pre <: Generation]() extends Rewriter[Pre] { expr match { case expr if expr.t.asPointer.isDefined => PointerLocation(dispatch(expr))(blame) - case expr if expr.t.isInstanceOf[TByValueClass[Pre]] => + case expr if expr.t.asByValueClass.isDefined => ByValueClassLocation(dispatch(expr))(blame) case DerefHeapVariable(ref) => HeapVariableLocation(succ(ref.decl)) case Deref(obj, ref) => FieldLocation(dispatch(obj), succ(ref.decl)) diff --git a/src/rewrite/vct/rewrite/EncodeAutoValue.scala b/src/rewrite/vct/rewrite/EncodeAutoValue.scala index 76589823ac..1334f2c940 100644 --- a/src/rewrite/vct/rewrite/EncodeAutoValue.scala +++ b/src/rewrite/vct/rewrite/EncodeAutoValue.scala @@ -63,20 +63,25 @@ case object EncodeAutoValue extends RewriterBuilder { override def blame(error: PostconditionFailed): Unit = app.blame.blame(AutoValueLeakCheckFailed(error.failure, app)) } + + private case class ConditionContext[Pre <: Generation]( + context: PreOrPost, + prePostMap: mutable.ArrayBuffer[(Expr[Pre], Expr[Rewritten[Pre]])] = + mutable.ArrayBuffer[(Expr[Pre], Expr[Rewritten[Pre]])](), + autoValueLocations: mutable.HashSet[Location[Rewritten[Pre]]] = mutable + .HashSet[Location[Rewritten[Pre]]](), + normalLocations: mutable.HashSet[Location[Rewritten[Pre]]] = mutable + .HashSet[Location[Rewritten[Pre]]](), + ) } case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { import EncodeAutoValue._ - private val conditionContext: ScopedStack[ - ( - PreOrPost, - mutable.ArrayBuffer[(Expr[Pre], Expr[Post])], - mutable.HashSet[Location[Post]], - mutable.HashSet[Location[Post]], - ) - ] = ScopedStack() + private val conditionContext: ScopedStack[Option[ConditionContext[Pre]]] = + ScopedStack() + conditionContext.push(None) private val inFunction: ScopedStack[Unit] = ScopedStack() private lazy val anyRead: Function[Post] = { @@ -122,14 +127,14 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { val postMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() node.rewrite( requires = - conditionContext.having( - (InPrecondition(), preMap, mutable.HashSet(), mutable.HashSet()) - ) { dispatch(node.requires) }, + conditionContext.having(Some( + ConditionContext(InPrecondition(), prePostMap = preMap) + )) { dispatch(node.requires) }, ensures = { val predicate = - conditionContext.having( - ((InPostcondition(), postMap, mutable.HashSet(), mutable.HashSet())) - ) { dispatch(node.ensures) } + conditionContext.having(Some( + ConditionContext(InPostcondition(), prePostMap = postMap) + )) { dispatch(node.ensures) } val filtered = preMap.filterNot(pre => postMap.exists(post => { Compare.isIsomorphic(pre._1, post._1, matchFreeVariables = true) @@ -145,12 +150,13 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(e: Expr[Pre]): Expr[Post] = { implicit val o: Origin = e.o - if (conditionContext.isEmpty) { + if (conditionContext.top.isEmpty) { e match { case AutoValue(_) => throw InvalidAutoValue(e) case _ => e.rewriteDefault() } } else { + val context = conditionContext.top.get e match { case AutoValue(loc) if inFunction.nonEmpty => Value(dispatch(loc)) case AutoValue(loc) => { @@ -161,8 +167,13 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { case SilverFieldLocation(obj, field) => (SilverFieldLocation(x.get, field), obj) } - conditionContext.top match { - case (InPrecondition(), preMap, avLocations, locations) => + context match { + case ConditionContext( + InPrecondition(), + preMap, + avLocations, + locations, + ) => preMap += (( e, @@ -187,7 +198,12 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { !ForPerm(Seq(x), genericLocation, ff) &* !ForPerm(Seq(x), genericLocation, x.get !== obj), ) - case (InPostcondition(), postMap, avLocations, locations) => + case ConditionContext( + InPostcondition(), + postMap, + avLocations, + locations, + ) => postMap += ((e, tt)) avLocations += postLoc if (locations.contains(postLoc)) { @@ -206,8 +222,8 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { } case Perm(loc, perm) => { val postLoc = dispatch(loc) - conditionContext.top match { - case (_, _, avLocations, locations) => { + context match { + case ConditionContext(_, _, avLocations, locations) => { locations += postLoc; if (avLocations.contains(postLoc)) { throw CombinedAutoValue(avLocations.find(_ == postLoc).get, e) @@ -219,25 +235,23 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { case Let(binding, value, main) => variables.scope { localHeapVariables.scope { - val top = conditionContext.pop() val (b, v) = - try { (variables.dispatch(binding), dispatch(value)) } - finally { conditionContext.push(top) } + conditionContext.having(None) { + (variables.dispatch(binding), dispatch(value)) + } val mMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() val m = - conditionContext.having(( - conditionContext.top._1, - mMap, - conditionContext.top._3, - conditionContext.top._4, - )) { dispatch(main) } + conditionContext.having(Some(context.copy(prePostMap = mMap))) { + dispatch(main) + } if (mMap.isEmpty) { Let(b, v, m) } else { mMap.foreach(postM => - conditionContext.top._2 - .append((Let(binding, value, postM._1), Let(b, v, postM._2))) + context.prePostMap.append( + (Let(binding, value, postM._1), Let(b, v, postM._2)) + ) ) - conditionContext.top._1 match { + context.context match { case InPrecondition() => Let(b, v, m) case InPostcondition() => Let( @@ -252,31 +266,22 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { } } case Select(condition, left, right) => - val top = conditionContext.pop() - val c = - try { dispatch(condition) } - finally { conditionContext.push(top) } + val c = conditionContext.having(None) { dispatch(condition) } val lMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() val rMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() val l = - conditionContext.having(( - conditionContext.top._1, - lMap, - conditionContext.top._3, - conditionContext.top._4, - )) { dispatch(left) } + conditionContext.having(Some(context.copy(prePostMap = lMap))) { + dispatch(left) + } val r = - conditionContext.having(( - conditionContext.top._1, - rMap, - conditionContext.top._3, - conditionContext.top._4, - )) { dispatch(right) } + conditionContext.having(Some(context.copy(prePostMap = rMap))) { + dispatch(right) + } if (lMap.isEmpty && rMap.isEmpty) Select(c, l, r) else { lMap.foreach(postL => - conditionContext.top._2.append(( + context.prePostMap.append(( Select(condition, postL._1, tt), Select( Old(c, None)(PanicBlame( @@ -288,7 +293,7 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { )) ) rMap.foreach(postR => - conditionContext.top._2.append(( + context.prePostMap.append(( Select(condition, tt, postR._1), Select( Old(c, None)(PanicBlame( @@ -299,7 +304,7 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { ), )) ) - conditionContext.top._1 match { + context.context match { case InPrecondition() => Select(c, l, r) case InPostcondition() => Select( @@ -313,22 +318,16 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { } case Implies(left, right) => val rMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() - val top = conditionContext.pop() - val l = - try { dispatch(left) } - finally { conditionContext.push(top); } + val l = conditionContext.having(None) { dispatch(left) } val r = - conditionContext.having(( - conditionContext.top._1, - rMap, - conditionContext.top._3, - conditionContext.top._4, - )) { dispatch(right) } + conditionContext.having(Some(context.copy(prePostMap = rMap))) { + dispatch(right) + } if (rMap.nonEmpty) { - conditionContext.top._1 match { + context.context match { case InPrecondition() => rMap.foreach(postR => - conditionContext.top._2 + context.prePostMap .append((Implies(left, postR._1), Implies(l, postR._2))) ) Implies(l, r) @@ -337,7 +336,7 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { // duplicates therefore we don't include the old here since // otherwise it wouldn't match the precondition case rMap.foreach(postR => - conditionContext.top._2.append((Implies(left, postR._1), tt)) + context.prePostMap.append((Implies(left, postR._1), tt)) ) Implies( Old(l, None)(PanicBlame( @@ -348,26 +347,20 @@ case class EncodeAutoValue[Pre <: Generation]() extends Rewriter[Pre] { } } else { Implies(l, r) } case Star(left, right) => - val top = conditionContext.pop() - try { - val lMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() - val rMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() - val l = - conditionContext.having((top._1, lMap, top._3, top._4)) { - dispatch(left) - } - val r = - conditionContext.having((top._1, rMap, top._3, top._4)) { - dispatch(right) - } - top._2.addAll(lMap) - top._2.addAll(rMap) - Star(l, r) - } finally { conditionContext.push(top); } - case _ => - val top = conditionContext.pop() - try { e.rewriteDefault() } - finally { conditionContext.push(top); } + val lMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() + val rMap = mutable.ArrayBuffer[(Expr[Pre], Expr[Post])]() + val l = + conditionContext.having(Some(context.copy(prePostMap = lMap))) { + dispatch(left) + } + val r = + conditionContext.having(Some(context.copy(prePostMap = rMap))) { + dispatch(right) + } + context.prePostMap.addAll(lMap) + context.prePostMap.addAll(rMap) + Star(l, r) + case _ => conditionContext.having(None) { e.rewriteDefault() } } } } diff --git a/src/rewrite/vct/rewrite/PrepareByValueClass.scala b/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala similarity index 85% rename from src/rewrite/vct/rewrite/PrepareByValueClass.scala rename to src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala index 42a128696c..0a8b9c3546 100644 --- a/src/rewrite/vct/rewrite/PrepareByValueClass.scala +++ b/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala @@ -8,11 +8,10 @@ import vct.col.resolve.ctx.Referrable import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ import vct.col.util.SuccessionMap -import vct.result.VerificationError.{Unreachable, UserError} +import vct.result.VerificationError.UserError -// TODO: Think of a better name -case object PrepareByValueClass extends RewriterBuilder { - override def key: String = "prepareByValueClass" +case object EncodeByValueClassUsage extends RewriterBuilder { + override def key: String = "encodeByValueClassUsage" override def desc: String = "Initialise ByValueClasses when they are declared and copy them whenever they're read" @@ -68,9 +67,9 @@ case object PrepareByValueClass extends RewriterBuilder { private case class NoCopy() extends CopyContext } -case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { +case class EncodeByValueClassUsage[Pre <: Generation]() extends Rewriter[Pre] { - import PrepareByValueClass._ + import EncodeByValueClassUsage._ private val inAssignment: ScopedStack[Unit] = ScopedStack() private val copyContext: ScopedStack[CopyContext] = ScopedStack() @@ -98,8 +97,8 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = node.o node match { case HeapLocalDecl(local) - if local.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => { - val t = local.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]] + if local.t.asPointer.get.element.asByValueClass.isDefined => { + val t = local.t.asPointer.get.element.asByValueClass.get val newLocal = localHeapVariables.dispatch(local) Block(Seq( HeapLocalDecl(newLocal), @@ -212,15 +211,14 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { ) case _ if inAssignment.nonEmpty => node.rewriteDefault() case Perm(ByValueClassLocation(e), p) => - unwrapClassPerm( - dispatch(e), - dispatch(p), - e.t.asInstanceOf[TByValueClass[Pre]], + unwrapClassPerm(dispatch(e), dispatch(p), e.t.asByValueClass.get) + case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) => + assert( + v.t.isInstanceOf[TNonNullPointer[Pre]], + "Frontends should ensure that HeapVariables are non-null pointers", ) - case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) - if v.t.isInstanceOf[TNonNullPointer[Pre]] => val t = v.t.asInstanceOf[TNonNullPointer[Pre]] - if (t.element.isInstanceOf[TByValueClass[Pre]]) { + if (t.element.asByValueClass.isDefined) { val newV: Ref[Post, HeapVariable[Post]] = succ(v) val newP = dispatch(p) Perm(HeapVariableLocation(newV), newP) &* Perm( @@ -230,7 +228,7 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } else { node.rewriteDefault() } case assign: PreAssignExpression[Pre] => val target = inAssignment.having(()) { dispatch(assign.target) } - if (assign.target.t.isInstanceOf[TByValueClass[Pre]]) { + if (assign.target.t.asByValueClass.isDefined) { copyContext.having(InAssignmentExpression(assign)) { assign.rewrite(target = target) } @@ -240,38 +238,28 @@ case class PrepareByValueClass[Pre <: Generation]() extends Rewriter[Pre] { } case invocation: Invocation[Pre] => invocation.rewrite(args = invocation.args.map { a => - if (a.t.isInstanceOf[TByValueClass[Pre]]) { + if (a.t.asByValueClass.isDefined) { copyContext.having(InCall(invocation)) { dispatch(a) } } else { copyContext.having(NoCopy()) { dispatch(a) } } }) case dp @ DerefPointer(HeapLocal(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) + if v.t.asPointer.get.element.asByValueClass.isDefined => + rewriteInCopyContext(dp, v.t.asPointer.get.element.asByValueClass.get) case dp @ DerefPointer(DerefHeapVariable(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) + if v.t.asPointer.get.element.asByValueClass.isDefined => + rewriteInCopyContext(dp, v.t.asPointer.get.element.asByValueClass.get) case deref @ Deref(_, Ref(f)) - if f.t.isInstanceOf[TByValueClass[Pre]] && - copyContext.top != NoCopy() => + if f.t.asByValueClass.isDefined && copyContext.top != NoCopy() => // TODO: Improve blame message here copyClassValue( deref.rewriteDefault(), - f.t.asInstanceOf[TByValueClass[Pre]], + f.t.asByValueClass.get, f => deref.blame, ) case dp @ DerefPointer(Local(Ref(v))) - if v.t.asPointer.get.element.isInstanceOf[TByValueClass[Pre]] => + if v.t.asPointer.get.element.asByValueClass.isDefined => // This can happen if the user specifies a local of type pointer to TByValueClass - rewriteInCopyContext( - dp, - v.t.asPointer.get.element.asInstanceOf[TByValueClass[Pre]], - ) + rewriteInCopyContext(dp, v.t.asPointer.get.element.asByValueClass.get) case _ => node.rewriteDefault() } } diff --git a/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala b/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala index 1cc1b3a7ab..30bef4fd7f 100644 --- a/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala +++ b/src/rewrite/vct/rewrite/EncodeIntrinsicLock.scala @@ -83,9 +83,10 @@ case class EncodeIntrinsicLock[Pre <: Generation]() extends Rewriter[Pre] { val needsHeld: mutable.Set[Class[Pre]] = mutable.Set() val needsCommitted: mutable.Set[Class[Pre]] = mutable.Set() - def getClass(obj: Expr[Pre]): Class[Pre] = + def getClass(obj: Expr[Pre]): ByReferenceClass[Pre] = obj.t match { - case t: TClass[Pre] => t.cls.decl + case t: TByReferenceClass[Pre] => + t.cls.decl.asInstanceOf[ByReferenceClass[Pre]] case _ => throw UnreachableAfterTypeCheck( "This argument is not a class type.", @@ -114,7 +115,7 @@ case class EncodeIntrinsicLock[Pre <: Generation]() extends Rewriter[Pre] { rewriteDefault(program) } - def needsInvariant(cls: Class[Pre]): Boolean = + def needsInvariant(cls: ByReferenceClass[Pre]): Boolean = cls.intrinsicLockInvariant != tt[Pre] def needsInvariant(e: Expr[Pre]): Boolean = needsInvariant(getClass(e)) diff --git a/src/rewrite/vct/rewrite/PropagateContextEverywhere.scala b/src/rewrite/vct/rewrite/PropagateContextEverywhere.scala index 6cdb955cca..814326d682 100644 --- a/src/rewrite/vct/rewrite/PropagateContextEverywhere.scala +++ b/src/rewrite/vct/rewrite/PropagateContextEverywhere.scala @@ -33,16 +33,6 @@ case class PropagateContextEverywhere[Pre <: Generation]() val invariants: ScopedStack[Seq[Expr[Pre]]] = ScopedStack() invariants.push(Nil) - def withInvariant[T](inv: Expr[Pre])(f: => T): T = { - val old = invariants.top - invariants.pop() - invariants.push(old ++ unfoldStar(inv)) - val result = f - invariants.pop() - invariants.push(old) - result - } - def freshInvariants()(implicit o: Origin): Expr[Post] = foldStar(invariants.top.map(dispatch)) @@ -51,7 +41,9 @@ case class PropagateContextEverywhere[Pre <: Generation]() case app: ContractApplicable[Pre] => allScopes.anyDeclare(allScopes.anySucceedOnly( app, - withInvariant(app.contract.contextEverywhere) { + invariants.having( + invariants.top ++ unfoldStar(app.contract.contextEverywhere) + ) { app match { case func: AbstractFunction[Pre] => func.rewrite(blame = diff --git a/src/rewrite/vct/rewrite/TrivialAddrOf.scala b/src/rewrite/vct/rewrite/TrivialAddrOf.scala index 2ae340aaac..2e51257eea 100644 --- a/src/rewrite/vct/rewrite/TrivialAddrOf.scala +++ b/src/rewrite/vct/rewrite/TrivialAddrOf.scala @@ -44,7 +44,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { case AddrOf(Deref(_, _)) => e.rewriteDefault() case AddrOf(other) => throw UnsupportedLocation(other) case assign @ PreAssignExpression(target, AddrOf(value)) - if value.t.isInstanceOf[TByReferenceClass[Pre]] => + if value.t.asByReferenceClass.isDefined => implicit val o: Origin = assign.o val (newPointer, newTarget, newValue) = rewriteAssign( target, @@ -66,7 +66,7 @@ case class TrivialAddrOf[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(s: Statement[Pre]): Statement[Post] = s match { case assign @ Assign(target, AddrOf(value)) - if value.t.isInstanceOf[TByReferenceClass[Pre]] => + if value.t.asByReferenceClass.isDefined => implicit val o: Origin = assign.o val (newPointer, newTarget, newValue) = rewriteAssign( target, diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 48e02ac91f..7add9ce003 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -39,16 +39,14 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { SuccessionMap() override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = { - // TODO: Replace the isInstanceOf[TByReferenceClass] checks with something that more clearly communicates that we want to exclude all reference types + // TODO: Replace the asByReferenceClass checks with something that more clearly communicates that we want to exclude all reference types addressedSet.addAll(program.collect { - case AddrOf(Local(Ref(v))) if !v.t.isInstanceOf[TByReferenceClass[Pre]] => - v + case AddrOf(Local(Ref(v))) if v.t.asByReferenceClass.isEmpty => v case AddrOf(DerefHeapVariable(Ref(v))) - if !v.t.isInstanceOf[TByReferenceClass[Pre]] => + if v.t.asByReferenceClass.isEmpty => v case AddrOf(Deref(o, Ref(f))) - if !f.t.isInstanceOf[TByReferenceClass[Pre]] && - !o.t.isInstanceOf[TByValueClass[Pre]] => + if f.t.asByReferenceClass.isEmpty && o.t.asByValueClass.isEmpty => f }) super.dispatch(program) diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index 8f4fd738ee..3a0ccfa6e3 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -2584,14 +2584,14 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) def rewriteLifetimeScope(scope: CPPLifetimeScope[Pre]): Statement[Post] = { implicit val o: Origin = scope.o - syclBufferSuccessor.push(mutable.Map.empty) + val successorMap = mutable.Map.empty[Variable[Post], SYCLBuffer[Post]] - val rewrittenBody = rw.dispatch(scope.body) + val rewrittenBody = + syclBufferSuccessor.having(successorMap) { rw.dispatch(scope.body) } // Destroy all buffers and copy their data back to host val bufferDestructions: Seq[Statement[Post]] = - syclBufferSuccessor.pop().map(tuple => destroySYCLBuffer(tuple._2, scope)) - .toSeq + successorMap.map(tuple => destroySYCLBuffer(tuple._2, scope)).toSeq Block[Post](rewrittenBody +: bufferDestructions) } diff --git a/src/rewrite/vct/rewrite/lang/LangCToCol.scala b/src/rewrite/vct/rewrite/lang/LangCToCol.scala index 2e766c9c97..91fe1a2d9d 100644 --- a/src/rewrite/vct/rewrite/lang/LangCToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCToCol.scala @@ -1057,7 +1057,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ) case None => val newT = - if (t.isInstanceOf[TByValueClass[Post]]) { TNonNullPointer(t) } + if (t.asByValueClass.isDefined) { TNonNullPointer(t) } else { t } cGlobalNameSuccessor(RefCGlobalDeclaration(decl, idx)) = rw .globalDeclarations diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 4574c1e8df..728125b10e 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -96,6 +96,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) (a, b) match { case (None, _) | (_, None) => None case (Some(a), Some(b)) + // TODO: This should be removed as soon as we have proper contracts we load from LLVM instead of mixing PVL and LLVM. Comparing in Post is really bad if a == b || rw.dispatch(a) == rw.dispatch(b) || moreSpecific(a, b) => Some(a) @@ -159,10 +160,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ) } if (subMap.isEmpty) { value } - else { - // TODO: Support multiple guesses? - SubstituteReferences(subMap.toMap).dispatch(value) - } + else { SubstituteReferences(subMap.toMap).dispatch(value) } } def getVariable(expr: Expr[Pre]): Option[Object] = { @@ -579,6 +577,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) } } // Save the expensive check for last. This check is for when we're mixing PVL and LLVM types + // TODO: This check should be removed ASAP when we get real LLVM contracts since comparing types in Post is bad case LLVMTPointer(Some(inner)) if rw.dispatch(inner) == rw.dispatch(untilType) => Some((pointer, currentType))