From 091c6936a43efe659834e081c8cd3f6c4cddf311 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Fri, 23 Aug 2024 15:38:35 -0700 Subject: [PATCH 1/5] Extract type requirements from old code --- internal/migrate/staging_validator.go | 100 +++++++++++++++------ internal/migrate/staging_validator_test.go | 56 +++++++++++- 2 files changed, 127 insertions(+), 29 deletions(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 5f9c95458..327993f0b 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -25,6 +25,8 @@ import ( "fmt" "strings" + "github.com/rs/zerolog" + "golang.org/x/exp/slices" "github.com/onflow/cadence" @@ -40,6 +42,7 @@ import ( "github.com/onflow/contract-updater/lib/go/templates" flowsdk "github.com/onflow/flow-go-sdk" "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/cmd/util/ledger/reporters" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flowkit/v2" @@ -187,10 +190,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate) for _, stagedContract := range stagedContracts { - v.stagedContracts[stagedContract.DeployLocation] = stagedContract + stagedContractLocation := stagedContract.DeployLocation + + v.stagedContracts[stagedContractLocation] = stagedContract // Add the contract code to the contracts map for pretty printing - v.contracts[stagedContract.SourceLocation] = stagedContract.Code + v.contracts[stagedContractLocation] = stagedContract.Code } // Load system contracts @@ -199,24 +204,64 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Parse and check all staged contracts errs := v.checkAllStaged() + typeRequirements := &migrations.LegacyTypeRequirements{} + + // Extract type requirements from the old codes for all staged contracts. + for _, contract := range v.stagedContracts { + location := contract.DeployLocation + + // Don't validate contracts with existing errors + if errs[location] != nil { + continue + } + + // Get the account for the contract + address := flowsdk.Address(location.Address) + account, err := v.flow.GetAccount(context.Background(), address) + if err != nil { + return fmt.Errorf("failed to get account: %w", err) + } + + // Get the target contract old code + contractName := location.Name + oldCode, ok := account.Contracts[contractName] + if !ok { + return fmt.Errorf("old contract code not found for contract: %s", contractName) + } + + migrations.ExtractTypeRequirements( + migrations.AddressContract{ + Location: location, + Code: oldCode, + }, + zerolog.Nop(), + reporters.ReportNilWriter{}, + typeRequirements, + ) + } + // Validate all contract updates for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Don't validate contracts with existing errors - if errs[contract.SourceLocation] != nil { + if errs[location] != nil { continue } // Validate the contract update - checker := v.checkingCache[contract.SourceLocation].checker - err := v.validateContractUpdate(contract, checker) + checker := v.checkingCache[location].checker + err := v.validateContractUpdate(contract, checker, typeRequirements) if err != nil { - errs[contract.SourceLocation] = err + errs[location] = err } } // Check for any upstream contract update failures for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + location := contract.DeployLocation + + err := errs[location] // We will override any errors other than those related // to missing dependencies, since they are more specific @@ -234,19 +279,14 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) badDeps := make([]common.Location, 0) v.forEachDependency(contract, func(dependency common.Location) { - strLocation, ok := dependency.(common.StringLocation) - if !ok { - return - } - - if errs[strLocation] != nil { + if errs[dependency] != nil { badDeps = append(badDeps, dependency) } }) if len(badDeps) > 0 { - errs[contract.SourceLocation] = &upstreamValidationError{ - Location: contract.SourceLocation, + errs[location] = &upstreamValidationError{ + Location: location, BadDependencies: badDeps, } } @@ -257,7 +297,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) // Map errors to address locations errsByAddress := make(map[common.AddressLocation]error) for _, contract := range v.stagedContracts { - err := errs[contract.SourceLocation] + err := errs[contract.DeployLocation] if err != nil { errsByAddress[contract.DeployLocation] = err } @@ -267,12 +307,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) return nil } -func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error { - errors := make(map[common.StringLocation]error) +func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error { + errors := make(map[common.Location]error) for _, contract := range v.stagedContracts { - _, err := v.checkContract(contract.SourceLocation) + location := contract.DeployLocation + _, err := v.checkContract(location) if err != nil { - errors[contract.SourceLocation] = err + errors[location] = err } } @@ -281,6 +322,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error // Note: nodes are not visited more than once so cyclic imports are not an issue // They will be reported, however, by the checker, if they do exist for _, contract := range v.stagedContracts { + location := contract.DeployLocation + // Create a set of all dependencies missingDependencies := make([]common.AddressLocation, 0) v.forEachDependency(contract, func(dependency common.Location) { @@ -292,7 +335,7 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error }) if len(missingDependencies) > 0 { - errors[contract.SourceLocation] = &missingDependenciesError{ + errors[location] = &missingDependenciesError{ MissingContracts: missingDependencies, } } @@ -300,7 +343,11 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error return errors } -func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) { +func (v *stagingValidatorImpl) validateContractUpdate( + contract stagedContractUpdate, + checker *sema.Checker, + typeRequirements *migrations.LegacyTypeRequirements, +) (err error) { // Gracefully recover from panics defer func() { if r := recover(); r != nil { @@ -333,7 +380,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd // Check if contract code is valid according to Cadence V1 Update Checker validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - contract.SourceLocation, + contract.DeployLocation, contractName, &accountContractNamesProviderImpl{ resolverFunc: v.resolveAddressContractNames, @@ -349,9 +396,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd return fmt.Errorf("unsupported network: %s", v.flow.Network().Name) } - // TODO: extract type requirements from the old contracts - typeRequirements := &migrations.LegacyTypeRequirements{} - validator.WithUserDefinedTypeChangeChecker( migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements), ) @@ -560,7 +604,7 @@ func (v *stagingValidatorImpl) resolveLocation( // If the contract one of our staged contract updates, use the source location if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok { - resolvedLocation = stagedUpdate.SourceLocation + resolvedLocation = stagedUpdate.DeployLocation } else { resolvedLocation = resovledAddrLocation } @@ -725,7 +769,7 @@ func (v *stagingValidatorImpl) forEachDependency( } } } - traverse(contract.SourceLocation) + traverse(contract.DeployLocation) } // Helper for pretty printing errors diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index b6b01b9ce..4242010c8 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -496,7 +496,7 @@ func Test_StagingValidator(t *testing.T) { // check that error exists & ensure that the local contract names are used (not the deploy locations) fooErr := validatorErr.errors[simpleAddressLocation("0x01.Foo")] require.ErrorContains(t, fooErr, "mismatched types") - require.ErrorContains(t, fooErr, "Foo.cdc") + require.ErrorContains(t, fooErr, "0000000000000001.Foo") // Bar should have an error related to var upstreamErr *upstreamValidationError @@ -822,4 +822,58 @@ func Test_StagingValidator(t *testing.T) { err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) require.NoError(t, err) }) + + t.Run("with type requirements", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract interface Foo { + pub let bar: @Bar? + pub resource Bar {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + import Foo from 0x01 + pub contract FooImpl: Foo { + pub let bar: @Foo.Bar? + pub resource BarImpl {} + init() { + self.bar <- nil + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract interface Foo { + access(all) let bar: @{Bar}? + access(all) resource interface Bar {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract FooImpl: Foo { + access(all) let bar: @{Foo.Bar}? + access(all) resource BarImpl: Foo.Bar {} + init() { + self.bar <- nil + } + }`), + }, + }) + + require.NoError(t, err) + }) } From c98add39df6110ec6831f8b16f85980a52044d40 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Fri, 23 Aug 2024 16:09:56 -0700 Subject: [PATCH 2/5] Re-use fetched old contract codes --- internal/migrate/staging_validator.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 327993f0b..67042a901 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -62,9 +62,13 @@ type stagingValidatorImpl struct { // Cache for account contract names so we don't have to fetch them multiple times accountContractNames map[common.Address][]string + // All resolved contract code contracts map[common.Location][]byte + // Contract codes that are not updated/staged + oldCodes map[common.Location][]byte + // Dependency graph for staged contracts // This root level map holds all nodes graph map[common.Location]node @@ -176,6 +180,7 @@ func newStagingValidator(flow flowkit.Services) *stagingValidatorImpl { checkingCache: make(map[common.Location]*cachedCheckingResult), accountContractNames: make(map[common.Address][]string), graph: make(map[common.Location]node), + oldCodes: make(map[common.Location][]byte), } } @@ -228,6 +233,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) if !ok { return fmt.Errorf("old contract code not found for contract: %s", contractName) } + v.oldCodes[location] = oldCode migrations.ExtractTypeRequirements( migrations.AddressContract{ @@ -355,16 +361,11 @@ func (v *stagingValidatorImpl) validateContractUpdate( } }() - // Get the account for the contract - address := flowsdk.Address(contract.DeployLocation.Address) - account, err := v.flow.GetAccount(context.Background(), address) - if err != nil { - return fmt.Errorf("failed to get account: %w", err) - } + location := contract.DeployLocation + contractName := location.Name // Get the target contract old code - contractName := contract.DeployLocation.Name - contractCode, ok := account.Contracts[contractName] + contractCode, ok := v.oldCodes[location] if !ok { return fmt.Errorf("old contract code not found for contract: %s", contractName) } @@ -380,7 +381,7 @@ func (v *stagingValidatorImpl) validateContractUpdate( // Check if contract code is valid according to Cadence V1 Update Checker validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - contract.DeployLocation, + location, contractName, &accountContractNamesProviderImpl{ resolverFunc: v.resolveAddressContractNames, From c1db89e2ec064061908782c2b9659c2db0226f76 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Fri, 23 Aug 2024 16:13:11 -0700 Subject: [PATCH 3/5] Refactor and cleanup --- internal/migrate/staging_validator.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 67042a901..90cfa2d35 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -597,19 +597,11 @@ func (v *stagingValidatorImpl) resolveLocation( for i := range resolvedLocations { identifier := identifiers[i] - var resolvedLocation common.Location - resovledAddrLocation := common.AddressLocation{ + resolvedLocation := common.AddressLocation{ Address: addressLocation.Address, Name: identifier.Identifier, } - // If the contract one of our staged contract updates, use the source location - if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok { - resolvedLocation = stagedUpdate.DeployLocation - } else { - resolvedLocation = resovledAddrLocation - } - resolvedLocations[i] = runtime.ResolvedLocation{ Location: resolvedLocation, Identifiers: []runtime.Identifier{identifier}, From 1c6c140af2a5ed29e854e5c227bca2707c76e4ef Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 27 Aug 2024 10:14:46 -0700 Subject: [PATCH 4/5] Lint --- internal/migrate/staging_validator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 7c5122996..8e90110ef 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -44,9 +44,10 @@ import ( "github.com/onflow/flow-go/model/flow" "github.com/onflow/contract-updater/lib/go/templates" - "github.com/onflow/flow-cli/internal/util" flowsdk "github.com/onflow/flow-go-sdk" "github.com/onflow/flowkit/v2" + + "github.com/onflow/flow-cli/internal/util" ) //go:generate mockery --name stagingValidator --inpackage --testonly --case underscore From bd87b9fa478ee61027ea811f87bc8300d0136c1b Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 27 Aug 2024 12:24:09 -0700 Subject: [PATCH 5/5] Add test for staging with entitlements --- internal/migrate/staging_validator_test.go | 57 ++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index c01bbb646..a1e15dc09 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -876,4 +876,61 @@ func Test_StagingValidator(t *testing.T) { require.NoError(t, err) }) + + t.Run("contract update with entitlements", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub resource Bar {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Test": []byte(` + import Foo from 0x01 + pub contract Test { + pub resource R { + pub var bar: auth &Foo.Bar? + init() { + self.bar = nil + } + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract Foo { + access(all) resource Bar { + access(E) fun foo(){} + } + access(all) entitlement E + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Test"), + SourceLocation: common.StringLocation("./Test.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Test { + access(all) resource R { + access(all) var bar: auth(Foo.E) &Foo.Bar? + init() { + self.bar = nil + } + } + }`), + }, + }) + + require.NoError(t, err) + }) }