diff --git a/.github/workflows/commit-lint.yml b/.github/workflows/commit-lint.yml index 7df13ebcac..4ace940171 100644 --- a/.github/workflows/commit-lint.yml +++ b/.github/workflows/commit-lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_mvp0.1_dev pull_request: branches: - master - develop + - state_expiry_mvp0.1_dev jobs: commitlint: diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 234cbfda5b..393930ed92 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_mvp0.1_dev pull_request: branches: - master - develop + - state_expiry_mvp0.1_dev jobs: truffle-test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 87bfb710dd..a8b0a49dc4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_mvp0.1_dev pull_request: branches: - master - develop + - state_expiry_mvp0.1_dev jobs: golang-lint: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index f4d9e053f2..fd860cfe6e 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_mvp0.1_dev pull_request: branches: - master - develop + - state_expiry_mvp0.1_dev jobs: unit-test: diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 3633ac5c4a..b54a766dcf 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -188,7 +188,7 @@ func (b *SimulatedBackend) stateByBlockNumber(ctx context.Context, blockNumber * if err != nil { return nil, err } - return b.blockchain.StateAt(block.Root()) + return b.blockchain.StateAt(block.Root(), block.Hash(), block.Number()) } // CodeAt returns the code associated with a certain account in the blockchain. diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 2d3c7d0d0d..816df6f2e0 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -56,7 +56,7 @@ var ( Flags: flags.Merge([]cli.Flag{ utils.CachePreimagesFlag, utils.StateSchemeFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` The init command initializes a new genesis block and definition for the network. This is a destructive action and changes the network in which you will be @@ -86,7 +86,7 @@ It expects the genesis file as argument.`, Name: "dumpgenesis", Usage: "Dumps genesis block JSON configuration to stdout", ArgsUsage: "", - Flags: append([]cli.Flag{utils.DataDirFlag}, utils.NetworkFlags...), + Flags: flags.Merge([]cli.Flag{utils.DataDirFlag}, utils.NetworkFlags, utils.StateExpiryBaseFlags), Description: ` The dumpgenesis command prints the genesis configuration of the network preset if one is set. Otherwise it prints the genesis from the datadir.`, @@ -121,7 +121,7 @@ if one is set. Otherwise it prints the genesis from the datadir.`, utils.TransactionHistoryFlag, utils.StateSchemeFlag, utils.StateHistoryFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` The import command imports blocks from an RLP-encoded form. The form can be one file with several RLP-encoded blocks, or several files can be used. 
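Note on the command changes in this patch: nearly every command follows one pattern — the per-command flag list is rebuilt with flags.Merge so that utils.StateExpiryBaseFlags is registered wherever a chain database may be opened. A minimal standalone sketch of that composition (the command itself is hypothetical; flags.Merge simply concatenates the given flag slices):

package main

import (
	"github.com/ethereum/go-ethereum/cmd/utils"
	"github.com/ethereum/go-ethereum/internal/flags"
	"github.com/urfave/cli/v2"
)

// exampleCmd is a made-up command that only illustrates how this patch
// threads the state expiry base flags into each command's flag set.
var exampleCmd = &cli.Command{
	Name:  "example",
	Usage: "illustrate flag composition",
	// flags.Merge concatenates the explicit flags with the shared groups.
	Flags: flags.Merge([]cli.Flag{
		utils.SyncModeFlag,
	}, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
}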
@@ -138,7 +138,7 @@ processing will proceed even if an individual RLP-file import failure occurs.`, utils.CacheFlag, utils.SyncModeFlag, utils.StateSchemeFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` Requires a first argument of the file to write to. Optional second and third arguments control the first and @@ -154,7 +154,7 @@ be gzipped.`, Flags: flags.Merge([]cli.Flag{ utils.CacheFlag, utils.SyncModeFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` The import-preimages command imports hash preimages from an RLP encoded stream. It's deprecated, please use "geth db import" instead. @@ -168,7 +168,7 @@ It's deprecated, please use "geth db import" instead. Flags: flags.Merge([]cli.Flag{ utils.CacheFlag, utils.SyncModeFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` The export-preimages command exports hash preimages to an RLP encoded stream. It's deprecated, please use "geth db export" instead. @@ -188,7 +188,7 @@ It's deprecated, please use "geth db export" instead. utils.StartKeyFlag, utils.DumpLimitFlag, utils.StateSchemeFlag, - }, utils.DatabasePathFlags), + }, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` This command dumps out the state for a given block (or latest, if none provided). `, diff --git a/cmd/geth/config.go b/cmd/geth/config.go index b1744c8040..1418ade4dc 100644 --- a/cmd/geth/config.go +++ b/cmd/geth/config.go @@ -49,7 +49,7 @@ var ( Name: "dumpconfig", Usage: "Export configuration values in a TOML format", ArgsUsage: "", - Flags: flags.Merge(nodeFlags, rpcFlags), + Flags: flags.Merge(nodeFlags, rpcFlags, utils.StateExpiryBaseFlags), Description: `Export configuration values in TOML format (to stdout by default).`, } diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go index 526ede9619..c6575d247e 100644 --- a/cmd/geth/consolecmd.go +++ b/cmd/geth/consolecmd.go @@ -33,7 +33,7 @@ var ( Action: localConsole, Name: "console", Usage: "Start an interactive JavaScript environment", - Flags: flags.Merge(nodeFlags, rpcFlags, consoleFlags), + Flags: flags.Merge(nodeFlags, rpcFlags, consoleFlags, utils.StateExpiryBaseFlags), Description: ` The Geth console is an interactive shell for the JavaScript runtime environment which exposes a node admin interface as well as the Ðapp JavaScript API. diff --git a/cmd/geth/dbcmd.go b/cmd/geth/dbcmd.go index eb6185fc2f..cb54c7694d 100644 --- a/cmd/geth/dbcmd.go +++ b/cmd/geth/dbcmd.go @@ -18,6 +18,7 @@ package main import ( "bytes" + "errors" "fmt" "math" "os" @@ -49,7 +50,7 @@ var ( Name: "removedb", Usage: "Remove blockchain and state databases", ArgsUsage: "", - Flags: utils.DatabasePathFlags, + Flags: flags.Merge(utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` Remove blockchain and state databases`, } @@ -59,6 +60,7 @@ Remove blockchain and state databases`, ArgsUsage: "", Subcommands: []*cli.Command{ dbInspectCmd, + dbInspectTrieCmd, dbStatCmd, dbCompactCmd, dbGetCmd, @@ -84,15 +86,25 @@ Remove blockchain and state databases`, ArgsUsage: " ", Flags: flags.Merge([]cli.Flag{ utils.SyncModeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), + }, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Usage: "Inspect the storage size for each type of data in the database", Description: `This commands iterates the entire database. 
If the optional 'prefix' and 'start' arguments are provided, then the iteration is
limited to the given subset of data.`,
	}
+	dbInspectTrieCmd = &cli.Command{
+		Action:    inspectTrie,
+		Name:      "inspect-trie",
+		ArgsUsage: "<blocknum|latest|snapshot> <jobnum>",
+		Flags: flags.Merge([]cli.Flag{
+			utils.SyncModeFlag,
+		}, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
+		Usage:       "Inspect the MPT trie of accounts and contracts.",
+		Description: `This command iterates the entire WorldState.`,
+	}
	dbCheckStateContentCmd = &cli.Command{
		Action:    checkStateContent,
		Name:      "check-state-content",
		ArgsUsage: "",
-		Flags:     flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags),
+		Flags:     flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Usage:     "Verify that state data is cryptographically correct",
		Description: `This command iterates the entire database for 32-byte keys, looking
for rlp-encoded trie nodes. For each trie node encountered, it checks that the key
corresponds to the keccak256(value). If this is not true, this indicates
@@ -143,7 +155,7 @@ a data corruption.`,
		Usage: "Print leveldb statistics",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
	}
	dbCompactCmd = &cli.Command{
		Action: dbCompact,
@@ -153,7 +165,7 @@ a data corruption.`,
			utils.SyncModeFlag,
			utils.CacheFlag,
			utils.CacheDatabaseFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `This command performs a database compaction.
WARNING: This operation may take a very long time to finish, and may cause database
corruption if it is aborted during execution'!`,
@@ -165,7 +177,7 @@ corruption if it is aborted during execution'!`,
		ArgsUsage: "",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: "This command looks up the specified database key from the database.",
	}
	dbDeleteCmd = &cli.Command{
@@ -175,7 +187,7 @@ corruption if it is aborted during execution'!`,
		ArgsUsage: "",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `This command deletes the specified database key from the database.
WARNING: This is a low-level operation which may cause database corruption!`,
	}
@@ -186,7 +198,7 @@ WARNING: This is a low-level operation which may cause database corruption!`,
		ArgsUsage: "<key> <value>",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `This command sets a given database key to the given value.
WARNING: This is a low-level operation which may cause database corruption!`,
	}
@@ -198,7 +210,7 @@ WARNING: This is a low-level operation which may cause database corruption!`,
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
			utils.StateSchemeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: "This command looks up the specified database key from the database.",
	}
	dbDumpFreezerIndex = &cli.Command{
		Action:    freezerInspect,
		Name:      "freezer-index",
		ArgsUsage: " ",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: "This command displays information about the freezer index.",
	}
	dbImportCmd = &cli.Command{
		Action:    importLDBdata,
		Name:      "import",
		ArgsUsage: " ",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: "Exports the specified chain data to an RLP encoded stream, optionally gzip-compressed.",
	}
	dbMetadataCmd = &cli.Command{
@@ -237,15 +249,16 @@ WARNING: This is a low-level operation which may cause database corruption!`,
		Usage: "Shows metadata about the chain status.",
		Flags: flags.Merge([]cli.Flag{
			utils.SyncModeFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: "Shows metadata about the chain status.",
	}
	ancientInspectCmd = &cli.Command{
		Action: ancientInspect,
		Name:   "inspect-reserved-oldest-blocks",
-		Flags: []cli.Flag{
-			utils.DataDirFlag,
-		},
+		Flags: flags.Merge(
+			[]cli.Flag{utils.DataDirFlag},
+			utils.StateExpiryBaseFlags,
+		),
		Usage: "Inspect the ancientStore information",
		Description: `This commands will read current offset from kvdb, which is the current offset and starting BlockNumber
of ancientStore, will also displays the reserved number of blocks in ancientStore `,
@@ -352,6 +365,79 @@ func ancientInspect(ctx *cli.Context) error {
	return rawdb.AncientInspect(db)
}

+func inspectTrie(ctx *cli.Context) error {
+	if ctx.NArg() < 1 {
+		return fmt.Errorf("required arguments: %v", ctx.Command.ArgsUsage)
+	}
+
+	if ctx.NArg() > 3 {
+		return fmt.Errorf("max 3 arguments: %v", ctx.Command.ArgsUsage)
+	}
+
+	var (
+		blockNumber uint64
+		blockRoot   common.Hash
+		jobnum      uint64
+	)
+
+	stack, _ := makeConfigNode(ctx)
+	defer stack.Close()
+
+	db := utils.MakeChainDatabase(ctx, stack, true, true)
+	defer db.Close()
+
+	if ctx.NArg() >= 1 {
+		if ctx.Args().Get(0) == "latest" {
+			headBlock := rawdb.ReadHeadBlock(db)
+			if headBlock == nil {
+				return errors.New("failed to load head block")
+			}
+			blockNumber = headBlock.NumberU64()
+			blockRoot = headBlock.Root()
+		} else if ctx.Args().Get(0) == "snapshot" {
+			blockRoot = rawdb.ReadSnapshotRoot(db)
+			blockNumber = math.MaxUint64
+		} else {
+			var err error
+			blockNumber, err = strconv.ParseUint(ctx.Args().Get(0), 10, 64)
+			if err != nil {
+				return fmt.Errorf("failed to parse blocknum, Args[0]: %v, err: %v", ctx.Args().Get(0), err)
+			}
+			blockHash := rawdb.ReadCanonicalHash(db, blockNumber)
+			block := rawdb.ReadBlock(db, blockHash, blockNumber)
+			// guard against a missing block, otherwise block.Root() panics on nil
+			if block == nil {
+				return fmt.Errorf("cannot find block, number: %v, hash: %v", blockNumber, blockHash)
+			}
+			blockRoot = block.Root()
+		}
+
+		if ctx.NArg() == 1 {
+			jobnum = 1000
+		} else {
+			var err error
+			jobnum, err = strconv.ParseUint(ctx.Args().Get(1), 10, 64)
+			if err != nil {
+				return fmt.Errorf("failed to parse jobnum, Args[1]: %v, err: %v", ctx.Args().Get(1), err)
+			}
+		}
+
+		if (blockRoot == common.Hash{}) {
+			return errors.New("empty root hash")
+		}
+		fmt.Printf("ReadBlockHeader, root: %v, blocknum: %v\n", blockRoot, blockNumber)
+		trieDB := trie.NewDatabase(db, nil)
+		theTrie, err := trie.New(trie.TrieID(blockRoot), trieDB)
+		if err != nil {
+			fmt.Printf("failed to open trie, err: %v, rootHash: %v\n", err, blockRoot.String())
+			return err
+		}
+		theInspect, err := trie.NewInspector(trieDB, theTrie, blockNumber, jobnum)
+		if err != nil {
+			return err
+		}
+		theInspect.Run()
+		theInspect.DisplayResult()
+	}
+	return nil
+}
+
 func checkStateContent(ctx *cli.Context) error {
	var (
		prefix []byte
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index b1b08588e2..d35016e2a8 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -215,6 +215,16 @@ var (
		utils.MetricsInfluxDBBucketFlag,
		utils.MetricsInfluxDBOrganizationFlag,
	}
+
+	stateExpiryFlags = []cli.Flag{
+		utils.StateExpiryEnableFlag,
+		utils.StateExpiryFullStateEndpointFlag,
+		utils.StateExpiryStateEpoch1BlockFlag,
+		utils.StateExpiryStateEpoch2BlockFlag,
+		utils.StateExpiryStateEpochPeriodFlag,
+		utils.StateExpiryEnableLocalReviveFlag,
+		utils.StateExpiryEnableRemoteModeFlag,
+	}
 )

 var app = flags.NewApp("the go-ethereum command line interface")
@@ -265,6 +275,7 @@ func init() {
		consoleFlags,
		debug.Flags,
		metricsFlags,
+		stateExpiryFlags,
	)

	app.Before = func(ctx *cli.Context) error {
diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go
index 0d9e583e79..99b9bc5161 100644
--- a/cmd/geth/snapshot.go
+++ b/cmd/geth/snapshot.go
@@ -25,11 +25,12 @@ import (
	"path/filepath"
	"time"

+	"github.com/ethereum/go-ethereum/core"
+
	"github.com/prometheus/tsdb/fileutil"

	"github.com/ethereum/go-ethereum/cmd/utils"
	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/state/pruner"
@@ -61,7 +62,9 @@ var (
		Flags: flags.Merge([]cli.Flag{
			utils.BloomFilterSizeFlag,
			utils.TriesInMemoryFlag,
-		}, utils.NetworkFlags, utils.DatabasePathFlags),
+			utils.StateExpiryMaxThreadFlag,
+			configFileFlag,
+		}, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `
geth snapshot prune-state
will prune historical state data with the help of the state snapshot.
@@ -79,13 +82,13 @@ WARNING: it's only supported in hash mode(--state.scheme=hash)".
		Usage:    "Prune block data offline",
		Action:   pruneBlock,
		Category: "MISCELLANEOUS COMMANDS",
-		Flags: []cli.Flag{
+		Flags: flags.Merge([]cli.Flag{
			utils.DataDirFlag,
			utils.AncientFlag,
			utils.BlockAmountReserved,
			utils.TriesInMemoryFlag,
			utils.CheckSnapshotWithMPT,
-		},
+		}, utils.StateExpiryBaseFlags),
		Description: `
geth offline prune-block for block data in ancientdb.
The amount of blocks expected for remaining after prune can be specified via block-amount-reserved in this command,
@@ -105,7 +108,7 @@ so it's very necessary to do block data prune, this feature will handle it.
Action: verifyState, Flags: flags.Merge([]cli.Flag{ utils.StateSchemeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), + }, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` geth snapshot verify-state will traverse the whole accounts and storages set based on the specified @@ -120,10 +123,10 @@ In other words, this command does the snapshot to trie conversion. ArgsUsage: "", Action: pruneAllState, Category: "MISCELLANEOUS COMMANDS", - Flags: []cli.Flag{ + Flags: flags.Merge([]cli.Flag{ utils.DataDirFlag, utils.AncientFlag, - }, + }, utils.StateExpiryBaseFlags), Description: ` will prune all historical trie state data except genesis block. All trie nodes will be deleted from the database. @@ -141,7 +144,7 @@ the trie clean cache with default directory will be deleted. Usage: "Check that there is no 'dangling' snap storage", ArgsUsage: "", Action: checkDanglingStorage, - Flags: flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags), + Flags: flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` geth snapshot check-dangling-storage traverses the snap storage data, and verifies that all snapshot storage data has a corresponding account. @@ -152,7 +155,7 @@ data, and verifies that all snapshot storage data has a corresponding account. Usage: "Check all snapshot layers for the a specific account", ArgsUsage: "
", Action: checkAccount, - Flags: flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags), + Flags: flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` geth snapshot inspect-account
checks all snapshot layers and prints out information about the specified address. @@ -165,7 +168,7 @@ information about the specified address. Action: traverseState, Flags: flags.Merge([]cli.Flag{ utils.StateSchemeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), + }, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` geth snapshot traverse-state will traverse the whole state from the given state root and will abort if any @@ -182,7 +185,7 @@ It's also usable without snapshot enabled. Action: traverseRawState, Flags: flags.Merge([]cli.Flag{ utils.StateSchemeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), + }, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` geth snapshot traverse-rawstate will traverse the whole state from the given root and will abort if any referenced @@ -205,7 +208,7 @@ It's also usable without snapshot enabled. utils.DumpLimitFlag, utils.TriesInMemoryFlag, utils.StateSchemeFlag, - }, utils.NetworkFlags, utils.DatabasePathFlags), + }, utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags), Description: ` This command is semantically equivalent to 'geth dump', but uses the snapshots as the backend data source, making this command a lot faster. @@ -411,24 +414,50 @@ func pruneBlock(ctx *cli.Context) error { // Deprecation: this command should be deprecated once the hash-based // scheme is deprecated. func pruneState(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) + stack, cfg := makeConfigNode(ctx) defer stack.Close() chaindb := utils.MakeChainDatabase(ctx, stack, false, false) defer chaindb.Close() - if rawdb.ReadStateScheme(chaindb) != rawdb.HashScheme { - log.Crit("Offline pruning is not required for path scheme") + chainConfig, _, err := core.LoadChainConfig(chaindb, cfg.Eth.Genesis) + if err != nil { + return err + } + + cacheConfig := &core.CacheConfig{ + TrieCleanLimit: cfg.Eth.TrieCleanCache, + TrieCleanNoPrefetch: cfg.Eth.NoPrefetch, + TrieDirtyLimit: cfg.Eth.TrieDirtyCache, + TrieDirtyDisabled: cfg.Eth.NoPruning, + TrieTimeLimit: cfg.Eth.TrieTimeout, + NoTries: cfg.Eth.TriesVerifyMode != core.LocalVerify, + SnapshotLimit: cfg.Eth.SnapshotCache, + TriesInMemory: cfg.Eth.TriesInMemory, + Preimages: cfg.Eth.Preimages, + StateHistory: cfg.Eth.StateHistory, + StateScheme: cfg.Eth.StateScheme, + StateExpiryCfg: cfg.Eth.StateExpiryCfg, } prunerconfig := pruner.Config{ - Datadir: stack.ResolvePath(""), - BloomSize: ctx.Uint64(utils.BloomFilterSizeFlag.Name), + Datadir: stack.ResolvePath(""), + BloomSize: ctx.Uint64(utils.BloomFilterSizeFlag.Name), + EnableStateExpiry: cfg.Eth.StateExpiryCfg.EnableExpiry(), + ChainConfig: chainConfig, + CacheConfig: cacheConfig, + MaxExpireThreads: ctx.Uint64(utils.StateExpiryMaxThreadFlag.Name), } pruner, err := pruner.NewPruner(chaindb, prunerconfig, ctx.Uint64(utils.TriesInMemoryFlag.Name)) if err != nil { log.Error("Failed to open snapshot tree", "err", err) return err } + + if cfg.Eth.StateScheme == rawdb.PathScheme { + // when using PathScheme, only prune expired state + return pruner.ExpiredPrune(common.Big0, common.Hash{}) + } + if ctx.NArg() > 1 { log.Error("Too many arguments given") return errors.New("too many arguments") diff --git a/cmd/geth/verkle.go b/cmd/geth/verkle.go index e3953ed5b7..aebee0c1cd 100644 --- a/cmd/geth/verkle.go +++ b/cmd/geth/verkle.go @@ -45,7 +45,7 @@ var ( Usage: "verify the conversion of a MPT into a verkle tree", ArgsUsage: "", Action: verifyVerkle, - Flags: 
flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags),
+		Flags:     flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `
geth verkle verify
This command takes a root commitment and attempts to rebuild the tree.
@@ -56,7 +56,7 @@ This command takes a root commitment and attempts to rebuild the tree.
		Usage:     "Dump a verkle tree to a DOT file",
		ArgsUsage: " [ ...]",
		Action:    expandVerkle,
-		Flags:     flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags),
+		Flags:     flags.Merge(utils.NetworkFlags, utils.DatabasePathFlags, utils.StateExpiryBaseFlags),
		Description: `
geth verkle dump [ ...]
This command will produce a dot file representing the tree, rooted at .
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 2f23414ee8..9b13e1012a 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -35,6 +35,9 @@ import (
	"strings"
	"time"

+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/rlp"
+
	"github.com/fatih/structs"
	pcsclite "github.com/gballet/go-libpcsclite"
	gopsutil "github.com/shirou/gopsutil/mem"
@@ -1116,6 +1119,56 @@ var (
		RemoteDBFlag,
		HttpHeaderFlag,
	}
+	StateExpiryBaseFlags = []cli.Flag{
+		StateExpiryEnableFlag,
+		StateExpiryEnableRemoteModeFlag,
+	}
+)
+
+var (
+	// State Expiry Flags
+	StateExpiryEnableFlag = &cli.BoolFlag{
+		Name:     "state-expiry",
+		Usage:    "Enable state expiry; it marks each state's epoch metadata and later prunes states that have not been accessed",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryFullStateEndpointFlag = &cli.StringFlag{
+		Name:     "state-expiry.remote",
+		Usage:    "Set the remote full-state RPC endpoint for state expiry; every expired state will be fetched from it",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryStateEpoch1BlockFlag = &cli.Uint64Flag{
+		Name:     "state-expiry.epoch1",
+		Usage:    "Set the block number at which state epoch 1 starts",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryStateEpoch2BlockFlag = &cli.Uint64Flag{
+		Name:     "state-expiry.epoch2",
+		Usage:    "Set the block number at which state epoch 2 starts",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryStateEpochPeriodFlag = &cli.Uint64Flag{
+		Name:     "state-expiry.period",
+		Usage:    "Set the state epoch period (in blocks) after epoch 2",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryEnableLocalReviveFlag = &cli.BoolFlag{
+		Name:     "state-expiry.localrevive",
+		Usage:    "Enable reviving expired state from local, non-pruned state",
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryMaxThreadFlag = &cli.Uint64Flag{
+		Name:     "state-expiry.maxthread",
+		Usage:    "Set the max number of threads used when pruning expired state",
+		Value:    10000,
+		Category: flags.StateExpiryCategory,
+	}
+	StateExpiryEnableRemoteModeFlag = &cli.BoolFlag{
+		Name:     "state-expiry.remotemode",
+		Usage:    "Run state expiry in remote mode",
+		Value:    false,
+		Category: flags.StateExpiryCategory,
+	}
 )

 func init() {
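The epoch1/epoch2/period flags above parameterize the epoch numbering that the rest of this patch reads via types.GetStateEpoch. As a rough, assumed illustration of what such a schedule could look like (this is not the actual types.GetStateEpoch implementation from this patch):

package main

import "math/big"

// stateEpochSketch is a hypothetical reading of the
// state-expiry.epoch1/epoch2/period flags: epoch 0 before epoch1,
// epoch 1 in [epoch1, epoch2), then one new epoch every `period`
// blocks after epoch2. The real types.GetStateEpoch may differ.
func stateEpochSketch(epoch1, epoch2, period uint64, height *big.Int) uint64 {
	h := height.Uint64()
	switch {
	case h < epoch1:
		return 0
	case h < epoch2:
		return 1
	case period == 0:
		// guard against division by zero for an unset period
		return 2
	default:
		return 2 + (h-epoch2)/period
	}
}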
@@ -1931,13 +1984,25 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
		cfg.StateHistory = ctx.Uint64(StateHistoryFlag.Name)
	}
	// Parse state scheme, abort the process if it's not compatible.
-	chaindb := tryMakeReadOnlyDatabase(ctx, stack)
+	chaindb := MakeChainDatabase(ctx, stack, false, false)
	scheme, err := ParseStateScheme(ctx, chaindb)
-	chaindb.Close()
	if err != nil {
		Fatalf("%v", err)
	}
	cfg.StateScheme = scheme
+	seCfg, err := ParseStateExpiryConfig(ctx, chaindb, scheme)
+	if err != nil {
+		Fatalf("%v", err)
+	}
+	seCfg.AllowedPeerList = make([]string, 0)
+	for _, seNode := range stack.Config().P2P.StateExpiryAllowedNodes {
+		seCfg.AllowedPeerList = append(seCfg.AllowedPeerList, seNode.ID().String())
+	}
+	if len(seCfg.AllowedPeerList) == 0 && seCfg.EnableRemoteMode {
+		log.Warn("State expiry remote mode is enabled but no allowed peers are specified.")
+	}
+	cfg.StateExpiryCfg = seCfg
+	chaindb.Close()

	// Parse transaction history flag, if user is still using legacy config
	// file with 'TxLookupLimit' configured, copy the value to 'TransactionHistory'.
@@ -2512,6 +2577,82 @@ func ParseStateScheme(ctx *cli.Context, disk ethdb.Database) (string, error) {
	return "", fmt.Errorf("incompatible state scheme, stored: %s, provided: %s", stored, scheme)
}

+// ParseStateExpiryConfig resolves the state expiry config from CLI flags and
+// the previously stored config, validates it, and persists it if enabled.
+func ParseStateExpiryConfig(ctx *cli.Context, disk ethdb.Database, scheme string) (*types.StateExpiryConfig, error) {
+	enc := rawdb.ReadStateExpiryCfg(disk)
+	var stored *types.StateExpiryConfig
+	if len(enc) > 0 {
+		var cfg types.StateExpiryConfig
+		if err := rlp.DecodeBytes(enc, &cfg); err != nil {
+			return nil, err
+		}
+		stored = &cfg
+	}
+	newCfg := &types.StateExpiryConfig{StateScheme: scheme}
+	if ctx.IsSet(StateExpiryEnableFlag.Name) {
+		newCfg.Enable = ctx.Bool(StateExpiryEnableFlag.Name)
+	}
+	if ctx.IsSet(StateExpiryEnableRemoteModeFlag.Name) {
+		newCfg.EnableRemoteMode = ctx.Bool(StateExpiryEnableRemoteModeFlag.Name)
+	}
+	if ctx.IsSet(StateExpiryFullStateEndpointFlag.Name) {
+		newCfg.FullStateEndpoint = ctx.String(StateExpiryFullStateEndpointFlag.Name)
+	}
+
+	// some config values fall back to the stored defaults
+	if ctx.IsSet(StateExpiryStateEpoch1BlockFlag.Name) {
+		newCfg.StateEpoch1Block = ctx.Uint64(StateExpiryStateEpoch1BlockFlag.Name)
+	} else if stored != nil {
+		newCfg.StateEpoch1Block = stored.StateEpoch1Block
+	}
+	if ctx.IsSet(StateExpiryStateEpoch2BlockFlag.Name) {
+		newCfg.StateEpoch2Block = ctx.Uint64(StateExpiryStateEpoch2BlockFlag.Name)
+	} else if stored != nil {
+		newCfg.StateEpoch2Block = stored.StateEpoch2Block
+	}
+	if ctx.IsSet(StateExpiryStateEpochPeriodFlag.Name) {
+		newCfg.StateEpochPeriod = ctx.Uint64(StateExpiryStateEpochPeriodFlag.Name)
+	} else if stored != nil {
+		newCfg.StateEpochPeriod = stored.StateEpochPeriod
+	}
+	if ctx.IsSet(StateExpiryEnableLocalReviveFlag.Name) {
+		newCfg.EnableLocalRevive = ctx.Bool(StateExpiryEnableLocalReviveFlag.Name)
+	}
+
+	// override prune level
+	newCfg.PruneLevel = types.StateExpiryPruneLevel1
+	switch newCfg.StateScheme {
+	case rawdb.HashScheme:
+		// TODO(0xbundler): support for HBSS will be dropped later.
+		newCfg.PruneLevel = types.StateExpiryPruneLevel0
+	case rawdb.PathScheme:
+		newCfg.PruneLevel = types.StateExpiryPruneLevel1
+	default:
+		return nil, fmt.Errorf("unsupported state scheme: %v", newCfg.StateScheme)
+	}
+
+	if err := newCfg.Validation(); err != nil {
+		return nil, err
+	}
+	if err := stored.CheckCompatible(newCfg); err != nil {
+		return nil, err
+	}
+
+	log.Info("Apply State Expiry", "cfg", newCfg)
+	if !newCfg.Enable {
+		return newCfg, nil
+	}
+
+	// save it into db
+	enc, err := rlp.EncodeToBytes(newCfg)
+	if err != nil {
+		return nil, err
+	}
+	if err = rawdb.WriteStateExpiryCfg(disk, enc); err != nil {
+		return nil, err
+	}
+	return newCfg, nil
+}
+
 // MakeTrieDatabase constructs a trie database based on the configured scheme.
 func MakeTrieDatabase(ctx *cli.Context, disk ethdb.Database, preimage bool, readOnly bool) *trie.Database {
	config := &trie.Config{
diff --git a/common/big.go b/common/big.go
index 65d4377bf7..4d32622092 100644
--- a/common/big.go
+++ b/common/big.go
@@ -28,3 +28,7 @@ var (
	Big256 = big.NewInt(256)
	Big257 = big.NewInt(257)
 )
+
+// AddBig1 returns tmp + 1 as a new big.Int, leaving tmp unmodified.
+func AddBig1(tmp *big.Int) *big.Int {
+	return new(big.Int).Add(tmp, Big1)
+}
diff --git a/common/bytes.go b/common/bytes.go
index d1f5c6c995..eaaa29c6cf 100644
--- a/common/bytes.go
+++ b/common/bytes.go
@@ -36,6 +36,13 @@ func FromHex(s string) []byte {
	return Hex2Bytes(s)
}

+// No0xPrefix strips the "0x" prefix from s if present.
+func No0xPrefix(s string) string {
+	if has0xPrefix(s) {
+		return s[2:]
+	}
+	return s
+}
+
 // CopyBytes returns an exact copy of the provided bytes.
 func CopyBytes(b []byte) (copiedBytes []byte) {
	if b == nil {
diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go
index 8f8fd18cc1..bc8362bdb3 100644
--- a/consensus/parlia/parlia.go
+++ b/consensus/parlia/parlia.go
@@ -1122,7 +1122,7 @@ func (p *Parlia) Finalize(chain consensus.ChainHeaderReader, header *types.Heade
			err = p.slash(spoiledVal, state, header, cx, txs, receipts, systemTxs, usedGas, false)
			if err != nil {
				// it is possible that slash validator failed because of the slash channel is disabled.
-				log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal)
+				log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err)
			}
		}
	}
@@ -1183,7 +1183,7 @@ func (p *Parlia) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header *
			err = p.slash(spoiledVal, state, header, cx, &txs, &receipts, nil, &header.GasUsed, true)
			if err != nil {
				// it is possible that slash validator failed because of the slash channel is disabled.
- log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal) + log.Error("slash validator failed", "block hash", header.Hash(), "address", spoiledVal, "err", err) } } } @@ -1692,13 +1692,14 @@ func (p *Parlia) applyTransaction( } actualTx := (*receivedTxs)[0] if !bytes.Equal(p.signer.Hash(actualTx).Bytes(), expectedHash.Bytes()) { - return fmt.Errorf("expected tx hash %v, get %v, nonce %d, to %s, value %s, gas %d, gasPrice %s, data %s", expectedHash.String(), actualTx.Hash().String(), - expectedTx.Nonce(), - expectedTx.To().String(), - expectedTx.Value().String(), - expectedTx.Gas(), - expectedTx.GasPrice().String(), - hex.EncodeToString(expectedTx.Data()), + return fmt.Errorf("expected tx hash %v, get %v, nonce %d:%d, to %s:%s, value %s:%s, gas %d:%d, gasPrice %s:%s, data %s:%s, dbErr: %v", expectedHash.String(), actualTx.Hash().String(), + expectedTx.Nonce(), actualTx.Nonce(), + expectedTx.To().String(), actualTx.To().String(), + expectedTx.Value().String(), actualTx.Value().String(), + expectedTx.Gas(), actualTx.Gas(), + expectedTx.GasPrice().String(), actualTx.GasPrice().String(), + hex.EncodeToString(expectedTx.Data()), hex.EncodeToString(actualTx.Data()), + state.Error(), ) } expectedTx = actualTx @@ -1959,7 +1960,7 @@ func applyMessage( msg.Value(), ) if err != nil { - log.Error("apply message failed", "msg", string(ret), "err", err) + log.Error("apply message failed", "contract", msg.To(), "caller", msg.From(), "data", msg.Data(), "msg", string(ret), "err", err, "dberror", state.Error()) } return msg.Gas() - returnGas, err } diff --git a/core/blockchain.go b/core/blockchain.go index 94c54fc996..219065930e 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -158,14 +158,18 @@ type CacheConfig struct { SnapshotNoBuild bool // Whether the background generation is allowed SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it + + // state expiry feature + StateExpiryCfg *types.StateExpiryConfig } -// triedbConfig derives the configures for trie database. -func (c *CacheConfig) triedbConfig() *trie.Config { +// TriedbConfig derives the configures for trie database. 
+func (c *CacheConfig) TriedbConfig() *trie.Config {
	config := &trie.Config{
-		Cache:     c.TrieCleanLimit,
-		Preimages: c.Preimages,
-		NoTries:   c.NoTries,
+		Cache:             c.TrieCleanLimit,
+		Preimages:         c.Preimages,
+		NoTries:           c.NoTries,
+		EnableStateExpiry: c.StateExpiryCfg.EnableExpiry(),
	}
	if c.StateScheme == rawdb.HashScheme {
		config.HashDB = &hashdb.Config{
@@ -178,6 +182,10 @@ func (c *CacheConfig) triedbConfig() *trie.Config {
			CleanCacheSize: c.TrieCleanLimit * 1024 * 1024,
			DirtyCacheSize: c.TrieDirtyLimit * 1024 * 1024,
		}
+		if config.EnableStateExpiry {
+			// state expiry needs extra cache to store epoch metadata, but it must not exceed maxBuffer
+			config.PathDB.CleanCacheSize = 2 * config.PathDB.CleanCacheSize
+		}
	}
	return config
}
@@ -293,6 +301,10 @@ type BlockChain struct {

	// monitor
	doubleSignMonitor *monitor.DoubleSignMonitor
+
+	// state expiry feature
+	stateExpiryCfg *types.StateExpiryConfig
+	fullStateDB    ethdb.FullStateDB
}

// NewBlockChain returns a fully initialised block chain using information
@@ -313,7 +325,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, genesis *Genesis
	diffLayerChanCache, _ := exlru.New(diffLayerCacheLimit)

	// Open trie database with provided config
-	triedb := trie.NewDatabase(db, cacheConfig.triedbConfig())
+	triedb := trie.NewDatabase(db, cacheConfig.TriedbConfig())
+
	// Setup the genesis block, commit the provided genesis specification
	// to database if the genesis block is not present yet, or load the
	// stored one from database.
@@ -364,6 +377,14 @@
	bc.processor = NewStateProcessor(chainConfig, bc, engine)

	var err error
+	if cacheConfig.StateExpiryCfg.EnableExpiry() {
+		log.Info("enable state expiry feature", "RemoteEndPoint", cacheConfig.StateExpiryCfg.FullStateEndpoint)
+		bc.stateExpiryCfg = cacheConfig.StateExpiryCfg
+		bc.fullStateDB, err = ethdb.NewFullStateRPCServer(cacheConfig.StateExpiryCfg.FullStateEndpoint)
+		if err != nil {
+			return nil, err
+		}
+	}
	bc.hc, err = NewHeaderChain(db, chainConfig, engine, bc.insertStopped)
	if err != nil {
		return nil, err
@@ -599,6 +620,22 @@ func (bc *BlockChain) cacheBlock(hash common.Hash, block *types.Block) {
	bc.blockCache.Add(hash, block)
}

+// EnableStateExpiry reports whether the state expiry feature is enabled.
+func (bc *BlockChain) EnableStateExpiry() bool {
+	return bc.stateExpiryCfg.EnableExpiry()
+}
+
+// EnableStateExpiryLocalRevive reports whether expired state may be revived locally.
+func (bc *BlockChain) EnableStateExpiryLocalRevive() bool {
+	if bc.EnableStateExpiry() {
+		return bc.stateExpiryCfg.EnableLocalRevive
+	}
+
+	return false
+}
+
+// FullStateDB returns the remote full state database used to fetch expired state.
+func (bc *BlockChain) FullStateDB() ethdb.FullStateDB {
+	return bc.fullStateDB
+}
+
 // empty returns an indicator whether the blockchain is empty.
// Note, it's a special case that we connect a non-empty ancient // database with an empty node, so that we can plugin the ancient @@ -1015,8 +1052,15 @@ func (bc *BlockChain) SnapSyncCommitHead(hash common.Hash) error { } // StateAtWithSharedPool returns a new mutable state based on a particular point in time with sharedStorage -func (bc *BlockChain) StateAtWithSharedPool(root common.Hash) (*state.StateDB, error) { - return state.NewWithSharedPool(root, bc.stateCache, bc.snaps) +func (bc *BlockChain) StateAtWithSharedPool(root, startAtBlockHash common.Hash, height *big.Int) (*state.StateDB, error) { + stateDB, err := state.NewWithSharedPool(root, bc.stateCache, bc.snaps) + if err != nil { + return nil, err + } + if bc.EnableStateExpiry() { + stateDB.InitStateExpiryFeature(bc.stateExpiryCfg, bc.fullStateDB, startAtBlockHash, height) + } + return stateDB, err } // Reset purges the entire blockchain, restoring it to its genesis state. @@ -1167,6 +1211,7 @@ func (bc *BlockChain) Stop() { log.Error("Failed to journal state snapshot", "err", err) } } + if bc.triedb.Scheme() == rawdb.PathScheme { // Ensure that the in-memory trie nodes are journaled to disk properly. if err := bc.triedb.Journal(bc.CurrentBlock().Root); err != nil { @@ -1180,7 +1225,7 @@ func (bc *BlockChain) Stop() { // - HEAD-127: So we have a hard limit on the number of blocks reexecuted if !bc.cacheConfig.TrieDirtyDisabled { triedb := bc.triedb - var once sync.Once + var once sync.Once for _, offset := range []uint64{0, 1, TriesInMemory - 1} { if number := bc.CurrentBlock().Number.Uint64(); number > offset { @@ -1191,9 +1236,9 @@ func (bc *BlockChain) Stop() { } else { rawdb.WriteSafePointBlockNumber(bc.db, recent.NumberU64()) once.Do(func() { - rawdb.WriteHeadBlockHash(bc.db, recent.Hash()) + rawdb.WriteHeadBlockHash(bc.db, recent.Hash()) }) - } + } } } if snapBase != (common.Hash{}) { @@ -1203,14 +1248,8 @@ func (bc *BlockChain) Stop() { } else { rawdb.WriteSafePointBlockNumber(bc.db, bc.CurrentBlock().Number.Uint64()) } - } - - if snapBase != (common.Hash{}) { - log.Info("Writing snapshot state to disk", "root", snapBase) - if err := bc.triedb.Commit(snapBase, true); err != nil { - log.Error("Failed to commit recent state trie", "err", err) - } else { - rawdb.WriteSafePointBlockNumber(bc.db, bc.CurrentBlock().Number.Uint64()) + if err := triedb.CommitEpochMeta(snapBase); err != nil { + log.Error("Failed to commit recent epoch meta", "err", err) } } for !bc.triegc.Empty() { @@ -1221,6 +1260,13 @@ func (bc *BlockChain) Stop() { } } } + epochMetaSnap := bc.triedb.EpochMetaSnapTree() + if epochMetaSnap != nil { + if err := epochMetaSnap.Journal(); err != nil { + log.Error("Failed to journal epochMetaSnapTree", "err", err) + } + } + // Close the trie database, release all the held resources as the last step. if err := bc.triedb.Close(); err != nil { log.Error("Failed to close trie database", "err", err) @@ -1602,13 +1648,17 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. // If node is running in path mode, skip explicit gc operation // which is unnecessary in this mode. 
 	if bc.triedb.Scheme() == rawdb.PathScheme {
+		err := bc.triedb.CommitEpochMeta(block.Root())
+		if err != nil {
+			return err
+		}
		return nil
	}
	triedb := bc.stateCache.TrieDB()
	// If we're running an archive node, always flush
	if bc.cacheConfig.TrieDirtyDisabled {
-		err := triedb.Commit(block.Root(), false)
+		err := triedb.CommitAll(block.Root(), false)
		if err != nil {
			return err
		}
@@ -1616,6 +1666,12 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
		// Full but not archive node, do proper garbage collection
		triedb.Reference(block.Root(), common.Hash{}) // metadata reference to keep trie alive
		bc.triegc.Push(block.Root(), -int64(block.NumberU64()))
+		// TODO(0xbundler): remove this once the commit path is optimized.
+		go triedb.CommitEpochMeta(block.Root())
+		//err := triedb.CommitEpochMeta(block.Root())
+		//if err != nil {
+		//	return err
+		//}

		if current := block.NumberU64(); current > bc.triesInMemory {
			// If we exceeded our memory allowance, flush matured singleton nodes to disk
@@ -1652,7 +1708,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
					log.Info("State in memory for too long, committing", "time", bc.gcproc,
						"allowance", flushInterval, "optimum", float64(chosen-bc.lastWrite)/float64(bc.triesInMemory))
				}
				// Flush an entire trie and restart the counters
-				triedb.Commit(header.Root, true)
+				triedb.CommitAll(header.Root, true)
				rawdb.WriteSafePointBlockNumber(bc.db, chosen)
				bc.lastWrite = chosen
				bc.gcproc = 0
@@ -1833,7 +1889,6 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool) (int, error)
	if bc.insertStopped() {
		return 0, nil
	}
-
	// Start a parallel signature recovery (signer will fluke on fork transition, minimal perf loss)
	signer := types.MakeSigner(bc.chainConfig, chain[0].Number(), chain[0].Time())
	go SenderCacher.RecoverFromBlocks(signer, chain)
@@ -2015,6 +2070,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool) (int, error)
			return it.index, err
		}
		bc.updateHighestVerifiedHeader(block.Header())
+		if bc.EnableStateExpiry() {
+			statedb.InitStateExpiryFeature(bc.stateExpiryCfg, bc.fullStateDB, parent.Hash(), block.Number())
+		}

		// Enable prefetching to pull in trie node paths while processing transactions
		statedb.StartPrefetcher("chain")
@@ -2022,9 +2080,11 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool) (int, error)
		// For diff sync, it may fallback to full sync, so we still do prefetch
		if len(block.Transactions()) >= prefetchTxNumber {
			// do Prefetch in a separate goroutine to avoid blocking the critical path
-
			// 1.do state prefetch for snapshot cache
			throwaway := statedb.CopyDoPrefetch()
+			if throwaway != nil && bc.EnableStateExpiry() {
+				throwaway.InitStateExpiryFeature(bc.stateExpiryCfg, bc.fullStateDB, parent.Hash(), block.Number())
+			}
			go bc.prefetcher.Prefetch(block, throwaway, &bc.vmConfig, interruptCh)

			// 2.do trie prefetch for MPT trie node cache
@@ -2090,6 +2150,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool) (int, error)
			status, err = bc.writeBlockAndSetHead(block, receipts, logs, statedb, false)
		}
		if err != nil {
+			log.Error("failed to commit block during chain insertion", "err", err)
			return it.index, err
		}
		// Update the metrics touched during block commit
@@ -2106,7 +2167,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, setHead bool) (int, error)
		stats.usedGas += usedGas

		dirty, _ := bc.triedb.Size()
-		stats.report(chain, it.index, dirty, setHead)
+		stats.report(chain, it.index, dirty, setHead, bc.stateExpiryCfg)

		if !setHead {
			// After
merge we expect few side chains. Simply count @@ -3105,3 +3166,8 @@ func (bc *BlockChain) SetTrieFlushInterval(interval time.Duration) { func (bc *BlockChain) GetTrieFlushInterval() time.Duration { return time.Duration(bc.flushInterval.Load()) } + +// StorageTrie just get Storage trie from db +func (bc *BlockChain) StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) { + return bc.stateCache.OpenStorageTrie(stateRoot, addr, root) +} diff --git a/core/blockchain_insert.go b/core/blockchain_insert.go index ffe2d6501c..9a225afad3 100644 --- a/core/blockchain_insert.go +++ b/core/blockchain_insert.go @@ -39,7 +39,7 @@ const statsReportLimit = 8 * time.Second // report prints statistics if some number of blocks have been processed // or more than a few seconds have passed since the last message. -func (st *insertStats) report(chain []*types.Block, index int, dirty common.StorageSize, setHead bool) { +func (st *insertStats) report(chain []*types.Block, index int, dirty common.StorageSize, setHead bool, config *types.StateExpiryConfig) { // Fetch the timings for the batch var ( now = mclock.Now() @@ -60,6 +60,9 @@ func (st *insertStats) report(chain []*types.Block, index int, dirty common.Stor "blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000, "elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed), } + if config.EnableExpiry() { + context = append(context, []interface{}{"stateEpoch", types.GetStateEpoch(config, end.Number())}...) + } if timestamp := time.Unix(int64(end.Time()), 0); time.Since(timestamp) > time.Minute { context = append(context, []interface{}{"age", common.PrettyAge(timestamp)}...) } diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index 802b979a10..2369d8a7da 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -347,17 +347,26 @@ func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) { // State returns a new mutable state based on the current HEAD block. func (bc *BlockChain) State() (*state.StateDB, error) { - return bc.StateAt(bc.CurrentBlock().Root) + return bc.StateAt(bc.CurrentBlock().Root, bc.CurrentBlock().Hash(), bc.CurrentBlock().Number) } // StateAt returns a new mutable state based on a particular point in time. -func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return state.New(root, bc.stateCache, bc.snaps) +func (bc *BlockChain) StateAt(startAtRoot common.Hash, startAtBlockHash common.Hash, expectHeight *big.Int) (*state.StateDB, error) { + sdb, err := state.New(startAtRoot, bc.stateCache, bc.snaps) + if err != nil { + return nil, err + } + if bc.EnableStateExpiry() { + sdb.InitStateExpiryFeature(bc.stateExpiryCfg, bc.fullStateDB, startAtBlockHash, expectHeight) + } + return sdb, err } // Config retrieves the chain's fork configuration. func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig } +func (bc *BlockChain) StateExpiryConfig() *types.StateExpiryConfig { return bc.stateExpiryCfg } + // Engine retrieves the blockchain's consensus engine. 
func (bc *BlockChain) Engine() consensus.Engine { return bc.engine }
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index 56a5c7763e..750e6518c9 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -4346,7 +4346,7 @@ func TestTransientStorageReset(t *testing.T) {
		t.Fatalf("failed to insert into chain: %v", err)
	}
	// Check the storage
-	state, err := chain.StateAt(chain.CurrentHeader().Root)
+	state, err := chain.StateAt(chain.CurrentHeader().Root, chain.CurrentHeader().Hash(), chain.CurrentHeader().Number)
	if err != nil {
		t.Fatalf("Failed to load state %v", err)
	}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index f0026089ac..58593bc9bf 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -324,7 +324,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
		if err != nil {
			panic(fmt.Sprintf("state write error: %v", err))
		}
-		if err = triedb.Commit(root, false); err != nil {
+		if err := statedb.Database().TrieDB().CommitAll(root, false); err != nil {
			panic(fmt.Sprintf("trie write error: %v", err))
		}
		return block, b.receipts
diff --git a/core/genesis.go b/core/genesis.go
index 8fa107ac66..d3ed27df2e 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -172,7 +172,7 @@ func (ga *GenesisAlloc) flush(db ethdb.Database, triedb *trie.Database, blockhas
	}
	// Commit newly generated states into disk if it's not empty.
	if root != types.EmptyRootHash {
-		if err := triedb.Commit(root, true); err != nil {
+		if err := triedb.CommitAll(root, true); err != nil {
			return err
		}
	}
diff --git a/core/rawdb/accessors_epoch_meta.go b/core/rawdb/accessors_epoch_meta.go
new file mode 100644
index 0000000000..9c72f36dc9
--- /dev/null
+++ b/core/rawdb/accessors_epoch_meta.go
@@ -0,0 +1,63 @@
+package rawdb
+
+import (
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+// DeleteEpochMetaSnapshotJournal removes the journal of the epoch meta snapshot tree.
+func DeleteEpochMetaSnapshotJournal(db ethdb.KeyValueWriter) {
+	if err := db.Delete(epochMetaSnapshotJournalKey); err != nil {
+		log.Crit("Failed to remove epoch meta snapshot journal", "err", err)
+	}
+}
+
+// ReadEpochMetaSnapshotJournal retrieves the journal of the epoch meta snapshot tree.
+func ReadEpochMetaSnapshotJournal(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(epochMetaSnapshotJournalKey)
+	return data
+}
+
+// WriteEpochMetaSnapshotJournal stores the journal of the epoch meta snapshot tree.
+func WriteEpochMetaSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) {
+	if err := db.Put(epochMetaSnapshotJournalKey, journal); err != nil {
+		log.Crit("Failed to store epoch meta snapshot journal", "err", err)
+	}
+}
+
+// ReadEpochMetaPlainStateMeta retrieves the disk layer metadata of the epoch meta plain state.
+func ReadEpochMetaPlainStateMeta(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(epochMetaPlainStateMeta)
+	return data
+}
+
+// WriteEpochMetaPlainStateMeta stores the disk layer metadata of the epoch meta plain state.
+func WriteEpochMetaPlainStateMeta(db ethdb.KeyValueWriter, val []byte) error {
+	return db.Put(epochMetaPlainStateMeta, val)
+}
+
+// ReadEpochMetaPlainState retrieves the epoch meta for the given account and trie node path.
+func ReadEpochMetaPlainState(db ethdb.KeyValueReader, addr common.Hash, path string) []byte {
+	val, _ := db.Get(epochMetaPlainStateKey(addr, path))
+	return val
+}
+
+// WriteEpochMetaPlainState stores the epoch meta for the given account and trie node path.
+func WriteEpochMetaPlainState(db ethdb.KeyValueWriter, addr common.Hash, path string, val []byte) error {
+	return db.Put(epochMetaPlainStateKey(addr, path), val)
+}
+
+// DeleteEpochMetaPlainState removes the epoch meta for the given account and trie node path.
+func DeleteEpochMetaPlainState(db ethdb.KeyValueWriter, addr common.Hash, path string) error {
+	return db.Delete(epochMetaPlainStateKey(addr, path))
+}
+
+// ReadStateExpiryCfg retrieves the persisted state expiry config.
+func ReadStateExpiryCfg(db ethdb.Reader) []byte {
+	val, _ := db.Get(stateExpiryCfgKey)
+	return val
+}
+
+// WriteStateExpiryCfg stores the persisted state expiry config.
+func WriteStateExpiryCfg(db ethdb.KeyValueWriter, val []byte) error {
+	return db.Put(stateExpiryCfgKey, val)
+}
+
+func epochMetaPlainStateKey(addr common.Hash, path string) []byte {
+	key :=
make([]byte, len(EpochMetaPlainStatePrefix)+len(addr)+len(path)) + copy(key[:], EpochMetaPlainStatePrefix) + copy(key[len(EpochMetaPlainStatePrefix):], addr.Bytes()) + copy(key[len(EpochMetaPlainStatePrefix)+len(addr):], path) + return key +} diff --git a/core/rawdb/accessors_trie.go b/core/rawdb/accessors_trie.go index f5c2f8899a..2911901ab4 100644 --- a/core/rawdb/accessors_trie.go +++ b/core/rawdb/accessors_trie.go @@ -20,6 +20,8 @@ import ( "fmt" "sync" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" @@ -110,9 +112,14 @@ func ReadStorageTrieNode(db ethdb.KeyValueReader, accountHash common.Hash, path if err != nil { return nil, common.Hash{} } + + raw, err := types.DecodeTypedTrieNodeRaw(data) + if err != nil { + panic(fmt.Errorf("ReadStorageTrieNode err, %v", err)) + } h := newHasher() defer h.release() - return data, h.hash(data) + return data, h.hash(raw) } // HasStorageTrieNode checks the storage trie node presence with the provided diff --git a/core/rawdb/ancient_utils.go b/core/rawdb/ancient_utils.go index 392ac79631..39632fb3be 100644 --- a/core/rawdb/ancient_utils.go +++ b/core/rawdb/ancient_utils.go @@ -109,9 +109,9 @@ func inspectFreezers(db ethdb.Database) ([]freezerInfo, error) { return nil, err } infos = append(infos, info) - - default: - return nil, fmt.Errorf("unknown freezer, supported ones: %v", freezers) + //TODO(0xbundler): bug? + //default: + // return nil, fmt.Errorf("unknown freezer, supported ones: %v", freezers) } } return infos, nil diff --git a/core/rawdb/database.go b/core/rawdb/database.go index 40a034923d..11e9580ad1 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -629,6 +629,8 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { txLookups stat accountSnaps stat storageSnaps stat + snapJournal stat + trieJournal stat preimages stat bloomBits stat cliqueSnaps stat @@ -642,6 +644,11 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { metadata stat unaccounted stat + // state expiry statistics + epochMetaMetaSize stat + epochMetaSnapJournalSize stat + epochMetaPlainStateSize stat + // Totals total common.StorageSize ) @@ -703,14 +710,24 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { bytes.HasPrefix(key, BloomTrieIndexPrefix) || bytes.HasPrefix(key, BloomTriePrefix): // Bloomtrie sub bloomTrieNodes.Add(size) + case bytes.Equal(key, epochMetaPlainStateMeta): + epochMetaMetaSize.Add(size) + case bytes.Equal(key, snapshotJournalKey): + snapJournal.Add(size) + case bytes.Equal(key, trieJournalKey): + trieJournal.Add(size) + case bytes.Equal(key, epochMetaSnapshotJournalKey): + epochMetaSnapJournalSize.Add(size) + case bytes.HasPrefix(key, EpochMetaPlainStatePrefix) && len(key) >= (len(EpochMetaPlainStatePrefix)+common.HashLength): + epochMetaPlainStateSize.Add(size) default: var accounted bool for _, meta := range [][]byte{ databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, - lastPivotKey, fastTrieProgressKey, snapshotDisabledKey, SnapshotRootKey, snapshotJournalKey, + lastPivotKey, fastTrieProgressKey, snapshotDisabledKey, SnapshotRootKey, snapshotGeneratorKey, snapshotRecoveryKey, txIndexTailKey, fastTxLookupLimitKey, uncleanShutdownKey, badBlockKey, transitionStatusKey, skeletonSyncStatusKey, - persistentStateIDKey, trieJournalKey, snapshotSyncStatusKey, + persistentStateIDKey, snapshotSyncStatusKey, } { if bytes.Equal(key, 
meta) {
				metadata.Add(size)
@@ -743,14 +760,19 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error {
		{"Key-Value store", "Path trie state lookups", stateLookups.Size(), stateLookups.Count()},
		{"Key-Value store", "Path trie account nodes", accountTries.Size(), accountTries.Count()},
		{"Key-Value store", "Path trie storage nodes", storageTries.Size(), storageTries.Count()},
+		{"Key-Value store", "Path trie snap journal", trieJournal.Size(), trieJournal.Count()},
		{"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()},
		{"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()},
		{"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()},
+		{"Key-Value store", "Snapshot journal", snapJournal.Size(), snapJournal.Count()},
		{"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()},
		{"Key-Value store", "Parlia snapshots", parliaSnaps.Size(), parliaSnaps.Count()},
		{"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()},
		{"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()},
		{"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()},
+		{"State Expiry", "Epoch metadata", epochMetaMetaSize.Size(), epochMetaMetaSize.Count()},
+		{"State Expiry", "EpochMeta KV", epochMetaPlainStateSize.Size(), epochMetaPlainStateSize.Count()},
+		{"State Expiry", "EpochMeta snap journal", epochMetaSnapJournalSize.Size(), epochMetaSnapJournalSize.Count()},
	}
	// Inspect all registered append-only file store then.
	ancients, err := inspectFreezers(db)
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
index 13a00d795a..99498174d4 100644
--- a/core/rawdb/schema.go
+++ b/core/rawdb/schema.go
@@ -103,6 +103,16 @@ var (
	// transitionStatusKey tracks the eth2 transition status.
	transitionStatusKey = []byte("eth2-transition")

+	// state expiry feature
+	// epochMetaSnapshotJournalKey tracks the in-memory epoch meta diff layers across restarts.
+	epochMetaSnapshotJournalKey = []byte("epochMetaSnapshotJournalKey")
+
+	// epochMetaPlainStateMeta saves the disk layer metadata of the epoch meta plain state.
+	epochMetaPlainStateMeta = []byte("epochMetaPlainStateMeta")
+
+	// stateExpiryCfgKey stores the persisted state expiry config.
+	stateExpiryCfgKey = []byte("stateExpiryCfgKey")
+
	// Data item prefixes (use single byte to avoid mixing data types, avoid `i`, used for indexes).
	headerPrefix   = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
	headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td
@@ -144,6 +154,9 @@ var (
	CliqueSnapshotPrefix = []byte("clique-")
	ParliaSnapshotPrefix = []byte("parlia-")

+	// state expiry feature
+	EpochMetaPlainStatePrefix = []byte("em") // EpochMetaPlainStatePrefix + addr hash + path -> val
+
	preimageCounter    = metrics.NewRegisteredCounter("db/preimage/total", nil)
	preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil)
 )
diff --git a/core/state/database.go b/core/state/database.go
index cc5dc73c77..ab17685f84 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -21,6 +21,8 @@ import (
	"fmt"
	"time"

+	"github.com/ethereum/go-ethereum/log"
+
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/common/lru"
	"github.com/ethereum/go-ethereum/core/rawdb"
@@ -99,6 +101,8 @@ type Trie interface {
	// a trie.MissingNodeError is returned.
 	GetStorage(addr common.Address, key []byte) ([]byte, error)

+	// GetStorageAndUpdateEpoch works like GetStorage, but also refreshes the epoch of the accessed storage.
+	GetStorageAndUpdateEpoch(addr common.Address, key []byte) ([]byte, error)
+
	// GetAccount abstracts an account read from the trie. It retrieves the
	// account blob from the trie with provided account address and decodes it
	// with associated decoding algorithm. If the specified account is not in
@@ -154,6 +158,21 @@ type Trie interface {
	// nodes of the longest existing prefix of the key (at least the root), ending
	// with the node that proves the absence of the key.
	Prove(key []byte, proofDb ethdb.KeyValueWriter) error
+
+	// ProveByPath generates a proof for the state at the given trie path.
+	ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error
+
+	// TryRevive revives expired state from the given proof.
+	TryRevive(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error)
+
+	// SetEpoch sets the current epoch in the trie. It must be called during
+	// initialization, otherwise the trie's behavior is undefined.
+	SetEpoch(types.StateEpoch)
+
+	// Epoch returns the current epoch of the trie.
+	Epoch() types.StateEpoch
+
+	// TryLocalRevive revives expired state using local, non-pruned states.
+	TryLocalRevive(addr common.Address, key []byte) ([]byte, error)
}

// NewDatabase creates a backing store for state. The returned database is safe for
@@ -191,6 +210,7 @@ func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
	}
}

+// TODO(0xbundler): TrieCacheSize may not be supported in PBSS
 func NewDatabaseWithConfigAndCache(db ethdb.Database, config *trie.Config) Database {
	atc, _ := exlru.New(accountTrieCacheSize)
	stc, _ := exlru.New(storageTrieCacheSize)
@@ -267,6 +287,7 @@ func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Addre
		triesPairs := tries.([3]*triePair)
		for _, triePair := range triesPairs {
			if triePair != nil && triePair.root == root {
+				log.Debug("OpenStorageTrie hit storageTrieCache", "addr", address, "root", root)
				return triePair.trie.(*trie.SecureTrie).Copy(), nil
			}
		}
diff --git a/core/state/iterator.go b/core/state/iterator.go
index bb9af08206..24fe4a77d5 100644
--- a/core/state/iterator.go
+++ b/core/state/iterator.go
@@ -127,6 +127,9 @@ func (it *nodeIterator) step() error {
		if err != nil {
			return err
		}
+		if it.state.EnableExpire() {
+			dataTrie.SetEpoch(it.state.Epoch())
+		}
		it.dataIt, err = dataTrie.NodeIterator(nil)
		if err != nil {
			return err
diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go
index 1555852d03..7550e4743b 100644
--- a/core/state/pruner/pruner.go
+++ b/core/state/pruner/pruner.go
@@ -19,14 +19,20 @@ package pruner
 import (
	"bytes"
	"encoding/binary"
+	"encoding/hex"
	"errors"
	"fmt"
	"math"
+	"math/big"
	"os"
	"path/filepath"
	"strings"
+	"sync"
	"time"

+	"github.com/ethereum/go-ethereum/params"
+	bloomfilter "github.com/holiman/bloomfilter/v2"
+
	"github.com/prometheus/tsdb/fileutil"

	"github.com/ethereum/go-ethereum/common"
@@ -57,13 +63,23 @@ const (
	// rangeCompactionThreshold is the minimal deleted entry number for
	// triggering range compaction. It's a quite arbitrary number but just
	// to avoid triggering range compaction because of small deletion.
-	rangeCompactionThreshold = 100000
+	rangeCompactionThreshold = 1000000
+
+	FixedPrefixAndAddrSize = 33
+
+	defaultReportDuration = 60 * time.Second
+
+	defaultChannelSize = 200000
 )

 // Config includes all the configurations for pruning.
// NewDatabase creates a backing store for state. The returned database is safe for @@ -191,6 +210,7 @@ func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database { } } +// TODO(0xbundler): TrieCacheSize may not be supported in PBSS func NewDatabaseWithConfigAndCache(db ethdb.Database, config *trie.Config) Database { atc, _ := exlru.New(accountTrieCacheSize) stc, _ := exlru.New(storageTrieCacheSize)
@@ -267,6 +287,7 @@ func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Addre triesPairs := tries.([3]*triePair) for _, triePair := range triesPairs { if triePair != nil && triePair.root == root { + log.Info("OpenStorageTrie hit storageTrieCache", "addr", address, "root", root) return triePair.trie.(*trie.SecureTrie).Copy(), nil } }
diff --git a/core/state/iterator.go b/core/state/iterator.go index bb9af08206..24fe4a77d5 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go
@@ -127,6 +127,9 @@ func (it *nodeIterator) step() error { if err != nil { return err } + if it.state.EnableExpire() { + dataTrie.SetEpoch(it.state.Epoch()) + } it.dataIt, err = dataTrie.NodeIterator(nil) if err != nil { return err
diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 1555852d03..7550e4743b 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go
@@ -19,14 +19,20 @@ package pruner import ( "bytes" "encoding/binary" + "encoding/hex" "errors" "fmt" "math" + "math/big" "os" "path/filepath" "strings" + "sync" "time" + "github.com/ethereum/go-ethereum/params" + bloomfilter "github.com/holiman/bloomfilter/v2" + "github.com/prometheus/tsdb/fileutil" "github.com/ethereum/go-ethereum/common"
@@ -57,13 +63,23 @@ const ( // rangeCompactionThreshold is the minimal deleted entry number for // triggering range compaction. It's a quite arbitrary number but just // to avoid triggering range compaction because of small deletion. - rangeCompactionThreshold = 100000 + rangeCompactionThreshold = 1000000 + + FixedPrefixAndAddrSize = 33 + + defaultReportDuration = 60 * time.Second + + defaultChannelSize = 200000 )
// Config includes all the configurations for pruning. type Config struct { - Datadir string // The directory of the state database - BloomSize uint64 // The Megabytes of memory allocated to bloom-filter + Datadir string // The directory of the state database + BloomSize uint64 // The Megabytes of memory allocated to bloom-filter + EnableStateExpiry bool + ChainConfig *params.ChainConfig + CacheConfig *core.CacheConfig + MaxExpireThreads uint64 }
// Pruner is an offline tool to prune the stale state with the @@ -84,6 +100,7 @@ type Pruner struct { stateBloom *stateBloom snaptree *snapshot.Tree triesInMemory uint64 + flattenBlock *types.Header }
type BlockPruner struct { @@ -102,6 +119,7 @@ func NewPruner(db ethdb.Database, config Config, triesInMemory uint64) (*Pruner, } // Offline pruning is only supported in legacy hash based scheme. triedb := trie.NewDatabase(db, trie.HashDefaults) + log.Info("ChainConfig", "headBlock", headBlock.NumberU64(), "config", config) snapconfig := snapshot.Config{ CacheSize: 256,
@@ -122,6 +140,12 @@ func NewPruner(db ethdb.Database, config Config, triesInMemory uint64) (*Pruner, if err != nil { return nil, err } + + flattenBlockHash := rawdb.ReadCanonicalHash(db, headBlock.NumberU64()-triesInMemory) + flattenBlock := rawdb.ReadHeader(db, flattenBlockHash, headBlock.NumberU64()-triesInMemory) + if flattenBlock == nil { + return nil, fmt.Errorf("cannot find the block %v levels below head, cannot prune", triesInMemory) + } return &Pruner{ config: config, chainHeader: headBlock.Header(), @@ -129,6 +153,7 @@ func NewPruner(db ethdb.Database, config Config, triesInMemory uint64) (*Pruner, stateBloom: stateBloom, snaptree: snaptree, triesInMemory: triesInMemory, + flattenBlock: flattenBlock, }, nil }
@@ -181,7 +206,7 @@ func pruneAll(maindb ethdb.Database, g *core.Genesis) error { ) eta = time.Duration(left/speed) * time.Millisecond } - if time.Since(logged) > 8*time.Second { + if time.Since(logged) > defaultReportDuration { log.Info("Pruning state data", "nodes", count, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart)), "eta", common.PrettyDuration(eta)) logged = time.Now()
@@ -241,6 +266,7 @@ func pruneAll(maindb ethdb.Database, g *core.Genesis) error { } func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, stateBloom *stateBloom, bloomPath string, middleStateRoots map[common.Hash]struct{}, start time.Time) error { + log.Info("Start Prune state data", "root", root) // Delete all stale trie nodes in the disk. With the help of state bloom // the trie nodes(and codes) belong to the active state will be filtered // out.
A very small part of stale tries will also be filtered because of @@ -288,7 +314,7 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta ) eta = time.Duration(left/speed) * time.Millisecond } - if time.Since(logged) > 8*time.Second { + if time.Since(logged) > defaultReportDuration { log.Info("Pruning state data", "nodes", count, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart)), "eta", common.PrettyDuration(eta)) logged = time.Now()
@@ -363,6 +389,7 @@ func (p *BlockPruner) backUpOldDb(name string, cache, handles int, namespace str log.Error("Failed to open ancient database", "err=", err) return err } + defer chainDb.Close() log.Info("chainDB opened successfully")
@@ -651,7 +678,319 @@ func (p *Pruner) Prune(root common.Hash) error { return err } log.Info("State bloom filter committed", "name", filterName) - return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start) + if err = prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start); err != nil { + return err + } + + // find target header + header := p.chainHeader + for header != nil && header.Root != root { + header = rawdb.ReadHeader(p.db, header.ParentHash, header.Number.Uint64()-1) + } + if header == nil || header.Root != root { + return fmt.Errorf("cannot find target block root, chainHeader: %v:%v:%v, targetRoot: %v", + p.chainHeader.Number, p.chainHeader.Hash(), p.chainHeader.Root, root) + } + + if err = p.ExpiredPrune(header.Number, root); err != nil { + return err + } + + return nil +} +
+// ExpiredPrune must run after the normal prune. In HBSS it relies on a bloom filter to avoid deleting trie nodes that are still in use, so it must not run concurrently with other pruning; in PBSS no bloom filter is needed.
+func (p *Pruner) ExpiredPrune(height *big.Int, root common.Hash) error { + if !p.config.EnableStateExpiry { + log.Info("stop prune expired state, disable state expiry", "height", height, "root", root, "scheme", p.config.CacheConfig.StateScheme) + return nil + } + + // if root is empty, use the deepest snapshot block to prune expired state + if root == (common.Hash{}) { + height = p.flattenBlock.Number + root = p.flattenBlock.Root + } + + var ( + bloom *bloomfilter.Filter + epoch = types.GetStateEpoch(p.config.CacheConfig.StateExpiryCfg, height) + trieDB = trie.NewDatabase(p.db, p.config.CacheConfig.TriedbConfig()) + err error + ) + log.Info("start prune expired state", "height", height, "root", root, "scheme", p.config.CacheConfig.StateScheme, "epoch", epoch) + + // when using HBSS, tag all unexpired trie nodes first to avoid deleting shared nodes + if rawdb.HashScheme == p.config.CacheConfig.StateScheme { + bloom, err = p.unExpiredBloomTag(trieDB, epoch, root) + if err != nil { + return err + } + } + + var ( + scanExpiredTrieCh = make(chan *snapshot.ContractItem, defaultChannelSize) + pruneExpiredInDiskCh = make(chan *trie.NodeInfo, defaultChannelSize) + rets = make([]error, 3) + tasksWG sync.WaitGroup + ) + tasksWG.Add(2) + go func() { + defer tasksWG.Done() + rets[0] = asyncScanExpiredInTrie(trieDB, root, epoch, scanExpiredTrieCh, pruneExpiredInDiskCh, p.config.MaxExpireThreads) + }() + go func() { + defer tasksWG.Done() + rets[1] = asyncPruneExpiredStorageInDisk(p.db, pruneExpiredInDiskCh, bloom, p.config.CacheConfig.StateScheme) + }() + rets[2] = snapshot.TraverseContractTrie(p.snaptree, root, scanExpiredTrieCh) + + // wait for all tasks to finish + tasksWG.Wait() + for i, item := range rets { + if item != nil { + log.Error("prune expired state got error", "index", i, "err", item) + } + } + + // cap the epoch-meta snapshot tree and persist its journal
+ snap := trieDB.EpochMetaSnapTree() + if snap != nil { + log.Info("epoch meta snap handle", "root", root) + if err := snap.Cap(root); err != nil { + log.Error("asyncPruneExpired, SnapTree Cap err", "err", err) + return err + } + if err := snap.Journal(); err != nil { + log.Error("asyncPruneExpired, SnapTree Journal err", "err", err) + return err + } + } + log.Info("Expired State pruning successful") + + return nil +} +
+func (p *Pruner) unExpiredBloomTag(trieDB *trie.Database, epoch types.StateEpoch, root common.Hash) (*bloomfilter.Filter, error) { + var ( + scanUnExpiredTrieCh = make(chan *snapshot.ContractItem, defaultChannelSize) + tagUnExpiredInBloomCh = make(chan *trie.NodeInfo, defaultChannelSize) + rets = make([]error, 3) + tasksWG sync.WaitGroup + ) + + bloom, err := bloomfilter.New(p.config.BloomSize*1024*1024*8, 4) + if err != nil { + return nil, err + } + + tasksWG.Add(2) + go func() { + defer tasksWG.Done() + rets[0] = asyncScanUnExpiredInTrie(trieDB, root, epoch, scanUnExpiredTrieCh, tagUnExpiredInBloomCh, p.config.MaxExpireThreads) + }() + go func() { + defer tasksWG.Done() + rets[1] = asyncTagUnExpiredInBloom(tagUnExpiredInBloomCh, bloom) + }() + rets[2] = snapshot.TraverseContractTrie(p.snaptree, root, scanUnExpiredTrieCh) + tasksWG.Wait() + + return bloom, nil +} +
+func asyncTagUnExpiredInBloom(tagUnExpiredInBloomCh chan *trie.NodeInfo, bloom *bloomfilter.Filter) error { + var ( + trieCount = 0 + start = time.Now() + logged = time.Now() + ) + for info := range tagUnExpiredInBloomCh { + trieCount++ + bloom.Add(stateBloomHasher(info.Hash[:])) + if time.Since(logged) > defaultReportDuration { + log.Info("Tag unexpired states in bloom", "trieNodes", trieCount) + logged = time.Now() + } + } + log.Info("Tag unexpired states in bloom", "trieNodes", trieCount, "elapsed", common.PrettyDuration(time.Since(start))) + return nil +} +
+func asyncScanUnExpiredInTrie(db *trie.Database, stateRoot common.Hash, epoch types.StateEpoch, scanUnExpiredTrieCh chan *snapshot.ContractItem, tagUnExpiredInBloomCh chan *trie.NodeInfo, maxThreads uint64) error { + defer func() { + close(tagUnExpiredInBloomCh) + }() + st := trie.NewScanTask(tagUnExpiredInBloomCh, maxThreads, false) + go st.Report(defaultReportDuration) + for item := range scanUnExpiredTrieCh { + log.Info("start scan trie unexpired state", "addrHash", item.Addr, "root", item.Root) + tr, err := trie.New(&trie.ID{ + StateRoot: stateRoot, + Owner: item.Addr, + Root: item.Root, + }, db) + if err != nil { + log.Error("asyncScanUnExpiredInTrie, trie.New err", "id", item, "err", err) + return err + } + tr.SetEpoch(epoch) + if st.MoreThread() { + st.Schedule(func() { + if err = tr.ScanForPrune(st); err != nil { + log.Error("asyncScanUnExpiredInTrie, ScanForPrune err", "id", item, "err", err) + } + }) + continue + } + if err = tr.ScanForPrune(st); err != nil { + log.Error("asyncScanUnExpiredInTrie, ScanForPrune err", "id", item, "err", err) + return err + } + } + st.WaitThreads() + return nil +}
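The unexpired scan above and the expired scan below share one shape: the snapshot traversal feeds contract roots into a channel, a scanner drains it and emits node infos into a second channel, and a sink drains that, with one error slot per task collected after the WaitGroup. A toy version of the fan-out, assuming string payloads in place of snapshot.ContractItem and trie.NodeInfo:

package main

import (
	"fmt"
	"sync"
)

// scan drains contracts and emits one prune item per node it finds.
func scan(in <-chan string, out chan<- string) error {
	for c := range in {
		out <- c + "/node-0x04"
	}
	return nil
}

// sink drains prune items; the real code batches disk deletes or bloom tags here.
func sink(in <-chan string) error {
	for n := range in {
		fmt.Println("handle", n)
	}
	return nil
}

func main() {
	var (
		scanCh  = make(chan string, 8) // contracts to scan
		pruneCh = make(chan string, 8) // node infos to act on
		rets    = make([]error, 3)     // one error slot per task, inspected after Wait
		wg      sync.WaitGroup
	)
	wg.Add(2)
	go func() { defer wg.Done(); defer close(pruneCh); rets[0] = scan(scanCh, pruneCh) }()
	go func() { defer wg.Done(); rets[1] = sink(pruneCh) }()
	// the synchronous third task, like TraverseContractTrie: feed and close scanCh
	for _, c := range []string{"acc-1", "acc-2"} {
		scanCh <- c
	}
	close(scanCh)
	rets[2] = nil
	wg.Wait()
	for i, err := range rets {
		if err != nil {
			fmt.Println("task", i, "failed:", err)
		}
	}
}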
+// asyncScanExpiredInTrie scans tries for expired state to prune. Deleting such nodes directly is problematic in hash-based storage, because key-values are shared between tries in HBSS; in PBSS it is safe.
+func asyncScanExpiredInTrie(db *trie.Database, stateRoot common.Hash, epoch types.StateEpoch, expireContractCh chan *snapshot.ContractItem, pruneExpiredInDisk chan *trie.NodeInfo, maxThreads uint64) error { + defer func() { + close(pruneExpiredInDisk) + }() + st := trie.NewScanTask(pruneExpiredInDisk, maxThreads, true) + go st.Report(defaultReportDuration) + for item := range expireContractCh { + log.Debug("start scan trie expired state", "addrHash", item.Addr, "root", item.Root) + tr, err := trie.New(&trie.ID{ + StateRoot: stateRoot, + Owner: item.Addr, + Root: item.Root, + }, db) + if err != nil { + log.Error("asyncScanExpiredInTrie, trie.New err", "id", item, "err", err) + return err + } + tr.SetEpoch(epoch) + if err = tr.ScanForPrune(st); err != nil { + log.Error("asyncScanExpiredInTrie, ScanForPrune err", "id", item, "err", err) + return err + } + } + st.WaitThreads() + return nil +}
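Why the bloom guard in the disk pruner below: under the hash scheme identical subtries collapse into one node keyed by hash, so a node reachable from an unexpired trie must survive even if an expired trie also references it. Unexpired hashes are tagged first, and the delete loop skips anything the filter may contain; false positives only make pruning more conservative, never unsafe. A standalone sketch of that guard, where fullHash plays the role of the pruner's stateBloomHasher by satisfying hash.Hash64:

package main

import (
	"encoding/binary"
	"fmt"

	bloomfilter "github.com/holiman/bloomfilter/v2"
)

// fullHash wraps an 8-byte slice as a hash.Hash64, the same trick the
// pruner's stateBloomHasher uses to feed raw node hashes to the filter.
type fullHash []byte

func (f fullHash) Write(p []byte) (int, error) { panic("not implemented") }
func (f fullHash) Sum(b []byte) []byte         { panic("not implemented") }
func (f fullHash) Reset()                      { panic("not implemented") }
func (f fullHash) BlockSize() int              { panic("not implemented") }
func (f fullHash) Size() int                   { return 8 }
func (f fullHash) Sum64() uint64               { return binary.BigEndian.Uint64(f) }

func main() {
	bloom, _ := bloomfilter.New(1024*8, 4) // m bits, k hash functions
	shared := fullHash{1, 2, 3, 4, 5, 6, 7, 8}
	bloom.Add(shared) // tagged during the unexpired scan
	// delete loop: skip anything that may still be referenced
	fmt.Println("skip delete:", bloom.Contains(shared))
}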
+func asyncPruneExpiredStorageInDisk(diskdb ethdb.Database, pruneExpiredInDisk chan *trie.NodeInfo, bloom *bloomfilter.Filter, scheme string) error { + var ( + itemCount = 0 + trieCount = 0 + epochMetaCount = 0 + snapCount = 0 + trieSize common.StorageSize + snapSize common.StorageSize + epochMetaSize common.StorageSize + start = time.Now() + logged = time.Now() + ) + batch := diskdb.NewBatch() + for info := range pruneExpiredInDisk { + log.Debug("found expired state", "addr", info.Addr, "path", + hex.EncodeToString(info.Path), "epoch", info.Epoch, "isBranch", + info.IsBranch, "isLeaf", info.IsLeaf) + itemCount++ + addr := info.Addr + switch scheme { + case rawdb.PathScheme: + val := rawdb.ReadTrieNode(diskdb, addr, info.Path, info.Hash, rawdb.PathScheme) + if len(val) == 0 { + log.Debug("cannot find source trie node", "addr", addr, "path", info.Path, "hash", info.Hash, "epoch", info.Epoch) + } else { + trieCount++ + trieSize += common.StorageSize(len(val) + FixedPrefixAndAddrSize + len(info.Path)) + rawdb.DeleteTrieNode(batch, addr, info.Path, info.Hash, rawdb.PathScheme) + } + case rawdb.HashScheme: + // HBSS shares key-values between tries, so use the bloom filter to screen out nodes that may still be referenced. + if bloom == nil || !bloom.Contains(stateBloomHasher(info.Hash.Bytes())) { + val := rawdb.ReadTrieNode(diskdb, addr, info.Path, info.Hash, rawdb.HashScheme) + if len(val) == 0 { + log.Debug("cannot find source trie node", "addr", addr, "path", info.Path, "hash", info.Hash, "epoch", info.Epoch) + } else { + trieCount++ + trieSize += common.StorageSize(len(val) + FixedPrefixAndAddrSize) + rawdb.DeleteTrieNode(batch, addr, info.Path, info.Hash, rawdb.HashScheme) + } + } + } + // delete epoch meta in HBSS + if info.IsBranch && rawdb.HashScheme == scheme { + val := rawdb.ReadEpochMetaPlainState(diskdb, addr, string(info.Path)) + if len(val) == 0 && info.Epoch > types.StateEpoch0 { + log.Debug("cannot find source epoch meta", "addr", addr, "path", info.Path, "hash", info.Hash, "epoch", info.Epoch) + } + if len(val) > 0 { + epochMetaCount++ + epochMetaSize += common.StorageSize(FixedPrefixAndAddrSize + len(info.Path) + len(val)) + rawdb.DeleteEpochMetaPlainState(batch, addr, string(info.Path)) + } + } + // replace the snapshot kv with its epoch metadata only + if info.IsLeaf { + size, err := snapshot.ShrinkExpiredLeaf(batch, diskdb, addr, info.Key, info.Epoch, scheme) + if err != nil { + log.Error("ShrinkExpiredLeaf err", "addr", addr, "key", info.Key, "err", err) + } + if size > 0 { + snapCount++ + snapSize += common.StorageSize(size) + } + } + if batch.ValueSize() >= ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Error("asyncPruneExpiredStorageInDisk, batch write err", "err", err) + } + batch.Reset() + } + if time.Since(logged) > defaultReportDuration { + log.Info("Pruning expired states", "items", itemCount, "trieNodes", trieCount, "trieSize", trieSize, + "SnapKV", snapCount, "SnapKVSize", snapSize, "EpochMeta", epochMetaCount, + "EpochMetaSize", epochMetaSize) + logged = time.Now() + } + } + if batch.ValueSize() > 0 { + if err := batch.Write(); err != nil { + log.Error("asyncPruneExpiredStorageInDisk, batch write err", "err", err) + } + batch.Reset() + } + log.Info("Pruned expired states", "items", itemCount, "trieNodes", trieCount, "trieSize", trieSize, + "SnapKV", snapCount, "SnapKVSize", snapSize, "EpochMeta", epochMetaCount, + "EpochMetaSize", epochMetaSize, "elapsed", common.PrettyDuration(time.Since(start))) + // Start compactions, will remove the deleted data from the disk immediately. + // Note for small pruning, the compaction is skipped. + if trieCount+snapCount+epochMetaCount >= rangeCompactionThreshold { + cstart := time.Now() + for b := 0x00; b <= 0xf0; b += 0x10 { + var ( + start = []byte{byte(b)} + end = []byte{byte(b + 0x10)} + ) + if b == 0xf0 { + end = nil + } + log.Info("Compacting database", "range", fmt.Sprintf("%#x-%#x", start, end), "elapsed", common.PrettyDuration(time.Since(cstart))) + if err := diskdb.Compact(start, end); err != nil { + log.Error("Database compaction failed", "error", err) + return err + } + } + log.Info("Database compaction finished", "elapsed", common.PrettyDuration(time.Since(cstart))) + } + return nil }
// RecoverPruning will resume the pruning procedure during the system restart.
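The compaction pass above walks the whole key space in sixteen fixed slices so no single Compact call has to scan everything, with the last slice left open-ended. A toy that prints the same ranges:

package main

import "fmt"

func main() {
	for b := 0x00; b <= 0xf0; b += 0x10 {
		start := []byte{byte(b)}
		end := []byte{byte(b + 0x10)}
		if b == 0xf0 {
			end = nil // final slice runs to the end of the key space
		}
		if end == nil {
			fmt.Printf("compact %#x-end\n", start)
		} else {
			fmt.Printf("compact %#x-%#x\n", start, end)
		}
	}
}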
diff --git a/core/state/snapshot/conversion.go b/core/state/snapshot/conversion.go index 5c5d6a8b48..0c61e168d6 100644 --- a/core/state/snapshot/conversion.go +++ b/core/state/snapshot/conversion.go
@@ -42,6 +42,11 @@ type trieKV struct { value []byte } +type ContractItem struct { + Addr common.Hash + Root common.Hash +} + type ( // trieGeneratorFn is the interface of trie generation which can // be implemented by different trie algorithm.
@@ -106,6 +111,71 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd return nil } +// TraverseContractTrie traverses all contract accounts via the snapshot iterator and feeds their storage roots into the given channel. +func TraverseContractTrie(snaptree *Tree, root common.Hash, pruneExpiredTrieCh chan *ContractItem) error { + stats := newGenerateStats() + // Traverse all accounts via the snapshot + acctIt, err := snaptree.AccountIterator(root, common.Hash{}) + if err != nil { + return err // The required snapshot might not exist. + } + defer acctIt.Release() + + var ( + stoplog = make(chan bool, 1) // 1-size buffer, works when logging is not enabled + wg sync.WaitGroup + ) + // Spin up a go-routine for progress logging + if stats != nil { + wg.Add(1) + go func() { + defer wg.Done() + runReport(stats, stoplog) + }() + } + + var ( + logged = time.Now() + processed = uint64(0) + account *types.StateAccount + ) + // Start to feed leaves + for acctIt.Next() { + // Fetch the next account and process it concurrently + account, err = types.FullAccount(acctIt.Account()) + if err != nil { + break + } + + hash := acctIt.Hash() + // hand the contract off for async expired-state pruning + if pruneExpiredTrieCh != nil && (account.Root != common.Hash{} && account.Root != types.EmptyRootHash) { + pruneExpiredTrieCh <- &ContractItem{ + Addr: hash, + Root: account.Root, + } + } + + // Accumulate the generation statistic if it's required. + processed++ + if time.Since(logged) > 3*time.Second && stats != nil { + stats.progressAccounts(hash, processed) + logged, processed = time.Now(), 0 + } + } + close(pruneExpiredTrieCh) + stoplog <- true + + // Commit the last part statistic. + if processed > 0 && stats != nil { + stats.finishAccounts(processed) + } + + // wait for all tasks to finish + wg.Wait() + return err +} + // generateStats is a collection of statistics gathered by the trie generator // for logging purposes. type generateStats struct {
@@ -338,7 +408,15 @@ func generateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou } leaf = trieKV{it.Hash(), fullData} } else { - leaf = trieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())} + enc := common.CopyBytes(it.(StorageIterator).Slot()) + if len(enc) > 0 { + val, err := DecodeValueFromRLPBytes(enc) + if err != nil { + return stop(err) + } + enc, _ = rlp.EncodeToBytes(val.GetVal()) + } + leaf = trieKV{it.Hash(), enc} } in <- leaf
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index fbd83e3b52..7b5be0f2e7 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go
@@ -194,7 +204,11 @@ func (dl *diskLayer) proveRange(ctx *generatorContext, trieId *trie.ID, prefix [ keys = append(keys, common.CopyBytes(key[len(prefix):])) if valueConvertFn == nil { - vals = append(vals, common.CopyBytes(iter.Value())) + rlpVal, err := convertSnapValToRLPVal(iter.Value()) + if err != nil { + return nil, err + } + vals = append(vals, rlpVal) } else { val, err := valueConvertFn(iter.Value()) if err != nil { @@ -204,10 +208,18 @@ // // Here append the original value to ensure that the number of key and // value are aligned.
- vals = append(vals, common.CopyBytes(iter.Value())) + rlpVal, err := convertSnapValToRLPVal(val) + if err != nil { + return nil, err + } + vals = append(vals, rlpVal) log.Error("Failed to convert account state data", "err", err) } else { - vals = append(vals, val) + rlpVal, err := convertSnapValToRLPVal(val) + if err != nil { + return nil, err + } + vals = append(vals, rlpVal) } } } @@ -367,8 +379,10 @@ func (dl *diskLayer) generateRange(ctx *generatorContext, trieId *trie.ID, prefi return false, nil, err } if nodes != nil { - tdb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) - tdb.Commit(root, false) + // TODO(Nathan): why block is zero? + block := uint64(0) + tdb.Update(root, types.EmptyRootHash, block, trienode.NewWithNodeSet(nodes), nil) + tdb.CommitAll(root, false) } resolver = func(owner common.Hash, path []byte, hash common.Hash) []byte { return rawdb.ReadTrieNode(mdb, owner, path, hash, tdb.Scheme()) @@ -733,6 +747,14 @@ func increaseKey(key []byte) []byte { return nil } +func convertSnapValToRLPVal(val []byte) ([]byte, error) { + snapVal, err := DecodeValueFromRLPBytes(val) + if err != nil { + return nil, err + } + return rlp.EncodeToBytes(snapVal.GetVal()) +} + // abortErr wraps an interruption signal received to represent the // generation is aborted by external processes. type abortErr struct { diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 07016b675c..40a73f5bd8 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -66,7 +66,7 @@ func testGeneration(t *testing.T, scheme string) { helper.makeStorageTrie(hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true) root, snap := helper.CommitAndGenerate() - if have, want := root, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"); have != want { + if have, want := root, common.HexToHash("0x1b4b0ae3b50e6ce40184d08fc5857c5b6909e2b1d8017d9e3f69170e323b1f6c"); have != want { t.Fatalf("have %#x want %#x", have, want) } select { @@ -196,7 +196,8 @@ func (t *testHelper) addAccount(acckey string, acc *types.StateAccount) { func (t *testHelper) addSnapStorage(accKey string, keys []string, vals []string) { accHash := hashData([]byte(accKey)) for i, key := range keys { - rawdb.WriteStorageSnapshot(t.diskdb, accHash, hashData([]byte(key)), []byte(vals[i])) + val, _ := rlp.EncodeToBytes(vals[i]) + rawdb.WriteStorageSnapshot(t.diskdb, accHash, hashData([]byte(key)), val) } } @@ -204,7 +205,8 @@ func (t *testHelper) makeStorageTrie(owner common.Hash, keys []string, vals []st id := trie.StorageTrieID(types.EmptyRootHash, owner, types.EmptyRootHash) stTrie, _ := trie.NewStateTrie(id, t.triedb) for i, k := range keys { - stTrie.MustUpdate([]byte(k), []byte(vals[i])) + rlpVal, _ := rlp.EncodeToBytes(vals[i]) + stTrie.MustUpdate([]byte(k), rlpVal) // [133,118,97,108,45,49] } if !commit { return stTrie.Hash() @@ -512,7 +514,7 @@ func testGenerateCorruptStorageTrie(t *testing.T, scheme string) { // Delete a node in the storage trie. 
targetPath := []byte{0x4} - targetHash := common.HexToHash("0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371") + targetHash := common.HexToHash("0x1b4b0ae3b50e6ce40184d08fc5857c5b6909e2b1d8017d9e3f69170e323b1f6c") rawdb.DeleteTrieNode(helper.diskdb, hashData([]byte("acc-1")), targetPath, targetHash, scheme) rawdb.DeleteTrieNode(helper.diskdb, hashData([]byte("acc-3")), targetPath, targetHash, scheme) @@ -552,12 +554,16 @@ func testGenerateWithExtraAccounts(t *testing.T, scheme string) { // Identical in the snap key := hashData([]byte("acc-1")) - rawdb.WriteAccountSnapshot(helper.diskdb, key, val) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), []byte("val-1")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), []byte("val-2")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), []byte("val-3")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-4")), []byte("val-4")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-5")), []byte("val-5")) + val, _ = rlp.EncodeToBytes([]byte("val-1")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), val) + val, _ = rlp.EncodeToBytes([]byte("val-2")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), val) + val, _ = rlp.EncodeToBytes([]byte("val-3")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), val) + val, _ = rlp.EncodeToBytes([]byte("val-4")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-4")), val) + val, _ = rlp.EncodeToBytes([]byte("val-5")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-5")), val) } { // Account two exists only in the snapshot @@ -570,9 +576,12 @@ func testGenerateWithExtraAccounts(t *testing.T, scheme string) { val, _ := rlp.EncodeToBytes(acc) key := hashData([]byte("acc-2")) rawdb.WriteAccountSnapshot(helper.diskdb, key, val) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-1")), []byte("b-val-1")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-2")), []byte("b-val-2")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-3")), []byte("b-val-3")) + val, _ = rlp.EncodeToBytes([]byte("b-val-1")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-1")), val) + val, _ = rlp.EncodeToBytes([]byte("b-val-2")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-2")), val) + val, _ = rlp.EncodeToBytes([]byte("b-val-3")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-3")), val) } root := helper.Commit() @@ -629,9 +638,12 @@ func testGenerateWithManyExtraAccounts(t *testing.T, scheme string) { // Identical in the snap key := hashData([]byte("acc-1")) rawdb.WriteAccountSnapshot(helper.diskdb, key, val) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), []byte("val-1")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), []byte("val-2")) - rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), []byte("val-3")) + val, _ = rlp.EncodeToBytes([]byte("val-1")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), val) + val, _ = rlp.EncodeToBytes([]byte("val-2")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), val) + val, _ = rlp.EncodeToBytes([]byte("val-3")) + rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), val) } { // 
100 accounts exist only in the snapshot
diff --git a/core/state/snapshot/snapshot_expire.go b/core/state/snapshot/snapshot_expire.go new file mode 100644 index 0000000000..769e0e3c60 --- /dev/null +++ b/core/state/snapshot/snapshot_expire.go
@@ -0,0 +1,35 @@ +package snapshot + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" +) + +// ShrinkExpiredLeaf is a helper for snapshot kv pruning: it replaces an expired leaf with its epoch metadata only. +func ShrinkExpiredLeaf(writer ethdb.KeyValueWriter, reader ethdb.KeyValueReader, accountHash common.Hash, storageHash common.Hash, epoch types.StateEpoch, scheme string) (int64, error) { + switch scheme { + case rawdb.HashScheme: + // cannot prune the snapshot in HBSS, because it is still used for trie pruning; in PBSS it is safe. + case rawdb.PathScheme: + val := rawdb.ReadStorageSnapshot(reader, accountHash, storageHash) + if len(val) == 0 { + log.Debug("cannot find source snapshot kv", "addr", accountHash, "key", storageHash, "epoch", epoch) + return 0, nil + } + valWithEpoch := NewValueWithEpoch(epoch, nil) + enc, err := EncodeValueToRLPBytes(valWithEpoch) + if err != nil { + return 0, err + } + rawdb.WriteStorageSnapshot(writer, accountHash, storageHash, enc) + shrinkSize := len(val) - len(enc) + if shrinkSize < 0 { + shrinkSize = 0 + } + return int64(shrinkSize), nil + } + return 0, nil +}
diff --git a/core/state/snapshot/snapshot_expire_test.go b/core/state/snapshot/snapshot_expire_test.go new file mode 100644 index 0000000000..0bb2dd762c --- /dev/null +++ b/core/state/snapshot/snapshot_expire_test.go
@@ -0,0 +1,31 @@ +package snapshot + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +var ( + accountHash = common.HexToHash("0x31b67165f56d0ac50814cafa06748fb3b8fccd3c611a8117350e7a49b44ce130") + storageHash1 = common.HexToHash("0x0bb2f3e66816c6fd12513f053d5ee034b1fa2d448a1dc8ee7f56e4c87d6c53fe") +) + +func TestShrinkExpiredLeaf(t *testing.T) { + db := memorydb.New() + rawdb.WriteStorageSnapshot(db, accountHash, storageHash1, encodeSnapVal(NewRawValue([]byte("val1")))) + + _, err := ShrinkExpiredLeaf(db, db, accountHash, storageHash1, types.StateEpoch0, rawdb.PathScheme) + assert.NoError(t, err) + + assert.Equal(t, encodeSnapVal(NewValueWithEpoch(types.StateEpoch0, nil)), rawdb.ReadStorageSnapshot(db, accountHash, storageHash1)) +} + +func encodeSnapVal(val SnapValue) []byte { + enc, _ := EncodeValueToRLPBytes(val) + return enc +}
diff --git a/core/state/snapshot/snapshot_value.go b/core/state/snapshot/snapshot_value.go new file mode 100644 index 0000000000..b9086476d2 --- /dev/null +++ b/core/state/snapshot/snapshot_value.go
@@ -0,0 +1,158 @@ +package snapshot + +import ( + "bytes" + "errors" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + RawValueType = iota // simple value, at most 32 bytes + ValueWithEpochType // value with attached epoch metadata +) + +var ( + ErrSnapValueNotSupport = errors.New("the snapshot value type is not supported yet") +) + +type SnapValue interface { + GetType() byte + GetEpoch() types.StateEpoch + GetVal() []byte // some value types may not provide a value + EncodeToRLPBytes(buf *rlp.EncoderBuffer) +} + +type RawValue []byte + +func NewRawValue(val []byte) SnapValue
{ + value := RawValue(val) + return &value +} + +func (v *RawValue) GetType() byte { + return RawValueType +} + +func (v *RawValue) GetEpoch() types.StateEpoch { + return types.StateEpoch0 +} + +func (v *RawValue) GetVal() []byte { + return *v +} + +func (v *RawValue) EncodeToRLPBytes(buf *rlp.EncoderBuffer) { + buf.WriteBytes(*v) +} + +type ValueWithEpoch struct { + Epoch types.StateEpoch // kv's epoch meta + Val []byte // if val is empty hash, just encode as empty string in RLP +} + +func NewValueWithEpoch(epoch types.StateEpoch, val []byte) SnapValue { + if epoch == types.StateEpoch0 { + return NewRawValue(val) + } + return &ValueWithEpoch{ + Epoch: epoch, + Val: val, + } +} + +func (v *ValueWithEpoch) GetType() byte { + return ValueWithEpochType +} + +func (v *ValueWithEpoch) GetEpoch() types.StateEpoch { + return v.Epoch +} + +func (v *ValueWithEpoch) GetVal() []byte { + return v.Val +} + +func (v *ValueWithEpoch) EncodeToRLPBytes(buf *rlp.EncoderBuffer) { + offset := buf.List() + buf.WriteUint64(uint64(v.Epoch)) + if len(v.Val) == 0 { + buf.Write(rlp.EmptyString) + } else { + buf.WriteBytes(v.Val) + } + buf.ListEnd(offset) +} + +func EncodeValueToRLPBytes(val SnapValue) ([]byte, error) { + switch raw := val.(type) { + case *RawValue: + return rlp.EncodeToBytes(raw) + default: + return encodeTypedVal(val) + } +} + +func DecodeValueFromRLPBytes(b []byte) (SnapValue, error) { + if len(b) == 0 { + return &RawValue{}, nil + } + if len(b) == 1 || b[0] > 0x7f { + var data RawValue + _, data, _, err := rlp.Split(b) + if err != nil { + return nil, err + } + return &data, nil + } + return decodeTypedVal(b) +} + +func decodeTypedVal(b []byte) (SnapValue, error) { + switch b[0] { + case ValueWithEpochType: + var data ValueWithEpoch + if err := decodeValueWithEpoch(b[1:], &data); err != nil { + return nil, err + } + return &data, nil + default: + return nil, ErrSnapValueNotSupport + } +} + +func decodeValueWithEpoch(data []byte, v *ValueWithEpoch) error { + elems, _, err := rlp.SplitList(data) + if err != nil { + return err + } + + epoch, left, err := rlp.SplitUint64(elems) + if err != nil { + return err + } + v.Epoch = types.StateEpoch(epoch) + + val, _, err := rlp.SplitString(left) + if err != nil { + return err + } + if len(val) == 0 { + v.Val = []byte{} + } else { + v.Val = val + } + return nil +} + +func encodeTypedVal(val SnapValue) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, 40)) + buf.WriteByte(val.GetType()) + encoder := rlp.NewEncoderBuffer(buf) + val.EncodeToRLPBytes(&encoder) + if err := encoder.Flush(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/core/state/snapshot/snapshot_value_test.go b/core/state/snapshot/snapshot_value_test.go new file mode 100644 index 0000000000..cb5088b20b --- /dev/null +++ b/core/state/snapshot/snapshot_value_test.go @@ -0,0 +1,62 @@ +package snapshot + +import ( + "encoding/hex" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" +) + +var ( + val, _ = hex.DecodeString("0000f9eef0150e074b32e3b3b6d34d2534222292e3953019a41d714d135763a6") +) + +func TestRawValueEncode(t *testing.T) { + value := NewRawValue(val) + enc1, _ := rlp.EncodeToBytes(value) + buf := rlp.NewEncoderBuffer(nil) + value.EncodeToRLPBytes(&buf) + assert.Equal(t, enc1, buf.ToBytes()) +} + +func TestSnapValEncodeDecode(t *testing.T) { + tests := []struct { + raw SnapValue + }{ + { + raw: 
NewRawValue(common.FromHex("0x3")), }, { raw: NewRawValue(val), }, { raw: NewValueWithEpoch(types.StateEpoch(0), common.FromHex("0x00")), }, { raw: NewValueWithEpoch(types.StateEpoch(0), common.FromHex("0x3")), }, { raw: NewValueWithEpoch(types.StateEpoch(1), common.FromHex("0x3")), }, { raw: NewValueWithEpoch(types.StateEpoch(0), val), }, { raw: NewValueWithEpoch(types.StateEpoch(1000), val), }, { raw: NewValueWithEpoch(types.StateEpoch(1000), []byte{}), }, } for _, item := range tests { enc, err := EncodeValueToRLPBytes(item.raw) assert.NoError(t, err) t.Log(hex.EncodeToString(enc)) tmp, err := DecodeValueFromRLPBytes(enc) assert.NoError(t, err) assert.Equal(t, item.raw, tmp) } }
diff --git a/core/state/state_expiry.go b/core/state/state_expiry.go new file mode 100644 index 0000000000..203ed2fa2e --- /dev/null +++ b/core/state/state_expiry.go
@@ -0,0 +1,195 @@ +package state + +import ( + "bytes" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/trie" +) + +var ( + reviveTrieTimer = metrics.NewRegisteredTimer("state/revivetrie/rt", nil) + reviveTrieMeter = metrics.NewRegisteredMeter("state/revivetrie", nil) + reviveFromLocalMeter = metrics.NewRegisteredMeter("state/revivetrie/local", nil) + reviveFromRemoteMeter = metrics.NewRegisteredMeter("state/revivetrie/remote", nil) +) + +// stateExpiryMeta contains all the state expiry metadata for the target block +type stateExpiryMeta struct { + enableStateExpiry bool + enableLocalRevive bool + fullStateDB ethdb.FullStateDB + epoch types.StateEpoch + originalRoot common.Hash + originalHash common.Hash +} + +func defaultStateExpiryMeta() *stateExpiryMeta { + return &stateExpiryMeta{enableStateExpiry: false} +} + +// tryReviveState requests expired state from a remote full-state node. +func tryReviveState(meta *stateExpiryMeta, addr common.Address, root common.Hash, tr Trie, prefixKey []byte, key common.Hash, force bool) (map[string][]byte, error) { + if !meta.enableStateExpiry { + return nil, nil + } + //log.Debug("fetching expired storage from remoteDB", "addr", addr, "prefix", prefixKey, "key", key) + reviveTrieMeter.Mark(1) + if meta.enableLocalRevive && !force { + // try to revive locally first: if the node has not been pruned yet, only its epoch needs renewing + val, err := tr.TryLocalRevive(addr, key.Bytes()) + //log.Debug("tryReviveState TryLocalRevive", "addr", addr, "key", key, "val", val, "err", err) + switch err.(type) { + case *trie.MissingNodeError: + // cannot revive locally, request from remote + case nil: + reviveFromLocalMeter.Mark(1) + return map[string][]byte{key.String(): val}, nil + default: + return nil, err + } + } + + reviveFromRemoteMeter.Mark(1) + // cannot revive locally, fetch remote proof + proofs, err := meta.fullStateDB.GetStorageReviveProof(meta.originalRoot, addr, root, []string{common.Bytes2Hex(prefixKey)}, []string{common.Bytes2Hex(key[:])}) + //log.Debug("tryReviveState GetStorageReviveProof", "addr", addr, "key", key, "proofs", len(proofs), "err", err) + if err != nil { + return nil, err + } + + if len(proofs) == 0 { + log.Error("cannot find any revive proof from remoteDB", "addr", addr, "prefix", prefixKey, "key", key) + return nil, fmt.Errorf("cannot find any revive proof from remoteDB") + } + + return ReviveStorageTrie(addr, tr, proofs[0], key)
}
+
+// batchFetchExpiredFromRemote requests expired state from a remote full-state node with a list of keys and prefixes.
+func batchFetchExpiredFromRemote(expiryMeta *stateExpiryMeta, addr common.Address, root common.Hash, tr Trie, prefixKeys [][]byte, keys []common.Hash) ([]map[string][]byte, error) { + reviveTrieMeter.Mark(int64(len(keys))) + ret := make([]map[string][]byte, len(keys)) + prefixKeysStr := make([]string, len(prefixKeys)) + keysStr := make([]string, len(keys)) + + if expiryMeta.enableLocalRevive { + var expiredKeys []common.Hash + var expiredPrefixKeys [][]byte + for i, key := range keys { + val, err := tr.TryLocalRevive(addr, key.Bytes()) + //log.Debug("tryReviveState TryLocalRevive", "addr", addr, "key", key, "val", val, "err", err) + switch err.(type) { + case *trie.MissingNodeError: + expiredKeys = append(expiredKeys, key) + expiredPrefixKeys = append(expiredPrefixKeys, prefixKeys[i]) + case nil: + kv := make(map[string][]byte, 1) + kv[key.String()] = val + ret = append(ret, kv) + default: + return nil, err + } + } + reviveFromLocalMeter.Mark(int64(len(keys) - len(expiredKeys))) + for i, prefix := range expiredPrefixKeys { + prefixKeysStr[i] = common.Bytes2Hex(prefix) + } + for i, key := range expiredKeys { + keysStr[i] = common.Bytes2Hex(key[:]) + } + } else { + for i, prefix := range prefixKeys { + prefixKeysStr[i] = common.Bytes2Hex(prefix) + } + + for i, key := range keys { + keysStr[i] = common.Bytes2Hex(key[:]) + } + } + if len(prefixKeysStr) == 0 { + return ret, nil + } + + // cannot revive locally, fetch remote proof + reviveFromRemoteMeter.Mark(int64(len(keysStr))) + proofs, err := expiryMeta.fullStateDB.GetStorageReviveProof(expiryMeta.originalRoot, addr, root, prefixKeysStr, keysStr) + //log.Debug("tryReviveState GetStorageReviveProof", "addr", addr, "keys", keysStr, "prefixKeys", prefixKeysStr, "proofs", len(proofs), "err", err) + if err != nil { + return nil, err + } + + if len(proofs) == 0 { + log.Error("cannot find any revive proof from remoteDB", "addr", addr, "keys", keysStr, "prefixKeys", prefixKeysStr) + return nil, fmt.Errorf("cannot find any revive proof from remoteDB") + } + + for i, proof := range proofs { + // kvs, err := ReviveStorageTrie(addr, tr, proof, common.HexToHash(keysStr[i])) // TODO(asyukii): this logically should work but does not, for an unknown reason; needs investigation + kvs, err := ReviveStorageTrie(addr, tr, proof, common.HexToHash(proof.Key)) + if err != nil { + log.Error("reviveStorageTrie failed", "addr", addr, "key", keys[i], "err", err) + continue + } + ret = append(ret, kvs) + } + + return ret, nil +} +
+// ReviveStorageTrie revives a trie's expired state from a proof
+func ReviveStorageTrie(addr common.Address, tr Trie, proof types.ReviveStorageProof, targetKey common.Hash) (map[string][]byte, error) { + defer func(start time.Time) { + reviveTrieTimer.Update(time.Since(start)) + }(time.Now()) + + // Decode keys and proofs + key := common.FromHex(proof.Key) + if !bytes.Equal(targetKey[:], key) { + return nil, fmt.Errorf("revive with wrong key, target: %#x, actual: %#x", targetKey, key) + } + prefixKey := common.FromHex(proof.PrefixKey) + innerProofs := make([][]byte, 0, len(proof.Proof)) + for _, p := range proof.Proof { + innerProofs = append(innerProofs, common.FromHex(p)) + } + + proofCache := trie.MPTProofCache{ + MPTProof: trie.MPTProof{ + RootKeyHex: prefixKey, + Proof: innerProofs, + }, + } + + if err := proofCache.VerifyProof(); err != nil { + return nil, err + }
+ nubs, err := tr.TryRevive(key, proofCache.CacheNubs()) + if err != nil { + return nil, err + } + + // sanity check that the key can now be read from the trie + if _, err = tr.GetStorage(addr, key); err != nil { + return nil, err + } + + ret := make(map[string][]byte) + for _, nub := range nubs { + kvs, err := nub.ResolveKV() + if err != nil { + return nil, err + } + for k, v := range kvs { + ret[k] = v + } + } + return ret, nil +}
diff --git a/core/state/state_object.go b/core/state/state_object.go index ed67fceefb..7146146907 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go
@@ -24,6 +24,9 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto"
@@ -69,7 +72,7 @@ type stateObject struct { data types.StateAccount // Account data with all mutations applied in the scope of block // Write caches. - trie Trie // storage trie, which becomes non-nil on first access + trie Trie // storage trie, which becomes non-nil on first access; it holds the committed trie code Code // contract bytecode, which gets set when code is loaded sharedOriginStorage *sync.Map // Point to the entry of the stateObject in sharedPool
@@ -77,6 +80,13 @@ pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction + // for state expiry feature + pendingReviveTrie Trie // pendingReviveTrie contains pending revived trie nodes; it can be updated and committed later + pendingReviveState map[string]common.Hash // pendingReviveState holds revived state for the block; reads and writes consult it first, keyed by hashed slot key + pendingAccessedState map[common.Hash]int // pendingAccessedState records which state has been accessed (reads only; update/delete/insert renew the epoch automatically); the epoch index is updated later + originStorageEpoch map[common.Hash]types.StateEpoch // originStorageEpoch records the original state epoch, to avoid overly frequent epoch updates + pendingFutureReviveState map[common.Hash]int // pendingFutureReviveState records state that is empty in the snapshot; it should be prefetched first and is checked in updateTrie + // Cache flags.
dirtyCode bool // true if the code was updated @@ -111,15 +121,19 @@ func newObject(db *StateDB, address common.Address, acct *types.StateAccount) *s } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - origin: origin, - data: *acct, - sharedOriginStorage: storageMap, - originStorage: make(Storage), - pendingStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + origin: origin, + data: *acct, + sharedOriginStorage: storageMap, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), + pendingReviveState: make(map[string]common.Hash), + pendingAccessedState: make(map[common.Hash]int), + pendingFutureReviveState: make(map[common.Hash]int), + originStorageEpoch: make(map[common.Hash]types.StateEpoch), } } @@ -158,12 +172,26 @@ func (s *stateObject) getTrie() (Trie, error) { if err != nil { return nil, err } + if s.db.EnableExpire() { + tr.SetEpoch(s.db.Epoch()) + } s.trie = tr } } return s.trie, nil } +func (s *stateObject) getPendingReviveTrie() (Trie, error) { + if s.pendingReviveTrie == nil { + src, err := s.getTrie() + if err != nil { + return nil, err + } + s.pendingReviveTrie = s.db.db.CopyTrie(src) + } + return s.pendingReviveTrie, nil +} + // GetState retrieves a value from the account storage trie. func (s *stateObject) GetState(key common.Hash) common.Hash { // If we have a dirty value for this state entry, return it @@ -172,7 +200,11 @@ func (s *stateObject) GetState(key common.Hash) common.Hash { return value } // Otherwise return the entry's original value - return s.GetCommittedState(key) + value = s.GetCommittedState(key) + if value != (common.Hash{}) { + s.accessState(key) + } + return value } func (s *stateObject) getOriginStorage(key common.Hash) (common.Hash, bool) { @@ -201,14 +233,26 @@ func (s *stateObject) setOriginStorage(key common.Hash, value common.Hash) { // GetCommittedState retrieves a value from the committed account storage trie. func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { + getCommittedStorageMeter.Mark(1) // If we have a pending write or clean cached, return that if value, pending := s.pendingStorage[key]; pending { return value } + if s.db.EnableExpire() { + if revived, revive := s.queryFromReviveState(s.pendingReviveState, key); revive { + return revived + } + } + if value, cached := s.getOriginStorage(key); cached { return value } + + if value, cached := s.originStorage[key]; cached { + return value + } + // If the object was destructed in *this* block (and potentially resurrected), // the storage has been cleared out, and we should *not* consult the previous // database about any storage values. 
The only possible alternatives are: @@ -225,33 +269,66 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { value common.Hash ) if s.db.snap != nil { + getCommittedStorageSnapMeter.Mark(1) start := time.Now() - enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + // handle state expiry situation + if s.db.EnableExpire() { + var dbError error + enc, err, dbError = s.getExpirySnapStorage(key) + if dbError != nil { + s.db.setError(fmt.Errorf("state expiry getExpirySnapStorage, contract: %v, key: %v, err: %v", s.address, key, dbError)) + return common.Hash{} + } + if len(enc) > 0 { + value.SetBytes(enc) + } + } else { + enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + if len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + s.db.setError(err) + } + value.SetBytes(content) + } + } if metrics.EnabledExpensive { s.db.SnapshotStorageReads += time.Since(start) } - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - s.db.setError(err) - } - value.SetBytes(content) - } } + // If the snapshot is unavailable or reading from it fails, load from the database. if s.db.snap == nil || err != nil { + getCommittedStorageTrieMeter.Mark(1) start := time.Now() - tr, err := s.getTrie() + var tr Trie + if s.db.EnableExpire() { + tr, err = s.getPendingReviveTrie() + } else { + tr, err = s.getTrie() + } if err != nil { - s.db.setError(err) + s.db.setError(fmt.Errorf("state object getTrie err, contract: %v, err: %v", s.address, err)) return common.Hash{} } val, err := tr.GetStorage(s.address, key.Bytes()) if metrics.EnabledExpensive { s.db.StorageReads += time.Since(start) } + // handle state expiry situation + if s.db.EnableExpire() { + if path, ok := trie.ParseExpiredNodeErr(err); ok { + //log.Debug("GetCommittedState expired in trie", "addr", s.address, "key", key, "err", err) + val, err = s.tryReviveState(path, key, false) + getCommittedStorageExpiredMeter.Mark(1) + } else if err != nil { + getCommittedStorageUnexpiredMeter.Mark(1) + // TODO(0xbundler): add epoch record cache for prevent frequency access epoch update, may implement later + //s.originStorageEpoch[key] = epoch + } + } if err != nil { - s.db.setError(err) + s.db.setError(fmt.Errorf("state object get storage err, contract: %v, key: %v, err: %v", s.address, key, err)) return common.Hash{} } value.SetBytes(val) @@ -290,6 +367,30 @@ func (s *stateObject) finalise(prefetch bool) { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure } } + + // try prefetch future revive states + for key := range s.pendingFutureReviveState { + if val, ok := s.dirtyStorage[key]; ok { + if val != s.originStorage[key] { + continue + } + } + slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure + } + + // try prefetch future update state + for key := range s.pendingAccessedState { + if val, ok := s.dirtyStorage[key]; ok { + if val != s.originStorage[key] { + continue + } + } + if _, ok := s.pendingFutureReviveState[key]; ok { + continue + } + slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure + } + if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { s.db.prefetcher.prefetch(s.addrHash, s.data.Root, s.address, slotsToPrefetch) } @@ -304,7 +405,7 @@ func (s *stateObject) finalise(prefetch bool) { func (s *stateObject) updateTrie() (Trie, error) { // Make sure all dirty slots are 
finalized into the pending storage area s.finalise(false) // Don't prefetch anymore, pull directly if need be - if len(s.pendingStorage) == 0 { + if !s.needUpdateTrie() { return s.trie, nil } // Track the amount of time wasted on updating the storage trie
@@ -320,10 +421,18 @@ storage map[common.Hash][]byte origin map[common.Hash][]byte hasher = crypto.NewKeccakState() + tr Trie + err error ) - tr, err := s.getTrie() + if s.db.EnableExpire() { + // with expiry enabled, use the pending revive trie; prefetcher.trie is still useful because it warms up the db cache, + // and when no state is expired or pruned, prefetcher.trie is used directly. + tr, err = s.getPendingReviveTrie() + } else { + tr, err = s.getTrie() + } if err != nil { - s.db.setError(err) + s.db.setError(fmt.Errorf("state object update trie getTrie err, contract: %v, err: %v", s.address, err)) return nil, err } // Insert all the pending updates into the trie
@@ -341,25 +450,100 @@ } dirtyStorage[key] = v } + + if s.db.EnableExpire() { + // append the accessed slots that also need updating in the db + for key := range s.pendingAccessedState { + if _, ok := dirtyStorage[key]; ok { + continue + } + // this must hit the cache + value := s.GetState(key) + dirtyStorage[key] = common.TrimLeftZeroes(value[:]) + //log.Debug("updateTrie access state", "contract", s.address, "key", key, "epoch", s.db.Epoch()) + } + } + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() + if s.db.EnableExpire() { + // revive state first, to detect conflicting expiry paths or local revives + for key := range s.pendingFutureReviveState { + _, err = tr.GetStorage(s.address, key.Bytes()) + if err == nil { + continue + } + path, ok := trie.ParseExpiredNodeErr(err) + if !ok { + s.db.setError(fmt.Errorf("updateTrie pendingFutureReviveState err, contract: %v, key: %v, err: %v", s.address, key, err)) + //log.Debug("updateTrie pendingFutureReviveState", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "tr.epoch", tr.Epoch(), "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s), "err", err) + continue + } + if _, err = tryReviveState(s.db.expiryMeta, s.address, s.data.Root, tr, path, key, false); err != nil { + s.db.setError(fmt.Errorf("updateTrie pendingFutureReviveState tryReviveState err, contract: %v, key: %v, path: %v, err: %v", s.address, key, path, err)) + } + //log.Debug("updateTrie pendingFutureReviveState", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "tr.epoch", tr.Epoch(), "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) + } + // TODO(0xbundler): some trie nodes are found with a wrong epoch; temporarily add a get op here, fix later + //for key, val := range dirtyStorage { + // _, err = tr.GetStorage(s.address, key.Bytes()) + // if err == nil { + // continue + // } + // log.Error("EnableExpire GetStorage error", "addr", s.address, "key", key, "val", val, "origin", s.originStorage[key], "err", err) + // path, ok := trie.ParseExpiredNodeErr(err) + // if !ok { + // s.db.setError(fmt.Errorf("state object dirtyStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) + // continue + // } + // if _, err = tryReviveState(s.db.expiryMeta, s.address, s.data.Root, tr, path, key); err != nil { + // log.Error("EnableExpire GetStorage tryReviveState error", "addr", s.address, "key", key, "val", val, "origin", s.originStorage[key], "err", err) + // s.db.setError(fmt.Errorf("state object dirtyStorage tryReviveState err, contract: %v, key: %v,
path: %v, err: %v", s.address, key, enErr.Path, err)) + // } + // //log.Debug("updateTrie dirtyStorage", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "tr.epoch", tr.Epoch(), "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) + //} + } + touchExpiredStorage := make(map[common.Hash][]byte) for key, value := range dirtyStorage { if len(value) == 0 { - if err := tr.DeleteStorage(s.address, key[:]); err != nil { - s.db.setError(err) + err := tr.DeleteStorage(s.address, key[:]) + if path, ok := trie.ParseExpiredNodeErr(err); ok { + touchExpiredStorage[key] = value + if _, err = tryReviveState(s.db.expiryMeta, s.address, s.data.Root, tr, path, key, true); err != nil { + s.db.setError(fmt.Errorf("updateTrie DeleteStorage tryReviveState err, contract: %v, key: %v, path: %v, err: %v", s.address, key, path, err)) + } + } else if err != nil { + s.db.setError(fmt.Errorf("updateTrie DeleteStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) } + //log.Debug("updateTrie DeleteStorage", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "value", value, "tr.epoch", tr.Epoch(), "err", err, "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) s.db.StorageDeleted += 1 } else { if err := tr.UpdateStorage(s.address, key[:], value); err != nil { - s.db.setError(err) + s.db.setError(fmt.Errorf("updateTrie UpdateStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) } + //log.Debug("updateTrie UpdateStorage", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "value", value, "tr.epoch", tr.Epoch(), "err", err, "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) s.db.StorageUpdated += 1 } // Cache the items for preloading usedStorage = append(usedStorage, common.CopyBytes(key[:])) } + + // re-execute touched expired storage + for key, value := range touchExpiredStorage { + if len(value) == 0 { + if err := tr.DeleteStorage(s.address, key[:]); err != nil { + s.db.setError(fmt.Errorf("updateTrie DeleteStorage in touchExpiredStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) + } + //log.Debug("updateTrie DeleteStorage in touchExpiredStorage", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "value", value, "tr.epoch", tr.Epoch(), "err", err, "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) + } else { + if err := tr.UpdateStorage(s.address, key[:], value); err != nil { + s.db.setError(fmt.Errorf("updateTrie UpdateStorage in touchExpiredStorage err, contract: %v, key: %v, err: %v", s.address, key, err)) + } + //log.Debug("updateTrie UpdateStorage in touchExpiredStorage", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "value", value, "tr.epoch", tr.Epoch(), "err", err, "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) + } + } }() // If state snapshotting is active, cache the data til commit wg.Add(1) @@ -385,9 +569,15 @@ func (s *stateObject) updateTrie() (Trie, error) { // rlp-encoded value to be used by the snapshot var snapshotVal []byte if len(value) != 0 { - snapshotVal, _ = rlp.EncodeToBytes(value) + // Encoding []byte cannot fail, ok to ignore the error. 
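The branch below is the on-disk compatibility point: with expiry off, slots keep the legacy encoding (plain RLP of the value, whose leading byte is at least 0x80 for multi-byte payloads), while with expiry on they become a type byte below 0x80 followed by rlp([epoch, value]). That first-byte gap is what DecodeValueFromRLPBytes in snapshot_value.go dispatches on. A round-trip sketch using go-ethereum's rlp package, with the 0x01 tag assumed to match ValueWithEpochType:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	// Legacy snapshot value: plain RLP string, leading byte 0x85 here.
	raw, _ := rlp.EncodeToBytes([]byte("val-1"))
	fmt.Printf("raw:   %x\n", raw)

	// Expiry-aware value: type byte (0x01, below the RLP string range)
	// followed by rlp([epoch, value]).
	payload, _ := rlp.EncodeToBytes([]interface{}{uint64(2), []byte("val-1")})
	typed := append([]byte{0x01}, payload...)
	fmt.Printf("typed: %x\n", typed)

	// A decoder can therefore dispatch on the first byte, as
	// DecodeValueFromRLPBytes does: len==1 or b[0] > 0x7f means raw.
	fmt.Println("raw is legacy:", len(raw) == 1 || raw[0] > 0x7f)
	fmt.Println("typed is legacy:", len(typed) == 1 || typed[0] > 0x7f)
}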
+ if s.db.EnableExpire() { + snapshotVal, _ = snapshot.EncodeValueToRLPBytes(snapshot.NewValueWithEpoch(s.db.Epoch(), value)) + } else { + snapshotVal, _ = rlp.EncodeToBytes(value) + } } storage[khash] = snapshotVal // snapshotVal will be nil if it's deleted + //log.Debug("updateTrie UpdateSnapShot", "contract", s.address, "key", key, "epoch", s.db.Epoch(), "value", snapshotVal, "tr.epoch", tr.Epoch(), "tr", fmt.Sprintf("%p", tr), "ins", fmt.Sprintf("%p", s)) // Track the original value of slot only if it's mutated first time prev := s.originStorage[key]
@@ -412,9 +602,39 @@ if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) } + if s.db.EnableExpire() { + if len(s.pendingReviveState) > 0 { + s.pendingReviveState = make(map[string]common.Hash) + } + if len(s.pendingAccessedState) > 0 { + s.pendingAccessedState = make(map[common.Hash]int) + } + if len(s.pendingFutureReviveState) > 0 { + s.pendingFutureReviveState = make(map[common.Hash]int) + } + if len(s.originStorageEpoch) > 0 { + s.originStorageEpoch = make(map[common.Hash]types.StateEpoch) + } + if s.pendingReviveTrie != nil { + s.pendingReviveTrie = nil + } + // replace the trie with the pending revive trie; it will be committed later + if tr != nil { + s.trie = tr + } + } return tr, nil }
+func (s *stateObject) needUpdateTrie() bool { + if !s.db.EnableExpire() { + return len(s.pendingStorage) > 0 + } + + return len(s.pendingStorage) > 0 || len(s.pendingReviveState) > 0 || + len(s.pendingAccessedState) > 0 +} + // UpdateRoot sets the trie root to the current root hash of the trie. An error // will be returned if the trie root hash is not computed correctly. func (s *stateObject) updateRoot() {
@@ -526,6 +746,28 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject { obj.selfDestructed = s.selfDestructed obj.dirtyCode = s.dirtyCode obj.deleted = s.deleted + + if s.db.EnableExpire() { + if s.pendingReviveTrie != nil { + obj.pendingReviveTrie = db.db.CopyTrie(s.pendingReviveTrie) + } + obj.pendingReviveState = make(map[string]common.Hash, len(s.pendingReviveState)) + for k, v := range s.pendingReviveState { + obj.pendingReviveState[k] = v + } + obj.pendingAccessedState = make(map[common.Hash]int, len(s.pendingAccessedState)) + for k, v := range s.pendingAccessedState { + obj.pendingAccessedState[k] = v + } + obj.pendingFutureReviveState = make(map[common.Hash]int, len(s.pendingFutureReviveState)) + for k, v := range s.pendingFutureReviveState { + obj.pendingFutureReviveState[k] = v + } + obj.originStorageEpoch = make(map[common.Hash]types.StateEpoch, len(s.originStorageEpoch)) + for k, v := range s.originStorageEpoch { + obj.originStorageEpoch[k] = v + } + } return obj }
@@ -610,3 +852,112 @@ func (s *stateObject) Balance() *big.Int { func (s *stateObject) Nonce() uint64 { return s.data.Nonce } +
+// accessState records all accessed states; for now this lives in pendingAccessedState and is not part of consensus
+func (s *stateObject) accessState(key common.Hash) { + if !s.db.EnableExpire() { + return + } + + if s.db.Epoch() > s.originStorageEpoch[key] { + count := s.pendingAccessedState[key] + s.pendingAccessedState[key] = count + 1 + } +} +
+// futureReviveState records state to revive later; it is loaded by the prefetcher or in updateTrie
+func (s *stateObject) futureReviveState(key common.Hash) { + if !s.db.EnableExpire() { + return + } + + count := s.pendingFutureReviveState[key] + s.pendingFutureReviveState[key] = count + 1 +} +
+// TODO(0xbundler): add hash key cache later
+func (s *stateObject) queryFromReviveState(reviveState map[string]common.Hash, key common.Hash) (common.Hash, bool) {
map[string]common.Hash, key common.Hash) (common.Hash, bool) {
+	val, ok := reviveState[string(crypto.Keccak256(key[:]))]
+	return val, ok
+}
+
+// tryReviveState requests expired state from a remote full state node.
+func (s *stateObject) tryReviveState(prefixKey []byte, key common.Hash, resolvePath bool) ([]byte, error) {
+	tr, err := s.getPendingReviveTrie()
+	if err != nil {
+		return nil, err
+	}
+
+	// If no prefix is given, resolve the path from the revive trie first to get the freshest expiry info.
+	if resolvePath {
+		val, err := tr.GetStorage(s.address, key.Bytes())
+		if err == nil {
+			// TODO(asyukii): temporary fix for the case where the snapshot has expired but the trie has not; investigate further later.
+			s.pendingReviveState[string(crypto.Keccak256(key[:]))] = common.BytesToHash(val)
+			return val, nil
+		}
+		path, ok := trie.ParseExpiredNodeErr(err)
+		if !ok {
+			return nil, fmt.Errorf("cannot find expired state from trie, err: %v", err)
+		}
+		prefixKey = path
+	}
+
+	kvs, err := tryReviveState(s.db.expiryMeta, s.address, s.data.Root, tr, prefixKey, key, false)
+	if err != nil {
+		return nil, err
+	}
+
+	for k, v := range kvs {
+		s.pendingReviveState[k] = common.BytesToHash(v)
+	}
+
+	getCommittedStorageRemoteMeter.Mark(1)
+	val := s.pendingReviveState[string(crypto.Keccak256(key[:]))]
+	return val.Bytes(), nil
+}
+
+func (s *stateObject) getExpirySnapStorage(key common.Hash) ([]byte, error, error) {
+	enc, err := s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes()))
+	if err != nil {
+		return nil, err, nil
+	}
+	var val snapshot.SnapValue
+	if len(enc) > 0 {
+		val, err = snapshot.DecodeValueFromRLPBytes(enc)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+
+	if val == nil {
+		// Record the access to the empty KV pair; updateTrie will touch it again for deduplication.
+		//log.Debug("getExpirySnapStorage nil val", "addr", s.address, "key", key, "val", val)
+		s.futureReviveState(key)
+		return nil, nil, nil
+	}
+
+	s.originStorageEpoch[key] = val.GetEpoch()
+	if !types.EpochExpired(val.GetEpoch(), s.db.Epoch()) {
+		getCommittedStorageUnexpiredMeter.Mark(1)
+		return val.GetVal(), nil, nil
+	}
+
+	getCommittedStorageExpiredMeter.Mark(1)
+	// If the found value has not been pruned yet, return it directly; it will be revived locally later.
+	if s.db.EnableLocalRevive() && len(val.GetVal()) > 0 {
+		s.futureReviveState(key)
+		getCommittedStorageExpiredLocalReviveMeter.Mark(1)
+		//log.Debug("getExpirySnapStorage GetVal", "addr", s.address, "key", key, "val", hex.EncodeToString(val.GetVal()))
+		return val.GetVal(), nil, nil
+	}
+
+	//log.Debug("GetCommittedState expired in snapshot", "addr", s.address, "key", key, "val", val, "enc", enc, "err", err)
+	// Fall back to the remote DB; on error just setError here (the consensus version would return, so the tx can revert).
+	valRaw, err := s.tryReviveState(nil, key, true)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return valRaw, nil, nil
+}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 62397b083e..24afa2256b 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -26,13 +26,14 @@ import (
 	"sync"
 	"time"
 
+	"github.com/ethereum/go-ethereum/ethdb"
+
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/gopool"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/state/snapshot"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
-	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/params"
@@ -44,6 +45,16 @@ import (
 
 const defaultNumOfSlots = 100
 
+var (
+	getCommittedStorageMeter                   = metrics.NewRegisteredMeter("state/contract/committed", nil)
+	getCommittedStorageSnapMeter               = metrics.NewRegisteredMeter("state/contract/committed/snap", nil)
+	getCommittedStorageTrieMeter               = metrics.NewRegisteredMeter("state/contract/committed/trie", nil)
+	getCommittedStorageExpiredMeter            = metrics.NewRegisteredMeter("state/contract/committed/expired", nil)
+	getCommittedStorageExpiredLocalReviveMeter = metrics.NewRegisteredMeter("state/contract/committed/expired/localrevive", nil)
+	getCommittedStorageUnexpiredMeter          = metrics.NewRegisteredMeter("state/contract/committed/unexpired", nil)
+	getCommittedStorageRemoteMeter             = metrics.NewRegisteredMeter("state/contract/committed/remote", nil)
+)
+
 type revision struct {
 	id           int
 	journalIndex int
@@ -141,6 +152,9 @@ type StateDB struct {
 	validRevisions []revision
 	nextRevisionId int
 
+	// state expiry feature
+	expiryMeta *stateExpiryMeta
+
 	// Measurements gathered during execution for debugging purposes
 	// MetricsMux should be used in more places, but it would affect performance, so the following metering is not accurate
 	MetricsMux sync.Mutex
@@ -193,6 +207,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
 		accessList:       newAccessList(),
 		transientStorage: newTransientStorage(),
 		hasher:           crypto.NewKeccakState(),
+		expiryMeta:       defaultStateExpiryMeta(),
 	}
 
 	if sdb.snaps != nil {
@@ -232,6 +247,25 @@ func (s *StateDB) TransferPrefetcher(prev *StateDB) {
 	s.prefetcherLock.Unlock()
}
 
+// InitStateExpiryFeature must be called during initialization; resetting it later will produce wrong results.
+// Attention: startAtBlockHash corresponds to the stateDB's originalRoot; expectHeight is the epoch indicator.
+func (s *StateDB) InitStateExpiryFeature(config *types.StateExpiryConfig, remote ethdb.FullStateDB, startAtBlockHash common.Hash, expectHeight *big.Int) *StateDB {
+	if config == nil || expectHeight == nil || remote == nil {
+		panic("cannot init state expiry stateDB with nil config/height/remote")
+	}
+	epoch := types.GetStateEpoch(config, expectHeight)
+	s.expiryMeta = &stateExpiryMeta{
+		enableStateExpiry: config.EnableExpiry(),
+		enableLocalRevive: config.EnableLocalRevive,
+		fullStateDB:       remote,
+		epoch:             epoch,
+		originalRoot:      s.originalRoot,
+		originalHash:      startAtBlockHash,
+	}
+	//log.Debug("StateDB enable state expiry feature", "expectHeight", expectHeight, "startAtBlockHash", startAtBlockHash, "epoch", epoch)
+	return s
+}
+
 // StartPrefetcher initializes a new trie prefetcher to pull in nodes from the
 // state trie concurrently while the state is mutated so that when we reach the
 // commit phase, most of the needed data is already hot.
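Seen end to end, the storage paths above share one control flow: attempt the trie operation, and only an expired-node error (which carries the trie path) triggers a revive followed by a retry. Below is a minimal sketch of that pattern, written against the Trie interface and the trie.ParseExpiredNodeErr helper this patch uses; writeSlotWithRevive and its revive callback are illustrative stand-ins for updateTrie's replay loop and tryReviveState, not code from the patch.

```go
package state

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/trie"
)

// Sketch, not part of this patch: write one storage slot, reviving an
// expired sub-path on demand. `revive` stands in for tryReviveState.
func writeSlotWithRevive(tr Trie, addr common.Address, key, val []byte,
	revive func(path []byte) error) error {
	err := tr.UpdateStorage(addr, key, val)
	path, expired := trie.ParseExpiredNodeErr(err)
	if !expired {
		return err // nil, or a failure unrelated to state expiry
	}
	// The expiry error carries the path of the expired trie node: revive
	// it (remote proof or local snapshot value), then replay the write.
	if err := revive(path); err != nil {
		return err
	}
	return tr.UpdateStorage(addr, key, val)
}
```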
@@ -252,6 +286,7 @@ func (s *StateDB) StartPrefetcher(namespace string) {
 		} else {
 			s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, common.Hash{}, namespace)
 		}
+		s.prefetcher.InitStateExpiryFeature(s.expiryMeta)
 	}
 }
 
@@ -365,7 +400,7 @@ func (s *StateDB) AddLog(log *types.Log) {
 }
 
 // GetLogs returns the logs matching the specified transaction hash, and annotates
 // them with the given blockNumber and blockHash.
 func (s *StateDB) GetLogs(hash common.Hash, blockNumber uint64, blockHash common.Hash) []*types.Log {
 	logs := s.logs[hash]
 	for _, l := range logs {
@@ -965,6 +1000,9 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB {
 		journal:              newJournal(),
 		hasher:               crypto.NewKeccakState(),
 
+		// state expiry copy
+		expiryMeta: s.expiryMeta,
+
 		// In order for the block producer to be able to use and make additions
 		// to the snapshot tree, we need to copy that as well. Otherwise, any
 		// block mined by ourselves will cause gaps in the tree, and force the
@@ -1040,12 +1078,14 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB {
 	state.transientStorage = s.transientStorage.Copy()
 	state.prefetcher = s.prefetcher
-	if s.prefetcher != nil && !doPrefetch {
+	if !s.EnableExpire() && s.prefetcher != nil && !doPrefetch {
 		// If there's a prefetcher running, make an inactive copy of it that can
 		// only access data but does not actively preload (since the user will not
 		// know that they need to explicitly terminate an active copy).
+		// State expiry mode cannot reuse the old prefetcher directly.
 		state.prefetcher = state.prefetcher.copy()
 	}
+
 	return state
 }
 
@@ -1354,6 +1394,9 @@ func (s *StateDB) deleteStorage(addr common.Address, addrHash common.Hash, root
 	if _, ok := tr.(*trie.EmptyTrie); ok {
 		return false, nil, nil, nil
 	}
+	if s.EnableExpire() {
+		tr.SetEpoch(s.Epoch())
+	}
 	it, err := tr.NodeIterator(nil)
 	if err != nil {
 		return false, nil, nil, fmt.Errorf("failed to open storage iterator, err: %w", err)
@@ -1761,6 +1804,7 @@ func (s *StateDB) Commit(block uint64, failPostCommitFunc func(), postCommitFunc
 	if root == (common.Hash{}) {
 		root = types.EmptyRootHash
 	}
+	//log.Info("state commit", "nodes", stringfyEpochMeta(nodes.FlattenEpochMeta()))
 	//origin := s.originalRoot
 	//if origin == (common.Hash{}) {
 	//	origin = types.EmptyRootHash
@@ -1886,6 +1930,18 @@ func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) {
 	}
 }
 
+func (s *StateDB) EnableExpire() bool {
+	return s.expiryMeta.enableStateExpiry
+}
+
+func (s *StateDB) Epoch() types.StateEpoch {
+	return s.expiryMeta.epoch
+}
+
+func (s *StateDB) EnableLocalRevive() bool {
+	return s.expiryMeta.enableLocalRevive
+}
+
 // AddressInAccessList returns true if the given address is in the access list.
 func (s *StateDB) AddressInAccessList(addr common.Address) bool {
 	if s.accessList == nil {
diff --git a/core/state/sync.go b/core/state/sync.go
index 61097c6462..1b288b304e 100644
--- a/core/state/sync.go
+++ b/core/state/sync.go
@@ -55,3 +55,32 @@ func NewStateSync(root common.Hash, database ethdb.KeyValueReader, onLeaf func(k
 	syncer = trie.NewSync(root, database, onAccount, scheme)
 	return syncer
 }
+
+func NewStateSyncWithExpiry(root common.Hash, database ethdb.KeyValueReader, onLeaf func(keys [][]byte, leaf []byte) error, scheme string, epoch types.StateEpoch) *trie.Sync {
+	// Register the storage slot callback if the external callback is specified.
+ var onSlot func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error + if onLeaf != nil { + onSlot = func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error { + return onLeaf(keys, leaf) + } + } + // Register the account callback to connect the state trie and the storage + // trie belongs to the contract. + var syncer *trie.Sync + onAccount := func(keys [][]byte, path []byte, leaf []byte, parent common.Hash, parentPath []byte) error { + if onLeaf != nil { + if err := onLeaf(keys, leaf); err != nil { + return err + } + } + var obj types.StateAccount + if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil { + return err + } + syncer.AddSubTrie(obj.Root, path, parent, parentPath, onSlot) + syncer.AddCodeEntry(common.BytesToHash(obj.CodeHash), path, parent, parentPath) + return nil + } + syncer = trie.NewSyncWithEpoch(root, database, onAccount, scheme, epoch) + return syncer +} diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index 4184369d9c..ccd62c395b 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + trie2 "github.com/ethereum/go-ethereum/trie" ) const ( @@ -63,6 +64,9 @@ type triePrefetcher struct { fetchersMutex sync.RWMutex prefetchChan chan *prefetchMsg // no need to wait for return + // state expiry feature + expiryMeta *stateExpiryMeta + deliveryMissMeter metrics.Meter accountLoadMeter metrics.Meter accountDupMeter metrics.Meter @@ -107,6 +111,8 @@ func newTriePrefetcher(db Database, root, rootParent common.Hash, namespace stri accountStaleDupMeter: metrics.GetOrRegisterMeter(prefix+"/accountst/dup", nil), accountStaleSkipMeter: metrics.GetOrRegisterMeter(prefix+"/accountst/skip", nil), accountStaleWasteMeter: metrics.GetOrRegisterMeter(prefix+"/accountst/waste", nil), + + expiryMeta: defaultStateExpiryMeta(), } go p.mainLoop() return p @@ -123,6 +129,10 @@ func (p *triePrefetcher) mainLoop() { fetcher := p.fetchers[id] if fetcher == nil { fetcher = newSubfetcher(p.db, p.root, pMsg.owner, pMsg.root, pMsg.addr) + if p.expiryMeta.enableStateExpiry { + fetcher.initStateExpiryFeature(p.expiryMeta) + } + fetcher.start() p.fetchersMutex.Lock() p.fetchers[id] = fetcher p.fetchersMutex.Unlock() @@ -204,6 +214,11 @@ func (p *triePrefetcher) mainLoop() { } } +// InitStateExpiryFeature it must call in initial period. +func (p *triePrefetcher) InitStateExpiryFeature(expiryMeta *stateExpiryMeta) { + p.expiryMeta = expiryMeta +} + // close iterates over all the subfetchers, aborts any that were left spinning // and reports the stats to the metrics subsystem. func (p *triePrefetcher) close() { @@ -354,6 +369,9 @@ type subfetcher struct { addr common.Address // Address of the account that the trie belongs to trie Trie // Trie being populated with nodes + // state expiry feature + expiryMeta *stateExpiryMeta + tasks [][]byte // Items queued up for retrieval lock sync.Mutex // Lock protecting the task queue @@ -374,21 +392,30 @@ type subfetcher struct { // particular root hash. 
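One lifecycle change is worth calling out before the subfetcher code below: newSubfetcher no longer spawns its goroutine itself, so the expiry metadata can be injected before the loop runs. Condensed, the construction sequence (the same one mainLoop above and scheduleParallel below now follow) is:

```go
// Sketch of the new subfetcher lifecycle (cf. mainLoop above): construct,
// optionally inject expiry metadata, and only then start the goroutine.
fetcher := newSubfetcher(p.db, p.root, pMsg.owner, pMsg.root, pMsg.addr)
if p.expiryMeta.enableStateExpiry {
	fetcher.initStateExpiryFeature(p.expiryMeta) // safe: loop not running yet
}
fetcher.start() // spawns sf.loop(); previously newSubfetcher did this itself
```

Splitting start() out of the constructor avoids racing initStateExpiryFeature against a loop goroutine that is already reading expiryMeta.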
func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash, addr common.Address) *subfetcher {
 	sf := &subfetcher{
-		db:    db,
-		state: state,
-		owner: owner,
-		root:  root,
-		addr:  addr,
-		wake:  make(chan struct{}, 1),
-		stop:  make(chan struct{}),
-		term:  make(chan struct{}),
-		copy:  make(chan chan Trie),
-		seen:  make(map[string]struct{}),
+		db:         db,
+		state:      state,
+		owner:      owner,
+		root:       root,
+		addr:       addr,
+		wake:       make(chan struct{}, 1),
+		stop:       make(chan struct{}),
+		term:       make(chan struct{}),
+		copy:       make(chan chan Trie),
+		seen:       make(map[string]struct{}),
+		expiryMeta: defaultStateExpiryMeta(),
 	}
-	go sf.loop()
 	return sf
 }
 
+func (sf *subfetcher) start() {
+	go sf.loop()
+}
+
+// initStateExpiryFeature must be called right after construction, before start().
+func (sf *subfetcher) initStateExpiryFeature(expiryMeta *stateExpiryMeta) {
+	sf.expiryMeta = expiryMeta
+}
+
 // schedule adds a batch of trie keys to the queue to prefetch.
 func (sf *subfetcher) schedule(keys [][]byte) {
 	atomic.AddUint32(&sf.pendingSize, uint32(len(keys)))
@@ -432,6 +459,10 @@ func (sf *subfetcher) scheduleParallel(keys [][]byte) {
 	keysLeftSize := len(keysLeft)
 	for i := 0; i*parallelTriePrefetchCapacity < keysLeftSize; i++ {
 		child := newSubfetcher(sf.db, sf.state, sf.owner, sf.root, sf.addr)
+		if sf.expiryMeta.enableStateExpiry {
+			child.initStateExpiryFeature(sf.expiryMeta)
+		}
+		child.start()
 		sf.paraChildren = append(sf.paraChildren, child)
 		endIndex := (i + 1) * parallelTriePrefetchCapacity
 		if endIndex >= keysLeftSize {
@@ -484,6 +515,9 @@ func (sf *subfetcher) loop() {
 		trie, err = sf.db.OpenTrie(sf.root)
 	} else {
 		trie, err = sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root)
+		if err == nil && sf.expiryMeta.enableStateExpiry {
+			trie.SetEpoch(sf.expiryMeta.epoch)
+		}
 	}
 	if err != nil {
 		log.Debug("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
@@ -502,6 +536,9 @@ func (sf *subfetcher) loop() {
 			} else {
 				// address is useless
 				sf.trie, err = sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root)
+				if err == nil && sf.expiryMeta.enableStateExpiry {
+					sf.trie.SetEpoch(sf.expiryMeta.epoch)
+				}
 			}
 			if err != nil {
 				continue
@@ -535,7 +572,18 @@ func (sf *subfetcher) loop() {
 			if len(task) == common.AddressLength {
 				sf.trie.GetAccount(common.BytesToAddress(task))
 			} else {
-				sf.trie.GetStorage(sf.addr, task)
+				_, err := sf.trie.GetStorage(sf.addr, task)
+				// handle expired state
+				if sf.expiryMeta.enableStateExpiry {
+					// TODO(0xbundler): revert to single-key fetches, because tasks arrive over a channel
+					if path, match := trie2.ParseExpiredNodeErr(err); match {
+						key := common.BytesToHash(task)
+						_, err = tryReviveState(sf.expiryMeta, sf.addr, sf.root, sf.trie, path, key, false)
+						if err != nil {
+							log.Error("subfetcher tryReviveState err", "addr", sf.addr, "path", path, "err", err)
+						}
+					}
+				}
 			}
 			sf.seen[string(task)] = struct{}{}
 		}
diff --git a/core/state_prefetcher.go b/core/state_prefetcher.go
index f1bb60febd..f8b7fb5fd5 100644
--- a/core/state_prefetcher.go
+++ b/core/state_prefetcher.go
@@ -59,7 +59,9 @@ func (p *statePrefetcher) Prefetch(block *types.Block, statedb *state.StateDB, c
 	for i := 0; i < prefetchThread; i++ {
 		go func() {
 			newStatedb := statedb.CopyDoPrefetch()
-			newStatedb.EnableWriteOnSharedStorage()
+			if !statedb.EnableExpire() {
+				newStatedb.EnableWriteOnSharedStorage()
+			}
 			gaspool := new(GasPool).AddGas(block.GasLimit())
 			blockContext := NewEVMBlockContext(header, p.bc, nil)
 			evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, *cfg)
@@ -106,7 +108,10 @@ func (p *statePrefetcher)
PrefetchMining(txs TransactionsByPriceAndNonce, header go func(startCh <-chan *types.Transaction, stopCh <-chan struct{}) { idx := 0 newStatedb := statedb.CopyDoPrefetch() - newStatedb.EnableWriteOnSharedStorage() + // TODO(0xbundler): access empty in trie cause shared concurrent bug? opt later + if !statedb.EnableExpire() { + newStatedb.EnableWriteOnSharedStorage() + } gaspool := new(GasPool).AddGas(gasLimit) blockContext := NewEVMBlockContext(header, p.bc, nil) evm := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, cfg) diff --git a/core/txpool/blobpool/blobpool.go b/core/txpool/blobpool/blobpool.go index 71cb2cb53f..b363d7d5a0 100644 --- a/core/txpool/blobpool/blobpool.go +++ b/core/txpool/blobpool/blobpool.go @@ -365,7 +365,7 @@ func (p *BlobPool) Init(gasTip *big.Int, head *types.Header, reserve txpool.Addr return err } } - state, err := p.chain.StateAt(head.Root) + state, err := p.chain.StateAt(head.Root, head.Hash(), head.Number) if err != nil { return err } @@ -746,7 +746,7 @@ func (p *BlobPool) Reset(oldHead, newHead *types.Header) { resettimeHist.Update(time.Since(start).Nanoseconds()) }(time.Now()) - statedb, err := p.chain.StateAt(newHead.Root) + statedb, err := p.chain.StateAt(newHead.Root, newHead.Hash(), newHead.Number) if err != nil { log.Error("Failed to reset blobpool state", "err", err) return diff --git a/core/txpool/blobpool/blobpool_test.go b/core/txpool/blobpool/blobpool_test.go index f8ddcc0c10..dd348f9c2a 100644 --- a/core/txpool/blobpool/blobpool_test.go +++ b/core/txpool/blobpool/blobpool_test.go @@ -158,7 +158,7 @@ func (bt *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return nil } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/core/txpool/blobpool/interface.go b/core/txpool/blobpool/interface.go index 6f296a54bd..a1852a5bd8 100644 --- a/core/txpool/blobpool/interface.go +++ b/core/txpool/blobpool/interface.go @@ -17,6 +17,8 @@ package blobpool import ( + "math/big" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -40,5 +42,5 @@ type BlockChain interface { GetBlock(hash common.Hash, number uint64) *types.Block // StateAt returns a state database for a given root hash (generally the head). - StateAt(root common.Hash) (*state.StateDB, error) + StateAt(root common.Hash, blockHash common.Hash, height *big.Int) (*state.StateDB, error) } diff --git a/core/txpool/legacypool/legacypool.go b/core/txpool/legacypool/legacypool.go index 2c23f142b1..88f1c839a2 100644 --- a/core/txpool/legacypool/legacypool.go +++ b/core/txpool/legacypool/legacypool.go @@ -125,7 +125,7 @@ type BlockChain interface { GetBlock(hash common.Hash, number uint64) *types.Block // StateAt returns a state database for a given root hash (generally the head). - StateAt(root common.Hash) (*state.StateDB, error) + StateAt(root common.Hash, blockHash common.Hash, height *big.Int) (*state.StateDB, error) } // Config are the configuration parameters of the transaction pool. 
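The StateAt signature change running through these pools threads the block's identity alongside its state root: the height selects the state epoch and the hash anchors remote revive proofs. Caller-side, the intended wiring looks roughly like the sketch below; names such as `bc` and `remoteDB` are illustrative, not part of this change.

```go
// Sketch: open the state for `header` and enable expiry when configured.
// `remoteDB` is an ethdb.FullStateDB client obtained during node setup.
statedb, err := bc.StateAt(header.Root, header.Hash(), header.Number)
if err != nil {
	return err
}
if cfg := bc.StateExpiryConfig(); cfg.EnableExpiry() {
	// Derives the state epoch once from header.Number; header.Hash()
	// anchors any remote revive proofs. remoteDB must be non-nil here.
	statedb.InitStateExpiryFeature(cfg, remoteDB, header.Hash(), header.Number)
}
```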
@@ -1470,7 +1470,7 @@ func (pool *LegacyPool) reset(oldHead, newHead *types.Header) { if newHead == nil { newHead = pool.chain.CurrentBlock() // Special case during testing } - statedb, err := pool.chain.StateAt(newHead.Root) + statedb, err := pool.chain.StateAt(newHead.Root, newHead.Hash(), newHead.Number) if err != nil { log.Error("Failed to reset txpool state", "err", err) return diff --git a/core/txpool/legacypool/legacypool_test.go b/core/txpool/legacypool/legacypool_test.go index 05ff64aed1..6efcdf8dac 100644 --- a/core/txpool/legacypool/legacypool_test.go +++ b/core/txpool/legacypool/legacypool_test.go @@ -88,7 +88,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return types.NewBlock(bc.CurrentBlock(), nil, nil, nil, trie.NewStackTrie(nil)) } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/core/types/gen_account_rlp.go b/core/types/gen_account_rlp.go index 5181d88411..9d07200e33 100644 --- a/core/types/gen_account_rlp.go +++ b/core/types/gen_account_rlp.go @@ -5,8 +5,11 @@ package types -import "github.com/ethereum/go-ethereum/rlp" -import "io" +import ( + "io" + + "github.com/ethereum/go-ethereum/rlp" +) func (obj *StateAccount) EncodeRLP(_w io.Writer) error { w := rlp.NewEncoderBuffer(_w) diff --git a/core/types/gen_header_rlp.go b/core/types/gen_header_rlp.go index a5ed5cd150..e05bde09f6 100644 --- a/core/types/gen_header_rlp.go +++ b/core/types/gen_header_rlp.go @@ -5,8 +5,11 @@ package types -import "github.com/ethereum/go-ethereum/rlp" -import "io" +import ( + "io" + + "github.com/ethereum/go-ethereum/rlp" +) func (obj *Header) EncodeRLP(_w io.Writer) error { w := rlp.NewEncoderBuffer(_w) diff --git a/core/types/gen_log_rlp.go b/core/types/gen_log_rlp.go index 4a6c6b0094..78fa783cee 100644 --- a/core/types/gen_log_rlp.go +++ b/core/types/gen_log_rlp.go @@ -5,8 +5,11 @@ package types -import "github.com/ethereum/go-ethereum/rlp" -import "io" +import ( + "io" + + "github.com/ethereum/go-ethereum/rlp" +) func (obj *rlpLog) EncodeRLP(_w io.Writer) error { w := rlp.NewEncoderBuffer(_w) diff --git a/core/types/gen_withdrawal_rlp.go b/core/types/gen_withdrawal_rlp.go index d0b4e0147a..e3fa001eb6 100644 --- a/core/types/gen_withdrawal_rlp.go +++ b/core/types/gen_withdrawal_rlp.go @@ -5,8 +5,11 @@ package types -import "github.com/ethereum/go-ethereum/rlp" -import "io" +import ( + "io" + + "github.com/ethereum/go-ethereum/rlp" +) func (obj *Withdrawal) EncodeRLP(_w io.Writer) error { w := rlp.NewEncoderBuffer(_w) diff --git a/core/types/meta.go b/core/types/meta.go new file mode 100644 index 0000000000..afd6f8655b --- /dev/null +++ b/core/types/meta.go @@ -0,0 +1,61 @@ +package types + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + MetaNoConsensusType = iota +) + +var ( + ErrMetaNotSupport = errors.New("the meta type not support now") + EmptyMetaNoConsensus = MetaNoConsensus(StateEpoch0) +) + +type MetaNoConsensus StateEpoch // Represents the epoch number + +type StateMeta interface { + GetType() byte + Hash() common.Hash + EncodeToRLPBytes() ([]byte, error) +} + +func NewMetaNoConsensus(epoch StateEpoch) StateMeta { + return MetaNoConsensus(epoch) +} + +func (m MetaNoConsensus) GetType() byte { + return MetaNoConsensusType +} + +func (m MetaNoConsensus) Hash() common.Hash { + return common.Hash{} +} + +func (m 
MetaNoConsensus) Epoch() StateEpoch { + return StateEpoch(m) +} + +func (m MetaNoConsensus) EncodeToRLPBytes() ([]byte, error) { + enc, err := rlp.EncodeToBytes(m) + if err != nil { + return nil, err + } + return enc, nil +} + +func DecodeMetaNoConsensusFromRLPBytes(enc []byte) (MetaNoConsensus, error) { + if len(enc) == 0 { + return EmptyMetaNoConsensus, nil + } + var mc MetaNoConsensus + if err := rlp.DecodeBytes(enc, &mc); err != nil { + return EmptyMetaNoConsensus, err + } + + return mc, nil +} diff --git a/core/types/meta_test.go b/core/types/meta_test.go new file mode 100644 index 0000000000..ac03a5595f --- /dev/null +++ b/core/types/meta_test.go @@ -0,0 +1,26 @@ +package types + +import ( + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetaEncodeDecode(t *testing.T) { + tests := []struct { + data MetaNoConsensus + }{ + {data: EmptyMetaNoConsensus}, + {data: MetaNoConsensus(StateEpoch(10000))}, + } + + for _, item := range tests { + enc, err := item.data.EncodeToRLPBytes() + assert.NoError(t, err) + t.Log(hex.EncodeToString(enc)) + mc, err := DecodeMetaNoConsensusFromRLPBytes(enc) + assert.NoError(t, err) + assert.Equal(t, item.data, mc) + } +} diff --git a/core/types/revive_state.go b/core/types/revive_state.go new file mode 100644 index 0000000000..684f1a06b1 --- /dev/null +++ b/core/types/revive_state.go @@ -0,0 +1,31 @@ +package types + +type ReviveStorageProof struct { + Key string `json:"key"` + PrefixKey string `json:"prefixKey"` + Proof []string `json:"proof"` +} + +type ReviveResult struct { + Err string `json:"err"` + StorageProof []ReviveStorageProof `json:"storageProof"` + BlockNum uint64 `json:"blockNum"` +} + +func NewReviveErrResult(err error, block uint64) *ReviveResult { + var errRet string + if err != nil { + errRet = err.Error() + } + return &ReviveResult{ + Err: errRet, + BlockNum: block, + } +} + +func NewReviveResult(proof []ReviveStorageProof, block uint64) *ReviveResult { + return &ReviveResult{ + StorageProof: proof, + BlockNum: block, + } +} diff --git a/core/types/state_account.go b/core/types/state_account.go index 314f4943ec..0fdc8b8f27 100644 --- a/core/types/state_account.go +++ b/core/types/state_account.go @@ -33,6 +33,7 @@ type StateAccount struct { Balance *big.Int Root common.Hash // merkle root of the storage trie CodeHash []byte + MetaHash common.Hash `rlp:"-"` // TODO (asyukii): handle this } // NewEmptyStateAccount constructs an empty state account. diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go new file mode 100644 index 0000000000..7846b9e051 --- /dev/null +++ b/core/types/state_epoch.go @@ -0,0 +1,61 @@ +package types + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +const ( + DefaultStateEpochPeriod = uint64(7_008_000) + StateEpoch0 = StateEpoch(0) + StateEpoch1 = StateEpoch(1) + StateEpochKeepLiveNum = StateEpoch(2) +) + +type StateEpoch uint16 + +// GetStateEpoch computes the current state epoch by hard fork and block number +// state epoch will indicate if the state is accessible or expiry. +// Before ClaudeBlock indicates state epoch0. +// ClaudeBlock indicates start state epoch1. +// ElwoodBlock indicates start state epoch2 and start epoch rotate by StateEpochPeriod. +// When N>=2 and epochN started, epoch(N-2)'s state will expire. 
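A worked example of the arithmetic implemented by the GetStateEpoch function that follows; the schedule numbers are illustrative, not mainnet parameters.

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/core/types"
)

func main() {
	// Hypothetical schedule: epoch1 at block 1M, epoch2 at 2M, then a new
	// epoch every 7,008,000 blocks (the patch's DefaultStateEpochPeriod).
	cfg := &types.StateExpiryConfig{
		StateEpoch1Block: 1_000_000,
		StateEpoch2Block: 2_000_000,
		StateEpochPeriod: 7_008_000,
	}
	for _, n := range []int64{0, 999_999, 1_000_000, 2_000_000, 9_007_999, 9_008_000} {
		fmt.Println(n, "->", types.GetStateEpoch(cfg, big.NewInt(n)))
	}
	// 0..999_999 -> 0, 1_000_000 -> 1, 2_000_000..9_007_999 -> 2, 9_008_000 -> 3.
	// With StateEpochKeepLiveNum = 2, epoch N-2 expires once epoch N begins:
	fmt.Println(types.EpochExpired(0, 2)) // true: epoch-0 state is expired in epoch 2
}
```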
+func GetStateEpoch(config *StateExpiryConfig, blockNumber *big.Int) StateEpoch {
+	if blockNumber == nil || config == nil {
+		return StateEpoch0
+	}
+	epochPeriod := new(big.Int).SetUint64(config.StateEpochPeriod)
+	epoch1Block := new(big.Int).SetUint64(config.StateEpoch1Block)
+	epoch2Block := new(big.Int).SetUint64(config.StateEpoch2Block)
+
+	if isBlockReached(blockNumber, epoch2Block) {
+		ret := new(big.Int).Sub(blockNumber, epoch2Block)
+		ret.Div(ret, epochPeriod)
+		ret.Add(ret, common.Big2)
+		return StateEpoch(ret.Uint64())
+	}
+	if isBlockReached(blockNumber, epoch1Block) {
+		return 1
+	}
+
+	return 0
+}
+
+// EpochExpired reports whether the pre epoch has expired relative to the current epoch.
+func EpochExpired(pre StateEpoch, cur StateEpoch) bool {
+	return cur > pre && cur-pre >= StateEpochKeepLiveNum
+}
+
+// isBlockReached reports whether block has reached the expected block number.
+func isBlockReached(block, expected *big.Int) bool {
+	if block == nil || expected == nil {
+		return false
+	}
+	return block.Cmp(expected) >= 0
+}
diff --git a/core/types/state_expiry.go b/core/types/state_expiry.go
new file mode 100644
index 0000000000..7ae6870bba
--- /dev/null
+++ b/core/types/state_expiry.go
@@ -0,0 +1,155 @@
+package types
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/ethereum/go-ethereum/log"
+)
+
+const (
+	// StateExpiryPruneLevel0 is for HBSS. Under HBSS no expired snapshot can be pruned: pruning
+	// old trie nodes would require rebuilding the trie, and shared trie nodes cannot be pruned either.
+	StateExpiryPruneLevel0 = iota
+	// StateExpiryPruneLevel1 is the default level; it keeps some expired snapshot metadata around
+	// to stay performance friendly.
+	StateExpiryPruneLevel1
+	// StateExpiryPruneLevel2 prunes all expired snapshot KVs and trie nodes, at the cost of more
+	// trie accesses during execution. TODO(0xbundler): will support it later
+	StateExpiryPruneLevel2
+)
+
+type StateExpiryConfig struct {
+	Enable            bool
+	FullStateEndpoint string
+	StateScheme       string
+	PruneLevel        uint8
+	StateEpoch1Block  uint64
+	StateEpoch2Block  uint64
+	StateEpochPeriod  uint64
+	EnableLocalRevive bool
+	EnableRemoteMode  bool     `rlp:"optional"` // when remoteDB mode is enabled, the node registers extra RPCs for partial proofs and keeps its sync behind for proof safety
+	AllowedPeerList   []string `rlp:"optional"` // when remoteDB mode is enabled, sync is only delayed for the peers in this list
+}
+
+// EnableExpiry reports whether this node actually expires state; under remote mode it never does.
+func (s *StateExpiryConfig) EnableExpiry() bool {
+	if s == nil {
+		return false
+	}
+	return s.Enable && !s.EnableRemoteMode
+}
+
+// EnableRemote reports whether the remoteDB mode is enabled.
+func (s *StateExpiryConfig) EnableRemote() bool {
+	if s == nil {
+		return false
+	}
+	return s.Enable && s.EnableRemoteMode
+}
+
+func (s *StateExpiryConfig) Validation() error {
+	if s == nil || !s.Enable {
+		return nil
+	}
+
+	s.FullStateEndpoint = strings.TrimSpace(s.FullStateEndpoint)
+	if s.StateEpoch1Block == 0 ||
+		s.StateEpoch2Block == 0 ||
+		s.StateEpochPeriod == 0 {
+		return errors.New("StateEpoch1Block, StateEpoch2Block and StateEpochPeriod cannot be 0")
+	}
+
+	if s.StateEpoch1Block >= s.StateEpoch2Block {
+		return errors.New("StateEpoch1Block must be less than StateEpoch2Block")
+	}
+
+	if s.StateEpochPeriod < DefaultStateEpochPeriod {
+		log.Warn("The state expiry epoch period is small and may cause frequent expiry, hurting performance",
+			"input", s.StateEpochPeriod, "default", DefaultStateEpochPeriod)
+	}
+
+	return nil
+}
+
+func (s *StateExpiryConfig) CheckCompatible(newCfg *StateExpiryConfig) error {
+	if s == nil || newCfg == nil {
+		return nil
+	}
+
+	if s.Enable && !newCfg.Enable {
+		return errors.New("disabling state expiry after it was enabled is dangerous, expired state may already be pruned")
+	}
+	if s.EnableRemoteMode && !newCfg.EnableRemoteMode {
+		return errors.New("disabling state expiry EnableRemoteMode after it was enabled is dangerous")
+	}
+
+	if err := s.CheckStateEpochCompatible(newCfg.StateEpoch1Block, newCfg.StateEpoch2Block, newCfg.StateEpochPeriod); err != nil {
+		return err
+	}
+
+	if s.StateScheme != newCfg.StateScheme {
+		return errors.New("StateScheme is incompatible")
+	}
+
+	if s.PruneLevel != newCfg.PruneLevel {
+		return errors.New("state expiry PruneLevel is incompatible")
+	}
+
+	return nil
+}
+
+func (s *StateExpiryConfig) CheckStateEpochCompatible(StateEpoch1Block, StateEpoch2Block, StateEpochPeriod uint64) error {
+	if s == nil {
+		return nil
+	}
+
+	if s.StateEpoch1Block != StateEpoch1Block ||
+		s.StateEpoch2Block != StateEpoch2Block ||
+		s.StateEpochPeriod != StateEpochPeriod {
+		return fmt.Errorf("state epoch info is incompatible, StateEpoch1Block: [%v|%v], StateEpoch2Block: [%v|%v], StateEpochPeriod: [%v|%v]",
+			s.StateEpoch1Block, StateEpoch1Block, s.StateEpoch2Block, StateEpoch2Block, s.StateEpochPeriod, StateEpochPeriod)
+	}
+
+	return nil
+}
+
+func (s *StateExpiryConfig) String() string {
+	if !s.Enable {
+		return "State Expiry Disabled."
+	}
+	if s.Enable && s.EnableRemoteMode {
+		return "State Expiry Enabled in RemoteMode; it will not expire any state."
+ } + return fmt.Sprintf("Enable State Expiry, RemoteEndpoint: %v, StateEpoch: [%v|%v|%v], StateScheme: %v, PruneLevel: %v, EnableLocalRevive: %v, AllowedPeerList: %v.", + s.FullStateEndpoint, s.StateEpoch1Block, s.StateEpoch2Block, s.StateEpochPeriod, s.StateScheme, s.PruneLevel, s.EnableLocalRevive, s.AllowedPeerList) +} + +// ShouldKeep1EpochBehind when enable state expiry, keep remoteDB behind the latest only 1 epoch blocks +func (s *StateExpiryConfig) ShouldKeep1EpochBehind(remote uint64, local uint64, peerId string) (bool, uint64) { + + if !s.EnableRemoteMode || remote <= local || remote < s.StateEpoch1Block { + return false, remote + } + + allowed := false + for _, allowPeer := range s.AllowedPeerList { + if allowPeer == peerId { + allowed = true + break + } + } + if !allowed { + return false, remote + } + + // if in epoch1, behind StateEpoch2Block-StateEpoch1Block + if remote < s.StateEpoch2Block { + if remote-(s.StateEpoch2Block-s.StateEpoch1Block) <= local { + return true, 0 + } + return false, remote - (s.StateEpoch2Block - s.StateEpoch1Block) + } + + // if in >= epoch2, behind StateEpochPeriod + if remote-s.StateEpochPeriod <= local { + return true, 0 + } + + return false, remote - s.StateEpochPeriod +} diff --git a/core/types/typed_trie_node.go b/core/types/typed_trie_node.go new file mode 100644 index 0000000000..dedb94eddd --- /dev/null +++ b/core/types/typed_trie_node.go @@ -0,0 +1,144 @@ +package types + +import ( + "errors" + "fmt" + "io" + + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + TrieNodeRawType = iota + TrieBranchNodeWithEpochType +) + +var ( + ErrTypedNodeNotSupport = errors.New("the typed node not support now") +) + +type TypedTrieNode interface { + Type() uint8 + EncodeToRLPBytes(buf *rlp.EncoderBuffer) +} + +type TrieNodeRaw []byte + +func (n TrieNodeRaw) Type() uint8 { + return TrieNodeRawType +} + +func (n TrieNodeRaw) EncodeToRLPBytes(buf *rlp.EncoderBuffer) { +} + +type TrieBranchNodeWithEpoch struct { + EpochMap [16]StateEpoch + Blob []byte +} + +func (n *TrieBranchNodeWithEpoch) Type() uint8 { + return TrieBranchNodeWithEpochType +} + +func (n *TrieBranchNodeWithEpoch) EncodeToRLPBytes(buf *rlp.EncoderBuffer) { + offset := buf.List() + mapOffset := buf.List() + for _, item := range n.EpochMap { + if item == 0 { + buf.Write(rlp.EmptyString) + } else { + buf.WriteUint64(uint64(item)) + } + } + buf.ListEnd(mapOffset) + buf.Write(n.Blob) + buf.ListEnd(offset) +} + +func DecodeTrieBranchNodeWithEpoch(enc []byte) (*TrieBranchNodeWithEpoch, error) { + var n TrieBranchNodeWithEpoch + if len(enc) == 0 { + return nil, io.ErrUnexpectedEOF + } + elems, _, err := rlp.SplitList(enc) + if err != nil { + return nil, fmt.Errorf("decode error: %v", err) + } + + maps, rest, err := rlp.SplitList(elems) + if err != nil { + return nil, fmt.Errorf("decode epochmap error: %v", err) + } + for i := 0; i < len(n.EpochMap); i++ { + var c uint64 + c, maps, err = rlp.SplitUint64(maps) + if err != nil { + return nil, fmt.Errorf("decode epochmap val error: %v", err) + } + n.EpochMap[i] = StateEpoch(c) + } + + k, content, _, err := rlp.Split(rest) + if err != nil { + return nil, fmt.Errorf("decode raw error: %v", err) + } + switch k { + case rlp.String: + n.Blob = content + case rlp.List: + n.Blob = rest + default: + return nil, fmt.Errorf("decode wrong raw type error: %v", err) + } + return &n, nil +} + +func EncodeTypedTrieNode(val TypedTrieNode) []byte { + switch raw := val.(type) { + case TrieNodeRaw: + return raw + case *TrieBranchNodeWithEpoch: + // encode with type 
prefix + w := rlp.NewEncoderBuffer(nil) + w.Write([]byte{val.Type()}) + val.EncodeToRLPBytes(&w) + result := w.ToBytes() + w.Flush() + return result + } + return nil +} + +func DecodeTypedTrieNode(enc []byte) (TypedTrieNode, error) { + if len(enc) == 0 { + return TrieNodeRaw{}, nil + } + if len(enc) == 1 || enc[0] > 0x7f { + return TrieNodeRaw(enc), nil + } + switch enc[0] { + case TrieBranchNodeWithEpochType: + return DecodeTrieBranchNodeWithEpoch(enc[1:]) + default: + return nil, ErrTypedNodeNotSupport + } +} + +func DecodeTypedTrieNodeRaw(enc []byte) ([]byte, error) { + if len(enc) == 0 { + return enc, nil + } + if len(enc) == 1 || enc[0] > 0x7f { + return enc, nil + } + switch enc[0] { + case TrieBranchNodeWithEpochType: + rn, err := DecodeTrieBranchNodeWithEpoch(enc[1:]) + if err != nil { + return nil, err + } + return rn.Blob, nil + default: + return nil, ErrTypedNodeNotSupport + } +} diff --git a/eth/api_backend.go b/eth/api_backend.go index 3192823148..7e802f2bc1 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -19,6 +19,7 @@ package eth import ( "context" "errors" + "fmt" "math/big" "time" @@ -102,7 +103,7 @@ func (b *EthAPIBackend) HeaderByNumberOrHash(ctx context.Context, blockNrOrHash if hash, ok := blockNrOrHash.Hash(); ok { header := b.eth.blockchain.GetHeaderByHash(hash) if header == nil { - return nil, errors.New("header for hash not found") + return nil, fmt.Errorf("header for hash not found, hash: %#x", hash) } if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, errors.New("hash is not currently canonical") @@ -169,7 +170,7 @@ func (b *EthAPIBackend) BlockByNumberOrHash(ctx context.Context, blockNrOrHash r if hash, ok := blockNrOrHash.Hash(); ok { header := b.eth.blockchain.GetHeaderByHash(hash) if header == nil { - return nil, errors.New("header for hash not found") + return nil, fmt.Errorf("header for hash not found, hash: %#x", hash) } if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, errors.New("hash is not currently canonical") @@ -204,7 +205,7 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B if header == nil { return nil, nil, errors.New("header not found") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Hash(), header.Number) return stateDb, header, err } @@ -218,17 +219,21 @@ func (b *EthAPIBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockN return nil, nil, err } if header == nil { - return nil, nil, errors.New("header for hash not found") + return nil, nil, fmt.Errorf("header for hash not found, hash: %#x", hash) } if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, nil, errors.New("hash is not currently canonical") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Hash(), header.Number) return stateDb, header, err } return nil, nil, errors.New("invalid arguments; neither block nor hash specified") } +func (b *EthAPIBackend) StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) { + return b.eth.BlockChain().StorageTrie(stateRoot, addr, root) +} + func (b *EthAPIBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { return b.eth.blockchain.GetReceiptsByHash(hash), nil } diff --git 
a/eth/api_debug.go b/eth/api_debug.go index 6afa046787..a3868faf0b 100644 --- a/eth/api_debug.go +++ b/eth/api_debug.go @@ -79,7 +79,7 @@ func (api *DebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error) { if header == nil { return state.Dump{}, fmt.Errorf("block #%d not found", blockNr) } - stateDb, err := api.eth.BlockChain().StateAt(header.Root) + stateDb, err := api.eth.BlockChain().StateAt(header.Root, header.Hash(), header.Number) if err != nil { return state.Dump{}, err } @@ -164,7 +164,7 @@ func (api *DebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, start hex if header == nil { return state.IteratorDump{}, fmt.Errorf("block #%d not found", number) } - stateDb, err = api.eth.BlockChain().StateAt(header.Root) + stateDb, err = api.eth.BlockChain().StateAt(header.Root, header.Hash(), header.Number) if err != nil { return state.IteratorDump{}, err } @@ -174,7 +174,7 @@ func (api *DebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, start hex if block == nil { return state.IteratorDump{}, fmt.Errorf("block %s not found", hash.Hex()) } - stateDb, err = api.eth.BlockChain().StateAt(block.Root()) + stateDb, err = api.eth.BlockChain().StateAt(block.Root(), block.Hash(), block.Number()) if err != nil { return state.IteratorDump{}, err } diff --git a/eth/backend.go b/eth/backend.go index cf593fbb71..879757cf04 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -215,6 +215,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { Preimages: config.Preimages, StateHistory: config.StateHistory, StateScheme: config.StateScheme, + StateExpiryCfg: config.StateExpiryCfg, } ) bcOps := make([]core.BlockChainOption, 0) @@ -275,6 +276,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { DirectBroadcast: config.DirectBroadcast, DisablePeerTxBroadcast: config.DisablePeerTxBroadcast, PeerSet: peers, + expiryConfig: config.StateExpiryCfg, }); err != nil { return nil, err } diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 14d68844eb..f82bc7e346 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -103,6 +103,9 @@ type Downloader struct { stateDB ethdb.Database // Database to state sync into (and deduplicate via) + // state expiry + expiryConfig *types.StateExpiryConfig + // Statistics syncStatsChainOrigin uint64 // Origin block number where syncing started at syncStatsChainHeight uint64 // Highest block number known when syncing started @@ -207,6 +210,10 @@ type BlockChain interface { // TrieDB retrieves the low level trie database used for interacting // with trie nodes. 
TrieDB() *trie.Database + + Config() *params.ChainConfig + + StateExpiryConfig() *types.StateExpiryConfig } type DownloadOption func(downloader *Downloader) *Downloader @@ -235,6 +242,30 @@ func New(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchai return dl } +func NewWithExpiry(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, expiryConfig *types.StateExpiryConfig, dropPeer peerDropFn, options ...DownloadOption) *Downloader { + if lightchain == nil { + lightchain = chain + } + dl := &Downloader{ + stateDB: stateDb, + mux: mux, + queue: newQueue(blockCacheMaxItems, blockCacheInitialItems), + peers: newPeerSet(), + blockchain: chain, + lightchain: lightchain, + dropPeer: dropPeer, + expiryConfig: expiryConfig, + headerProcCh: make(chan *headerTask, 1), + quitCh: make(chan struct{}), + SnapSyncer: snap.NewSyncerWithStateExpiry(stateDb, chain.TrieDB().Scheme(), expiryConfig.EnableExpiry()), + stateSyncStart: make(chan *stateSync), + syncStartBlock: chain.CurrentSnapBlock().Number.Uint64(), + } + + go dl.stateFetcher() + return dl +} + // Progress retrieves the synchronisation boundaries, specifically the origin // block where synchronisation started at (may have failed/suspended); the block // or header sync is currently at; and the latest known block which the sync targets. @@ -494,6 +525,15 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * localHeight = d.lightchain.CurrentHeader().Number.Uint64() } + if d.expiryConfig.EnableRemote() { + var keep bool + keep, remoteHeight = d.expiryConfig.ShouldKeep1EpochBehind(remoteHeight, localHeight, p.id) + log.Debug("EnableRemote wait remote more blocks", "remoteHeight", remoteHeader.Number, "request", remoteHeight, "localHeight", localHeight, "keep", keep, "config", d.expiryConfig) + if keep { + return errCanceled + } + } + origin, err := d.findAncestor(p, localHeight, remoteHeader) if err != nil { return err @@ -581,9 +621,9 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd * } fetchers := []func() error{ - func() error { return d.fetchHeaders(p, origin+1, remoteHeader.Number.Uint64()) }, // Headers are always retrieved - func() error { return d.fetchBodies(origin+1, beaconMode) }, // Bodies are retrieved during normal and snap sync - func() error { return d.fetchReceipts(origin+1, beaconMode) }, // Receipts are retrieved during snap sync + func() error { return d.fetchHeaders(p, origin+1, remoteHeight) }, // Headers are always retrieved + func() error { return d.fetchBodies(origin+1, beaconMode) }, // Bodies are retrieved during normal and snap sync + func() error { return d.fetchReceipts(origin+1, beaconMode) }, // Receipts are retrieved during snap sync func() error { return d.processHeaders(origin+1, td, ttd, beaconMode) }, } if mode == SnapSync { @@ -1169,6 +1209,15 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, head uint64) e return errCanceled } from += uint64(len(headers)) + // if EnableRemote, just return if ahead the head + if d.expiryConfig.EnableRemote() && from > head { + select { + case d.headerProcCh <- nil: + return nil + case <-d.cancelCh: + return errCanceled + } + } } // If we're still skeleton filling snap sync, check pivot staleness // before continuing to the next skeleton filling @@ -1464,10 +1513,14 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { func (d *Downloader) processSnapSyncContent() error { // Start syncing state of the reported head block. 
This should get us most of // the state of the pivot block. + var epoch types.StateEpoch + d.pivotLock.RLock() - sync := d.syncState(d.pivotHeader.Root) + sync := d.syncStateWithEpoch(d.pivotHeader.Root, epoch) d.pivotLock.RUnlock() + epoch = types.GetStateEpoch(d.blockchain.StateExpiryConfig(), new(big.Int).SetUint64(d.pivotHeader.Number.Uint64())) + defer func() { // The `sync` object is replaced every time the pivot moves. We need to // defer close the very last active one, hence the lazy evaluation vs. @@ -1516,11 +1569,12 @@ func (d *Downloader) processSnapSyncContent() error { d.pivotLock.RLock() pivot := d.pivotHeader d.pivotLock.RUnlock() + epoch = types.GetStateEpoch(d.blockchain.StateExpiryConfig(), new(big.Int).SetUint64(pivot.Number.Uint64())) if oldPivot == nil { if pivot.Root != sync.root { sync.Cancel() - sync = d.syncState(pivot.Root) + sync = d.syncStateWithEpoch(pivot.Root, epoch) go closeOnErr(sync) } @@ -1558,7 +1612,7 @@ func (d *Downloader) processSnapSyncContent() error { // If new pivot block found, cancel old state retrieval and restart if oldPivot != P { sync.Cancel() - sync = d.syncState(P.Header.Root) + sync = d.syncStateWithEpoch(P.Header.Root, types.GetStateEpoch(d.blockchain.StateExpiryConfig(), new(big.Int).SetUint64(P.Header.Number.Uint64()))) go closeOnErr(sync) oldPivot = P diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 501af63ed5..e7d9952ff2 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -20,6 +20,7 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" ) @@ -40,6 +41,22 @@ func (d *Downloader) syncState(root common.Hash) *stateSync { return s } +func (d *Downloader) syncStateWithEpoch(root common.Hash, epoch types.StateEpoch) *stateSync { + // Create the state sync + s := newStateSyncWithEpoch(d, root, epoch) + select { + case d.stateSyncStart <- s: + // If we tell the statesync to restart with a new root, we also need + // to wait for it to actually also start -- when old requests have timed + // out or been delivered + <-s.started + case <-d.quitCh: + s.err = errCancelStateFetch + close(s.done) + } + return s +} + // stateFetcher manages the active state sync and accepts requests // on its behalf. func (d *Downloader) stateFetcher() { @@ -77,8 +94,10 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { // stateSync schedules requests for downloading a particular state trie defined // by a given state root. 
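A detail of processSnapSyncContent above: the target epoch is a function of the pivot height, so whenever the pivot moves the epoch is re-derived and a fresh sync is started. Condensed, the step looks like this:

```go
// Condensed from processSnapSyncContent: re-derive the epoch from the new
// pivot height, then restart the state sync against the new root with it.
epoch = types.GetStateEpoch(d.blockchain.StateExpiryConfig(),
	new(big.Int).SetUint64(pivot.Number.Uint64()))
sync.Cancel()
sync = d.syncStateWithEpoch(pivot.Root, epoch)
```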
type stateSync struct { - d *Downloader // Downloader instance to access and manage current peerset - root common.Hash // State root currently being synced + d *Downloader // Downloader instance to access and manage current peerset + root common.Hash // State root currently being synced + epoch types.StateEpoch + enableStateExpiry bool started chan struct{} // Started is signalled once the sync loop starts cancel chan struct{} // Channel to signal a termination request @@ -99,11 +118,26 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync { } } +func newStateSyncWithEpoch(d *Downloader, root common.Hash, epoch types.StateEpoch) *stateSync { + return &stateSync{ + d: d, + root: root, + epoch: epoch, + enableStateExpiry: true, + cancel: make(chan struct{}), + done: make(chan struct{}), + started: make(chan struct{}), + } +} + // run starts the task assignment and response processing loop, blocking until // it finishes, and finally notifying any goroutines waiting for the loop to // finish. func (s *stateSync) run() { close(s.started) + if s.enableStateExpiry { + s.d.SnapSyncer.UpdateEpoch(s.epoch) + } s.err = s.d.SnapSyncer.Sync(s.root, s.cancel) close(s.done) } diff --git a/eth/ethconfig/config.go b/eth/ethconfig/config.go index 384996e7fc..a8278c4ebf 100644 --- a/eth/ethconfig/config.go +++ b/eth/ethconfig/config.go @@ -21,6 +21,8 @@ import ( "errors" "time" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/beacon" @@ -103,6 +105,9 @@ type Config struct { // to turn it on. DisablePeerTxBroadcast bool + // state expiry configs + StateExpiryCfg *types.StateExpiryConfig + // This can be set to list of enrtree:// URLs which will be queried for // for nodes to connect to. EthDiscoveryURLs []string diff --git a/eth/fetcher/block_fetcher.go b/eth/fetcher/block_fetcher.go index b2879cd25f..5d94395f3d 100644 --- a/eth/fetcher/block_fetcher.go +++ b/eth/fetcher/block_fetcher.go @@ -203,6 +203,9 @@ type BlockFetcher struct { fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62) importedHook func(*types.Header, *types.Block) // Method to call upon successful header or block import (both eth/61 and eth/62) + + // state expiry + expiryConfig *types.StateExpiryConfig } // NewBlockFetcher creates a block fetcher to retrieve blocks based on hash announcements. 
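Before the fetcher hunk below, a worked example of the keep-behind rule from ShouldKeep1EpochBehind; the schedule and peer IDs are illustrative only.

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/core/types"
)

func main() {
	// Hypothetical schedule: epoch1 at 1M, epoch2 at 2M, period 7,008,000.
	cfg := &types.StateExpiryConfig{
		Enable:           true,
		EnableRemoteMode: true,
		StateEpoch1Block: 1_000_000,
		StateEpoch2Block: 2_000_000,
		StateEpochPeriod: 7_008_000,
		AllowedPeerList:  []string{"peer-a"},
	}
	// Remote head in epoch 1: the request is capped at remote-(epoch2-epoch1).
	keep, capped := cfg.ShouldKeep1EpochBehind(1_500_000, 100_000, "peer-a")
	fmt.Println(keep, capped) // false 500000 -> sync, but only up to block 500,000
	// Local head already within one epoch of the remote head: hold off entirely.
	keep, capped = cfg.ShouldKeep1EpochBehind(1_500_000, 600_000, "peer-a")
	fmt.Println(keep, capped) // true 0 -> keep waiting
	// Peers not on the allow-list are never delayed.
	keep, capped = cfg.ShouldKeep1EpochBehind(1_500_000, 100_000, "peer-b")
	fmt.Println(keep, capped) // false 1500000 -> sync to the head as usual
}
```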
@@ -383,6 +386,14 @@ func (f *BlockFetcher) loop() { f.forgetBlock(hash) continue } + + if f.expiryConfig.EnableRemote() { + if keep, _ := f.expiryConfig.ShouldKeep1EpochBehind(number, height, op.origin); keep { + log.Debug("BlockFetcher EnableRemote wait remote more blocks", "remoteHeight", number, "localHeight", height, "config", f.expiryConfig) + break + } + } + if f.light { f.importHeaders(op) } else { @@ -994,3 +1005,7 @@ func (f *BlockFetcher) forgetBlock(hash common.Hash) { delete(f.queued, hash) } } + +func (f *BlockFetcher) InitExpiryConfig(config *types.StateExpiryConfig) { + f.expiryConfig = config +} diff --git a/eth/handler.go b/eth/handler.go index d081c76266..f784629117 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -123,6 +123,7 @@ type handlerConfig struct { DirectBroadcast bool DisablePeerTxBroadcast bool PeerSet *peerSet + expiryConfig *types.StateExpiryConfig } type handler struct { @@ -134,6 +135,8 @@ type handler struct { acceptTxs atomic.Bool // Flag whether we're considered synchronised (enables transaction processing) directBroadcast bool + expiryConfig *types.StateExpiryConfig + database ethdb.Database txpool txPool votepool votePool @@ -196,6 +199,7 @@ func newHandler(config *handlerConfig) (*handler, error) { peersPerIP: make(map[string]int), requiredBlocks: config.RequiredBlocks, directBroadcast: config.DirectBroadcast, + expiryConfig: config.expiryConfig, quitSync: make(chan struct{}), handlerDoneCh: make(chan struct{}), handlerStartCh: make(chan struct{}), @@ -249,7 +253,7 @@ func newHandler(config *handlerConfig) (*handler, error) { downloadOptions = append(downloadOptions, success) */ - h.downloader = downloader.New(config.Database, h.eventMux, h.chain, nil, h.removePeer, downloadOptions...) + h.downloader = downloader.NewWithExpiry(config.Database, h.eventMux, h.chain, nil, config.expiryConfig, h.removePeer, downloadOptions...) // Construct the fetcher (short sync) validator := func(header *types.Header) error { @@ -335,6 +339,9 @@ func newHandler(config *handlerConfig) (*handler, error) { } h.blockFetcher = fetcher.NewBlockFetcher(false, nil, h.chain.GetBlockByHash, validator, h.BroadcastBlock, heighter, finalizeHeighter, nil, inserter, h.removePeer) + if config.expiryConfig != nil { + h.blockFetcher.InitExpiryConfig(config.expiryConfig) + } fetchTx := func(peer string, hashes []common.Hash) error { p := h.peers.peer(peer) diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index f2f8ee2d2b..3d82823406 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -534,10 +534,11 @@ func testGetNodeData(t *testing.T, protocol uint, drop bool) { // Sanity check whether all state matches. 
 	accounts := []common.Address{testAddr, acc1Addr, acc2Addr}
 	for i := uint64(0); i <= backend.chain.CurrentBlock().Number.Uint64(); i++ {
-		root := backend.chain.GetBlockByNumber(i).Root()
+		block := backend.chain.GetBlockByNumber(i)
+		root := block.Root()
 		reconstructed, _ := state.New(root, state.NewDatabase(reconstructDB), nil)
 		for j, acc := range accounts {
-			state, _ := backend.chain.StateAt(root)
+			state, _ := backend.chain.StateAt(root, block.Hash(), block.Number())
 			bw := state.GetBalance(acc)
 			bh := reconstructed.GetBalance(acc)
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
index b2fd03766e..e60fc0fafc 100644
--- a/eth/protocols/snap/handler.go
+++ b/eth/protocols/snap/handler.go
@@ -23,7 +23,10 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
 	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
@@ -381,13 +384,26 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP
 		storage []*StorageData
 		last    common.Hash
 		abort   bool
+		sv      snapshot.SnapValue
+		hash    common.Hash
+		slot    []byte
+		enc     []byte
 	)
 	for it.Next() {
 		if size >= hardLimit {
 			abort = true
 			break
 		}
-		hash, slot := it.Hash(), common.CopyBytes(it.Slot())
+		hash, enc = it.Hash(), common.CopyBytes(it.Slot())
+		slot = nil
+		if len(enc) > 0 {
+			sv, err = snapshot.DecodeValueFromRLPBytes(enc)
+			if err != nil || sv == nil {
+				log.Warn("Failed to decode storage slot", "err", err)
+				return nil, nil
+			}
+			// Strip the epoch wrapper; only the raw value goes on the wire.
+			slot = sv.GetVal()
+		}
 
 		// Track the returned interval for the Merkle proofs
 		last = hash
@@ -429,13 +445,23 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP
 		}
 		proof := light.NewNodeSet()
 		if err := stTrie.Prove(origin[:], proof); err != nil {
-			log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
-			return nil, nil
+			if path, ok := trie.ParseExpiredNodeErr(err); ok {
+				err := reviveAndGetProof(chain.FullStateDB(), stTrie, req.Root, common.BytesToAddress(account[:]), acc.Root, path, origin, proof)
+				if err != nil {
+					log.Warn("Failed to prove storage range", "origin", origin, "err", err)
+					return nil, nil
+				}
+			} else {
+				// Not an expiry error: fail the request as before.
+				log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
+				return nil, nil
+			}
 		}
 		if last != (common.Hash{}) {
 			if err := stTrie.Prove(last[:], proof); err != nil {
-				log.Warn("Failed to prove storage range", "last", last, "err", err)
-				return nil, nil
+				if path, ok := trie.ParseExpiredNodeErr(err); ok {
+					err := reviveAndGetProof(chain.FullStateDB(), stTrie, req.Root, common.BytesToAddress(account[:]), acc.Root, path, last, proof)
+					if err != nil {
+						log.Warn("Failed to prove storage range", "last", last, "err", err)
+						return nil, nil
+					}
+				} else {
+					// Not an expiry error: fail the request as before.
+					log.Warn("Failed to prove storage range", "last", last, "err", err)
+					return nil, nil
+				}
 			}
 		}
 		for _, blob := range proof.NodeList() {
@@ -567,6 +593,24 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
 	return nodes, nil
 }
 
+func reviveAndGetProof(fullStateDB ethdb.FullStateDB, tr *trie.StateTrie, stateRoot common.Hash, account common.Address, root common.Hash, prefixKey []byte, key common.Hash, proofDb *light.NodeSet) error {
+	proofs, err := fullStateDB.GetStorageReviveProof(stateRoot, account, root, []string{common.Bytes2Hex(prefixKey)}, []string{common.Bytes2Hex(key[:])})
+	if err != nil || len(proofs) == 0 {
+		return err
+	}
+
+	_, err = state.ReviveStorageTrie(account, tr, proofs[0], key)
+	if err != nil {
+		return err
+	}
+
+	if
err := tr.Prove(key[:], proofDb); err != nil { + return err + } + + return nil +} + // NodeInfo represents a short summary of the `snap` sub-protocol metadata // known about the host peer. type NodeInfo struct{} diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index f56a9480b9..3e56bf054e 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" @@ -408,6 +409,7 @@ type SyncPeer interface { // - The peer remains connected, but does not deliver a response in time // - The peer delivers a stale response after a previous timeout // - The peer delivers a refusal to serve the requested state + type Syncer struct { db ethdb.KeyValueStore // Database to store the trie nodes into (and dedup) scheme string // Node scheme used in node database @@ -423,6 +425,9 @@ type Syncer struct { peerDrop *event.Feed // Event feed to react to peers dropping rates *msgrate.Trackers // Message throughput rates for peers + enableStateExpiry bool + epoch types.StateEpoch + // Request tracking during syncing phase statelessPeers map[string]struct{} // Peers that failed to deliver state data accountIdlers map[string]struct{} // Peers that aren't serving account requests @@ -509,6 +514,39 @@ func NewSyncer(db ethdb.KeyValueStore, scheme string) *Syncer { } } +func NewSyncerWithStateExpiry(db ethdb.KeyValueStore, scheme string, enableStateExpiry bool) *Syncer { + return &Syncer{ + db: db, + scheme: scheme, + + peers: make(map[string]SyncPeer), + peerJoin: new(event.Feed), + peerDrop: new(event.Feed), + rates: msgrate.NewTrackers(log.New("proto", "snap")), + update: make(chan struct{}, 1), + + enableStateExpiry: enableStateExpiry, + + accountIdlers: make(map[string]struct{}), + storageIdlers: make(map[string]struct{}), + bytecodeIdlers: make(map[string]struct{}), + + accountReqs: make(map[uint64]*accountRequest), + storageReqs: make(map[uint64]*storageRequest), + bytecodeReqs: make(map[uint64]*bytecodeRequest), + + trienodeHealIdlers: make(map[string]struct{}), + bytecodeHealIdlers: make(map[string]struct{}), + + trienodeHealReqs: make(map[uint64]*trienodeHealRequest), + bytecodeHealReqs: make(map[uint64]*bytecodeHealRequest), + trienodeHealThrottle: maxTrienodeHealThrottle, // Tune downward instead of insta-filling with junk + stateWriter: db.NewBatch(), + + extProgress: new(SyncProgress), + } +} + // Register injects a new data source into the syncer's peerset. 
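The epoch reaches the healer indirectly: stateSync.run calls UpdateEpoch before Sync, and Sync (below) reads the stored value when it builds the trie scheduler. A condensed sketch of that required call order, using only the constructors and methods added in this patch:

```go
// Sketch: how the downloader drives an expiry-aware snap sync.
syncer := snap.NewSyncerWithStateExpiry(db, scheme, true)
syncer.UpdateEpoch(epoch) // must precede Sync: Sync reads s.epoch when it starts
if err := syncer.Sync(root, cancel); err != nil {
	log.Error("snap sync failed", "err", err)
}
```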
func (s *Syncer) Register(peer SyncPeer) error { // Make sure the peer is not registered yet @@ -572,10 +610,16 @@ func (s *Syncer) Unregister(id string) error { func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error { // Move the trie root from any previous value, revert stateless markers for // any peers and initialize the syncer if it was not yet run + var scheduler *trie.Sync s.lock.Lock() s.root = root + if s.enableStateExpiry { + scheduler = state.NewStateSyncWithExpiry(root, s.db, s.onHealState, s.scheme, s.epoch) + } else { + scheduler = state.NewStateSync(root, s.db, s.onHealState, s.scheme) + } s.healer = &healTask{ - scheduler: state.NewStateSync(root, s.db, s.onHealState, s.scheme), + scheduler: scheduler, trieTasks: make(map[string]common.Hash), codeTasks: make(map[common.Hash]struct{}), } @@ -717,6 +761,10 @@ func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error { } } +func (s *Syncer) UpdateEpoch(epoch types.StateEpoch) { + s.epoch = epoch +} + // loadSyncStatus retrieves a previously aborted sync status from the database, // or generates a fresh one if none is available. func (s *Syncer) loadSyncStatus() { @@ -2008,14 +2056,24 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { s.storageBytes += common.StorageSize(len(key) + len(value)) }, } + var genTrie *trie.StackTrie + if s.enableStateExpiry { + genTrie = trie.NewStackTrieWithStateExpiry(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, func(owner common.Hash, path []byte, blob []byte) { + rawdb.WriteEpochMetaPlainState(batch, owner, string(path), blob) + }, account, s.epoch) + } else { + genTrie = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, account) + } tasks = append(tasks, &storageTask{ Next: common.Hash{}, Last: r.End(), root: acc.Root, genBatch: batch, - genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { - rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) - }, account), + genTrie: genTrie, }) for r.Next() { batch := ethdb.HookedBatch{ @@ -2024,14 +2082,23 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { s.storageBytes += common.StorageSize(len(key) + len(value)) }, } + if s.enableStateExpiry { + genTrie = trie.NewStackTrieWithStateExpiry(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, func(owner common.Hash, path []byte, blob []byte) { + rawdb.WriteEpochMetaPlainState(batch, owner, string(path), blob) + }, account, s.epoch) + } else { + genTrie = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, account) + } tasks = append(tasks, &storageTask{ Next: r.Start(), Last: r.End(), root: acc.Root, genBatch: batch, - genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { - rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) - }, account), + genTrie: genTrie, }) } for _, task := range tasks { @@ -2076,9 +2143,18 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { slots += len(res.hashes[i]) if i < len(res.hashes)-1 || res.subTask == nil { - tr := trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, 
val []byte) { - rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) - }, account) + var tr *trie.StackTrie + if s.enableStateExpiry { + tr = trie.NewStackTrieWithStateExpiry(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, func(owner common.Hash, path []byte, blob []byte) { + rawdb.WriteEpochMetaPlainState(batch, owner, string(path), blob) + }, account, s.epoch) + } else { + tr = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { + rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) + }, account) + } for j := 0; j < len(res.hashes[i]); j++ { tr.Update(res.hashes[i][j][:], res.slots[i][j]) } @@ -2088,7 +2164,13 @@ func (s *Syncer) processStorageResponse(res *storageResponse) { // outdated during the sync, but it can be fixed later during the // snapshot generation. for j := 0; j < len(res.hashes[i]); j++ { - rawdb.WriteStorageSnapshot(batch, account, res.hashes[i][j], res.slots[i][j]) + var snapVal []byte + if s.enableStateExpiry { + snapVal, _ = snapshot.EncodeValueToRLPBytes(snapshot.NewValueWithEpoch(s.epoch, res.slots[i][j])) + } else { + snapVal, _ = rlp.EncodeToBytes(res.slots[i][j]) + } + rawdb.WriteStorageSnapshot(batch, account, res.hashes[i][j], snapVal) // If we're storing large contracts, generate the trie nodes // on the fly to not trash the gluing points diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 1514ad4e13..25cbef9d28 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -640,6 +640,16 @@ func setupSyncer(scheme string, peers ...*testPeer) *Syncer { return syncer } +func setupSyncerWithExpiry(scheme string, expiry bool, peers ...*testPeer) *Syncer { + stateDb := rawdb.NewMemoryDatabase() + syncer := NewSyncerWithStateExpiry(stateDb, scheme, expiry) + for _, peer := range peers { + syncer.Register(peer) + peer.remote = syncer + } + return syncer +} + // TestSync tests a basic sync with one peer func TestSync(t *testing.T) { t.Parallel() @@ -750,6 +760,7 @@ func TestSyncWithStorage(t *testing.T) { testSyncWithStorage(t, rawdb.HashScheme) testSyncWithStorage(t, rawdb.PathScheme) + testSyncWithStorageStateExpiry(t, rawdb.PathScheme) } func testSyncWithStorage(t *testing.T, scheme string) { @@ -762,7 +773,7 @@ func testSyncWithStorage(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false, false) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) @@ -781,6 +792,35 @@ func testSyncWithStorage(t *testing.T, scheme string) { verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t) } +func testSyncWithStorageStateExpiry(t *testing.T, scheme string) { + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { + once.Do(func() { + close(cancel) + }) + } + ) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false, true) + mkSource := func(name string) *testPeer { + source := newTestPeer(name, t, term) + source.accountTrie = sourceAccountTrie.Copy() + source.accountValues = elems + source.setStorageTries(storageTries) + source.storageValues = storageElems + return source + } + syncer := setupSyncerWithExpiry(nodeScheme, 
true, mkSource("sourceA")) + syncer.UpdateEpoch(10) + done := checkStall(t, term) + if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { + t.Fatalf("sync failed: %v", err) + } + close(done) + verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t) +} + // TestMultiSyncManyUseless contains one good peer, and many which doesn't return anything valuable at all func TestMultiSyncManyUseless(t *testing.T) { t.Parallel() @@ -799,7 +839,7 @@ func testMultiSyncManyUseless(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) @@ -853,7 +893,7 @@ func testMultiSyncManyUselessWithLowTimeout(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) @@ -912,7 +952,7 @@ func testMultiSyncManyUnresponsive(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false) mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { source := newTestPeer(name, t, term) @@ -1215,7 +1255,7 @@ func testSyncBoundaryStorageTrie(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 10, 1000, false, true) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 10, 1000, false, true, false) mkSource := func(name string) *testPeer { source := newTestPeer(name, t, term) @@ -1257,7 +1297,7 @@ func testSyncWithStorageAndOneCappedPeer(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 300, 1000, false, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 300, 1000, false, false, false) mkSource := func(name string, slow bool) *testPeer { source := newTestPeer(name, t, term) @@ -1304,7 +1344,7 @@ func testSyncWithStorageAndCorruptPeer(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false) mkSource := func(name string, handler storageHandlerFunc) *testPeer { source := newTestPeer(name, t, term) @@ -1348,7 +1388,7 @@ func testSyncWithStorageAndNonProvingPeer(t *testing.T, scheme string) { }) } ) - nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) + nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := 
makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false) mkSource := func(name string, handler storageHandlerFunc) *testPeer { source := newTestPeer(name, t, term) @@ -1608,16 +1648,22 @@ func makeAccountTrieWithStorageWithUniqueStorage(scheme string, accounts, slots } // makeAccountTrieWithStorage spits out a trie, along with the leafs -func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, boundary bool) (string, *trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) { +func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, boundary bool, expiry bool) (string, *trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) { var ( - db = trie.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme)) - accTrie = trie.NewEmpty(db) + db *trie.Database + accTrie *trie.Trie entries []*kv storageRoots = make(map[common.Hash]common.Hash) storageTries = make(map[common.Hash]*trie.Trie) storageEntries = make(map[common.Hash][]*kv) nodes = trienode.NewMergedNodeSet() ) + if expiry { + db = trie.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfigWithExpiry(scheme, expiry)) + } else { + db = trie.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme)) + } + accTrie = trie.NewEmpty(db) // Create n accounts in the trie for i := uint64(1); i <= uint64(accounts); i++ { key := key32(i) @@ -1897,3 +1943,10 @@ func newDbConfig(scheme string) *trie.Config { } return &trie.Config{PathDB: pathdb.Defaults} } + +func newDbConfigWithExpiry(scheme string, expiry bool) *trie.Config { + if scheme == rawdb.HashScheme { + return &trie.Config{} + } + return &trie.Config{PathDB: pathdb.Defaults, EnableStateExpiry: expiry} +} diff --git a/eth/state_accessor.go b/eth/state_accessor.go index 71a38253ef..0a9b9c985f 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -53,8 +53,8 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u // The state is available in live database, create a reference // on top to prevent garbage collection and return a release // function to deref it. - if statedb, err = eth.blockchain.StateAt(block.Root()); err == nil { - eth.blockchain.TrieDB().Reference(block.Root(), common.Hash{}) + if statedb, err = eth.blockchain.StateAt(block.Root(), block.Hash(), block.Number()); err == nil { + statedb.Database().TrieDB().Reference(block.Root(), common.Hash{}) return statedb, func() { eth.blockchain.TrieDB().Dereference(block.Root()) }, nil @@ -71,6 +71,9 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u // please re-enable it for better performance. 
 	database = state.NewDatabaseWithConfig(eth.chainDb, trie.HashDefaults)
 	if statedb, err = state.New(block.Root(), database, nil); err == nil {
+		if eth.blockchain.EnableStateExpiry() {
+			statedb.InitStateExpiryFeature(eth.blockchain.StateExpiryConfig(), eth.blockchain.FullStateDB(), block.Hash(), block.Number())
+		}
 		log.Info("Found disk backend for state trie", "root", block.Root(), "number", block.Number())
 		return statedb, noopReleaser, nil
 	}
@@ -98,6 +101,9 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u
 	if !readOnly {
 		statedb, err = state.New(current.Root(), database, nil)
 		if err == nil {
+			if eth.blockchain.EnableStateExpiry() {
+				statedb.InitStateExpiryFeature(eth.blockchain.StateExpiryConfig(), eth.blockchain.FullStateDB(), current.Hash(), block.Number())
+			}
 			return statedb, noopReleaser, nil
 		}
 	}
@@ -117,6 +123,9 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u
 
 		statedb, err = state.New(current.Root(), database, nil)
 		if err == nil {
+			if eth.blockchain.EnableStateExpiry() {
+				statedb.InitStateExpiryFeature(eth.blockchain.StateExpiryConfig(), eth.blockchain.FullStateDB(), current.Hash(), block.Number())
+			}
 			break
 		}
 	}
@@ -166,6 +175,9 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u
 		if err != nil {
 			return nil, nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err)
 		}
+		if eth.blockchain.EnableStateExpiry() {
+			statedb.InitStateExpiryFeature(eth.blockchain.StateExpiryConfig(), eth.blockchain.FullStateDB(), current.Hash(), new(big.Int).Add(current.Number(), common.Big1))
+		}
 		// Hold the state reference and also drop the parent state
 		// to prevent accumulating too many nodes in memory.
 		triedb.Reference(root, common.Hash{})
@@ -183,7 +195,7 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u
 
 func (eth *Ethereum) pathState(block *types.Block) (*state.StateDB, func(), error) {
 	// Check if the requested state is available in the live chain.
-	statedb, err := eth.blockchain.StateAt(block.Root())
+	statedb, err := eth.blockchain.StateAt(block.Root(), block.Hash(), block.Number())
 	if err == nil {
 		return statedb, noopReleaser, nil
 	}
diff --git a/eth/tracers/api_test.go b/eth/tracers/api_test.go
index c665f8c32b..44b3595215 100644
--- a/eth/tracers/api_test.go
+++ b/eth/tracers/api_test.go
@@ -141,7 +141,7 @@ func (b *testBackend) teardown() {
 }
 
 func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, StateReleaseFunc, error) {
-	statedb, err := b.chain.StateAt(block.Root())
+	statedb, err := b.chain.StateAt(block.Root(), block.Hash(), block.Number())
 	if err != nil {
 		return nil, nil, errStateNotFound
 	}
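The recurring StateAt signature change above threads the block hash and number through to the state layer so the state expiry code can derive the epoch of the block being executed. A minimal sketch of the new call shape, reusing names from this diff (assume the obvious imports):

    // sketch: obtaining state for a block under the extended signature
    func stateFor(chain *core.BlockChain, n uint64) (*state.StateDB, error) {
        block := chain.GetBlockByNumber(n)
        return chain.StateAt(block.Root(), block.Hash(), block.Number())
    }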
diff --git a/ethclient/gethclient/gethclient.go b/ethclient/gethclient/gethclient.go
index c029611678..b219d0a9bc 100644
--- a/ethclient/gethclient/gethclient.go
+++ b/ethclient/gethclient/gethclient.go
@@ -125,6 +125,25 @@ func (ec *Client) GetProof(ctx context.Context, account common.Address, keys []s
 	return &result, err
 }
 
+// GetStorageReviveProof returns revive proofs for the given storage keys. A
+// prefix key may be supplied per key to obtain a partial proof below that
+// prefix. keys and prefixKeys must have the same length; to obtain a full
+// proof for a key, pass the empty string as its prefix key.
+func (ec *Client) GetStorageReviveProof(ctx context.Context, stateRoot common.Hash, account common.Address, root common.Hash, keys []string, prefixKeys []string) (*types.ReviveResult, error) {
+	type reviveResult struct {
+		StorageProof []types.ReviveStorageProof `json:"storageProof"`
+		BlockNum     uint64                     `json:"blockNum"`
+	}
+
+	var res reviveResult
+	err := ec.c.CallContext(ctx, &res, "eth_getStorageReviveProof", stateRoot, account, root, keys, prefixKeys)
+
+	return &types.ReviveResult{
+		StorageProof: res.StorageProof,
+		BlockNum:     res.BlockNum,
+	}, err
+}
+
 // CallContract executes a message call transaction, which is directly executed in the VM
 // of the node, but never mined into the blockchain.
 //
diff --git a/ethclient/gethclient/gethclient_test.go b/ethclient/gethclient/gethclient_test.go
index 65adae8ea7..29b4b95f5f 100644
--- a/ethclient/gethclient/gethclient_test.go
+++ b/ethclient/gethclient/gethclient_test.go
@@ -93,7 +93,7 @@ func generateTestChain() (*core.Genesis, []*types.Block) {
 }
 
 func TestGethClient(t *testing.T) {
-	backend, _ := newTestBackend(t)
+	backend, blocks := newTestBackend(t)
 	client := backend.Attach()
 	defer backend.Close()
 	defer client.Close()
@@ -105,6 +105,9 @@ func TestGethClient(t *testing.T) {
 		{
 			"TestGetProof",
 			func(t *testing.T) { testGetProof(t, client) },
+		}, {
+			"TestGetStorageReviveProof",
+			func(t *testing.T) { testGetStorageReviveProof(t, client, blocks[0]) },
 		}, {
 			"TestGetProofCanonicalizeKeys",
 			func(t *testing.T) { testGetProofCanonicalizeKeys(t, client) },
@@ -236,6 +239,28 @@ func testGetProof(t *testing.T, client *rpc.Client) {
 	}
 }
 
+func testGetStorageReviveProof(t *testing.T, client *rpc.Client, block *types.Block) {
+	ec := New(client)
+	result, err := ec.GetStorageReviveProof(context.Background(), block.Header().Root, testAddr, block.Header().Root, []string{testSlot.String()}, []string{""})
+	if err != nil {
+		t.Fatal(err)
+	}
+	proofs := result.StorageProof
+
+	// test storage
+	if len(proofs) != 1 {
+		t.Fatalf("invalid storage proof, want 1 proof, got %v proof(s)", len(proofs))
+	}
+
+	if proofs[0].Key != testSlot.String() {
+		t.Fatalf("invalid storage proof key, want: %q, got: %q", testSlot.String(), proofs[0].Key)
+	}
+
+	if proofs[0].PrefixKey != "" {
+		t.Fatalf("invalid storage proof prefix key, want: %q, got: %q", "", proofs[0].PrefixKey)
+	}
+}
+
 func testGetProofCanonicalizeKeys(t *testing.T, client *rpc.Client) {
 	ec := New(client)
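For reference, a minimal client-side sketch of the new call, assuming a reachable node that exposes eth_getStorageReviveProof (the variable names are illustrative):

    // sketch: requesting a full revive proof for one storage slot
    func fetchReviveProof(cl *rpc.Client, stateRoot, storageRoot common.Hash, account common.Address, slot common.Hash) error {
        ec := gethclient.New(cl)
        // an empty prefix key requests a full proof for the slot
        res, err := ec.GetStorageReviveProof(context.Background(), stateRoot, account, storageRoot,
            []string{slot.String()}, []string{""})
        if err != nil {
            return err
        }
        fmt.Println("proof nodes:", len(res.StorageProof[0].Proof))
        return nil
    }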
diff --git a/ethdb/fullstatedb.go b/ethdb/fullstatedb.go
new file mode 100644
index 0000000000..367351f5e4
--- /dev/null
+++ b/ethdb/fullstatedb.go
@@ -0,0 +1,120 @@
+package ethdb
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/rpc"
+	lru "github.com/hashicorp/golang-lru"
+)
+
+var (
+	getProofMeter         = metrics.NewRegisteredMeter("ethdb/fullstatedb/getproof", nil)
+	getProofHitCacheMeter = metrics.NewRegisteredMeter("ethdb/fullstatedb/getproof/cache", nil)
+	getStorageProofTimer  = metrics.NewRegisteredTimer("ethdb/fullstatedb/getproof/rt", nil)
+)
+
+// FullStateDB is a backend, typically a remote full node, from which expired
+// state can be fetched.
+type FullStateDB interface {
+	// GetStorageReviveProof fetches revive proofs for the given prefix keys and keys.
+	GetStorageReviveProof(stateRoot common.Hash, account common.Address, root common.Hash, prefixKeys, keys []string) ([]types.ReviveStorageProof, error)
+}
+
+type FullStateRPCServer struct {
+	endpoint string
+	client   *rpc.Client
+	cache    *lru.Cache
+}
+
+func NewFullStateRPCServer(endpoint string) (FullStateDB, error) {
+	if endpoint == "" {
+		return nil, errors.New("endpoint must be specified")
+	}
+	if strings.HasPrefix(endpoint, "rpc:") || strings.HasPrefix(endpoint, "ipc:") {
+		// Backwards compatibility with geth < 1.5 which required
+		// these prefixes.
+		endpoint = endpoint[4:]
+	}
+	// TODO(0xbundler): add more opts, like auth, cache size?
+	client, err := rpc.DialOptions(context.Background(), endpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	cache, err := lru.New(10000)
+	if err != nil {
+		return nil, err
+	}
+	return &FullStateRPCServer{
+		endpoint: endpoint,
+		client:   client,
+		cache:    cache,
+	}, nil
+}
+
+func (f *FullStateRPCServer) GetStorageReviveProof(stateRoot common.Hash, account common.Address, root common.Hash, prefixKeys, keys []string) ([]types.ReviveStorageProof, error) {
+	defer func(start time.Time) {
+		getStorageProofTimer.Update(time.Since(start))
+	}(time.Now())
+
+	var result types.ReviveResult
+
+	getProofMeter.Mark(int64(len(keys)))
+	// Consult the LRU cache first; it caches proofs per key.
+	uncachedPrefixKeys := make([]string, 0, len(prefixKeys))
+	uncachedKeys := make([]string, 0, len(keys))
+	ret := make([]types.ReviveStorageProof, 0, len(keys))
+	for i, key := range keys {
+		val, ok := f.cache.Get(ProofCacheKey(account, root, prefixKeys[i], key))
+		log.Debug("GetStorageReviveProof hit cache", "account", account, "key", key, "ok", ok)
+		if !ok {
+			uncachedPrefixKeys = append(uncachedPrefixKeys, prefixKeys[i])
+			uncachedKeys = append(uncachedKeys, keys[i])
+			continue
+		}
+		getProofHitCacheMeter.Mark(1)
+		ret = append(ret, val.(types.ReviveStorageProof))
+	}
+	if len(uncachedKeys) == 0 {
+		return ret, nil
+	}
+
+	// TODO(0xbundler): add timeout in flags?
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancelFunc()
+	err := f.client.CallContext(ctx, &result, "eth_getStorageReviveProof", stateRoot, account, root, uncachedKeys, uncachedPrefixKeys)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get storage revive proof, err: %v, remote's block number: %v", err, result.BlockNum)
+	}
+	if len(result.Err) > 0 {
+		return nil, fmt.Errorf("failed to get storage revive proof, err: %v, remote's block number: %v", result.Err, result.BlockNum)
+	}
+
+	// add to cache
+	for _, proof := range result.StorageProof {
+		f.cache.Add(ProofCacheKey(account, root, proof.PrefixKey, proof.Key), proof)
+	}
+
+	ret = append(ret, result.StorageProof...)
+	return ret, nil
+}
+
+func ProofCacheKey(account common.Address, root common.Hash, prefix, key string) string {
+	buf := bytes.NewBuffer(make([]byte, 0, 67+len(prefix)+len(key)))
+	buf.Write(account[:])
+	buf.WriteByte('$')
+	buf.Write(root[:])
+	buf.WriteByte('$')
+	buf.WriteString(common.No0xPrefix(prefix))
+	buf.WriteByte('$')
+	buf.WriteString(common.No0xPrefix(key))
+	return buf.String()
+}
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index 810e8bba60..496502b14e 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -25,6 +25,8 @@ import (
 	"strings"
 	"time"
 
+	"github.com/ethereum/go-ethereum/metrics"
+
 	"github.com/davecgh/go-spew/spew"
 
 	"github.com/ethereum/go-ethereum/accounts"
@@ -44,15 +46,20 @@ import (
 	"github.com/ethereum/go-ethereum/core/vm"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/eth/tracers/logger"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rpc"
+	lru "github.com/hashicorp/golang-lru"
 	"github.com/tyler-smith/go-bip39"
 )
 
-const UnHealthyTimeout = 5 * time.Second
+const (
+	UnHealthyTimeout = 5 * time.Second
+	APICache         = 10000
+)
 
 // max is a helper function which returns the larger of the two given integers.
 func max(a, b int64) int64 {
@@ -64,6 +71,10 @@ func max(a, b int64) int64 {
 
-// PublicEthereumAPI provides an API to access Ethereum related information.
-// It offers only methods that operate on public data that is freely available to anyone.
+var (
+	getStorageProofTimer = metrics.NewRegisteredTimer("ethapi/getstorageproof/rt", nil)
+)
+
+// EthereumAPI provides an API to access Ethereum related information.
+// It offers only methods that operate on public data that is freely available to anyone.
 type EthereumAPI struct {
 	b Backend
 }
@@ -617,12 +629,18 @@ func (s *PersonalAccountAPI) Unpair(ctx context.Context, url string, pin string)
 
 // BlockChainAPI provides an API to access Ethereum blockchain data.
 type BlockChainAPI struct {
-	b Backend
+	b     Backend
+	cache *lru.Cache
 }
 
 // NewBlockChainAPI creates a new Ethereum blockchain API.
 func NewBlockChainAPI(b Backend) *BlockChainAPI {
-	return &BlockChainAPI{b}
+	// APICache is a positive constant, so lru.New cannot fail here.
+	cache, _ := lru.New(APICache)
+	return &BlockChainAPI{
+		b:     b,
+		cache: cache,
+	}
 }
 
 // ChainId is the EIP-155 replay-protection chain id for the current Ethereum chain config.
@@ -757,6 +777,96 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st
 	}, state.Error()
 }
 
+// GetStorageReviveProof returns revive proofs for the given storage keys. A
+// prefix key may be supplied per key to obtain a partial proof below that
+// prefix. storageKeys and storagePrefixKeys must have the same length; to
+// obtain a full proof for a key, pass the empty string as its prefix key.
+func (s *BlockChainAPI) GetStorageReviveProof(ctx context.Context, stateRoot common.Hash, address common.Address, root common.Hash, storageKeys []string, storagePrefixKeys []string) (*types.ReviveResult, error) {
+	defer func(start time.Time) {
+		getStorageProofTimer.Update(time.Since(start))
+	}(time.Now())
+
+	if len(storageKeys) != len(storagePrefixKeys) {
+		return nil, errors.New("storageKeys and storagePrefixKeys must be same length")
+	}
+
+	var (
+		blockNum     uint64
+		keys         = make([]common.Hash, len(storageKeys))
+		keyLengths   = make([]int, len(storageKeys))
+		prefixKeys   = make([][]byte, len(storagePrefixKeys))
+		storageProof = make([]types.ReviveStorageProof, len(storageKeys))
+	)
+
+	// Try to open the storage trie at the requested state root.
+	storageTrie, err := s.b.StorageTrie(stateRoot, address, root)
+	if err != nil {
+		// The target state root is unavailable; fall back to the latest state.
+		stateDb, header, err := s.b.StateAndHeaderByNumber(ctx, rpc.LatestBlockNumber)
+		if err != nil {
+			return nil, err
+		}
+		blockNum = header.Number.Uint64()
+		storageTrie, err = stateDb.StorageTrie(address)
+		if err != nil {
+			return types.NewReviveErrResult(err, blockNum), nil
+		}
+		log.Debug("GetStorageReviveProof from latest block number", "blockNum", blockNum, "blockHash", header.Hash())
+	}
+
+	// Deserialize all keys. This prevents state access on invalid input.
+	for i, hexKey := range storageKeys {
+		keys[i], keyLengths[i], err = decodeHash(hexKey)
+		if err != nil {
+			return types.NewReviveErrResult(err, blockNum), nil
+		}
+	}
+
+	// Decode prefix keys
+	for i, prefixKey := range storagePrefixKeys {
+		prefixKeys[i], err = hex.DecodeString(prefixKey)
+		if err != nil {
+			return types.NewReviveErrResult(err, blockNum), nil
+		}
+	}
+
+	// Create the proofs for the storageKeys.
+	for i, key := range keys {
+		// Output key encoding is a bit special: if the input was a 32-byte hash, it is
+		// returned as such. Otherwise, we apply the QUANTITY encoding mandated by the
+		// JSON-RPC spec for getProof. This behavior exists to preserve backwards
+		// compatibility with older client versions.
+		var outputKey string
+		if keyLengths[i] != 32 {
+			outputKey = hexutil.EncodeBig(key.Big())
+		} else {
+			outputKey = hexutil.Encode(key[:])
+		}
+
+		var proof proofList
+		prefixKey := prefixKeys[i]
+
+		// Check whether this request has been cached already.
+		val, ok := s.cache.Get(ethdb.ProofCacheKey(address, root, storagePrefixKeys[i], storageKeys[i]))
+		if ok {
+			storageProof[i] = val.(types.ReviveStorageProof)
+			continue
+		}
+
+		if err := storageTrie.ProveByPath(crypto.Keccak256(key.Bytes()), prefixKey, &proof); err != nil {
+			return types.NewReviveErrResult(err, blockNum), nil
+		}
+		storageProof[i] = types.ReviveStorageProof{
+			Key:       outputKey,
+			PrefixKey: storagePrefixKeys[i],
+			Proof:     proof,
+		}
+		s.cache.Add(ethdb.ProofCacheKey(address, root, storagePrefixKeys[i], storageKeys[i]), storageProof[i])
+	}
+
+	return types.NewReviveResult(storageProof, blockNum), nil
+}
+
 // decodeHash parses a hex-encoded 32-byte hash. The input may optionally
 // be prefixed by 0x and can have a byte length up to 32.
func decodeHash(s string) (h common.Hash, inputLength int, err error) { @@ -1385,11 +1493,11 @@ func (s *BlockChainAPI) needToReplay(ctx context.Context, block *types.Block, ac if err != nil { return false, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - parentState, err := s.b.Chain().StateAt(parent.Root()) + parentState, err := s.b.Chain().StateAt(parent.Root(), parent.Hash(), parent.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64()-1, err) } - currentState, err := s.b.Chain().StateAt(block.Root()) + currentState, err := s.b.Chain().StateAt(block.Root(), block.Hash(), block.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64(), err) } @@ -1415,7 +1523,7 @@ func (s *BlockChainAPI) replay(ctx context.Context, block *types.Block, accounts if err != nil { return nil, nil, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - statedb, err := s.b.Chain().StateAt(parent.Root()) + statedb, err := s.b.Chain().StateAt(parent.Root(), parent.Hash(), block.Number()) if err != nil { return nil, nil, fmt.Errorf("state not found for block number (%d): %v", block.NumberU64()-1, err) } @@ -2195,9 +2303,9 @@ func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (c if tx.To() == nil { addr := crypto.CreateAddress(from, tx.Nonce()) - log.Info("Submitted contract creation", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "contract", addr.Hex(), "value", tx.Value(), "x-forward-ip", xForward) + log.Debug("Submitted contract creation", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "contract", addr.Hex(), "value", tx.Value(), "x-forward-ip", xForward) } else { - log.Info("Submitted transaction", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "recipient", tx.To(), "value", tx.Value(), "x-forward-ip", xForward) + log.Debug("Submitted transaction", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "recipient", tx.To(), "value", tx.Value(), "x-forward-ip", xForward) } return tx.Hash(), nil } diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index 6d61d6103c..a5d9e332a8 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -370,6 +370,10 @@ func newTestBackend(t *testing.T, n int, gspec *core.Genesis, generator func(i i return backend } +func (b *testBackend) StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) { + panic("not implemented") +} + // nolint:unused func (b *testBackend) setPendingBlock(block *types.Block) { b.pending = block @@ -454,7 +458,7 @@ func (b testBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.Bloc if header == nil { return nil, nil, errors.New("header not found") } - stateDb, err := b.chain.StateAt(header.Root) + stateDb, err := b.chain.StateAt(header.Root, header.Hash(), header.Number) return stateDb, header, err } func (b testBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index d71d7e8eba..a6f0c31b88 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -67,6 +67,7 @@ type Backend interface { BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) StateAndHeaderByNumber(ctx context.Context, number 
rpc.BlockNumber) (*state.StateDB, *types.Header, error) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) + StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) PendingBlockAndReceipts() (*types.Block, types.Receipts) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) GetTd(ctx context.Context, hash common.Hash) *big.Int diff --git a/internal/ethapi/transaction_args_test.go b/internal/ethapi/transaction_args_test.go index fc42df3ddb..a5b53bab6a 100644 --- a/internal/ethapi/transaction_args_test.go +++ b/internal/ethapi/transaction_args_test.go @@ -243,6 +243,10 @@ func newBackendMock() *backendMock { } } +func (b *backendMock) StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) { + panic("not implemented") +} + func (b *backendMock) activateLondon() { b.current.Number = big.NewInt(1100) } diff --git a/internal/flags/categories.go b/internal/flags/categories.go index 7a6b8d374c..1627cbf72f 100644 --- a/internal/flags/categories.go +++ b/internal/flags/categories.go @@ -40,6 +40,7 @@ const ( FastNodeCategory = "FAST NODE" FastFinalityCategory = "FAST FINALITY" HistoryCategory = "HISTORY" + StateExpiryCategory = "STATE EXPIRY" ) func init() { diff --git a/les/api_backend.go b/les/api_backend.go index 9ad566f1ad..391f571806 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -149,6 +149,10 @@ func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B return light.NewState(ctx, header, b.eth.odr), header, nil } +func (b *LesApiBackend) StorageTrie(stateRoot common.Hash, addr common.Address, root common.Hash) (state.Trie, error) { + panic("not implemented") +} + func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { if blockNr, ok := blockNrOrHash.Number(); ok { return b.StateAndHeaderByNumber(ctx, blockNr) diff --git a/light/trie.go b/light/trie.go index 529f1e5d89..81a45b4bfd 100644 --- a/light/trie.go +++ b/light/trie.go @@ -129,6 +129,14 @@ func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { return content, err } +func (t *odrTrie) GetStorageAndUpdateEpoch(_ common.Address, key []byte) ([]byte, error) { + panic("not implemented") +} + +func (t *odrTrie) SetEpoch(epoch types.StateEpoch) { + panic("not implemented") +} + func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error) { var ( enc []byte @@ -212,6 +220,10 @@ func (t *odrTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { return errors.New("not implemented, needs client/server interface split") } +func (t *odrTrie) Epoch() types.StateEpoch { + return types.StateEpoch0 +} + // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError func (t *odrTrie) do(key []byte, fn func() error) error { @@ -240,10 +252,22 @@ func (t *odrTrie) do(key []byte, fn func() error) error { } } -func (db *odrTrie) NoTries() bool { +func (t *odrTrie) NoTries() bool { return false } +func (t *odrTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { + return errors.New("not implemented, needs client/server interface split") +} + +func (t *odrTrie) TryRevive(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error) { + return nil, errors.New("not implemented, needs client/server interface split") +} + 
+func (t *odrTrie) TryLocalRevive(addr common.Address, key []byte) ([]byte, error) { + return nil, errors.New("not implemented, needs client/server interface split") +} + type nodeIterator struct { trie.NodeIterator t *odrTrie diff --git a/miner/miner.go b/miner/miner.go index 4db6140803..176188fd3f 100644 --- a/miner/miner.go +++ b/miner/miner.go @@ -220,7 +220,7 @@ func (miner *Miner) Pending() (*types.Block, *state.StateDB) { if block == nil { return nil, nil } - stateDb, err := miner.worker.chain.StateAt(block.Root) + stateDb, err := miner.worker.chain.StateAt(block.Root, block.Hash(), block.Number) if err != nil { return nil, nil } diff --git a/miner/miner_test.go b/miner/miner_test.go index 489bc46a91..fa8941a076 100644 --- a/miner/miner_test.go +++ b/miner/miner_test.go @@ -85,7 +85,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return types.NewBlock(bc.CurrentBlock(), nil, nil, nil, trie.NewStackTrie(nil)) } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/miner/worker.go b/miner/worker.go index ffade84a39..6892da7d35 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -24,6 +24,8 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/misc/eip1559" @@ -35,7 +37,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" lru "github.com/hashicorp/golang-lru" @@ -650,7 +651,7 @@ func (w *worker) makeEnv(parent *types.Header, header *types.Header, coinbase co prevEnv *environment) (*environment, error) { // Retrieve the parent state to execute on top and start a prefetcher for // the miner to speed block sealing up a bit - state, err := w.chain.StateAtWithSharedPool(parent.Root) + state, err := w.chain.StateAtWithSharedPool(parent.Root, parent.Hash(), header.Number) if err != nil { return nil, err } @@ -1182,7 +1183,6 @@ func (w *worker) commit(env *environment, interval func(), update bool, start ti // Create a local environment copy, avoid the data race with snapshot state. // https://github.com/ethereum/go-ethereum/issues/24299 env := env.copy() - // If we're post merge, just ignore if !w.isTTDReached(block.Header()) { select { diff --git a/p2p/server.go b/p2p/server.go index 0c6e0b7ee0..cd076692a4 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -134,6 +134,10 @@ type Config struct { // allowed to connect, even above the peer limit. TrustedNodes []*enode.Node + // StateExpiryAllowedNodes are used to ensure that the state expiry remoteDb + // will only delay its sync for the peers in the list. + StateExpiryAllowedNodes []*enode.Node + // Connectivity can be restricted to certain IP networks. // If this option is set to a non-nil value, only hosts which match one of the // IP networks contained in the list are considered. 
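The new StateExpiryAllowedNodes allow-list mirrors TrustedNodes. A minimal configuration sketch (the enode URL is a placeholder, not a real node):

    cfg := p2p.Config{
        MaxPeers: 50,
        StateExpiryAllowedNodes: []*enode.Node{
            enode.MustParse("enode://<hex-node-id>@10.0.0.1:30303"), // placeholder
        },
    }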
diff --git a/trie/committer.go b/trie/committer.go index 4b222f9710..f43f58419b 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -19,6 +19,8 @@ package trie import ( "fmt" + "github.com/ethereum/go-ethereum/trie/epochmeta" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/trie/trienode" ) @@ -27,17 +29,21 @@ import ( // capture all dirty nodes during the commit process and keep them cached in // insertion order. type committer struct { - nodes *trienode.NodeSet - tracer *tracer - collectLeaf bool + nodes *trienode.NodeSet + tracer *tracer + collectLeaf bool + enableStateExpiry bool + enableMetaDB bool } // newCommitter creates a new committer or picks one from the pool. -func newCommitter(nodeset *trienode.NodeSet, tracer *tracer, collectLeaf bool) *committer { +func newCommitter(nodeset *trienode.NodeSet, tracer *tracer, collectLeaf, enableStateExpiry, enableMetaDB bool) *committer { return &committer{ - nodes: nodeset, - tracer: tracer, - collectLeaf: collectLeaf, + nodes: nodeset, + tracer: tracer, + collectLeaf: collectLeaf, + enableStateExpiry: enableStateExpiry, + enableMetaDB: enableMetaDB, } } @@ -139,12 +145,30 @@ func (c *committer) store(path []byte, n node) node { } // Collect the dirty node to nodeset for return. nhash := common.BytesToHash(hash) - c.nodes.AddNode(path, trienode.New(nhash, nodeToBytes(n))) + var blob []byte + if c.enableStateExpiry && !c.enableMetaDB { + blob = nodeToBytesWithEpoch(n) + } else { + blob = nodeToBytes(n) + } + changed := c.tracer.checkNodeChanged(path, blob) + if changed { + c.nodes.AddNode(path, trienode.New(nhash, blob)) + } + if c.enableStateExpiry && c.enableMetaDB { + switch n := n.(type) { + case *fullNode: + metaBlob := epochmeta.BranchMeta2Bytes(epochmeta.NewBranchNodeEpochMeta(n.EpochMap)) + if c.tracer.checkEpochMetaChanged(path, metaBlob) { + c.nodes.AddBranchNodeEpochMeta(path, metaBlob) + } + } + } // Collect the corresponding leaf node if it's required. We don't check // full node since it's impossible to store value in fullNode. The key // length of leaves should be exactly same. - if c.collectLeaf { + if changed && c.collectLeaf { if sn, ok := n.(*shortNode); ok { if val, ok := sn.Val.(valueNode); ok { c.nodes.AddLeaf(nhash, val) diff --git a/trie/database.go b/trie/database.go index 7bad532dde..d41056b687 100644 --- a/trie/database.go +++ b/trie/database.go @@ -18,8 +18,12 @@ package trie import ( "errors" + "fmt" + "math/big" "strings" + "github.com/ethereum/go-ethereum/trie/epochmeta" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" @@ -40,6 +44,10 @@ type Config struct { // Testing hooks OnCommit func(states *triestate.Set) // Hook invoked when commit is performed + + // state expiry feature + EnableStateExpiry bool + EpochMeta *epochmeta.Config } // HashDefaults represents a config for using hash-based scheme with @@ -87,6 +95,7 @@ type Database struct { diskdb ethdb.Database // Persistent database to store the snapshot preimages *preimageStore // The store for caching preimages backend backend // The backend for managing trie nodes + snapTree *epochmeta.SnapshotTree } // prepare initializes the database with provided configs, but the @@ -138,10 +147,12 @@ func NewDatabase(diskdb ethdb.Database, config *Config) *Database { * 2. Second, initialize the db according to the scheme already used by db * 3. 
Last, use the default scheme, namely hash scheme
 */
+	enableEpochMetaDB := false
 	if config.HashDB != nil {
 		if rawdb.ReadStateScheme(diskdb) == rawdb.PathScheme {
 			log.Warn("incompatible state scheme", "old", rawdb.PathScheme, "new", rawdb.HashScheme)
 		}
+		enableEpochMetaDB = true
 		db.backend = hashdb.New(diskdb, config.HashDB, mptResolver{})
 	} else if config.PathDB != nil {
 		if rawdb.ReadStateScheme(diskdb) == rawdb.HashScheme {
@@ -157,8 +168,16 @@ func NewDatabase(diskdb ethdb.Database, config *Config) *Database {
 		if config.HashDB == nil {
 			config.HashDB = hashdb.Defaults
 		}
+		enableEpochMetaDB = true
 		db.backend = hashdb.New(diskdb, config.HashDB, mptResolver{})
 	}
+	if config.EnableStateExpiry && enableEpochMetaDB {
+		snapTree, err := epochmeta.NewEpochMetaSnapTree(diskdb, config.EpochMeta)
+		if err != nil {
+			panic(fmt.Sprintf("init SnapshotTree err: %v", err))
+		}
+		db.snapTree = snapTree
+	}
 	return db
 }
 
@@ -192,7 +211,16 @@ func (db *Database) Update(root common.Hash, parent common.Hash, block uint64, n
 	if db.preimages != nil {
 		db.preimages.commit(false)
 	}
-	return db.backend.Update(root, parent, block, nodes, states)
+	if err := db.backend.Update(root, parent, block, nodes, states); err != nil {
+		return err
+	}
+	if db.snapTree != nil {
+		err := db.snapTree.Update(parent, new(big.Int).SetUint64(block), root, nodes.FlattenEpochMeta())
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 // Commit iterates over all the children of a particular node, writes them out
@@ -202,7 +230,35 @@ func (db *Database) Commit(root common.Hash, report bool) error {
 	if db.preimages != nil {
 		db.preimages.commit(true)
 	}
 	return db.backend.Commit(root, report)
 }
+
+// CommitAll commits the trie nodes for the given root and then flushes the
+// corresponding epoch metadata.
+func (db *Database) CommitAll(root common.Hash, report bool) error {
+	if err := db.Commit(root, report); err != nil {
+		return err
+	}
+	return db.CommitEpochMeta(root)
+}
+
+// CommitEpochMeta caps the epoch meta snapshot tree at the given root, if the
+// state expiry feature is enabled.
+func (db *Database) CommitEpochMeta(root common.Hash) error {
+	if db.snapTree != nil {
+		if err := db.snapTree.Cap(root); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// EnableExpiry reports whether the state expiry feature is enabled.
+func (db *Database) EnableExpiry() bool {
+	if db.config != nil {
+		return db.config.EnableStateExpiry
+	}
+	return false
+}
 
 // Size returns the storage size of dirty trie nodes in front of the persistent
@@ -353,3 +408,9 @@ func (db *Database) SetBufferSize(size int) error {
 	}
 	return pdb.SetBufferSize(size)
 }
+
+// EpochMetaSnapTree returns the underlying epoch meta snapshot tree, or nil
+// if the state expiry feature is disabled.
+func (db *Database) EpochMetaSnapTree() *epochmeta.SnapshotTree {
+	return db.snapTree
+}
diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go
index 41478253d9..28aef864cd 100644
--- a/trie/dummy_trie.go
+++ b/trie/dummy_trie.go
@@ -26,6 +26,10 @@ import (
 
 type EmptyTrie struct{}
 
+func (t *EmptyTrie) Epoch() types.StateEpoch {
+	return types.StateEpoch0
+}
+
 // NewSecure creates a dummy trie
 func NewEmptyTrie() *EmptyTrie {
 	return &EmptyTrie{}
@@ -81,6 +85,24 @@ func (t *EmptyTrie) NodeIterator(startKey []byte) (NodeIterator, error) {
 func (t *EmptyTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
 	return nil
 }
+
+func (t *EmptyTrie) GetStorageAndUpdateEpoch(addr common.Address, key []byte) ([]byte, error) {
+	return nil, nil
+}
+
+func (t *EmptyTrie) SetEpoch(epoch types.StateEpoch) {
+}
+
+func (t *EmptyTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error {
+	return nil
+}
+
+func (t *EmptyTrie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) {
+	return nil, nil
+}
+
+func (t *EmptyTrie) TryLocalRevive(addr common.Address, key []byte) ([]byte, error) {
+	return nil, nil
+}
 
 // Copy returns a copy of SecureTrie.
 func (t *EmptyTrie) Copy() *EmptyTrie {
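The new Config knobs gate the epoch meta snapshot tree, which NewDatabase above only wires up for the hash-scheme backend. A minimal sketch of enabling it on a fresh database (EpochMeta is left nil, which the epochmeta tests below also use as the default):

    db := trie.NewDatabase(rawdb.NewMemoryDatabase(), &trie.Config{
        HashDB:            hashdb.Defaults,
        EnableStateExpiry: true, // the epoch meta snapshot tree is created alongside the hash backend
    })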
diff --git a/trie/epochmeta/database.go b/trie/epochmeta/database.go
new file mode 100644
index 0000000000..566c33915c
--- /dev/null
+++ b/trie/epochmeta/database.go
@@ -0,0 +1,98 @@
+package epochmeta
+
+import (
+	"bytes"
+	"fmt"
+	"math/big"
+
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+
+	"github.com/ethereum/go-ethereum/rlp"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+)
+
+var (
+	AccountMetadataPath   = []byte("m")
+	metaAccessMeter       = metrics.NewRegisteredMeter("epochmeta/access", nil)
+	metaHitDiffMeter      = metrics.NewRegisteredMeter("epochmeta/access/hit/diff", nil)
+	metaHitDiskCacheMeter = metrics.NewRegisteredMeter("epochmeta/access/hit/diskcache", nil)
+	metaHitDiskMeter      = metrics.NewRegisteredMeter("epochmeta/access/hit/disk", nil)
+)
+
+type BranchNodeEpochMeta struct {
+	EpochMap [16]types.StateEpoch
+}
+
+func NewBranchNodeEpochMeta(epochMap [16]types.StateEpoch) *BranchNodeEpochMeta {
+	return &BranchNodeEpochMeta{EpochMap: epochMap}
+}
+
+func (n *BranchNodeEpochMeta) Encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+	for _, e := range n.EpochMap {
+		w.WriteUint64(uint64(e))
+	}
+	w.ListEnd(offset)
+}
+
+func DecodeFullNodeEpochMeta(enc []byte) (*BranchNodeEpochMeta, error) {
+	var n BranchNodeEpochMeta
+
+	if err := rlp.DecodeBytes(enc, &n.EpochMap); err != nil {
+		return nil, err
+	}
+
+	return &n, nil
+}
+
+type Reader struct {
+	snap snapshot
+	tree *SnapshotTree
+}
+
+// NewReader looks up the snapshot for blockRoot in the given tree; if none is
+// found, it falls back to the default disk layer keyed by the empty root.
+func NewReader(tree *SnapshotTree, number *big.Int, blockRoot common.Hash) (*Reader, error) {
+	snap := tree.Snapshot(blockRoot)
+	if snap == nil {
+		// try using default snap
+		if snap = tree.Snapshot(types.EmptyRootHash); snap == nil {
+			return nil, fmt.Errorf("cannot find target epoch layer %#x", blockRoot)
+		}
+		log.Debug("NewReader using default snapshot", "number", number, "root", blockRoot)
+	}
+	return &Reader{
+		snap: snap,
+		tree: tree,
+	}, nil
+}
+
+func (s *Reader) Get(addr common.Hash, path string) ([]byte, error) {
+	metaAccessMeter.Mark(1)
+	return s.snap.EpochMeta(addr, path)
+}
+
+func BranchMeta2Bytes(meta *BranchNodeEpochMeta) []byte {
+	if meta == nil || *meta == (BranchNodeEpochMeta{}) {
+		return []byte{}
+	}
+	buf := rlp.NewEncoderBuffer(nil)
+	meta.Encode(buf)
+	return buf.ToBytes()
+}
+
+func AccountMeta2Bytes(meta types.StateMeta) ([]byte, error) {
+	if meta == nil {
+		return []byte{}, nil
+	}
+	return meta.EncodeToRLPBytes()
+}
+
+// IsEpochMetaPath reports whether the path addresses epoch metadata, which is
+// exempt from the usual trie node hash checks.
+func IsEpochMetaPath(path []byte) bool {
+	return bytes.Equal(AccountMetadataPath, path)
+}
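A usage sketch for the Reader defined above, assuming a snapshot tree that already tracks blockRoot (variable names are illustrative):

    r, err := epochmeta.NewReader(tree, big.NewInt(100), blockRoot)
    if err != nil {
        return err
    }
    // raw epoch metadata recorded for the account at the metadata path
    blob, err := r.Get(accountHash, string(epochmeta.AccountMetadataPath))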
diff --git a/trie/epochmeta/database_test.go b/trie/epochmeta/database_test.go
new file mode 100644
index 0000000000..2bb728c8ae
--- /dev/null
+++ b/trie/epochmeta/database_test.go
@@ -0,0 +1,70 @@
+package epochmeta
+
+import (
+	"bytes"
+	"math/big"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/core/types"
+
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/rlp"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb/memorydb"
+	"github.com/stretchr/testify/assert"
+)
+
+func makeDiskLayer(diskdb *memorydb.Database, number *big.Int, root common.Hash, addr common.Hash, kv []string) {
+	if len(kv)%2 != 0 {
+		panic("wrong kv")
+	}
+	meta := epochMetaPlainMeta{
+		BlockNumber: number,
+		BlockRoot:   root,
+	}
+	enc, _ := rlp.EncodeToBytes(&meta)
+	rawdb.WriteEpochMetaPlainStateMeta(diskdb, enc)
+
+	for i := 0; i < len(kv); i += 2 {
+		rawdb.WriteEpochMetaPlainState(diskdb, addr, kv[i], []byte(kv[i+1]))
+	}
+}
+
+func TestEpochMetaReader(t *testing.T) {
+	diskdb := memorydb.New()
+	makeDiskLayer(diskdb, common.Big1, blockRoot1, contract1, []string{"hello", "world"})
+	tree, err := NewEpochMetaSnapTree(diskdb, nil)
+	assert.NoError(t, err)
+	storageDB, err := NewReader(tree, common.Big1, blockRoot1)
+	assert.NoError(t, err)
+	val, err := storageDB.Get(contract1, "hello")
+	assert.NoError(t, err)
+	assert.Equal(t, []byte("world"), val)
+}
+
+func TestBranchNodeEpochMeta_encodeDecode(t *testing.T) {
+	dt := []struct {
+		n BranchNodeEpochMeta
+	}{
+		{
+			n: BranchNodeEpochMeta{
+				EpochMap: [16]types.StateEpoch{},
+			},
+		},
+		{
+			n: BranchNodeEpochMeta{
+				EpochMap: [16]types.StateEpoch{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+			},
+		},
+	}
+	for _, item := range dt {
+		buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{}))
+		item.n.Encode(buf)
+		enc := buf.ToBytes()
+
+		rn, err := DecodeFullNodeEpochMeta(enc)
+		assert.NoError(t, err)
+		assert.Equal(t, &item.n, rn)
+	}
+}
diff --git a/trie/epochmeta/difflayer.go b/trie/epochmeta/difflayer.go
new file mode 100644
index 0000000000..2e2f73050a
--- /dev/null
+++ b/trie/epochmeta/difflayer.go
@@ -0,0 +1,237 @@
+package epochmeta
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"math"
+	"math/big"
+	"math/rand"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/rlp"
+	bloomfilter "github.com/holiman/bloomfilter/v2"
+)
+
+const (
+	// MaxEpochMetaDiffDepth default is 128 layers
+	MaxEpochMetaDiffDepth = 128
+	journalVersion uint64 = 1
+	enableBloomFilter     = false
+)
+
+var (
+	// aggregatorMemoryLimit is the maximum size of the bottom-most diff layer
+	// that aggregates the writes from above until it's flushed into the disk
+	// layer.
+	//
+	// Note, bumping this up might drastically increase the size of the bloom
+	// filters that's stored in every diff layer. Don't do that without fully
+	// understanding all the implications.
+	aggregatorMemoryLimit = uint64(4 * 1024 * 1024)
+
+	// aggregatorItemLimit is an approximate number of items that will end up
+	// in the aggregator layer before it's flushed out to disk. A plain account
+	// weighs around 14B (+hash), a storage slot 32B (+hash), a deleted slot
+	// 0B (+hash). Slots are mostly set/unset in lockstep, so that average at
+	// 16B (+hash). All in all, the average entry seems to be 15+32=47B. Use a
+	// smaller number to be on the safe side.
+	aggregatorItemLimit = aggregatorMemoryLimit / 42
+
+	// bloomTargetError is the target false positive rate when the aggregator
+	// layer is at its fullest. The actual value will probably move around up
+	// and down from this number, it's mostly a ballpark figure.
+	//
+	// Note, dropping this down might drastically increase the size of the bloom
+	// filters that's stored in every diff layer. Don't do that without fully
+	// understanding all the implications.
+	bloomTargetError = 0.02
+
+	// bloomSize is the ideal bloom filter size given the maximum number of items
+	// it's expected to hold and the target false positive error rate.
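+	// With the defaults above (aggregatorItemLimit ≈ 4MiB/42 ≈ 99.8k items,
+	// bloomTargetError = 0.02) this works out to roughly 0.81M bits, and
+	// bloomFuncs below rounds to about 6 hash functions per entry.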
+	bloomSize = math.Ceil(float64(aggregatorItemLimit) * math.Log(bloomTargetError) / math.Log(1/math.Pow(2, math.Log(2))))
+
+	// bloomFuncs is the ideal number of bits a single entry should set in the
+	// bloom filter to keep its size to a minimum (given its size and maximum
+	// entry count).
+	bloomFuncs = math.Round((bloomSize / float64(aggregatorItemLimit)) * math.Log(2))
+	// the bloom offsets are runtime constants which determines which part of the
+	// account/storage hash the hasher functions looks at, to determine the
+	// bloom key for an account/slot. This is randomized at init(), so that the
+	// global population of nodes do not all display the exact same behaviour with
+	// regards to bloom content
+	bloomStorageHasherOffset = 0
+)
+
+func init() {
+	// Init the bloom offsets in the range [0:24] (requires 8 bytes)
+	bloomStorageHasherOffset = rand.Intn(25)
+}
+
+// storageBloomHasher is a wrapper around an account hash and a node path to
+// satisfy the interface API requirements of the bloom library used. It's used
+// to convert an (account, path) pair into a 64 bit mini hash.
+type storageBloomHasher struct {
+	accountHash common.Hash
+	path        string
+}
+
+func (h storageBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h storageBloomHasher) Sum(b []byte) []byte               { panic("not implemented") }
+func (h storageBloomHasher) Reset()                            { panic("not implemented") }
+func (h storageBloomHasher) BlockSize() int                    { panic("not implemented") }
+func (h storageBloomHasher) Size() int                         { return 8 }
+func (h storageBloomHasher) Sum64() uint64 {
+	if len(h.path) < 8 {
+		path := [8]byte{}
+		copy(path[:], h.path)
+		return binary.BigEndian.Uint64(h.accountHash[bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+			binary.BigEndian.Uint64(path[:])
+	}
+	if len(h.path) < bloomStorageHasherOffset+8 {
+		return binary.BigEndian.Uint64(h.accountHash[bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+			binary.BigEndian.Uint64([]byte(h.path[len(h.path)-8:]))
+	}
+	return binary.BigEndian.Uint64(h.accountHash[bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+		binary.BigEndian.Uint64([]byte(h.path[bloomStorageHasherOffset:bloomStorageHasherOffset+8]))
+}
+
+type diffLayer struct {
+	blockNumber *big.Int
+	blockRoot   common.Hash
+	parent      snapshot
+	origin      *diskLayer
+	nodeSet     map[common.Hash]map[string][]byte
+	diffed      *bloomfilter.Filter // Bloom filter tracking all the diffed items up to the disk layer
+	lock        sync.RWMutex        // lock only protects changes to the parent field for now
+}
+
+func newEpochMetaDiffLayer(blockNumber *big.Int, blockRoot common.Hash, parent snapshot, nodeSet map[common.Hash]map[string][]byte) *diffLayer {
+	dl := &diffLayer{
+		blockNumber: blockNumber,
+		blockRoot:   blockRoot,
+		parent:      parent,
+		nodeSet:     nodeSet,
+	}
+
+	if enableBloomFilter {
+		switch p := parent.(type) {
+		case *diffLayer:
+			dl.origin = p.origin
+			dl.diffed, _ = p.diffed.Copy()
+		case *diskLayer:
+			dl.origin = p
+			dl.diffed, _ = bloomfilter.New(uint64(bloomSize), uint64(bloomFuncs))
+		default:
+			panic("newEpochMetaDiffLayer got wrong snapshot type")
+		}
+		// Iterate over all the accounts and storage metas and index them
+		for accountHash, metas := range dl.nodeSet {
+			for path := range metas {
+				dl.diffed.Add(storageBloomHasher{accountHash, path})
+			}
+		}
+	}
+
+	return dl
+}
+
+func (s *diffLayer) Root() common.Hash {
+	return s.blockRoot
+}
+
+// EpochMeta finds the target epoch meta, looking in this diff layer first and
+// then walking up the parent chain.
+func (s *diffLayer) EpochMeta(addrHash common.Hash, path string) ([]byte, error) {
+	// if the bloom filter rules the item out of the diff chain, read it
+	// directly from the disk layer
+	if s.diffed != nil && !s.diffed.Contains(storageBloomHasher{addrHash, path}) {
+		return s.origin.EpochMeta(addrHash, path)
+	}
+
+	cm, exist := s.nodeSet[addrHash]
+	if exist {
+		if ret, ok := cm[path]; ok {
+			metaHitDiffMeter.Mark(1)
+			return ret, nil
+		}
+	}
+
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	return s.parent.EpochMeta(addrHash, path)
+}
+
+func (s *diffLayer) Parent() snapshot {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	return s.parent
+}
+
+// Update appends a new diff layer on top of the current one; a value of
+// []byte{} in nodeSet marks the corresponding key as deleted.
+func (s *diffLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (snapshot, error) {
+	if s.blockNumber.Int64() != 0 && s.blockNumber.Cmp(blockNumber) >= 0 {
+		return nil, errors.New("out-of-order diff layer update on a diff layer")
+	}
+	return newEpochMetaDiffLayer(blockNumber, blockRoot, s, nodeSet), nil
+}
+
+func (s *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	if err := rlp.Encode(buffer, s.blockNumber); err != nil {
+		return common.Hash{}, err
+	}
+
+	if s.parent == nil {
+		return common.Hash{}, errors.New("found nil parent in Journal")
+	}
+
+	if err := rlp.Encode(buffer, s.parent.Root()); err != nil {
+		return common.Hash{}, err
+	}
+
+	if err := rlp.Encode(buffer, s.blockRoot); err != nil {
+		return common.Hash{}, err
+	}
+	storage := make([]journalEpochMeta, 0, len(s.nodeSet))
+	for hash, nodes := range s.nodeSet {
+		keys := make([]string, 0, len(nodes))
+		vals := make([][]byte, 0, len(nodes))
+		for key, val := range nodes {
+			keys = append(keys, key)
+			vals = append(vals, val)
+		}
+		storage = append(storage, journalEpochMeta{Hash: hash, Keys: keys, Vals: vals})
+	}
+	if err := rlp.Encode(buffer, storage); err != nil {
+		return common.Hash{}, err
+	}
+	return s.blockRoot, nil
+}
+
+func (s *diffLayer) getNodeSet() map[common.Hash]map[string][]byte {
+	return s.nodeSet
+}
+
+func (s *diffLayer) resetParent(parent snapshot) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	s.parent = parent
+}
+
+type journalEpochMeta struct {
+	Hash common.Hash
+	Keys []string
+	Vals [][]byte
+}
+
+type epochMetaPlainMeta struct {
+	BlockNumber *big.Int
+	BlockRoot   common.Hash
+}
diff --git a/trie/epochmeta/difflayer_test.go b/trie/epochmeta/difflayer_test.go
new file mode 100644
index 0000000000..ebe1c99818
--- /dev/null
+++ b/trie/epochmeta/difflayer_test.go
@@ -0,0 +1,156 @@ +package epochmeta + +import ( + "testing" + + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +const hashLen = len(common.Hash{}) + +var ( + blockRoot0 = makeHash("b0") + blockRoot1 = makeHash("b1") + blockRoot2 = makeHash("b2") + contract1 = makeHash("c1") + contract2 = makeHash("c2") + contract3 = makeHash("c3") +) + +func TestEpochMetaDiffLayer_whenGenesis(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewEpochMetaSnapTree(diskdb, nil) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap = tree.Snapshot(blockRoot1) + assert.Nil(t, snap) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, makeNodeSet(contract1, []string{"hello", "world"})) + assert.NoError(t, err) + err = tree.Update(blockRoot1, common.Big2, blockRoot2, makeNodeSet(contract1, []string{"hello2", "world2"})) + assert.NoError(t, err) + err = tree.Cap(blockRoot1) + assert.NoError(t, err) + err = tree.Journal() + assert.NoError(t, err) + + // reload + tree, err = NewEpochMetaSnapTree(diskdb, nil) + assert.NoError(t, err) + diskLayer := tree.Snapshot(types.EmptyRootHash) + assert.NotNil(t, diskLayer) + snap = tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap1 := tree.Snapshot(blockRoot1) + n, err := snap1.EpochMeta(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + assert.Equal(t, diskLayer, snap1.Parent()) + assert.Equal(t, blockRoot1, snap1.Root()) + + // read from child + snap2 := tree.Snapshot(blockRoot2) + assert.Equal(t, snap1, snap2.Parent()) + assert.Equal(t, blockRoot2, snap2.Root()) + n, err = snap2.EpochMeta(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + n, err = snap2.EpochMeta(contract1, "hello2") + assert.NoError(t, err) + assert.Equal(t, []byte("world2"), n) +} + +func TestEpochMetaDiffLayer_crud(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewEpochMetaSnapTree(diskdb, nil) + assert.NoError(t, err) + set1 := makeNodeSet(contract1, []string{"hello", "world", "h1", "w1"}) + appendNodeSet(set1, contract3, []string{"h3", "w3"}) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, set1) + assert.NoError(t, err) + set2 := makeNodeSet(contract1, []string{"hello", "", "h1", ""}) + appendNodeSet(set2, contract2, []string{"hello", "", "h2", "w2"}) + err = tree.Update(blockRoot1, common.Big2, blockRoot2, set2) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot1) + assert.NotNil(t, snap) + val, err := snap.EpochMeta(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + val, err = snap.EpochMeta(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte("w1"), val) + val, err = snap.EpochMeta(contract3, "h3") + assert.NoError(t, err) + assert.Equal(t, []byte("w3"), val) + + snap = tree.Snapshot(blockRoot2) + assert.NotNil(t, snap) + val, err = snap.EpochMeta(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.EpochMeta(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.EpochMeta(contract2, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.EpochMeta(contract2, "h2") + assert.NoError(t, err) + assert.Equal(t, []byte("w2"), val) + val, err = snap.EpochMeta(contract3, "h3") + assert.NoError(t, err) + assert.Equal(t, 
[]byte("w3"), val)
+}
+
+func makeHash(s string) common.Hash {
+	var ret common.Hash
+	if len(s) >= 32 {
+		copy(ret[:], []byte(s)[:hashLen])
+		return ret
+	}
+	for i := 0; i < hashLen; i++ {
+		ret[i] = '0'
+	}
+	copy(ret[hashLen-len(s):hashLen], s)
+	return ret
+}
+
+func makeNodeSet(addr common.Hash, kvs []string) map[common.Hash]map[string][]byte {
+	if len(kvs)%2 != 0 {
+		panic("makeNodeSet: wrong params")
+	}
+	ret := make(map[common.Hash]map[string][]byte)
+	ret[addr] = make(map[string][]byte)
+	for i := 0; i < len(kvs); i += 2 {
+		if len(kvs[i+1]) == 0 { // an empty value marks the key as deleted
+			ret[addr][kvs[i]] = nil
+			continue
+		}
+		ret[addr][kvs[i]] = []byte(kvs[i+1])
+	}
+
+	return ret
+}
+
+func appendNodeSet(ret map[common.Hash]map[string][]byte, addr common.Hash, kvs []string) {
+	if len(kvs)%2 != 0 {
+		panic("appendNodeSet: wrong params")
+	}
+	if _, ok := ret[addr]; !ok {
+		ret[addr] = make(map[string][]byte)
+	}
+	for i := 0; i < len(kvs); i += 2 {
+		if len(kvs[i+1]) == 0 { // an empty value marks the key as deleted
+			ret[addr][kvs[i]] = nil
+			continue
+		}
+		ret[addr][kvs[i]] = []byte(kvs[i+1])
+	}
+}
diff --git a/trie/epochmeta/disklayer.go b/trie/epochmeta/disklayer.go
new file mode 100644
index 0000000000..4d71de4c3b
--- /dev/null
+++ b/trie/epochmeta/disklayer.go
@@ -0,0 +1,152 @@
+package epochmeta
+
+import (
+	"bytes"
+	"errors"
+	"math/big"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/rlp"
+	lru "github.com/hashicorp/golang-lru"
+)
+
+const (
+	defaultDiskLayerCacheSize = 1024000
+)
+
+type diskLayer struct {
+	diskdb      ethdb.KeyValueStore
+	blockNumber *big.Int
+	blockRoot   common.Hash
+	cache       *lru.Cache
+	lock        sync.RWMutex
+}
+
+func newEpochMetaDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, blockRoot common.Hash) (*diskLayer, error) {
+	cache, err := lru.New(defaultDiskLayerCacheSize)
+	if err != nil {
+		return nil, err
+	}
+	return &diskLayer{
+		diskdb:      diskdb,
+		blockNumber: blockNumber,
+		blockRoot:   blockRoot,
+		cache:       cache,
+	}, nil
+}
+
+func (s *diskLayer) Root() common.Hash {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	return s.blockRoot
+}
+
+func (s *diskLayer) EpochMeta(addr common.Hash, path string) ([]byte, error) {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+
+	key := cacheKey(addr, path)
+	cached, exist := s.cache.Get(key)
+	if exist {
+		metaHitDiskCacheMeter.Mark(1)
+		return cached.([]byte), nil
+	}
+
+	metaHitDiskMeter.Mark(1)
+	val := rawdb.ReadEpochMetaPlainState(s.diskdb, addr, path)
+	s.cache.Add(key, val)
+	return val, nil
+}
+
+func (s *diskLayer) Parent() snapshot {
+	return nil
+}
+
+func (s *diskLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (snapshot, error) {
+	// hold the read lock for the whole check, so the error path cannot
+	// leak a locked mutex
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	if s.blockNumber.Int64() != 0 && s.blockNumber.Cmp(blockNumber) >= 0 {
+		return nil, errors.New("update an unordered diff layer in disk layer")
+	}
+	return newEpochMetaDiffLayer(blockNumber, blockRoot, s, nodeSet), nil
+}
+
+func (s *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) {
+	return common.Hash{}, nil
+}
+
+func (s *diskLayer) PushDiff(diff *diffLayer) (*diskLayer, error) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	number := diff.blockNumber
+	if s.blockNumber.Cmp(number) >= 0 {
+		return nil, errors.New("push a lower block to disk")
+	}
+	batch := s.diskdb.NewBatch()
+	nodeSet := diff.getNodeSet()
+	if err := s.writeHistory(number, batch, nodeSet); err != nil {
+		return nil, err
+	}
+
+	// update meta
+	meta := epochMetaPlainMeta{
+		BlockNumber: number,
+		BlockRoot:   diff.blockRoot,
+	}
+	enc, err := rlp.EncodeToBytes(meta)
+	if err != nil {
+		return nil, err
+	}
+	if err = rawdb.WriteEpochMetaPlainStateMeta(batch, enc); err != nil {
+		return nil, err
+	}
+
+	if err = batch.Write(); err != nil {
+		return nil, err
+	}
+	diskLayer := &diskLayer{
+		diskdb:      s.diskdb,
+		blockNumber: number,
+		blockRoot:   diff.blockRoot,
+		cache:       s.cache,
+	}
+
+	// reuse cache
+	for addr, nodes := range nodeSet {
+		for path, val := range nodes {
+			diskLayer.cache.Add(cacheKey(addr, path), val)
+		}
+	}
+	return diskLayer, nil
+}
+
+func (s *diskLayer) writeHistory(number *big.Int, batch ethdb.Batch, nodeSet map[common.Hash]map[string][]byte) error {
+	for addr, subSet := range nodeSet {
+		for path, val := range subSet {
+			// refresh plain state
+			if len(val) == 0 {
+				if err := rawdb.DeleteEpochMetaPlainState(batch, addr, path); err != nil {
+					return err
+				}
+			} else {
+				if err := rawdb.WriteEpochMetaPlainState(batch, addr, path, val); err != nil {
+					return err
+				}
+			}
+		}
+	}
+	log.Debug("epoch meta history pruned, only plain state is kept", "number", number, "count", len(nodeSet))
+	return nil
+}
+
+func cacheKey(addr common.Hash, path string) string {
+	key := make([]byte, len(addr)+len(path))
+	copy(key[:], addr.Bytes())
+	copy(key[len(addr):], path)
+	return string(key)
+}
diff --git a/trie/epochmeta/snapshot.go b/trie/epochmeta/snapshot.go
new file mode 100644
index 0000000000..decbdaaf87
--- /dev/null
+++ b/trie/epochmeta/snapshot.go
@@ -0,0 +1,347 @@
+package epochmeta
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// snapshot is the common interface of the epoch-meta diff layers and the disk
+// layer; keeping recent diff layers in memory supports mini reorgs
+type snapshot interface {
+	// Root returns the state root of the block this layer belongs to
+	Root() common.Hash
+
+	// EpochMeta returns the RLP-encoded epoch meta for the given account and path
+	EpochMeta(addrHash common.Hash, path string) ([]byte, error)
+
+	// Parent returns the parent snapshot, or nil for the disk layer
+	Parent() snapshot
+
+	// Update creates a new diff layer on top of this one
+	Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (snapshot, error)
+
+	// Journal writes this layer as a journal entry into the buffer
+	Journal(buffer *bytes.Buffer) (common.Hash, error)
+}
+
+type Config struct {
+	capLimit int // maximum depth of in-memory diff layers to keep
+}
+
+var Defaults = &Config{
+	capLimit: MaxEpochMetaDiffDepth,
+}
+
+// SnapshotTree maintains all diff layers and supports reorgs; layers deeper than
+// MaxEpochMetaDiffDepth are flushed to the database. Every layer corresponds to
+// the state change set of one block, and diff layers are only ever flattened
+// into the disk layer, never into each other.
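+//
+// A minimal usage sketch, mirroring the tests in this package: call Update with
+// each block's parent root, number, root and changed meta set; call Cap(root)
+// periodically to flatten layers beyond capLimit into the disk layer; call
+// Journal() on shutdown so NewEpochMetaSnapTree can reload the diff layers.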
+type SnapshotTree struct {
+	diskdb ethdb.KeyValueStore
+
+	// all layers keyed by block root: the diff layers plus the disk layer,
+	// which is always present
+	layers   map[common.Hash]snapshot
+	children map[common.Hash][]common.Hash
+	cfg      *Config
+
+	lock sync.RWMutex
+}
+
+func NewEpochMetaSnapTree(diskdb ethdb.KeyValueStore, cfg *Config) (*SnapshotTree, error) {
+	diskLayer, err := loadDiskLayer(diskdb)
+	if err != nil {
+		return nil, err
+	}
+	layers, children, err := loadDiffLayers(diskdb, diskLayer)
+	if err != nil {
+		return nil, err
+	}
+
+	layers[diskLayer.blockRoot] = diskLayer
+	// make sure the loaded diff layers connect to the disk layer
+	if len(layers) > 1 && len(children[diskLayer.blockRoot]) == 0 {
+		return nil, errors.New("cannot find any diff layer linked to the disk layer")
+	}
+
+	if cfg == nil {
+		cfg = Defaults
+	}
+	return &SnapshotTree{
+		diskdb:   diskdb,
+		layers:   layers,
+		children: children,
+		cfg:      cfg,
+	}, nil
+}
+
+// Cap keeps the tree depth at or below the configured limit: layers deeper than
+// the limit are flattened into the disk layer, and stale forks hanging off the
+// flattened layers are deleted
+func (s *SnapshotTree) Cap(blockRoot common.Hash) error {
+	snap := s.Snapshot(blockRoot)
+	if snap == nil {
+		return fmt.Errorf("epoch meta snapshot missing: [%#x]", blockRoot)
+	}
+	nextDiff, ok := snap.(*diffLayer)
+	if !ok {
+		return nil
+	}
+	for i := 0; i < s.cfg.capLimit-1; i++ {
+		nextDiff, ok = nextDiff.Parent().(*diffLayer)
+		// if the depth is still within the cap limit, there is nothing to do
+		if !ok {
+			return nil
+		}
+	}
+
+	flatten := make([]snapshot, 0)
+	parent := nextDiff.Parent()
+	for parent != nil {
+		flatten = append(flatten, parent)
+		parent = parent.Parent()
+	}
+	if len(flatten) <= 1 {
+		return nil
+	}
+
+	last, ok := flatten[len(flatten)-1].(*diskLayer)
+	if !ok {
+		return errors.New("the diff layers do not link to the disk layer")
+	}
+
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	newDiskLayer, err := s.flattenDiffs2Disk(flatten[:len(flatten)-1], last)
+	if err != nil {
+		return err
+	}
+
+	// clear stale forks, but keep the chain that leads to the new disk layer
+	for i := len(flatten) - 1; i > 0; i-- {
+		childRoot := flatten[i-1].Root()
+		root := flatten[i].Root()
+		s.removeSubLayers(s.children[root], &childRoot)
+		delete(s.layers, root)
+		delete(s.children, root)
+	}
+
+	// reset newDiskLayer and children's parent
+	s.layers[newDiskLayer.Root()] = newDiskLayer
+	for _, child := range s.children[newDiskLayer.Root()] {
+		if diff, exist := s.layers[child].(*diffLayer); exist {
+			diff.resetParent(newDiskLayer)
+		}
+	}
+	log.Debug("epochmeta snap tree cap", "root", blockRoot, "layers", len(s.layers), "flatten", len(flatten))
+	return nil
+}
+
+func (s *SnapshotTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error {
+	// if there are no changes, just skip
+	if blockRoot == parentRoot {
+		return nil
+	}
+
+	// Generate a new snapshot on top of the parent
+	parent := s.Snapshot(parentRoot)
+	if parent == nil {
+		// fall back to the empty disk layer as the parent
+		parent = s.Snapshot(types.EmptyRootHash)
+		if parent == nil {
+			return errors.New("cannot find any suitable parent")
+		}
+		parentRoot = parent.Root()
+	}
+	snap, err := parent.Update(blockNumber, blockRoot, nodeSet)
+	if err != nil {
+		return err
+	}
+
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	s.layers[blockRoot] = snap
+	s.children[parentRoot] = append(s.children[parentRoot], blockRoot)
+	log.Debug("epochmeta snap tree update", "root", blockRoot, "number", blockNumber, "layers", len(s.layers))
+	return nil
+}
+
+func (s *SnapshotTree) 
Snapshot(blockRoot common.Hash) snapshot {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	return s.layers[blockRoot]
+}
+
+func (s *SnapshotTree) DB() ethdb.KeyValueStore {
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	return s.diskdb
+}
+
+func (s *SnapshotTree) Journal() error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	// Write out the journal version first
+	journal := new(bytes.Buffer)
+	if err := rlp.Encode(journal, journalVersion); err != nil {
+		return err
+	}
+	for _, snap := range s.layers {
+		if _, err := snap.Journal(journal); err != nil {
+			return err
+		}
+	}
+	rawdb.WriteEpochMetaSnapshotJournal(s.diskdb, journal.Bytes())
+	return nil
+}
+
+func (s *SnapshotTree) removeSubLayers(layers []common.Hash, skip *common.Hash) {
+	for _, layer := range layers {
+		if skip != nil && layer == *skip {
+			continue
+		}
+		s.removeSubLayers(s.children[layer], nil)
+		delete(s.layers, layer)
+		delete(s.children, layer)
+	}
+}
+
+// flattenDiffs2Disk pushes the given diff layers into the database, bottom-up
+func (s *SnapshotTree) flattenDiffs2Disk(flatten []snapshot, diskLayer *diskLayer) (*diskLayer, error) {
+	var err error
+	for i := len(flatten) - 1; i >= 0; i-- {
+		diskLayer, err = diskLayer.PushDiff(flatten[i].(*diffLayer))
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return diskLayer, nil
+}
+
+// loadDiskLayer loads the disk layer from the database; if none has been
+// persisted yet, an empty disk layer is constructed instead
+func loadDiskLayer(db ethdb.KeyValueStore) (*diskLayer, error) {
+	val := rawdb.ReadEpochMetaPlainStateMeta(db)
+	// if there is no disk layer yet, construct an empty one
+	if len(val) == 0 {
+		diskLayer, err := newEpochMetaDiskLayer(db, common.Big0, types.EmptyRootHash)
+		if err != nil {
+			return nil, err
+		}
+		return diskLayer, nil
+	}
+	var meta epochMetaPlainMeta
+	if err := rlp.DecodeBytes(val, &meta); err != nil {
+		return nil, err
+	}
+
+	layer, err := newEpochMetaDiskLayer(db, meta.BlockNumber, meta.BlockRoot)
+	if err != nil {
+		return nil, err
+	}
+	return layer, nil
+}
+
+type diffTmp struct {
+	parent  common.Hash
+	number  big.Int
+	root    common.Hash
+	nodeSet map[common.Hash]map[string][]byte
+}
+
+func loadDiffLayers(db ethdb.KeyValueStore, dl *diskLayer) (map[common.Hash]snapshot, map[common.Hash][]common.Hash, error) {
+	layers := make(map[common.Hash]snapshot)
+	children := make(map[common.Hash][]common.Hash)
+
+	journal := rawdb.ReadEpochMetaSnapshotJournal(db)
+	if len(journal) == 0 {
+		return layers, children, nil
+	}
+	r := rlp.NewStream(bytes.NewReader(journal), 0)
+	// Resolve the first element as the journal version
+	version, err := r.Uint64()
+	if err != nil {
+		return nil, nil, errors.New("failed to resolve journal version")
+	}
+	if version != journalVersion {
+		return nil, nil, errors.New("wrong journal version")
+	}
+
+	diffTmps := make(map[common.Hash]diffTmp)
+	parents := make(map[common.Hash]common.Hash)
+	for {
+		var (
+			parent common.Hash
+			number big.Int
+			root   common.Hash
+			js     []journalEpochMeta
+		)
+		// Read the next diff journal entry; the field order mirrors
+		// diffLayer.Journal: number, parent root, block root, meta set
+		if err := r.Decode(&number); err != nil {
+			// The first read may fail with EOF, marking the end of the journal
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			return nil, nil, fmt.Errorf("load diff number: %v", err)
+		}
+		if err := r.Decode(&parent); err != nil {
+			return nil, nil, fmt.Errorf("load diff parent: %v", err)
+		}
+		if err := r.Decode(&root); err != nil {
+			return nil, nil, fmt.Errorf("load diff root: %v", err)
+		}
+		if err := r.Decode(&js); err != nil {
+			return nil, nil, fmt.Errorf("load diff storage: %v", err)
+		}
+
+		
nodeSet := make(map[common.Hash]map[string][]byte) + for _, entry := range js { + nodes := make(map[string][]byte) + for i, key := range entry.Keys { + if len(entry.Vals[i]) > 0 { // RLP loses nil-ness, but `[]byte{}` is not a valid item, so reinterpret that + nodes[key] = entry.Vals[i] + } else { + nodes[key] = nil + } + } + nodeSet[entry.Hash] = nodes + } + + diffTmps[root] = diffTmp{ + parent: parent, + number: number, + root: root, + nodeSet: nodeSet, + } + children[parent] = append(children[parent], root) + + parents[root] = parent + layers[root] = newEpochMetaDiffLayer(&number, root, dl, nodeSet) + } + + // rebuild diff layers from disk layer + rebuildFromParent(dl, children, layers, diffTmps) + return layers, children, nil +} + +func rebuildFromParent(p snapshot, children map[common.Hash][]common.Hash, layers map[common.Hash]snapshot, diffTmps map[common.Hash]diffTmp) { + subs := children[p.Root()] + for _, cur := range subs { + df := diffTmps[cur] + layers[cur] = newEpochMetaDiffLayer(&df.number, df.root, p, df.nodeSet) + rebuildFromParent(layers[cur], children, layers, diffTmps) + } +} diff --git a/trie/epochmeta/snapshot_test.go b/trie/epochmeta/snapshot_test.go new file mode 100644 index 0000000000..0cc4145eee --- /dev/null +++ b/trie/epochmeta/snapshot_test.go @@ -0,0 +1,121 @@ +package epochmeta + +import ( + "math/big" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +func TestEpochMetaDiffLayer_capDiffLayers(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewEpochMetaSnapTree(diskdb, nil) + assert.NoError(t, err) + + // push 200 diff layers + count := 1 + for i := 0; i < 200; i++ { + ns := strconv.Itoa(count) + root := makeHash("b" + ns) + parent := makeHash("b" + strconv.Itoa(count-1)) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, + root, makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 10 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + + err = tree.Cap(root) + assert.NoError(t, err) + count++ + } + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 100; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + strconv.Itoa(count-1)) + root := makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 20 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(makeHash("b"+strconv.Itoa(count-1)+"f"+fs), number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + count++ + } + lastRoot := makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 129; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + strconv.Itoa(count-1)) + root := 
makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + count++ + } + lastRoot = makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + + assert.Equal(t, 129, len(tree.layers)) + assert.Equal(t, 128, len(tree.children)) + for parent, children := range tree.children { + if tree.layers[parent] == nil { + t.Log(tree.layers[parent]) + } + assert.NotNil(t, tree.layers[parent]) + for _, child := range children { + if tree.layers[child] == nil { + t.Log(tree.layers[child]) + } + assert.NotNil(t, tree.layers[child]) + } + } + + snap := tree.Snapshot(lastRoot) + assert.NotNil(t, snap) + for i := 1; i < count; i++ { + ns := strconv.Itoa(i) + n, err := snap.EpochMeta(contract1, "hello"+ns) + assert.NoError(t, err) + assert.Equal(t, []byte("world"+ns), n) + } + + // store + err = tree.Journal() + assert.NoError(t, err) + + tree, err = NewEpochMetaSnapTree(diskdb, nil) + assert.NoError(t, err) + assert.Equal(t, 129, len(tree.layers)) + assert.Equal(t, 128, len(tree.children)) +} diff --git a/trie/errors.go b/trie/errors.go index 7be7041c7f..0ef3187ebd 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" ) // ErrCommitted is returned when a already committed trie is requested for usage. @@ -50,3 +51,51 @@ func (err *MissingNodeError) Error() string { } return fmt.Sprintf("missing trie node %x (owner %x) (path %x) %v", err.NodeHash, err.Owner, err.Path, err.err) } + +type ReviveNotExpiredError struct { + Path []byte // hex-encoded path to the expired node + Epoch types.StateEpoch +} + +func NewReviveNotExpiredErr(path []byte, epoch types.StateEpoch) error { + return &ReviveNotExpiredError{ + Path: path, + Epoch: epoch, + } +} + +func (e *ReviveNotExpiredError) Error() string { + return fmt.Sprintf("revive not expired kv, path: %v, epoch: %v", e.Path, e.Epoch) +} + +type ExpiredNodeError struct { + Path []byte // hex-encoded path to the expired node + Epoch types.StateEpoch + Node node +} + +func NewExpiredNodeError(path []byte, epoch types.StateEpoch, n node) error { + return &ExpiredNodeError{ + Path: path, + Epoch: epoch, + Node: n, + } +} + +func (err *ExpiredNodeError) Error() string { + return fmt.Sprintf("expired trie node, path: %v, epoch: %v, node: %v", err.Path, err.Epoch, err.Node.fstring("")) +} + +func ParseExpiredNodeErr(err error) ([]byte, bool) { + var path []byte + switch enErr := err.(type) { + case *ExpiredNodeError: + path = enErr.Path + case *MissingNodeError: // when meet MissingNodeError, try revive or fail + path = enErr.Path + default: + return nil, false + } + + return path, true +} diff --git a/trie/inspect_trie.go b/trie/inspect_trie.go new file mode 100644 index 0000000000..5d4cc00d04 --- /dev/null +++ b/trie/inspect_trie.go @@ -0,0 +1,314 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + "math/big" + "os" + "runtime" + "sort" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/olekukonko/tablewriter" +) + +type Account struct { + Nonce uint64 + Balance *big.Int + Root common.Hash // merkle root of the storage trie + CodeHash []byte +} + +type Inspector struct { + trieDB *Database + trie 
*Trie // traverse trie
+	blocknum        uint64
+	root            node               // root of triedb
+	result          *TotalTrieTreeStat // inspector result
+	totalNum        uint64
+	concurrentQueue chan struct{}
+	wg              sync.WaitGroup
+}
+
+type RWMap struct {
+	sync.RWMutex
+	m map[uint64]*TrieTreeStat
+}
+
+// NewRWMap creates a new RWMap
+func NewRWMap() *RWMap {
+	return &RWMap{
+		m: make(map[uint64]*TrieTreeStat, 1),
+	}
+}
+
+// Get reads a value from the map
+func (m *RWMap) Get(k uint64) (*TrieTreeStat, bool) {
+	m.RLock()
+	defer m.RUnlock()
+	v, existed := m.m[k] // read from the map under the read lock
+	return v, existed
+}
+
+// Set stores a key/value pair
+func (m *RWMap) Set(k uint64, v *TrieTreeStat) {
+	m.Lock() // guarded by the write lock
+	defer m.Unlock()
+	m.m[k] = v
+}
+
+// Delete removes a key
+func (m *RWMap) Delete(k uint64) {
+	m.Lock() // guarded by the write lock
+	defer m.Unlock()
+	delete(m.m, k)
+}
+
+// Len returns the number of entries in the map
+func (m *RWMap) Len() int {
+	m.RLock() // guarded by the read lock
+	defer m.RUnlock()
+	return len(m.m)
+}
+
+// Each iterates over the map, stopping early when f returns false
+func (m *RWMap) Each(f func(k uint64, v *TrieTreeStat) bool) {
+	m.RLock() // hold the read lock for the whole iteration
+	defer m.RUnlock()
+
+	for k, v := range m.m {
+		if !f(k, v) {
+			return
+		}
+	}
+}
+
+type TotalTrieTreeStat struct {
+	theTrieTreeStats RWMap
+}
+
+type TrieTreeStat struct {
+	isAccountTrie      bool
+	theNodeStatByLevel [15]NodeStat
+	totalNodeStat      NodeStat
+}
+
+type NodeStat struct {
+	ShortNodeCnt uint64
+	FullNodeCnt  uint64
+	ValueNodeCnt uint64
+}
+
+func (trieStat *TrieTreeStat) AtomicAdd(theNode node, height uint32) {
+	switch (theNode).(type) {
+	case *shortNode:
+		atomic.AddUint64(&trieStat.totalNodeStat.ShortNodeCnt, 1)
+		atomic.AddUint64(&(trieStat.theNodeStatByLevel[height].ShortNodeCnt), 1)
+	case *fullNode:
+		atomic.AddUint64(&trieStat.totalNodeStat.FullNodeCnt, 1)
+		atomic.AddUint64(&trieStat.theNodeStatByLevel[height].FullNodeCnt, 1)
+	case valueNode:
+		atomic.AddUint64(&trieStat.totalNodeStat.ValueNodeCnt, 1)
+		atomic.AddUint64(&((trieStat.theNodeStatByLevel[height]).ValueNodeCnt), 1)
+	default:
+		panic(errors.New("invalid node type for statistics"))
+	}
+}
+
+func (trieStat *TrieTreeStat) Display(rootHash uint64, treeType string) {
+	table := tablewriter.NewWriter(os.Stdout)
+	table.SetHeader([]string{"TrieType", "Level", "ShortNodeCnt", "FullNodeCnt", "ValueNodeCnt"})
+	table.SetAlignment(1)
+	for i := 0; i < len(trieStat.theNodeStatByLevel); i++ {
+		nodeStat := trieStat.theNodeStatByLevel[i]
+		if nodeStat.FullNodeCnt == 0 && nodeStat.ShortNodeCnt == 0 && nodeStat.ValueNodeCnt == 0 {
+			break
+		}
+		table.AppendBulk([][]string{
+			{"-", strconv.Itoa(i), nodeStat.ShortNodeCount(), nodeStat.FullNodeCount(), nodeStat.ValueNodeCount()},
+		})
+	}
+	table.AppendBulk([][]string{
+		{fmt.Sprintf("%v-%v", treeType, rootHash), "Total", trieStat.totalNodeStat.ShortNodeCount(), trieStat.totalNodeStat.FullNodeCount(), trieStat.totalNodeStat.ValueNodeCount()},
+	})
+	table.Render()
+}
+
+func Uint64ToString(cnt uint64) string {
+	return fmt.Sprintf("%v", cnt)
+}
+
+func (nodeStat *NodeStat) ShortNodeCount() string {
+	return Uint64ToString(nodeStat.ShortNodeCnt)
+}
+
+func (nodeStat *NodeStat) FullNodeCount() string {
+	return Uint64ToString(nodeStat.FullNodeCnt)
+}
+
+func (nodeStat *NodeStat) ValueNodeCount() string {
+	return Uint64ToString(nodeStat.ValueNodeCnt)
+}
+
+// NewInspector returns an Inspector instance
+func NewInspector(trieDB *Database, tr *Trie, blocknum uint64, jobnum uint64) (*Inspector, error) {
+	if tr == nil {
+		return nil, errors.New("trie is nil")
+	}
+
+	if tr.root == nil {
+		return nil, errors.New("trie root is nil")
+	}
+
+	ins := &Inspector{
+		trieDB:   trieDB,
+		trie:     tr,
+		blocknum: blocknum,
+		root:     tr.root,
+		result: 
&TotalTrieTreeStat{
+			theTrieTreeStats: *NewRWMap(),
+		},
+		totalNum:        (uint64)(0),
+		concurrentQueue: make(chan struct{}, jobnum),
+		wg:              sync.WaitGroup{},
+	}
+
+	return ins, nil
+}
+
+// Run starts the statistics collection; it is the external entry point
+func (inspect *Inspector) Run() {
+	accountTrieStat := new(TrieTreeStat)
+	roothash := inspect.trie.Hash().Big().Uint64()
+	path := make([]byte, 0)
+
+	ticker := time.NewTicker(30 * time.Second)
+	go func() {
+		defer ticker.Stop()
+		for range ticker.C {
+			inspect.trieDB.Cap(DEFAULT_TRIEDBCACHE_SIZE)
+		}
+	}()
+
+	inspect.result.theTrieTreeStats.Set(roothash, accountTrieStat)
+	log.Info("Find Account Trie Tree", "rootHash", inspect.trie.Hash().String(), "blockNum", inspect.blocknum)
+	inspect.ConcurrentTraversal(inspect.trie, accountTrieStat, inspect.root, 0, path)
+	inspect.wg.Wait()
+}
+
+func (inspect *Inspector) SubConcurrentTraversal(theTrie *Trie, theTrieTreeStat *TrieTreeStat, theNode node, height uint32, path []byte) {
+	inspect.concurrentQueue <- struct{}{}
+	inspect.ConcurrentTraversal(theTrie, theTrieTreeStat, theNode, height, path)
+	<-inspect.concurrentQueue
+	inspect.wg.Done()
+}
+
+func (inspect *Inspector) ConcurrentTraversal(theTrie *Trie, theTrieTreeStat *TrieTreeStat, theNode node, height uint32, path []byte) {
+	// print processing progress
+	totalNum := atomic.AddUint64(&inspect.totalNum, 1)
+	if totalNum%100000 == 0 {
+		fmt.Printf("Complete progress: %v, go routines Num: %v, inspect concurrentQueue: %v\n", totalNum, runtime.NumGoroutine(), len(inspect.concurrentQueue))
+	}
+
+	// nil node
+	if theNode == nil {
+		return
+	}
+
+	switch current := (theNode).(type) {
+	case *shortNode:
+		path = append(path, current.Key...)
+		inspect.ConcurrentTraversal(theTrie, theTrieTreeStat, current.Val, height+1, path)
+	case *fullNode:
+		for idx, child := range current.Children {
+			if child == nil {
+				continue
+			}
+			if len(inspect.concurrentQueue)*2 < cap(inspect.concurrentQueue) {
+				inspect.wg.Add(1)
+				go inspect.SubConcurrentTraversal(theTrie, theTrieTreeStat, child, height+1, copy2NewBytes(path, []byte{byte(idx)}))
+			} else {
+				inspect.ConcurrentTraversal(theTrie, theTrieTreeStat, child, height+1, append(path, byte(idx)))
+			}
+		}
+	case hashNode:
+		n, err := theTrie.resolveHash(current, path)
+		if err != nil {
+			fmt.Printf("Resolve HashNode error: %v, TrieRoot: %v, Height: %v, Path: %v\n", err, theTrie.Hash().String(), height+1, path)
+			return
+		}
+		inspect.ConcurrentTraversal(theTrie, theTrieTreeStat, n, height, path)
+		return
+	case valueNode:
+		if !hasTerm(path) {
+			break
+		}
+		var account Account
+		if err := rlp.Decode(bytes.NewReader(current), &account); err != nil {
+			break
+		}
+		if account.Root == (common.Hash{}) || account.Root == types.EmptyRootHash {
+			break
+		}
+		root, _ := theTrie.root.cache()
+		contractTrie, err := New(StorageTrieID(common.BytesToHash(root), common.BytesToHash(hexToKeybytes(path)), account.Root), inspect.trieDB)
+		if err != nil {
+			break
+		}
+		trieStat := new(TrieTreeStat)
+		trieStat.isAccountTrie = false
+		subRootHash := contractTrie.Hash().Big().Uint64()
+		inspect.result.theTrieTreeStats.Set(subRootHash, trieStat)
+		contractPath := make([]byte, 0)
+		inspect.wg.Add(1)
+		go inspect.SubConcurrentTraversal(contractTrie, trieStat, contractTrie.root, 0, contractPath)
+	default:
+		panic(errors.New("invalid node type to 
traverse"))
+	}
+	theTrieTreeStat.AtomicAdd(theNode, height)
+}
+
+func (inspect *Inspector) DisplayResult() {
+	// display the account trie stats first
+	roothash := inspect.trie.Hash().Big().Uint64()
+	rootStat, _ := inspect.result.theTrieTreeStats.Get(roothash)
+	rootStat.Display(roothash, "AccountTrie")
+
+	// display the contract tries
+	trieNodeNums := make([][]uint64, 0, inspect.result.theTrieTreeStats.Len()-1)
+	var totalContractsNodeStat NodeStat
+	var contractTrieCnt uint64 = 0
+	inspect.result.theTrieTreeStats.Each(func(rootHash uint64, stat *TrieTreeStat) bool {
+		if rootHash == roothash {
+			return true
+		}
+		contractTrieCnt++
+		totalContractsNodeStat.ShortNodeCnt += stat.totalNodeStat.ShortNodeCnt
+		totalContractsNodeStat.FullNodeCnt += stat.totalNodeStat.FullNodeCnt
+		totalContractsNodeStat.ValueNodeCnt += stat.totalNodeStat.ValueNodeCnt
+		totalNodeCnt := stat.totalNodeStat.ShortNodeCnt + stat.totalNodeStat.ValueNodeCnt + stat.totalNodeStat.FullNodeCnt
+		trieNodeNums = append(trieNodeNums, []uint64{totalNodeCnt, rootHash})
+		return true
+	})
+
+	fmt.Printf("Contract Trie, total trie num: %v, ShortNodeCnt: %v, FullNodeCnt: %v, ValueNodeCnt: %v\n",
+		contractTrieCnt, totalContractsNodeStat.ShortNodeCnt, totalContractsNodeStat.FullNodeCnt, totalContractsNodeStat.ValueNodeCnt)
+	sort.Slice(trieNodeNums, func(i, j int) bool {
+		return trieNodeNums[i][0] > trieNodeNums[j][0]
+	})
+	// only display the top 5
+	for i, cntHash := range trieNodeNums {
+		if i >= 5 {
+			break
+		}
+		stat, _ := inspect.result.theTrieTreeStats.Get(cntHash[1])
+		stat.Display(cntHash[1], "ContractTrie")
+	}
+}
diff --git a/trie/iterator.go b/trie/iterator.go
index 6f054a7245..4fb626e522 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -735,7 +735,7 @@ func (it *unionIterator) AddResolver(resolver NodeResolver) {
 //
 // In the case that descend=false - eg, we're asked to ignore all subnodes of the
 // current node - we also advance any iterators in the heap that have the current
-// path as a prefix.
+// Path as a prefix.
 func (it *unionIterator) Next(descend bool) bool {
 	if len(*it.items) == 0 {
 		return false
diff --git a/trie/node.go b/trie/node.go
index d78ed5c569..1ec7f00855 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -22,26 +22,43 @@ import (
 	"strings"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
+const (
+	BranchNodeLength = 17
+)
+
+const (
+	shortNodeType = iota
+	fullNodeType
+	hashNodeType
+	valueNodeType
+	rawNodeType
+)
+
 var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
 
 type node interface {
 	cache() (hashNode, bool)
 	encode(w rlp.EncoderBuffer)
 	fstring(string) string
+	nodeType() int
 }
 
 type (
 	fullNode struct {
 		Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
 		flags    nodeFlag
+		EpochMap [16]types.StateEpoch `rlp:"-" json:"-"`
+		epoch    types.StateEpoch     `rlp:"-" json:"-"`
 	}
 	shortNode struct {
 		Key   []byte
 		Val   node
 		flags nodeFlag
+		epoch types.StateEpoch `rlp:"-" json:"-"`
 	}
 	hashNode  []byte
 	valueNode []byte
@@ -58,6 +75,48 @@ func (n *fullNode) EncodeRLP(w io.Writer) error {
 	return eb.Flush()
 }
 
+func (n *fullNode) setEpoch(epoch types.StateEpoch) {
+	if n.epoch >= epoch {
+		return
+	}
+	n.epoch = epoch
+}
+
+func (n *shortNode) setEpoch(epoch types.StateEpoch) {
+	if n.epoch >= epoch {
+		return
+	}
+	n.epoch = epoch
+}
+
+func (n *fullNode) GetChildEpoch(index int) types.StateEpoch {
+	if index < 16 {
+		return n.EpochMap[index]
+	}
+	return n.epoch
+}
+
+func (n *fullNode) UpdateChildEpoch(index int, epoch types.StateEpoch) {
+	if index < 16 {
+		n.EpochMap[index] = epoch
+	}
+}
+
+func (n *fullNode) SetEpochMap(epochMap [16]types.StateEpoch) {
+	n.EpochMap = epochMap
+}
+
+func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.StateEpoch) (bool, error) {
+	childEpoch := n.GetChildEpoch(index)
+	if types.EpochExpired(childEpoch, currentEpoch) {
+		return true, &ExpiredNodeError{
+			Path:  prefix,
+			Epoch: childEpoch,
+		}
+	}
+	return false, nil
+}
+
 func (n *fullNode) copy() *fullNode   { copy := *n; return &copy }
 func (n *shortNode) copy() *shortNode { copy := *n; return &copy }
 
@@ -107,6 +166,10 @@ type rawNode []byte
 func (n rawNode) cache() (hashNode, bool)   { panic("this should never end up in a live trie") }
 func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
 
+func (n rawNode) nodeType() int {
+	return rawNodeType
+}
+
 func (n rawNode) EncodeRLP(w io.Writer) error {
 	_, err := w.Write(n)
 	return err
@@ -117,6 +180,22 @@ func NodeString(hash, buf []byte) string {
 	return node.fstring("NodeString: ")
 }
 
+func (n *shortNode) nodeType() int {
+	return shortNodeType
+}
+
+func (n *fullNode) nodeType() int {
+	return fullNodeType
+}
+
+func (n hashNode) nodeType() int {
+	return hashNodeType
+}
+
+func (n valueNode) nodeType() int {
+	return valueNodeType
+}
+
 // mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered.
 func mustDecodeNode(hash, buf []byte) node {
 	n, err := decodeNode(hash, buf)
@@ -126,10 +205,10 @@ func mustDecodeNode(hash, buf []byte) node {
 	return n
 }
 
-// mustDecodeNodeUnsafe is a wrapper of decodeNodeUnsafe and panic if any error is
+// mustDecodeNodeUnsafe is a wrapper of decodeTypedNodeUnsafe and panic if any error is
 // encountered.
func mustDecodeNodeUnsafe(hash, buf []byte) node { - n, err := decodeNodeUnsafe(hash, buf) + n, err := decodeTypedNodeUnsafe(hash, buf) if err != nil { panic(fmt.Sprintf("node %x: %v", hash, err)) } @@ -142,7 +221,30 @@ func mustDecodeNodeUnsafe(hash, buf []byte) node { // scenarios with low performance requirements and hard to determine whether the // byte slice be modified or not. func decodeNode(hash, buf []byte) (node, error) { - return decodeNodeUnsafe(hash, common.CopyBytes(buf)) + return decodeTypedNodeUnsafe(hash, common.CopyBytes(buf)) +} + +func decodeTypedNodeUnsafe(hash, buf []byte) (node, error) { + // try decode typed node first + tn, err := types.DecodeTypedTrieNode(buf) + if err != nil { + return nil, err + } + switch tn := tn.(type) { + case types.TrieNodeRaw: + return decodeNodeUnsafe(hash, tn) + case *types.TrieBranchNodeWithEpoch: + rn, err := decodeNodeUnsafe(hash, tn.Blob) + if err != nil { + return nil, err + } + if rn, ok := rn.(*fullNode); ok { + rn.EpochMap = tn.EpochMap + } + return rn, nil + default: + return nil, types.ErrTypedNodeNotSupport + } } // decodeNodeUnsafe parses the RLP encoding of a trie node. The passed byte slice @@ -181,13 +283,13 @@ func decodeShort(hash, elems []byte) (node, error) { if err != nil { return nil, fmt.Errorf("invalid value node: %v", err) } - return &shortNode{key, valueNode(val), flag}, nil + return &shortNode{Key: key, Val: valueNode(val), flags: flag}, nil } r, _, err := decodeRef(rest) if err != nil { return nil, wrapError(err, "val") } - return &shortNode{key, r, flag}, nil + return &shortNode{Key: key, Val: r, flags: flag}, nil } func decodeFull(hash, elems []byte) (*fullNode, error) { diff --git a/trie/node_enc.go b/trie/node_enc.go index 1b2eca682f..98b17bca73 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -17,6 +17,7 @@ package trie import ( + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" ) @@ -28,6 +29,39 @@ func nodeToBytes(n node) []byte { return result } +func nodeToBytesWithEpoch(n node) []byte { + switch n := n.(type) { + case *fullNode: + withEpoch := false + for i := 0; i < BranchNodeLength-1; i++ { + if n.EpochMap[i] > 0 { + withEpoch = true + break + } + } + if withEpoch { + w := rlp.NewEncoderBuffer(nil) + w.Write([]byte{types.TrieBranchNodeWithEpochType}) + offset := w.List() + mapOffset := w.List() + for _, item := range n.EpochMap { + if item == 0 { + w.Write(rlp.EmptyString) + } else { + w.WriteUint64(uint64(item)) + } + } + w.ListEnd(mapOffset) + n.encode(w) + w.ListEnd(offset) + result := w.ToBytes() + w.Flush() + return result + } + } + return nodeToBytes(n) +} + func (n *fullNode) encode(w rlp.EncoderBuffer) { offset := w.List() for _, c := range n.Children { diff --git a/trie/proof.go b/trie/proof.go index a463c80b48..5abee92cf9 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -18,10 +18,14 @@ package trie import ( "bytes" + "encoding/hex" "errors" "fmt" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) @@ -34,6 +38,8 @@ import ( // nodes of the longest existing prefix of the key (at least the root node), ending // with the node that proves the absence of the key. func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { + var nodeEpoch types.StateEpoch + // Short circuit if the trie is already committed and not usable. 
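+	// When expiry is enabled, nodeEpoch (declared above) starts at the root
+	// epoch and is narrowed to the matching child epoch at every fullNode
+	// visited, so the walk below fails fast with an ExpiredNodeError as soon
+	// as it reaches an expired subtree.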
 	if t.committed {
 		return ErrCommitted
@@ -45,7 +51,14 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
 		tn = t.root
 	)
 	key = keybytesToHex(key)
+
+	if t.enableExpiry {
+		nodeEpoch = t.getRootEpoch()
+	}
 	for len(key) > 0 && tn != nil {
+		if t.enableExpiry && t.epochExpired(tn, nodeEpoch) {
+			return NewExpiredNodeError(prefix, nodeEpoch, tn)
+		}
 		switch n := tn.(type) {
 		case *shortNode:
 			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
@@ -58,6 +71,9 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
 			}
 			nodes = append(nodes, n)
 		case *fullNode:
+			if t.enableExpiry {
+				nodeEpoch = n.GetChildEpoch(int(key[0]))
+			}
 			tn = n.Children[key[0]]
 			prefix = append(prefix, key[0])
 			key = key[1:]
@@ -77,6 +93,12 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
 			// clean cache or the database, they are all in their own
 			// copy and safe to use unsafe decoder.
 			tn = mustDecodeNodeUnsafe(n, blob)
+
+			if child, ok := tn.(*fullNode); t.enableExpiry && ok {
+				if err = t.resolveEpochMeta(child, nodeEpoch, prefix); err != nil {
+					return err
+				}
+			}
 		default:
 			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
 		}
@@ -111,6 +133,128 @@ func (t *StateTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
 	return t.trie.Prove(key, proofDb)
 }
 
+// traverseNodes traverses the trie with the given key starting at the given node.
+// If the trie contains the key, the returned node is the node that contains the
+// value for the key. If nodes is specified, the traversed nodes are appended to
+// it.
+func (t *Trie) traverseNodes(tn node, prefixKey, suffixKey []byte, nodes *[]node, epoch types.StateEpoch, updateEpoch bool) (node, error) {
+	for len(suffixKey) > 0 && tn != nil {
+		switch n := tn.(type) {
+		case *shortNode:
+			if len(suffixKey) >= len(n.Key) && bytes.Equal(n.Key, suffixKey[:len(n.Key)]) {
+				tn = n.Val
+				prefixKey = append(prefixKey, n.Key...)
+				suffixKey = suffixKey[len(n.Key):]
+				if nodes != nil {
+					*nodes = append(*nodes, n)
+				}
+				continue
+			}
+
+			tn = nil
+			if nodes != nil {
+				*nodes = append(*nodes, n)
+			}
+			// if the short node's value is an unresolved hash node, it must
+			// be resolved and included in the proof as well
+			hn, isExternNode := n.Val.(hashNode)
+			if isExternNode && nodes != nil {
+				prefixKey = append(prefixKey, n.Key...)
+				nextBlob, err := t.reader.node(prefixKey, common.BytesToHash(hn))
+				if err != nil {
+					log.Error("Unhandled next trie error in traverseNodes", "err", err)
+					return nil, err
+				}
+				next := mustDecodeNodeUnsafe(hn, nextBlob)
+				*nodes = append(*nodes, next)
+			}
+		case *fullNode:
+			tn = n.Children[suffixKey[0]]
+			prefixKey = append(prefixKey, suffixKey[0])
+			suffixKey = suffixKey[1:]
+			if nodes != nil {
+				*nodes = append(*nodes, n)
+			}
+		case hashNode:
+			// Retrieve the specified node from the underlying node reader.
+			// trie.resolveAndTrack is not used since in that function the
+			// loaded blob will be tracked, while it's not required here since
+			// all loaded nodes won't be linked to trie at all and track nodes
+			// may lead to out-of-memory issue.
+			blob, err := t.reader.node(prefixKey, common.BytesToHash(n))
+			if err != nil {
+				log.Error("Unhandled trie error in traverseNodes", "err", err)
+				return nil, err
+			}
+			// The raw-blob format nodes are loaded either from the
+			// clean cache or the database, they are all in their own
+			// copy and safe to use unsafe decoder.
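+			// After decoding, resolveEpochMeta below re-attaches the node's
+			// epoch metadata from the epochmeta store, so expiry checks keep
+			// working on nodes loaded during this traversal.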
+			tn = mustDecodeNodeUnsafe(n, blob)
+			if err = t.resolveEpochMeta(tn, epoch, prefixKey); err != nil {
+				return nil, err
+			}
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+		}
+	}
+
+	return tn, nil
+}
+
+func (t *Trie) ProveByPath(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error {
+	if t.committed {
+		return ErrCommitted
+	}
+
+	if len(key) == 0 {
+		return fmt.Errorf("key is empty")
+	}
+
+	key = keybytesToHex(key)
+
+	// traverse down using the prefixKeyHex
+	var nodes []node
+	tn := t.root
+	startNode, err := t.traverseNodes(tn, nil, prefixKeyHex, nil, 0, false) // obtain the node the prefixKeyHex leads to
+	if err != nil {
+		return err
+	}
+
+	key = key[len(prefixKeyHex):] // obtain the suffix key
+
+	// traverse through the suffix key
+	_, err = t.traverseNodes(startNode, prefixKeyHex, key, &nodes, 0, false)
+	if err != nil {
+		return err
+	}
+
+	if len(nodes) == 0 {
+		log.Error("no proof nodes found", "prefix", prefixKeyHex, "key", key)
+		return fmt.Errorf("cannot find target proof, prefix: %#x, suffix: %#x", prefixKeyHex, key)
+	}
+
+	hasher := newHasher(false)
+	defer returnHasherToPool(hasher)
+
+	// construct the proof
+	for _, n := range nodes {
+		var hn node
+		n, hn = hasher.proofHash(n)
+		// only nodes referenced by hash become proof elements; smaller
+		// nodes are embedded in their parent
+		if hash, ok := hn.(hashNode); ok {
+			enc := nodeToBytes(n)
+			proofDb.Put(hash, enc)
+		}
+	}
+
+	return nil
+}
+
+func (t *StateTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error {
+	return t.trie.ProveByPath(key, path, proofDb)
+}
+
 // VerifyProof checks merkle proofs. The given proof must contain the value for
 // key in a trie with the given root hash. VerifyProof returns an error if the
 // proof contains invalid trie nodes or the wrong value.
@@ -613,3 +757,299 @@ func get(tn node, key []byte, skipResolved bool) ([]byte, node) {
 		}
 	}
 }
+
+type MPTProof struct {
+	RootKeyHex []byte // prefix key in nibbles format, max 65 bytes. 
TODO: optimize witness size
+	Proof      [][]byte // list of RLP-encoded nodes
+}
+
+type MPTProofNub struct {
+	n1PrefixKey []byte // n1's prefix hex key, max 64 bytes
+	n1          node
+	n2PrefixKey []byte // n2's prefix hex key, max 64 bytes
+	n2          node
+}
+
+// ResolveKV extracts the revived key/value pairs reachable from the nub's
+// nodes (fullNode children [0-15], fullNode[16] or shortNode.Val) and returns
+// them for the cache and the snapshot.
+func (m *MPTProofNub) ResolveKV() (map[string][]byte, error) {
+	kvMap := make(map[string][]byte)
+	if err := resolveKV(m.n1, m.n1PrefixKey, kvMap); err != nil {
+		return nil, err
+	}
+	if err := resolveKV(m.n2, m.n2PrefixKey, kvMap); err != nil {
+		return nil, err
+	}
+
+	return kvMap, nil
+}
+
+func (m *MPTProofNub) GetValue() []byte {
+	if val := getNubValue(m.n1, m.n1PrefixKey); val != nil {
+		return val
+	}
+
+	if val := getNubValue(m.n2, m.n2PrefixKey); val != nil {
+		return val
+	}
+
+	return nil
+}
+
+func (m *MPTProofNub) String() string {
+	buf := bytes.NewBuffer(nil)
+	buf.WriteString("n1: ")
+	buf.WriteString(hex.EncodeToString(m.n1PrefixKey))
+	buf.WriteString(", n1proof: ")
+	if m.n1 != nil {
+		buf.WriteString(m.n1.fstring(""))
+	}
+	buf.WriteString(", n2: ")
+	buf.WriteString(hex.EncodeToString(m.n2PrefixKey))
+	buf.WriteString(", n2proof: ")
+	if m.n2 != nil {
+		buf.WriteString(m.n2.fstring(""))
+	}
+	return buf.String()
+}
+
+func getNubValue(origin node, prefixKey []byte) []byte {
+	switch n := origin.(type) {
+	case nil, hashNode:
+		return nil
+	case valueNode:
+		_, content, _, _ := rlp.Split(n)
+		return content
+	case *shortNode:
+		return getNubValue(n.Val, append(prefixKey, n.Key...))
+	case *fullNode:
+		for i := 0; i < BranchNodeLength-1; i++ {
+			if val := getNubValue(n.Children[i], append(prefixKey, byte(i))); val != nil {
+				return val
+			}
+		}
+		return getNubValue(n.Children[BranchNodeLength-1], prefixKey)
+	default:
+		panic(fmt.Sprintf("invalid node: %v", origin))
+	}
+}
+
+func resolveKV(origin node, prefixKey []byte, kvWriter map[string][]byte) error {
+	switch n := origin.(type) {
+	case nil, hashNode:
+		return nil
+	case valueNode:
+		_, content, _, err := rlp.Split(n)
+		if err != nil {
+			return err
+		}
+		kvWriter[string(hexToKeybytes(prefixKey))] = content
+		return nil
+	case *shortNode:
+		return resolveKV(n.Val, append(prefixKey, n.Key...), kvWriter)
+	case *fullNode:
+		for i := 0; i < BranchNodeLength-1; i++ {
+			if err := resolveKV(n.Children[i], append(prefixKey, byte(i)), kvWriter); err != nil {
+				return err
+			}
+		}
+		return resolveKV(n.Children[BranchNodeLength-1], prefixKey, kvWriter)
+	default:
+		panic(fmt.Sprintf("invalid node: %v", origin))
+	}
+}
+
+type MPTProofCache struct {
+	MPTProof
+
+	cacheHexPath [][]byte       // cache path for performance
+	cacheHashes  [][]byte       // cache hash for performance
+	cacheNodes   []node         // cache node for performance
+	cacheNubs    []*MPTProofNub // cache proof nubs to check revive duplicate
+}
+
+// VerifyProof verifies the proof in the MPT witness:
+// 1. calculate the hashes
+// 2. decode the trie nodes
+// 3. verify the partial merkle proof of the witness
+// 4. 
split it into partial witness nubs
+func (m *MPTProofCache) VerifyProof() error {
+	m.cacheHashes = make([][]byte, len(m.Proof))
+	m.cacheNodes = make([]node, len(m.Proof))
+	m.cacheHexPath = make([][]byte, len(m.Proof))
+	hasher := newHasher(false)
+	defer returnHasherToPool(hasher)
+
+	var child []byte
+	for i := len(m.Proof) - 1; i >= 0; i-- {
+		m.cacheHashes[i] = hasher.hashData(m.Proof[i])
+		n, err := decodeNode(m.cacheHashes[i], m.Proof[i])
+		if err != nil {
+			return err
+		}
+		m.cacheNodes[i] = n
+
+		switch t := n.(type) {
+		case *shortNode:
+			m.cacheHexPath[i] = t.Key
+			if err := matchHashNodeInShortNode(child, t); err != nil {
+				return err
+			}
+		case *fullNode:
+			index, err := matchHashNodeInFullNode(child, t)
+			if err != nil {
+				return err
+			}
+			if index >= 0 {
+				m.cacheHexPath[i] = []byte{byte(index)}
+			}
+		case valueNode:
+			if child != nil {
+				return errors.New("proof has wrong child in valueNode")
+			}
+		default:
+			return fmt.Errorf("proof got wrong trie node: %v", t.nodeType())
+		}
+
+		child = m.cacheHashes[i]
+	}
+
+	// cache proof nubs
+	m.cacheNubs = make([]*MPTProofNub, 0, len(m.Proof))
+	prefix := m.RootKeyHex
+	for i := 0; i < len(m.cacheNodes); i++ {
+		if i-1 >= 0 {
+			prefix = copy2NewBytes(prefix, m.cacheHexPath[i-1])
+		}
+		n1 := m.cacheNodes[i]
+		nub := MPTProofNub{
+			n1PrefixKey: prefix,
+			n1:          n1,
+			n2:          nil,
+			n2PrefixKey: nil,
+		}
+
+		// check that the partial-witness rule is satisfied: a short node
+		// must be accompanied by its child, which may be a full node or a
+		// value node
+		merge, err := mergeNextNode(m.cacheNodes, i)
+		if err != nil {
+			return err
+		}
+		if merge {
+			i++
+			prefix = copy2NewBytes(prefix, m.cacheHexPath[i-1])
+			nub.n2 = m.cacheNodes[i]
+			nub.n2PrefixKey = prefix
+		}
+		m.cacheNubs = append(m.cacheNubs, &nub)
+	}
+
+	return nil
+}
+
+func copy2NewBytes(s1, s2 []byte) []byte {
+	ret := make([]byte, len(s1)+len(s2))
+	copy(ret, s1)
+	copy(ret[len(s1):], s2)
+	return ret
+}
+
+func renewBytes(s []byte) []byte {
+	ret := make([]byte, len(s))
+	copy(ret, s)
+	return ret
+}
+
+func (m *MPTProofCache) CacheNubs() []*MPTProofNub {
+	return m.cacheNubs
+}
+
+// mergeNextNode reports whether the node at index i must be merged with the
+// next node into the same nub: a short node always travels with its child
+func mergeNextNode(nodes []node, i int) (bool, error) {
+	if i >= len(nodes) {
+		return false, errors.New("mergeNextNode: index out of bounds")
+	}
+
+	n1 := nodes[i]
+	switch n := n1.(type) {
+	case *shortNode:
+		need, err := needNextProofNode(n, n.Val)
+		if err != nil {
+			return false, err
+		}
+		if need && i+1 >= len(nodes) {
+			return false, errors.New("mergeNextNode: short node must be followed by its child")
+		}
+		return need, nil
+	case valueNode:
+		return false, errors.New("mergeNextNode: value node should have been merged with the previous node")
+	}
+
+	if i+1 >= len(nodes) {
+		return false, nil
+	}
+	return nodes[i+1].nodeType() == valueNodeType, nil
+}
+
+// needNextProofNode reports whether the node must be merged with the next
+// proof node into one nub, since an extension node must be revived together
+// with its child
+func needNextProofNode(parent, origin node) (bool, error) {
+	switch n := origin.(type) {
+	case *fullNode:
+		for i := 0; i < BranchNodeLength-1; i++ {
+			need, err := needNextProofNode(n, n.Children[i])
+			if err != nil {
+				return false, err
+			}
+			if need {
+				return true, nil
+			}
+		}
+		return false, nil
+	case *shortNode:
+		if parent.nodeType() == shortNodeType {
+			return false, errors.New("needNextProofNode: a short node's child cannot be another short node")
+		}
+		return needNextProofNode(n, n.Val)
+	case valueNode:
+		return false, nil
+	case hashNode:
+		if parent.nodeType() == 
fullNodeType { + return false, nil + } + return true, nil + default: + return false, errors.New("needNextProofNode unsupported node") + } +} + +func matchHashNodeInFullNode(child []byte, n *fullNode) (int, error) { + if child == nil { + return -1, nil + } + + for i := 0; i < BranchNodeLength-1; i++ { + switch v := n.Children[i].(type) { + case hashNode: + if bytes.Equal(child, v) { + return i, nil + } + } + } + return -1, errors.New("proof cannot find target child in fullNode") +} + +func matchHashNodeInShortNode(child []byte, n *shortNode) error { + if child == nil { + return nil + } + + switch v := n.Val.(type) { + case hashNode: + if !bytes.Equal(child, v) { + return errors.New("proof wrong child in shortNode") + } + default: + return errors.New("proof must hashNode when meet shortNode") + } + return nil +} diff --git a/trie/proof_test.go b/trie/proof_test.go index fc2de62649..62f337f87b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -901,6 +901,17 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { } } +type proofList [][]byte + +func (n *proofList) Put(key []byte, value []byte) error { + *n = append(*n, value) + return nil +} + +func (n *proofList) Delete(key []byte) error { + panic("not supported") +} + // mutateByte changes one byte in b. func mutateByte(b []byte) { for r := mrand.Intn(len(b)); ; { @@ -1071,6 +1082,26 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) { return trie, vals } +func nonRandomTrieWithExpiry(n int) (*Trie, map[string]*kv) { + db := NewDatabase(rawdb.NewMemoryDatabase(), nil) + trie := NewEmpty(db) + trie.currentEpoch = 10 + trie.rootEpoch = 10 + trie.enableExpiry = true + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + elem := &kv{key, value, false} + trie.MustUpdate(elem.k, elem.v) + vals[string(elem.k)] = elem + } + return trie, vals +} + func TestRangeProofKeysWithSharedPrefix(t *testing.T) { keys := [][]byte{ common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), diff --git a/trie/secure_trie.go b/trie/secure_trie.go index ffe006c1ff..e665c8c58d 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -94,6 +94,15 @@ func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { return content, err } +func (t *StateTrie) GetStorageAndUpdateEpoch(addr common.Address, key []byte) ([]byte, error) { + enc, err := t.trie.GetAndUpdateEpoch(t.hashKey(key)) + if err != nil || len(enc) == 0 { + return nil, err + } + _, content, _, err := rlp.Split(enc) + return content, err +} + // GetAccount attempts to retrieve an account with provided account address. // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. @@ -245,6 +254,14 @@ func (t *StateTrie) Hash() common.Hash { return t.trie.Hash() } +func (t *StateTrie) SetEpoch(epoch types.StateEpoch) { + t.trie.SetEpoch(epoch) +} + +func (t *StateTrie) Epoch() types.StateEpoch { + return t.trie.currentEpoch +} + // Copy returns a copy of StateTrie. 
func (t *StateTrie) Copy() *StateTrie { return &StateTrie{ @@ -302,3 +319,13 @@ func (t *StateTrie) getSecKeyCache() map[string][]byte { } return t.secKeyCache } + +func (t *StateTrie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { + key = t.hashKey(key) + return t.trie.TryRevive(key, proof) +} + +func (t *StateTrie) TryLocalRevive(_ common.Address, key []byte) ([]byte, error) { + key = t.hashKey(key) + return t.trie.TryLocalRevive(key) +} diff --git a/trie/stacktrie.go b/trie/stacktrie.go index ee1ce28291..aed56a2241 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -27,6 +27,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie/epochmeta" ) var ErrCommitDisabled = errors.New("no database for committing") @@ -40,6 +42,7 @@ var stPool = sync.Pool{ // NodeWriteFunc is used to provide all information of a dirty node for committing // so that callers can flush nodes into database with desired scheme. type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) +type NodeMetaWriteFunc = func(owner common.Hash, path []byte, blob []byte) func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { st := stPool.Get().(*StackTrie) @@ -48,6 +51,16 @@ func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { return st } +func stackTrieFromPoolWithExpiry(writeFn NodeWriteFunc, writeMetaFn NodeMetaWriteFunc, owner common.Hash, epoch types.StateEpoch) *StackTrie { + st := stPool.Get().(*StackTrie) + st.owner = owner + st.writeFn = writeFn + st.writeMetaFn = writeMetaFn + st.epoch = epoch + st.enableStateExpiry = true + return st +} + func returnToPool(st *StackTrie) { st.Reset() stPool.Put(st) @@ -57,12 +70,16 @@ func returnToPool(st *StackTrie) { // in order. Once it determines that a subtree will no longer be inserted // into, it will hash it and free up the memory it uses. type StackTrie struct { - owner common.Hash // the owner of the trie - nodeType uint8 // node type (as in branch, ext, leaf) - val []byte // value contained by this node if it's a leaf - key []byte // key chunk covered by this (leaf|ext) node - children [16]*StackTrie // list of children (for branch and exts) - writeFn NodeWriteFunc // function for committing nodes, can be nil + owner common.Hash // the owner of the trie + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (leaf|ext) node + children [16]*StackTrie // list of children (for branch and exts) + enableStateExpiry bool // whether to enable state expiry + epoch types.StateEpoch + epochMap [16]types.StateEpoch + writeFn NodeWriteFunc // function for committing nodes, can be nil + writeMetaFn NodeMetaWriteFunc // function for committing epoch metadata, can be nil } // NewStackTrie allocates and initializes an empty trie. @@ -83,6 +100,17 @@ func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie } } +func NewStackTrieWithStateExpiry(writeFn NodeWriteFunc, writeMetaFn NodeMetaWriteFunc, owner common.Hash, epoch types.StateEpoch) *StackTrie { + return &StackTrie{ + owner: owner, + nodeType: emptyNode, + epoch: epoch, + enableStateExpiry: true, + writeFn: writeFn, + writeMetaFn: writeMetaFn, + } +} + // NewFromBinary initialises a serialized stacktrie with the given db. 
func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { var st StackTrie @@ -208,6 +236,10 @@ func (st *StackTrie) Update(key, value []byte) error { if len(value) == 0 { panic("deletion not supported") } + if st.enableStateExpiry { + st.insertWithEpoch(k[:len(k)-1], value, nil, st.epoch) + return nil + } st.insert(k[:len(k)-1], value, nil) return nil } @@ -387,6 +419,154 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) { } } +func (st *StackTrie) insertWithEpoch(key, value []byte, prefix []byte, epoch types.StateEpoch) { + switch st.nodeType { + case branchNode: /* Branch */ + idx := int(key[0]) + + // Unresolve elder siblings + for i := idx - 1; i >= 0; i-- { + if st.children[i] != nil { + if st.children[i].nodeType != hashedNode { + st.children[i].hash(append(prefix, byte(i))) + } + break + } + } + + // Add new child + if st.children[idx] == nil { + st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn) + } else { + st.children[idx].insertWithEpoch(key[1:], value, append(prefix, key[0]), epoch) + } + st.epochMap[idx] = epoch + + case extNode: /* Ext */ + // Compare both key chunks and see where they differ + diffidx := st.getDiffIndex(key) + + // Check if chunks are identical. If so, recurse into + // the child node. Otherwise, the key has to be split + // into 1) an optional common prefix, 2) the fullnode + // representing the two differing path, and 3) a leaf + // for each of the differentiated subtrees. + if diffidx == len(st.key) { + // Ext key and key segment are identical, recurse into + // the child node. + st.children[0].insertWithEpoch(key[diffidx:], value, append(prefix, key[:diffidx]...), epoch) + return + } + // Save the original part. Depending if the break is + // at the extension's last byte or not, create an + // intermediate extension or use the extension's child + // node directly. + var n *StackTrie + if diffidx < len(st.key)-1 { + // Break on the non-last byte, insert an intermediate + // extension. The path prefix of the newly-inserted + // extension should also contain the different byte. + n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn) + n.hash(append(prefix, st.key[:diffidx+1]...)) + } else { + // Break on the last byte, no need to insert + // an extension node: reuse the current node. + // The path prefix of the original part should + // still be same. + n = st.children[0] + n.hash(append(prefix, st.key...)) + } + var p *StackTrie + if diffidx == 0 { + // the break is on the first byte, so + // the current node is converted into + // a branch node. + st.children[0] = nil + p = st + st.nodeType = branchNode + } else { + // the common prefix is at least one byte + // long, insert a new intermediate branch + // node. 
+ st.children[0] = stackTrieFromPoolWithExpiry(st.writeFn, st.writeMetaFn, st.owner, st.epoch) + st.children[0].nodeType = branchNode + p = st.children[0] + } + // Create a leaf for the inserted part + o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + + // Insert both child leaves where they belong: + origIdx := st.key[diffidx] + newIdx := key[diffidx] + p.children[origIdx] = n + p.children[newIdx] = o + st.key = st.key[:diffidx] + p.epochMap[origIdx] = epoch + p.epochMap[newIdx] = epoch + + case leafNode: /* Leaf */ + // Compare both key chunks and see where they differ + diffidx := st.getDiffIndex(key) + + // Overwriting a key isn't supported, which means that + // the current leaf is expected to be split into 1) an + // optional extension for the common prefix of these 2 + // keys, 2) a fullnode selecting the path on which the + // keys differ, and 3) one leaf for the differentiated + // component of each key. + if diffidx >= len(st.key) { + panic("Trying to insert into existing key") + } + + // Check if the split occurs at the first nibble of the + // chunk. In that case, no prefix extnode is necessary. + // Otherwise, create that + var p *StackTrie + if diffidx == 0 { + // Convert current leaf into a branch + st.nodeType = branchNode + p = st + st.children[0] = nil + } else { + // Convert current node into an ext, + // and insert a child branch node. + st.nodeType = extNode + st.children[0] = NewStackTrieWithStateExpiry(st.writeFn, st.writeMetaFn, st.owner, st.epoch) + st.children[0].nodeType = branchNode + p = st.children[0] + } + + // Create the two child leaves: one containing the original + // value and another containing the new value. The child leaf + // is hashed directly in order to free up some memory. + origIdx := st.key[diffidx] + p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) + p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) + + newIdx := key[diffidx] + p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + + p.epochMap[origIdx] = epoch + p.epochMap[newIdx] = epoch + + // Finally, cut off the key part that has been passed + // over to the children. + st.key = st.key[:diffidx] + st.val = nil + + case emptyNode: /* Empty */ + st.nodeType = leafNode + st.key = key + st.val = value + + case hashedNode: + panic("trying to insert into hash") + + default: + panic("invalid type") + } +} + // hash converts st into a 'hashedNode', if possible. Possible outcomes: // // 1. 
The rlp-encoded value was >= 32 bytes: @@ -469,6 +649,7 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { panic("invalid node type") } + prevNodeType := st.nodeType st.nodeType = hashedNode st.key = st.key[:0] if len(encodedNode) < 32 { @@ -481,6 +662,12 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { st.val = hasher.hashData(encodedNode) if st.writeFn != nil { st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) + if st.enableStateExpiry && prevNodeType == branchNode && st.writeMetaFn != nil { + epochMeta := epochmeta.NewBranchNodeEpochMeta(st.epochMap) + buf := rlp.NewEncoderBuffer(nil) + epochMeta.Encode(buf) + st.writeMetaFn(st.owner, path, buf.ToBytes()) + } } } @@ -514,6 +701,9 @@ func (st *StackTrie) Commit() (h common.Hash, err error) { if st.writeFn == nil { return common.Hash{}, ErrCommitDisabled } + if st.enableStateExpiry && st.writeMetaFn == nil { + return common.Hash{}, ErrCommitDisabled + } hasher := newHasher(false) defer returnHasherToPool(hasher) @@ -529,6 +719,12 @@ func (st *StackTrie) Commit() (h common.Hash, err error) { hasher.sha.Write(st.val) hasher.sha.Read(h[:]) - st.writeFn(st.owner, nil, h, st.val) + st.writeFn(st.owner, nil, h, st.val) // func(owner common.Hash, path []byte, hash common.Hash, blob []byte) + if st.enableStateExpiry && st.nodeType == branchNode && st.writeMetaFn != nil { + epochMeta := epochmeta.NewBranchNodeEpochMeta(st.epochMap) + buf := rlp.NewEncoderBuffer(nil) + epochMeta.Encode(buf) + st.writeMetaFn(st.owner, nil, buf.ToBytes()) + } return h, nil } diff --git a/trie/sync.go b/trie/sync.go index 4f55845991..12f5a73ed3 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -27,6 +27,8 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie/epochmeta" ) // ErrNotRequested is returned by the trie sync when it's requested to process a @@ -61,7 +63,7 @@ const maxFetchesPerDepth = 16384 // - Path 0x012345678901234567890123456789010123456789012345678901234567890199 -> {0x0123456789012345678901234567890101234567890123456789012345678901, 0x0099} type SyncPath [][]byte -// NewSyncPath converts an expanded trie path from nibble form into a compact +// NewSyncPath converts an expanded trie Path from nibble form into a compact // version that can be sent over the network. func NewSyncPath(path []byte) SyncPath { // If the hash is from the account trie, append a single item, if it @@ -82,8 +84,8 @@ func NewSyncPath(path []byte) SyncPath { // trie (account) or a layered trie (account -> storage). Each key in the tuple // is in the raw format(32 bytes). // -// The path is a composite hexary path identifying the trie node. All the key -// bytes are converted to the hexary nibbles and composited with the parent path +// The Path is a composite hexary Path identifying the trie node. All the key +// bytes are converted to the hexary nibbles and composited with the parent Path // if the trie node is in a layered trie. // // It's used by state sync and commit to allow handling external references @@ -93,9 +95,10 @@ type LeafCallback func(keys [][]byte, path []byte, leaf []byte, parent common.Ha // nodeRequest represents a scheduled or already in-flight trie node retrieval request. 
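Both hashRec and Commit above serialize a branch node's epoch map before handing it to writeMetaFn. The sketch below reproduces that encoding step standalone; the plain [16]uint64 stands in for the patch's epochmeta.BranchNodeEpochMeta wrapper, whose exact layout lives elsewhere in the patch:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	var epochMap [16]uint64 // stand-in for [16]types.StateEpoch
	epochMap[3] = 10        // child 3 last written in epoch 10
	blob, err := rlp.EncodeToBytes(epochMap)
	if err != nil {
		panic(err)
	}
	// The resulting blob is what writeMetaFn persists next to the node.
	fmt.Printf("epoch meta blob: %x\n", blob)
}
```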
type nodeRequest struct { - hash common.Hash // Hash of the trie node to retrieve - path []byte // Merkle path leading to this node for prioritization - data []byte // Data content of the node, cached until all subtrees complete + hash common.Hash // Hash of the trie node to retrieve + path []byte // Merkle path leading to this node for prioritization + data []byte // Data content of the node, cached until all subtrees complete + epochMap [16]types.StateEpoch parent *nodeRequest // Parent state node referencing this entry deps int // Number of dependencies before allowed to commit this node @@ -125,22 +128,24 @@ type CodeSyncResult struct { // syncMemBatch is an in-memory buffer of successfully downloaded but not yet // persisted data items. type syncMemBatch struct { - nodes map[string][]byte // In-memory membatch of recently completed nodes - hashes map[string]common.Hash // Hashes of recently completed nodes - codes map[common.Hash][]byte // In-memory membatch of recently completed codes - size uint64 // Estimated batch-size of in-memory data. + nodes map[string][]byte // In-memory membatch of recently completed nodes + hashes map[string]common.Hash // Hashes of recently completed nodes + epochMaps map[string][16]types.StateEpoch + codes map[common.Hash][]byte // In-memory membatch of recently completed codes + size uint64 // Estimated batch-size of in-memory data. } // newSyncMemBatch allocates a new memory-buffer for not-yet persisted trie nodes. func newSyncMemBatch() *syncMemBatch { return &syncMemBatch{ - nodes: make(map[string][]byte), - hashes: make(map[string]common.Hash), - codes: make(map[common.Hash][]byte), + nodes: make(map[string][]byte), + hashes: make(map[string]common.Hash), + codes: make(map[common.Hash][]byte), + epochMaps: make(map[string][16]types.StateEpoch), } } -// hasNode reports the trie node with specific path is already cached. +// hasNode reports the trie node with specific Path is already cached. func (batch *syncMemBatch) hasNode(path []byte) bool { _, ok := batch.nodes[string(path)] return ok @@ -156,13 +161,15 @@ func (batch *syncMemBatch) hasCode(hash common.Hash) bool { // unknown trie hashes to retrieve, accepts node data associated with said hashes // and reconstructs the trie step by step until all is done. type Sync struct { - scheme string // Node scheme descriptor used in database. - database ethdb.KeyValueReader // Persistent database to check for existing entries - membatch *syncMemBatch // Memory buffer to avoid frequent database writes - nodeReqs map[string]*nodeRequest // Pending requests pertaining to a trie node path - codeReqs map[common.Hash]*codeRequest // Pending requests pertaining to a code hash - queue *prque.Prque[int64, any] // Priority queue with the pending requests - fetches map[int]int // Number of active fetches per trie node depth + scheme string // Node scheme descriptor used in database. + database ethdb.KeyValueReader // Persistent database to check for existing entries + membatch *syncMemBatch // Memory buffer to avoid frequent database writes + nodeReqs map[string]*nodeRequest // Pending requests pertaining to a trie node path + codeReqs map[common.Hash]*codeRequest // Pending requests pertaining to a code hash + queue *prque.Prque[int64, any] // Priority queue with the pending requests + fetches map[int]int // Number of active fetches per trie node depth + enableStateExpiry bool + epoch types.StateEpoch } // NewSync creates a new trie data download scheduler. 
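The membatch extension is easiest to see in miniature: epochMaps is a map parallel to nodes, keyed by the same node path, and is only populated when a request actually carried epoch data. Field and type names below mirror the patch; the surrounding program is a sketch:

```go
package main

import "fmt"

type StateEpoch uint64 // stand-in for types.StateEpoch

type syncMemBatch struct {
	nodes     map[string][]byte         // recently completed nodes, by path
	epochMaps map[string][16]StateEpoch // epoch maps for a subset of those paths
}

func main() {
	batch := &syncMemBatch{
		nodes:     make(map[string][]byte),
		epochMaps: make(map[string][16]StateEpoch),
	}
	path := string([]byte{0x01, 0x02})
	batch.nodes[path] = []byte{0xaa}

	var em [16]StateEpoch
	em[2] = 7 // child 2 fetched under epoch 7
	if em != ([16]StateEpoch{}) {
		batch.epochMaps[path] = em // only stored when non-empty
	}
	fmt.Println(len(batch.nodes), len(batch.epochMaps)) // 1 1
}
```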
@@ -180,6 +187,22 @@ func NewSync(root common.Hash, database ethdb.KeyValueReader, callback LeafCallb return ts } +func NewSyncWithEpoch(root common.Hash, database ethdb.KeyValueReader, callback LeafCallback, scheme string, epoch types.StateEpoch) *Sync { + ts := &Sync{ + scheme: scheme, + database: database, + membatch: newSyncMemBatch(), + nodeReqs: make(map[string]*nodeRequest), + codeReqs: make(map[common.Hash]*codeRequest), + queue: prque.New[int64, any](nil), // Ugh, can contain both string and hash, whyyy + fetches: make(map[int]int), + epoch: epoch, + enableStateExpiry: true, + } + ts.AddSubTrie(root, nil, common.Hash{}, nil, callback) + return ts +} + // AddSubTrie registers a new trie to the sync code, rooted at the designated // parent for completion tracking. The given path is a unique node path in // hex format and contain all the parent path if it's layered trie node. @@ -328,6 +351,14 @@ func (s *Sync) ProcessNode(result NodeSyncResult) error { } req.data = result.Data + if fn, ok := node.(*fullNode); s.enableStateExpiry && ok { + for i := 0; i < 16; i++ { + if fn.Children[i] != nil { + req.epochMap[i] = s.epoch + } + } + } + // Create and schedule a request for all the children nodes requests, err := s.children(req, node) if err != nil { @@ -351,6 +382,14 @@ func (s *Sync) Commit(dbw ethdb.Batch) error { for path, value := range s.membatch.nodes { owner, inner := ResolvePath([]byte(path)) rawdb.WriteTrieNode(dbw, owner, inner, s.membatch.hashes[path], value, s.scheme) + if s.enableStateExpiry { + if s.membatch.epochMaps[path] != [16]types.StateEpoch{} { + epochMeta := epochmeta.NewBranchNodeEpochMeta(s.membatch.epochMaps[path]) + buf := rlp.NewEncoderBuffer(nil) + epochMeta.Encode(buf) + rawdb.WriteEpochMetaPlainState(dbw, owner, string(inner), buf.ToBytes()) + } + } } for hash, value := range s.membatch.codes { rawdb.WriteCode(dbw, hash, value) @@ -509,10 +548,18 @@ func (s *Sync) commitNodeRequest(req *nodeRequest) error { // Write the node content to the membatch s.membatch.nodes[string(req.path)] = req.data s.membatch.hashes[string(req.path)] = req.hash + if req.epochMap != [16]types.StateEpoch{} { + s.membatch.epochMaps[string(req.path)] = req.epochMap + } // The size tracking refers to the db-batch, not the in-memory data. - // Therefore, we ignore the req.path, and account only for the hash+data + // Therefore, we ignore the req.Path, and account only for the hash+data // which eventually is written to db. s.membatch.size += common.HashLength + uint64(len(req.data)) + for _, epoch := range req.epochMap { + if epoch != 0 { + s.membatch.size += 16 + } + } delete(s.nodeReqs, string(req.path)) s.fetches[len(req.path)]-- diff --git a/trie/tracer.go b/trie/tracer.go index 5786af4d3e..a4321e5345 100644 --- a/trie/tracer.go +++ b/trie/tracer.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "github.com/ethereum/go-ethereum/common" ) @@ -40,20 +42,29 @@ import ( // Note tracer is not thread-safe, callers should be responsible for handling // the concurrency issues by themselves. type tracer struct { - inserts map[string]struct{} - deletes map[string]struct{} - accessList map[string][]byte + inserts map[string]struct{} + deletes map[string]struct{} + deleteEpochMetas map[string]struct{} // record for epoch meta + accessList map[string][]byte + accessEpochMetaList map[string][]byte + tagEpochMeta bool } // newTracer initializes the tracer for capturing trie changes. 
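commitNodeRequest above estimates the batch size as hash length plus data length, plus a flat 16 bytes per non-zero child epoch, and it detects "no epoch data" by comparing the array against its zero value, which works because Go arrays are comparable. A runnable sketch of that estimate; sizeFor is an illustrative name, not part of the patch:

```go
package main

import "fmt"

const hashLength = 32 // common.HashLength

type StateEpoch uint64

// sizeFor mirrors the membatch accounting: hash + data, plus a flat
// 16 bytes for every child slot that carries a non-zero epoch.
func sizeFor(data []byte, epochMap [16]StateEpoch) uint64 {
	size := uint64(hashLength + len(data))
	for _, epoch := range epochMap {
		if epoch != 0 {
			size += 16
		}
	}
	return size
}

func main() {
	var em [16]StateEpoch
	em[0], em[9] = 4, 4
	fmt.Println(sizeFor(make([]byte, 100), em)) // 32 + 100 + 2*16 = 164
}
```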
func newTracer() *tracer { return &tracer{ - inserts: make(map[string]struct{}), - deletes: make(map[string]struct{}), - accessList: make(map[string][]byte), + inserts: make(map[string]struct{}), + deletes: make(map[string]struct{}), + deleteEpochMetas: make(map[string]struct{}), + accessList: make(map[string][]byte), + accessEpochMetaList: make(map[string][]byte), } } +func (t *tracer) enableTagEpochMeta() { + t.tagEpochMeta = true +} + // onRead tracks the newly loaded trie node and caches the rlp-encoded // blob internally. Don't change the value outside of function since // it's not deep-copied. @@ -61,6 +72,14 @@ func (t *tracer) onRead(path []byte, val []byte) { t.accessList[string(path)] = val } +// onReadEpochMeta tracks the newly loaded trie epoch meta +func (t *tracer) onReadEpochMeta(path []byte, val []byte) { + if !t.tagEpochMeta { + return + } + t.accessEpochMetaList[string(path)] = val +} + // onInsert tracks the newly inserted trie node. If it's already // in the deletion set (resurrected node), then just wipe it from // the deletion set as it's "untouched". @@ -72,6 +91,14 @@ func (t *tracer) onInsert(path []byte) { t.inserts[string(path)] = struct{}{} } +// onExpandToBranchNode tracks the newly inserted trie branch node. +func (t *tracer) onExpandToBranchNode(path []byte) { + if !t.tagEpochMeta { + return + } + delete(t.deleteEpochMetas, string(path)) +} + // onDelete tracks the newly deleted trie node. If it's already // in the addition set, then just wipe it from the addition set // as it's untouched. @@ -83,19 +110,31 @@ func (t *tracer) onDelete(path []byte) { t.deletes[string(path)] = struct{}{} } +// onDeleteBranchNode tracks the newly deleted trie branch node. +func (t *tracer) onDeleteBranchNode(path []byte) { + if !t.tagEpochMeta { + return + } + t.deleteEpochMetas[string(path)] = struct{}{} +} + // reset clears the content tracked by tracer. func (t *tracer) reset() { t.inserts = make(map[string]struct{}) t.deletes = make(map[string]struct{}) + t.deleteEpochMetas = make(map[string]struct{}) t.accessList = make(map[string][]byte) + t.accessEpochMetaList = make(map[string][]byte) } // copy returns a deep copied tracer instance. func (t *tracer) copy() *tracer { var ( - inserts = make(map[string]struct{}) - deletes = make(map[string]struct{}) - accessList = make(map[string][]byte) + inserts = make(map[string]struct{}) + deletes = make(map[string]struct{}) + deleteBranchNodes = make(map[string]struct{}) + accessList = make(map[string][]byte) + accessEpochMetaList = make(map[string][]byte) ) for path := range t.inserts { inserts[path] = struct{}{} @@ -103,13 +142,22 @@ func (t *tracer) copy() *tracer { for path := range t.deletes { deletes[path] = struct{}{} } + for path := range t.deleteEpochMetas { + deleteBranchNodes[path] = struct{}{} + } for path, blob := range t.accessList { accessList[path] = common.CopyBytes(blob) } + for path, blob := range t.accessEpochMetaList { + accessEpochMetaList[path] = common.CopyBytes(blob) + } return &tracer{ - inserts: inserts, - deletes: deletes, - accessList: accessList, + inserts: inserts, + deletes: deletes, + deleteEpochMetas: deleteBranchNodes, + accessList: accessList, + accessEpochMetaList: accessEpochMetaList, + tagEpochMeta: t.tagEpochMeta, } } @@ -128,3 +176,42 @@ func (t *tracer) deletedNodes() []string { } return paths } + +// deletedBranchNodes returns a list of branch node paths which are deleted from the trie. 
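The epoch-meta hooks added here keep the tracer's existing symmetry: an insertion cancels a pending deletion of the same path and vice versa, so only net changes survive until commit. That invariant, reproduced with bare maps in a standalone program:

```go
package main

import "fmt"

type tracer struct {
	inserts map[string]struct{}
	deletes map[string]struct{}
}

// onInsert cancels a pending deletion of the same path (a resurrected
// node is a net no-op); otherwise it records the insertion.
func (t *tracer) onInsert(path string) {
	if _, ok := t.deletes[path]; ok {
		delete(t.deletes, path)
		return
	}
	t.inserts[path] = struct{}{}
}

// onDelete is the mirror image, for short-lived nodes.
func (t *tracer) onDelete(path string) {
	if _, ok := t.inserts[path]; ok {
		delete(t.inserts, path)
		return
	}
	t.deletes[path] = struct{}{}
}

func main() {
	tr := &tracer{inserts: map[string]struct{}{}, deletes: map[string]struct{}{}}
	tr.onInsert("ab")
	tr.onDelete("ab")
	fmt.Println(len(tr.inserts), len(tr.deletes)) // 0 0: nothing to commit
}
```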
+func (t *tracer) deletedBranchNodes() []string {
+ var paths []string
+ for path := range t.deleteEpochMetas {
+ _, ok := t.accessEpochMetaList[path]
+ if !ok {
+ continue
+ }
+ paths = append(paths, path)
+ }
+ return paths
+}
+
+// cached returns the cached blob of the node at the given path, if any.
+func (t *tracer) cached(path []byte) ([]byte, bool) {
+ val, ok := t.accessList[string(path)]
+ return val, ok
+}
+
+// checkNodeChanged reports whether the node blob at the given path differs
+// from its originally accessed value.
+func (t *tracer) checkNodeChanged(path []byte, blob []byte) bool {
+ val, ok := t.accessList[string(path)]
+ if !ok {
+ return len(blob) > 0
+ }
+
+ return !bytes.Equal(val, blob)
+}
+
+// checkEpochMetaChanged reports whether the epoch metadata at the given path
+// differs from its originally accessed value.
+func (t *tracer) checkEpochMetaChanged(path []byte, blob []byte) bool {
+ val, ok := t.accessEpochMetaList[string(path)]
+ if !ok {
+ return len(blob) > 0
+ }
+
+ return !bytes.Equal(val, blob)
+}
diff --git a/trie/trie.go b/trie/trie.go
index d19cb31063..0e5928673d 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -21,13 +21,25 @@ import (
 "bytes"
 "errors"
 "fmt"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
 "github.com/ethereum/go-ethereum/common"
 "github.com/ethereum/go-ethereum/core/types"
 "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/trie/epochmeta"
 "github.com/ethereum/go-ethereum/trie/trienode"
)
+var (
+ reviveMeter = metrics.NewRegisteredMeter("trie/revive", nil)
+ reviveNotExpiredMeter = metrics.NewRegisteredMeter("trie/revive/noexpired", nil)
+ reviveErrMeter = metrics.NewRegisteredMeter("trie/revive/err", nil)
+)
+
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
// top of a database. Whenever trie performs a commit operation, the generated
// nodes will be gathered and returned in a set. Once the trie is committed,
@@ -37,7 +49,7 @@ import (
// Trie is not safe for concurrent use.
type Trie struct {
 root node
- owner common.Hash
+ owner common.Hash // can be used to distinguish the account trie from storage tries
 // Flag whether the commit operation is already performed. If so the
 // trie is not usable(latest states is invisible).
@@ -54,6 +66,12 @@ type Trie struct {
 // tracer is the tool to track the trie changes.
 // It will be reset after each commit operation.
 tracer *tracer
+
+ // fields for state expiry
+ currentEpoch types.StateEpoch
+ rootEpoch types.StateEpoch
+ enableExpiry bool
+ enableMetaDB bool
}
// newFlag returns the cache flag value for a newly created node.
@@ -64,12 +82,15 @@ func (t *Trie) newFlag() nodeFlag {
// Copy returns a copy of Trie.
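checkNodeChanged and checkEpochMetaChanged share one decision rule: a path that was never read counts as changed only if the new blob is non-empty, while a path that was read counts as changed when its bytes differ from the recorded original. The rule isolated as a runnable function (changed is an illustrative name):

```go
package main

import (
	"bytes"
	"fmt"
)

// changed applies the tracer's rule against a recorded access list.
func changed(accessList map[string][]byte, path string, blob []byte) bool {
	orig, ok := accessList[path]
	if !ok {
		return len(blob) > 0 // never read: only "new and non-empty" is a change
	}
	return !bytes.Equal(orig, blob)
}

func main() {
	access := map[string][]byte{"ab": {0x01}}
	fmt.Println(changed(access, "ab", []byte{0x01})) // false: unchanged
	fmt.Println(changed(access, "ab", []byte{0x02})) // true: rewritten
	fmt.Println(changed(access, "cd", nil))          // false: absent and empty
}
```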
func (t *Trie) Copy() *Trie { return &Trie{ - root: t.root, - owner: t.owner, - committed: t.committed, - unhashed: t.unhashed, - reader: t.reader, - tracer: t.tracer.copy(), + root: t.root, + owner: t.owner, + committed: t.committed, + unhashed: t.unhashed, + reader: t.reader, + tracer: t.tracer.copy(), + rootEpoch: t.rootEpoch, + currentEpoch: t.currentEpoch, + enableExpiry: t.enableExpiry, } } @@ -85,10 +106,28 @@ func New(id *ID, db *Database) (*Trie, error) { return nil, err } trie := &Trie{ - owner: id.Owner, - reader: reader, - tracer: newTracer(), + owner: id.Owner, + reader: reader, + tracer: newTracer(), + enableExpiry: enableStateExpiry(id, db), + enableMetaDB: reader.emReader != nil, + } + // resolve root epoch + if trie.enableExpiry { + if trie.enableMetaDB { + trie.tracer.enableTagEpochMeta() + } + if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { + trie.root = hashNode(id.Root[:]) + meta, err := trie.resolveAccountMetaAndTrack() + if err != nil { + return nil, err + } + trie.rootEpoch = meta.Epoch() + } + return trie, nil } + if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { rootnode, err := trie.resolveAndTrack(id.Root[:], nil) if err != nil { @@ -96,12 +135,22 @@ func New(id *ID, db *Database) (*Trie, error) { } trie.root = rootnode } + return trie, nil } +func enableStateExpiry(id *ID, db *Database) bool { + if id.Owner == (common.Hash{}) { + return false + } + + return db.EnableExpiry() +} + // NewEmpty is a shortcut to create empty tree. It's mostly used in tests. func NewEmpty(db *Database) *Trie { tr, _ := New(TrieID(types.EmptyRootHash), db) + tr.enableExpiry = db.EnableExpiry() return tr } @@ -140,14 +189,36 @@ func (t *Trie) MustGet(key []byte) []byte { // // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. -func (t *Trie) Get(key []byte) ([]byte, error) { +func (t *Trie) Get(key []byte) (value []byte, err error) { + var newroot node + var didResolve bool + // Short circuit if the trie is already committed and not usable. 
 if t.committed {
 return nil, ErrCommitted
 }
- value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
+
+ if t.enableExpiry {
+ value, newroot, didResolve, err = t.getWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), false)
+ } else {
+ value, newroot, didResolve, err = t.get(t.root, keybytesToHex(key), 0)
+ }
+ if err == nil && didResolve {
+ t.root = newroot
+ }
+ return value, err
+}
+
+func (t *Trie) GetAndUpdateEpoch(key []byte) (value []byte, err error) {
+ if !t.enableExpiry {
+ return nil, errors.New("expiry is not enabled")
+ }
+
+ value, newroot, didResolve, err := t.getWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), true)
+
 if err == nil && didResolve {
 t.root = newroot
+ t.rootEpoch = t.currentEpoch
 }
 return value, err
}
@@ -188,6 +259,90 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no
 }
}
+func (t *Trie) getWithEpoch(origNode node, key []byte, pos int, epoch types.StateEpoch, updateEpoch bool) (value []byte, newnode node, didResolve bool, err error) {
+ if t.epochExpired(origNode, epoch) {
+ return nil, nil, false, NewExpiredNodeError(key[:pos], epoch, origNode)
+ }
+ switch n := (origNode).(type) {
+ case nil:
+ return nil, nil, false, nil
+ case valueNode:
+ return n, n, false, nil
+ case *shortNode:
+ if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
+ // key not found in trie
+ return nil, n, false, nil
+ }
+ value, newnode, didResolve, err = t.getWithEpoch(n.Val, key, pos+len(n.Key), epoch, updateEpoch)
+ if err == nil && t.renewNode(epoch, didResolve, updateEpoch) {
+ n = n.copy()
+ n.Val = newnode
+ if updateEpoch {
+ n.setEpoch(t.currentEpoch)
+ }
+ n.flags = t.newFlag()
+ didResolve = true
+ }
+ return value, n, didResolve, err
+ case *fullNode:
+ value, newnode, didResolve, err = t.getWithEpoch(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])), updateEpoch)
+ if err == nil && t.renewNode(epoch, didResolve, updateEpoch) {
+ n = n.copy()
+ n.Children[key[pos]] = newnode
+ if updateEpoch {
+ n.setEpoch(t.currentEpoch)
+ }
+ if updateEpoch && newnode != nil {
+ n.UpdateChildEpoch(int(key[pos]), t.currentEpoch)
+ }
+ n.flags = t.newFlag()
+ didResolve = true
+ }
+ return value, n, didResolve, err
+ case hashNode:
+ child, err := t.resolveAndTrack(n, key[:pos])
+ if err != nil {
+ return nil, n, true, err
+ }
+
+ if err = t.resolveEpochMetaAndTrack(child, epoch, key[:pos]); err != nil {
+ return nil, n, true, err
+ }
+ value, newnode, _, err := t.getWithEpoch(child, key, pos, epoch, updateEpoch)
+ return value, newnode, true, err
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
+ }
+}
+
+// refreshNubEpoch traverses the trie and updates the node epoch for each accessed trie node.
+// Under the expiry scheme, when a hash node is accessed, its parent node's epoch is not updated.
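getWithEpoch delegates its copy-or-reuse decision to renewNode, which is defined further down in trie.go: a node is copied when a child was resolved, or when the caller asked to stamp the current epoch onto an older node. A standalone sketch of that decision table, with illustrative names:

```go
package main

import "fmt"

type StateEpoch uint64

// renew mirrors renewNode: a read-only walk copies a node only when a child
// was resolved; an epoch-updating walk also copies whenever the node's epoch
// lags the current one, so the stale cached node is never reused.
func renew(currentEpoch, nodeEpoch StateEpoch, childDirty, updateEpoch bool) bool {
	if !updateEpoch {
		return childDirty
	}
	if currentEpoch > nodeEpoch {
		return true
	}
	return childDirty
}

func main() {
	fmt.Println(renew(10, 9, false, true))  // true: epoch bump forces a copy
	fmt.Println(renew(10, 10, false, true)) // false: already current, child clean
	fmt.Println(renew(10, 9, false, false)) // false: read-only path, no copy
}
```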
+func refreshNubEpoch(origNode node, epoch types.StateEpoch) node { + switch n := (origNode).(type) { + case nil: + return nil + case valueNode: + return n + case *shortNode: + n.Val = refreshNubEpoch(n.Val, epoch) + n.setEpoch(epoch) + n.flags = nodeFlag{dirty: true} + return n + case *fullNode: + for i := 0; i < BranchNodeLength-1; i++ { + n.Children[i] = refreshNubEpoch(n.Children[i], epoch) + n.UpdateChildEpoch(i, epoch) + } + n.setEpoch(epoch) + n.flags = nodeFlag{dirty: true} + return n + case hashNode: + return n + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + // MustGetNode is a wrapper of GetNode and will omit any encountered error but // just print out an error message. func (t *Trie) MustGetNode(path []byte) ([]byte, int) { @@ -304,6 +459,10 @@ func (t *Trie) Update(key, value []byte) error { if t.committed { return ErrCommitted } + + if t.enableExpiry { + return t.updateWithEpoch(key, value, t.getRootEpoch()) + } return t.update(key, value) } @@ -326,6 +485,26 @@ func (t *Trie) update(key, value []byte) error { return nil } +func (t *Trie) updateWithEpoch(key, value []byte, epoch types.StateEpoch) error { + t.unhashed++ + k := keybytesToHex(key) + if len(value) != 0 { + _, n, err := t.insertWithEpoch(t.root, nil, k, valueNode(value), epoch) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.deleteWithEpoch(t.root, nil, k, epoch) + if err != nil { + return err + } + t.root = n + } + t.rootEpoch = t.currentEpoch + return nil +} + func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { if len(key) == 0 { if v, ok := n.(valueNode); ok { @@ -343,7 +522,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if !dirty || err != nil { return false, n, err } - return true, &shortNode{n.Key, nn, t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag()}, nil } // Otherwise branch out at the index where they differ. branch := &fullNode{flags: t.newFlag()} @@ -358,15 +537,17 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { + t.tracer.onExpandToBranchNode(prefix) return true, branch, nil } // New branch node is created as a child of the original short node. // Track the newly inserted node in the tracer. The node identifier // passed is the path from the root node. t.tracer.onInsert(append(prefix, key[:matchlen]...)) + t.tracer.onExpandToBranchNode(append(prefix, key[:matchlen]...)) // Replace it with a short node leading up to the branch. - return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil + return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag()}, nil case *fullNode: dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) @@ -384,7 +565,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error // since it's always embedded in its parent. t.tracer.onInsert(prefix) - return true, &shortNode{key, value, t.newFlag()}, nil + return true, &shortNode{Key: key, Val: value, flags: t.newFlag()}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. 
Load @@ -405,6 +586,103 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error } } +func (t *Trie) insertWithEpoch(n node, prefix, key []byte, value node, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch, n) + } + + if len(key) == 0 { + if v, ok := n.(valueNode); ok { + return !bytes.Equal(v, value.(valueNode)), value, nil + } + return true, value, nil + } + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + // If the whole key matches, keep this short node as is + // and only update the value. + if matchlen == len(n.Key) { + dirty, nn, err := t.insertWithEpoch(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value, epoch) + if !t.renewNode(epoch, dirty, true) || err != nil { + return false, n, err + } + return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag(), epoch: t.currentEpoch}, nil + } + // Otherwise branch out at the index where they differ. + branch := &fullNode{flags: t.newFlag(), epoch: t.currentEpoch} + var err error + _, branch.Children[n.Key[matchlen]], err = t.insertWithEpoch(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val, t.currentEpoch) + if err != nil { + return false, nil, err + } + branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) + + _, branch.Children[key[matchlen]], err = t.insertWithEpoch(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value, t.currentEpoch) + if err != nil { + return false, nil, err + } + branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) + + // Replace this shortNode with the branch if it occurs at index 0. + if matchlen == 0 { + t.tracer.onExpandToBranchNode(prefix) + return true, branch, nil + } + // New branch node is created as a child of the original short node. + // Track the newly inserted node in the tracer. The node identifier + // passed is the path from the root node. + t.tracer.onInsert(append(prefix, key[:matchlen]...)) + t.tracer.onExpandToBranchNode(append(prefix, key[:matchlen]...)) + + // Replace it with a short node leading up to the branch. + return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil + + case *fullNode: + dirty, nn, err := t.insertWithEpoch(n.Children[key[0]], append(prefix, key[0]), key[1:], value, n.GetChildEpoch(int(key[0]))) + if !t.renewNode(epoch, dirty, true) || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + + return true, n, nil + + case nil: + // New short node is created and track it in the tracer. The node identifier + // passed is the path from the root node. Note the valueNode won't be tracked + // since it's always embedded in its parent. + t.tracer.onInsert(prefix) + + return true, &shortNode{Key: key, Val: value, flags: t.newFlag(), epoch: t.currentEpoch}, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and insert into it. This leaves all child nodes on + // the path to the value in the trie. 
+ rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + + if err = t.resolveEpochMetaAndTrack(rn, epoch, prefix); err != nil { + return false, nil, err + } + + dirty, nn, err := t.insertWithEpoch(rn, prefix, key, value, epoch) + if !t.renewNode(epoch, dirty, true) || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + // MustDelete is a wrapper of Delete and will omit any encountered error but // just print out an error message. func (t *Trie) MustDelete(key []byte) { @@ -418,13 +696,22 @@ func (t *Trie) MustDelete(key []byte) { // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. func (t *Trie) Delete(key []byte) error { + var n node + var err error // Short circuit if the trie is already committed and not usable. if t.committed { return ErrCommitted } t.unhashed++ k := keybytesToHex(key) - _, n, err := t.delete(t.root, nil, k) + + if t.enableExpiry { + _, n, err = t.deleteWithEpoch(t.root, nil, k, t.getRootEpoch()) + t.rootEpoch = t.currentEpoch + } else { + _, n, err = t.delete(t.root, nil, k) + } + if err != nil { return err } @@ -470,9 +757,9 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil + return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag()}, nil default: - return true, &shortNode{n.Key, child, t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag()}, nil } case *fullNode: @@ -520,7 +807,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos))) + cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)), n.GetChildEpoch(pos)) if err != nil { return false, nil, err } @@ -529,14 +816,16 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // Mark the original short node as deleted since the // value is embedded into the parent now. t.tracer.onDelete(append(prefix, byte(pos))) + t.tracer.onDeleteBranchNode(prefix) k := append([]byte{byte(pos)}, cnode.Key...) - return true, &shortNode{k, cnode.Val, t.newFlag()}, nil + return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag()}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil + t.tracer.onDeleteBranchNode(prefix) + return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag()}, nil } // n still contains at least two values and cannot be reduced. 
return true, n, nil @@ -566,75 +855,346 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } } -func concat(s1 []byte, s2 ...byte) []byte { - r := make([]byte, len(s1)+len(s2)) - copy(r, s1) - copy(r[len(s1):], s2) - return r -} - -func (t *Trie) resolve(n node, prefix []byte) (node, error) { - if n, ok := n.(hashNode); ok { - return t.resolveAndTrack(n, prefix) - } - return n, nil -} - -// resolveAndTrack loads node from the underlying store with the given node hash -// and path prefix and also tracks the loaded node blob in tracer treated as the -// node's original value. The rlp-encoded blob is preferred to be loaded from -// database because it's easy to decode node while complex to encode node to blob. -func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { - blob, err := t.reader.node(prefix, common.BytesToHash(n)) - if err != nil { - return nil, err +func (t *Trie) deleteWithEpoch(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch, n) } - t.tracer.onRead(prefix, blob) - return mustDecodeNode(n, blob), nil -} -// Hash returns the root hash of the trie. It does not write to the -// database and can be used even if the trie doesn't have one. -func (t *Trie) Hash() common.Hash { - hash, cached := t.hashRoot() - t.root = cached - return common.BytesToHash(hash.(hashNode)) -} + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + if matchlen < len(n.Key) { + return false, n, nil // don't replace n on mismatch + } + if matchlen == len(key) { + // The matched short node is deleted entirely and track + // it in the deletion set. The same the valueNode doesn't + // need to be tracked at all since it's always embedded. + t.tracer.onDelete(prefix) -// Commit collects all dirty nodes in the trie and replaces them with the -// corresponding node hash. All collected nodes (including dirty leaves if -// collectLeaf is true) will be encapsulated into a nodeset for return. -// The returned nodeset can be nil if the trie is clean (nothing to commit). -// Once the trie is committed, it's not usable anymore. A new trie must -// be created with new root and updated trie database for following usage -func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { - defer t.tracer.reset() - defer func() { - // StateDB will cache the trie and reuse it to read and write, - // the committed flag is true will prevent the cache trie access - // the trie node. - t.committed = false - }() - // Trie is empty and can be classified into two types of situations: - // (a) The trie was empty and no update happens => return nil - // (b) The trie was non-empty and all nodes are dropped => return - // the node set includes all deleted nodes - if t.root == nil { - paths := t.tracer.deletedNodes() - if len(paths) == 0 { - return types.EmptyRootHash, nil, nil // case (a) + return true, nil, nil // remove n entirely for whole matches } - nodes := trienode.NewNodeSet(t.owner) - for _, path := range paths { - nodes.AddNode([]byte(path), trienode.NewDeleted()) + // The key is longer than n.Key. Remove the remaining suffix + // from the subtrie. Child can never be nil here since the + // subtrie must contain at least two other values with keys + // longer than n.Key. 
+ dirty, child, err := t.deleteWithEpoch(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):], epoch) + if !dirty || err != nil { + return false, n, err } - return types.EmptyRootHash, nodes, nil // case (b) - } - // Derive the hash for all dirty nodes first. We hold the assumption - // in the following procedure that all nodes are hashed. - rootHash := t.Hash() + switch child := child.(type) { + case *shortNode: + // The child shortNode is merged into its parent, track + // is deleted as well. + t.tracer.onDelete(append(prefix, n.Key...)) - // Do a quick check if we really need to commit. This can happen e.g. + // Deleting from the subtrie reduced it to another + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which + // always creates a new slice) instead of append to + // avoid modifying n.Key since it might be shared with + // other nodes. + return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil + default: + return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag(), epoch: t.currentEpoch}, nil + } + + case *fullNode: + dirty, nn, err := t.deleteWithEpoch(n.Children[key[0]], append(prefix, key[0]), key[1:], n.GetChildEpoch(int(key[0]))) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + + // Because n is a full node, it must've contained at least two children + // before the delete operation. If the new child value is non-nil, n still + // has at least two children after the deletion, and cannot be reduced to + // a short node. + if nn != nil { + return true, n, nil + } + // Reduction: + // Check how many non-nil entries are left after deleting and + // reduce the full node to a short node if only one entry is + // left. Since n must've contained at least two children + // before deletion (otherwise it would not be a full node) n + // can never be reduced to nil. + // + // When the loop is done, pos contains the index of the single + // value that is left in n or -2 if n contains at least two + // values. + pos := -1 + for i, cld := range &n.Children { + if cld != nil { + if pos == -1 { + pos = i + } else { + pos = -2 + break + } + } + } + if pos >= 0 { + if pos != 16 { + // If the remaining entry is a short node, it replaces + // n and its key gets the missing nibble tacked to the + // front. This avoids creating an invalid + // shortNode{..., shortNode{...}}. Since the entry + // might not be loaded yet, resolve it just for this + // check. + // Attention: if Children[pos] has pruned, just fetch from remote + cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)), n.GetChildEpoch(pos)) + if err != nil { + return false, nil, err + } + if cnode, ok := cnode.(*shortNode); ok { + // Replace the entire full node with the short node. + // Mark the original short node as deleted since the + // value is embedded into the parent now. + t.tracer.onDelete(append(prefix, byte(pos))) + t.tracer.onDeleteBranchNode(prefix) + + k := append([]byte{byte(pos)}, cnode.Key...) + return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil + } + } + // Otherwise, n is replaced by a one-nibble short node + // containing the child. 
+ t.tracer.onDeleteBranchNode(prefix) + return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag(), epoch: t.currentEpoch}, nil + } + // n still contains at least two values and cannot be reduced. + return true, n, nil + + case valueNode: + return true, nil, nil + + case nil: + return false, nil, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and delete from it. This leaves all child nodes on + // the path to the value in the trie. + rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + + if err = t.resolveEpochMetaAndTrack(rn, epoch, prefix); err != nil { + return false, nil, err + } + + dirty, nn, err := t.deleteWithEpoch(rn, prefix, key, epoch) + if !dirty || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) + } +} + +func concat(s1 []byte, s2 ...byte) []byte { + r := make([]byte, len(s1)+len(s2)) + copy(r, s1) + copy(r[len(s1):], s2) + return r +} + +func (t *Trie) resolve(n node, prefix []byte, epoch types.StateEpoch) (node, error) { + if n, ok := n.(hashNode); ok { + n, err := t.resolveAndTrack(n, prefix) + if err != nil { + return nil, err + } + if err = t.resolveEpochMetaAndTrack(n, epoch, prefix); err != nil { + return nil, err + } + return n, nil + } + return n, nil +} + +// resolveAndTrack loads node from the underlying store with the given node hash +// and path prefix and also tracks the loaded node blob in tracer treated as the +// node's original value. The rlp-encoded blob is preferred to be loaded from +// database because it's easy to decode node while complex to encode node to blob. +func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { + if t.enableExpiry { + // when meet expired, the trie will skip the resolve path, but cache in tracer + blob, ok := t.tracer.cached(prefix) + if ok { + return mustDecodeNode(n, blob), nil + } + } + blob, err := t.reader.node(prefix, common.BytesToHash(n)) + if err != nil { + return nil, err + } + t.tracer.onRead(prefix, blob) + return mustDecodeNode(n, blob), nil +} + +func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { + blob, err := t.reader.node(prefix, common.BytesToHash(n)) + if err != nil { + return nil, err + } + return mustDecodeNode(n, blob), nil +} + +// resolveEpochMeta resolve full node's epoch map. +func (t *Trie) resolveEpochMeta(n node, epoch types.StateEpoch, prefix []byte) error { + if !t.enableExpiry { + return nil + } + + switch n := n.(type) { + case *shortNode: + n.setEpoch(epoch) + return nil + case *fullNode: + n.setEpoch(epoch) + if t.enableMetaDB { + enc, err := t.reader.epochMeta(prefix) + if err != nil { + return err + } + if len(enc) > 0 { + meta, err := epochmeta.DecodeFullNodeEpochMeta(enc) + if err != nil { + return err + } + n.EpochMap = meta.EpochMap + } + } + return nil + case valueNode, hashNode, nil: + // just skip + return nil + default: + return errors.New("resolveShadowNode unsupported node type") + } +} + +// resolveEpochMetaAndTrack resolve full node's epoch map. +func (t *Trie) resolveEpochMetaAndTrack(n node, epoch types.StateEpoch, prefix []byte) error { + if !t.enableExpiry { + return nil + } + // 1. 
Check if the node is a full node + switch n := n.(type) { + case *shortNode: + n.setEpoch(epoch) + return nil + case *fullNode: + n.setEpoch(epoch) + if t.enableMetaDB { + enc, err := t.reader.epochMeta(prefix) + if err != nil { + return err + } + t.tracer.onReadEpochMeta(prefix, enc) + if len(enc) > 0 { + meta, err := epochmeta.DecodeFullNodeEpochMeta(enc) + if err != nil { + return err + } + n.EpochMap = meta.EpochMap + } + } + return nil + case valueNode, hashNode, nil: + // just skip + return nil + default: + return errors.New("resolveShadowNode unsupported node type") + } +} + +// resolveAccountMetaAndTrack resolve account's epoch map. +func (t *Trie) resolveAccountMetaAndTrack() (types.MetaNoConsensus, error) { + if !t.enableExpiry { + return types.EmptyMetaNoConsensus, nil + } + var ( + enc []byte + err error + ) + + if t.enableMetaDB { + enc, err = t.reader.accountMeta() + if err != nil { + return types.EmptyMetaNoConsensus, err + } + t.tracer.onReadEpochMeta(epochmeta.AccountMetadataPath, enc) + } else { + enc, err = t.reader.node(epochmeta.AccountMetadataPath, types.EmptyRootHash) + if err != nil { + return types.EmptyMetaNoConsensus, err + } + t.tracer.onRead(epochmeta.AccountMetadataPath, enc) + } + + if len(enc) > 0 { + return types.DecodeMetaNoConsensusFromRLPBytes(enc) + } + return types.EmptyMetaNoConsensus, nil +} + +// Hash returns the root hash of the trie. It does not write to the +// database and can be used even if the trie doesn't have one. +func (t *Trie) Hash() common.Hash { + hash, cached := t.hashRoot() + t.root = cached + return common.BytesToHash(hash.(hashNode)) +} + +// Commit collects all dirty nodes in the trie and replaces them with the +// corresponding node hash. All collected nodes (including dirty leaves if +// collectLeaf is true) will be encapsulated into a nodeset for return. +// The returned nodeset can be nil if the trie is clean (nothing to commit). +// Once the trie is committed, it's not usable anymore. A new trie must +// be created with new root and updated trie database for following usage +func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error) { + defer t.tracer.reset() + defer func() { + // StateDB will cache the trie and reuse it to read and write, + // the committed flag is true will prevent the cache trie access + // the trie node. + t.committed = false + }() + // Trie is empty and can be classified into two types of situations: + // (a) The trie was empty and no update happens => return nil + // (b) The trie was non-empty and all nodes are dropped => return + // the node set includes all deleted nodes + if t.root == nil { + paths := t.tracer.deletedNodes() + if len(paths) == 0 { + return types.EmptyRootHash, nil, nil // case (a) + } + nodes := trienode.NewNodeSet(t.owner) + for _, path := range paths { + nodes.AddNode([]byte(path), trienode.NewDeleted()) + } + if t.enableExpiry && t.enableMetaDB { + for _, path := range t.tracer.deletedBranchNodes() { + nodes.AddBranchNodeEpochMeta([]byte(path), nil) + } + } + return types.EmptyRootHash, nodes, nil // case (b) + } + // Derive the hash for all dirty nodes first. We hold the assumption + // in the following procedure that all nodes are hashed. + rootHash := t.Hash() + + // Do a quick check if we really need to commit. This can happen e.g. // if we load a trie for reading storage values, but don't write to it. 
 if hashedNode, dirty := t.root.cache(); !dirty {
 // Replace the root node with the origin hash in order to
@@ -646,7 +1206,27 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet, error)
 for _, path := range t.tracer.deletedNodes() {
 nodes.AddNode([]byte(path), trienode.NewDeleted())
 }
- t.root = newCommitter(nodes, t.tracer, collectLeaf).Commit(t.root)
+ // store state expiry account meta
+ if t.enableExpiry {
+ blob, err := epochmeta.AccountMeta2Bytes(types.NewMetaNoConsensus(t.rootEpoch))
+ if err != nil {
+ return common.Hash{}, nil, err
+ }
+ if t.enableMetaDB {
+ for _, path := range t.tracer.deletedBranchNodes() {
+ nodes.AddBranchNodeEpochMeta([]byte(path), nil)
+ }
+ if t.rootEpoch > types.StateEpoch0 && t.tracer.checkEpochMetaChanged(epochmeta.AccountMetadataPath, blob) {
+ nodes.AddAccountMeta(blob)
+ }
+ } else {
+ // TODO(0xbundler): the account meta lifecycle is the same as the account data; when should it be deleted?
+ if t.rootEpoch > types.StateEpoch0 && t.tracer.checkNodeChanged(epochmeta.AccountMetadataPath, blob) {
+ nodes.AddNode(epochmeta.AccountMetadataPath, trienode.New(types.EmptyRootHash, blob))
+ }
+ }
+ }
+ t.root = newCommitter(nodes, t.tracer, collectLeaf, t.enableExpiry, t.enableMetaDB).Commit(t.root)
 return rootHash, nodes, nil
}
@@ -682,3 +1262,510 @@ func (t *Trie) Size() int {
 func (t *Trie) Owner() common.Hash {
 return t.owner
}
+
+// TryRevive attempts to revive expired trie nodes from a list of MPTProofNubs.
+// It performs a full or partial revive and returns the list of successfully
+// applied nubs. TryRevive does not guarantee that a value is revived
+// completely if the proof is not fully valid.
+func (t *Trie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) {
+ key = keybytesToHex(key)
+ successNubs := make([]*MPTProofNub, 0, len(proof))
+ reviveMeter.Mark(int64(len(proof)))
+ // Revive trie with each proof nub
+ for _, nub := range proof {
+ rootExpired := types.EpochExpired(t.getRootEpoch(), t.currentEpoch)
+ newNode, didRevive, err := t.tryRevive(t.root, key, nub.n1PrefixKey, *nub, 0, t.currentEpoch, rootExpired)
+ //log.Debug("tryRevive", "key", key, "nub.n1PrefixKey", nub.n1PrefixKey, "nub", nub, "err", err)
+ if _, ok := err.(*ReviveNotExpiredError); ok {
+ reviveNotExpiredMeter.Mark(1)
+ continue
+ }
+ if err != nil {
+ reviveErrMeter.Mark(1)
+ return nil, err
+ }
+ if didRevive {
+ successNubs = append(successNubs, nub)
+ t.root = newNode
+ t.rootEpoch = t.currentEpoch
+ }
+ }
+ return successNubs, nil
+}
+
+// tryRevive revives only the subtree rooted at targetPrefixKey.
+func (t *Trie) tryRevive(n node, key []byte, targetPrefixKey []byte, nub MPTProofNub, pos int, epoch types.StateEpoch, isExpired bool) (node, bool, error) {
+ if pos > len(targetPrefixKey) {
+ return nil, false, fmt.Errorf("target revive node not found")
+ }
+
+ if pos == len(targetPrefixKey) {
+ if !t.isExpiredNode(n, targetPrefixKey, epoch, isExpired) {
+ return nil, false, NewReviveNotExpiredErr(targetPrefixKey[:pos], epoch)
+ }
+ hn, ok := n.(hashNode)
+ if !ok {
+ return nil, false, fmt.Errorf("not match hashNode stub")
+ }
+
+ cachedHash, _ := nub.n1.cache()
+ if !bytes.Equal(cachedHash, hn) {
+ return nil, false, fmt.Errorf("hash values do not match")
+ }
+
+ if nub.n2 != nil {
+ n1, ok := nub.n1.(*shortNode)
+ if !ok {
+ return nil, false, fmt.Errorf("invalid node type")
+ }
+ n1.Val = nub.n2
+ return refreshNubEpoch(n1, t.currentEpoch), true, nil
+ }
+ return refreshNubEpoch(nub.n1, t.currentEpoch), true, nil
+ }
+
+ if isExpired {
+ return nil, false, NewExpiredNodeError(targetPrefixKey[:pos], epoch, n)
+ }
+
+ switch n := n.(type) {
+ case *shortNode:
+ if len(targetPrefixKey)-pos < len(n.Key) || !bytes.Equal(n.Key, targetPrefixKey[pos:pos+len(n.Key)]) {
+ return nil, false, fmt.Errorf("key %v not found", targetPrefixKey)
+ }
+ newNode, didRevive, err := t.tryRevive(n.Val, key, targetPrefixKey, nub, pos+len(n.Key), epoch, isExpired)
+ if didRevive && err == nil {
+ n = n.copy()
+ n.Val = newNode
+ n.setEpoch(t.currentEpoch)
+ n.flags = t.newFlag()
+ }
+ return n, didRevive, err
+ case *fullNode:
+ childIndex := int(targetPrefixKey[pos])
+ isExpired, _ := n.ChildExpired(nil, childIndex, t.currentEpoch)
+ newNode, didRevive, err := t.tryRevive(n.Children[childIndex], key, targetPrefixKey, nub, pos+1, epoch, isExpired)
+ if didRevive && err == nil {
+ n = n.copy()
+ n.Children[childIndex] = newNode
+ n.setEpoch(t.currentEpoch)
+ n.UpdateChildEpoch(childIndex, t.currentEpoch)
+ n.flags = t.newFlag()
+ }
+
+ if e, ok := err.(*ExpiredNodeError); ok {
+ e.Epoch = n.GetChildEpoch(childIndex)
+ return n, didRevive, e
+ }
+
+ return n, didRevive, err
+ case hashNode:
+ child, err := t.resolveAndTrack(n, targetPrefixKey[:pos])
+ if err != nil {
+ return nil, false, err
+ }
+ if err = t.resolveEpochMetaAndTrack(child, epoch, targetPrefixKey[:pos]); err != nil {
+ return nil, false, err
+ }
+
+ newNode, _, err := t.tryRevive(child, key, targetPrefixKey, nub, pos, epoch, isExpired)
+ return newNode, true, err
+ case nil:
+ return nil, false, nil
+ default:
+ panic(fmt.Sprintf("invalid node: %T", n))
+ }
+}
+
+// ExpireByPrefix is used to simulate the expiration of a trie by prefix key.
+// It is not used in the actual trie implementation. ExpireByPrefix ensures
+// that only a child node of a full node is expired; otherwise an error is
+// returned.
+func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error {
+ hn, _, err := t.expireByPrefix(t.root, prefixKeyHex)
+ if len(prefixKeyHex) == 0 && hn != nil { // whole trie is expired
+ t.root = hn
+ t.rootEpoch = 0
+ }
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, bool, error) {
+ // Walk down along the prefix key. Once the prefix key is exhausted,
+ // generate the hash node of the current node and replace the current
+ // node with it.
+
+ // The prefix key is exhausted at this node
+ if len(prefixKeyHex) == 0 {
+ hasher := newHasher(false)
+ defer returnHasherToPool(hasher)
+ var hn node
+ _, hn = hasher.proofHash(n)
+ if _, ok := hn.(hashNode); ok {
+ return hn, false, nil
+ }
+
+ return nil, true, nil
+ }
+
+ switch n := n.(type) {
+ case *shortNode:
+ matchLen := prefixLen(prefixKeyHex, n.Key)
+ hn, didUpdateEpoch, err := t.expireByPrefix(n.Val, prefixKeyHex[matchLen:])
+ if err != nil {
+ return nil, didUpdateEpoch, err
+ }
+
+ if hn != nil {
+ return nil, didUpdateEpoch, fmt.Errorf("can only expire child short node")
+ }
+
+ return nil, didUpdateEpoch, err
+ case *fullNode:
+ childIndex := int(prefixKeyHex[0])
+ hn, didUpdateEpoch, err := t.expireByPrefix(n.Children[childIndex], prefixKeyHex[1:])
+ if err != nil {
+ return nil, didUpdateEpoch, err
+ }
+
+ // Replace child node with hash node
+ if hn != nil {
+ n.Children[prefixKeyHex[0]] = hn
+ }
+
+ // Update the epoch so that it is expired
+ if !didUpdateEpoch {
+ n.UpdateChildEpoch(childIndex, 0)
+ didUpdateEpoch = true
+ }
+
+ return nil, didUpdateEpoch, err
+ default:
+ return nil, false, fmt.Errorf("invalid node type")
+ }
+}
+
+func (t *Trie) getRootEpoch() types.StateEpoch {
+ return t.rootEpoch
+}
+
+// renewNode decides whether a node must be renewed (copied), based on the
+// node's epoch and the childDirty flag: a dirty child or a pending epoch
+// update requires a copy to prevent reusing the cached trie node.
+func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bool) bool {
+ // when !updateEpoch, this behaves the same as !t.withEpochMeta
+ if !t.enableExpiry || !updateEpoch {
+ return childDirty
+ }
+
+ // the node's epoch is stale; renew it
+ if t.currentEpoch > epoch {
+ return true
+ }
+
+ // no epoch update needed; behave as before
+ return childDirty
+}
+
+func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool {
+ // when node is nil, skip epoch check
+ if !t.enableExpiry || n == nil {
+ return false
+ }
+ return types.EpochExpired(epoch, t.currentEpoch)
+}
+
+func (t *Trie) SetEpoch(epoch types.StateEpoch) {
+ t.currentEpoch = epoch
+}
+
+type NodeInfo struct {
+ Addr common.Hash
+ Path []byte
+ Hash common.Hash
+ Epoch types.StateEpoch
+ Key common.Hash // only leaf nodes carry a meaningful Key.
+type NodeInfo struct {
+	Addr     common.Hash
+	Path     []byte
+	Hash     common.Hash
+	Epoch    types.StateEpoch
+	Key      common.Hash // only leaf nodes carry a valid Key
+	IsLeaf   bool
+	IsBranch bool
+}
+
+type ScanTask struct {
+	itemCh        chan *NodeInfo
+	unexpiredStat *atomic.Uint64
+	expiredStat   *atomic.Uint64
+	findExpired   bool
+	routineCh     chan struct{}
+	wg            *sync.WaitGroup
+	reportDone    chan struct{}
+}
+
+func NewScanTask(itemCh chan *NodeInfo, maxThreads uint64, findExpired bool) *ScanTask {
+	return &ScanTask{
+		itemCh:        itemCh,
+		unexpiredStat: &atomic.Uint64{},
+		expiredStat:   &atomic.Uint64{},
+		findExpired:   findExpired,
+		routineCh:     make(chan struct{}, maxThreads),
+		wg:            &sync.WaitGroup{},
+		reportDone:    make(chan struct{}),
+	}
+}
+
+func (st *ScanTask) Stat(expired bool) {
+	if expired {
+		st.expiredStat.Add(1)
+	} else {
+		st.unexpiredStat.Add(1)
+	}
+}
+
+func (st *ScanTask) ExpiredStat() uint64 {
+	return st.expiredStat.Load()
+}
+
+func (st *ScanTask) UnexpiredStat() uint64 {
+	return st.unexpiredStat.Load()
+}
+
+func (st *ScanTask) WaitThreads() {
+	st.wg.Wait()
+	close(st.reportDone)
+}
+
+func (st *ScanTask) Report(d time.Duration) {
+	start := time.Now()
+	timer := time.NewTimer(d)
+	defer timer.Stop()
+	for {
+		select {
+		case <-timer.C:
+			log.Info("Scan trie stats", "total", st.TotalScan(), "unexpired", st.UnexpiredStat(), "expired", st.ExpiredStat(), "go routine", runtime.NumGoroutine(), "elapsed", common.PrettyDuration(time.Since(start)))
+			timer.Reset(d)
+		case <-st.reportDone:
+			log.Info("Scan trie done", "total", st.TotalScan(), "unexpired", st.UnexpiredStat(), "expired", st.ExpiredStat(), "elapsed", common.PrettyDuration(time.Since(start)))
+			return
+		}
+	}
+}
+
+func (st *ScanTask) Schedule(f func()) {
+	log.Debug("schedule", "total", st.TotalScan(), "go routine", runtime.NumGoroutine())
+	st.wg.Add(1)
+	go func() {
+		defer func() {
+			st.wg.Done()
+			select {
+			case <-st.routineCh:
+				log.Debug("task Schedule done", "routine", len(st.routineCh))
+			default:
+			}
+		}()
+		f()
+	}()
+}
+
+func (st *ScanTask) TotalScan() uint64 {
+	return st.expiredStat.Load() + st.unexpiredStat.Load()
+}
+
+func (st *ScanTask) MoreThread() bool {
+	select {
+	case st.routineCh <- struct{}{}:
+		return true
+	default:
+		return false
+	}
+}
+
+// ScanForPrune traverses the storage trie and streams out expired nodes for
+// pruning (or unexpired ones, when st.findExpired is false).
+func (t *Trie) ScanForPrune(st *ScanTask) error {
+	if !t.enableExpiry {
+		return nil
+	}
+
+	if t.owner == (common.Hash{}) {
+		return fmt.Errorf("cannot prune account trie")
+	}
+
+	err := t.findExpiredSubTree(t.root, nil, t.getRootEpoch(), func(n node, path []byte, epoch types.StateEpoch) {
+		if pruneErr := t.recursePruneExpiredNode(n, path, epoch, st); pruneErr != nil {
+			log.Error("recursePruneExpiredNode err", "Path", path, "err", pruneErr)
+		}
+	}, st)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
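A sketch of how a caller would drive a prune scan (the consumer loop and channel sizing here are illustrative assumptions, not part of the patch):

    // Sketch: collect expired nodes from one storage trie.
    itemCh := make(chan *NodeInfo, 1024)
    st := NewScanTask(itemCh, 8, true) // up to 8 helper goroutines, report expired nodes

    go st.Report(10 * time.Second) // periodic progress logging

    go func() {
        for info := range itemCh {
            _ = info // e.g. batch the (info.Path, info.Hash) pairs for deletion
        }
    }()

    if err := storageTrie.ScanForPrune(st); err != nil {
        log.Error("scan failed", "err", err)
    }
    st.WaitThreads() // drain scheduled goroutines, which also stops the reporter
    close(itemCh)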
+func (t *Trie) findExpiredSubTree(n node, path []byte, epoch types.StateEpoch, pruner func(n node, path []byte, epoch types.StateEpoch), st *ScanTask) error {
+	// Upon reaching an expired node, recursively traverse all of its children
+	// and collect their hashes; the corresponding key-value pairs can then be
+	// deleted from the database in batches.
+	if t.epochExpired(n, epoch) {
+		if st.findExpired {
+			pruner(n, path, epoch)
+		}
+		return nil
+	}
+
+	switch n := n.(type) {
+	case *shortNode:
+		st.Stat(false)
+		if !st.findExpired {
+			st.itemCh <- &NodeInfo{
+				Hash: common.BytesToHash(n.flags.hash),
+			}
+		}
+		err := t.findExpiredSubTree(n.Val, append(path, n.Key...), epoch, pruner, st)
+		if err != nil {
+			return err
+		}
+		return nil
+	case *fullNode:
+		st.Stat(false)
+		if !st.findExpired {
+			st.itemCh <- &NodeInfo{
+				Hash: common.BytesToHash(n.flags.hash),
+			}
+		}
+		var err error
+		// Go through every child and recursively delete expired nodes
+		for i, child := range n.Children {
+			err = t.findExpiredSubTree(child, append(path, byte(i)), n.GetChildEpoch(i), pruner, st)
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+
+	case hashNode:
+		resolve, err := t.resolveHash(n, path)
+		if err != nil {
+			return err
+		}
+		if err = t.resolveEpochMeta(resolve, epoch, path); err != nil {
+			return err
+		}
+		if st.TotalScan()%1000 == 0 && st.MoreThread() {
+			path := common.CopyBytes(path)
+			st.Schedule(func() {
+				if err := t.findExpiredSubTree(resolve, path, epoch, pruner, st); err != nil {
+					log.Error("findExpiredSubTree err", "addr", t.owner, "path", path, "epoch", epoch, "err", err)
+				}
+			})
+			return nil
+		}
+		return t.findExpiredSubTree(resolve, path, epoch, pruner, st)
+	case valueNode:
+		return nil
+	case nil:
+		return nil
+	default:
+		panic(fmt.Sprintf("invalid node type: %T", n))
+	}
+}
+
+func (t *Trie) recursePruneExpiredNode(n node, path []byte, epoch types.StateEpoch, st *ScanTask) error {
+	switch n := n.(type) {
+	case *shortNode:
+		st.Stat(true)
+		subPath := append(path, n.Key...)
+		key := common.Hash{}
+		_, isLeaf := n.Val.(valueNode)
+		if isLeaf {
+			key = common.BytesToHash(hexToKeybytes(subPath))
+		}
+		if st.findExpired {
+			st.itemCh <- &NodeInfo{
+				Addr:   t.owner,
+				Hash:   common.BytesToHash(n.flags.hash),
+				Path:   renewBytes(path),
+				Key:    key,
+				Epoch:  epoch,
+				IsLeaf: isLeaf,
+			}
+		}
+
+		err := t.recursePruneExpiredNode(n.Val, subPath, epoch, st)
+		if err != nil {
+			return err
+		}
+		return nil
+	case *fullNode:
+		st.Stat(true)
+		if st.findExpired {
+			st.itemCh <- &NodeInfo{
+				Addr:     t.owner,
+				Hash:     common.BytesToHash(n.flags.hash),
+				Path:     renewBytes(path),
+				Epoch:    epoch,
+				IsBranch: true,
+			}
+		}
+		// recurse into the children, skipping the trailing value-node slot
+		for i := 0; i < BranchNodeLength-1; i++ {
+			err := t.recursePruneExpiredNode(n.Children[i], append(path, byte(i)), n.EpochMap[i], st)
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	case hashNode:
+		// a hashNode is only an index into trie node storage; it needs no pruning itself
+		rn, err := t.resolveHash(n, path)
+		// skip nodes that are already missing, possibly pruned via an old state snapshot
+		if _, ok := err.(*MissingNodeError); ok {
+			return nil
+		}
+		if err != nil {
+			return err
+		}
+		if err = t.resolveEpochMeta(rn, epoch, path); err != nil {
+			return err
+		}
+
+		if st.TotalScan()%1000 == 0 && st.MoreThread() {
+			path := common.CopyBytes(path)
+			st.Schedule(func() {
+				if err := t.recursePruneExpiredNode(rn, path, epoch, st); err != nil {
+					log.Error("recursePruneExpiredNode err", "addr", t.owner, "path", path, "epoch", epoch, "err", err)
+				}
+			})
+			return nil
+		}
+		return t.recursePruneExpiredNode(rn, path, epoch, st)
+	case valueNode:
+		// a valueNode is not a standalone storage unit, so there is nothing to prune
+		return nil
+	case nil:
+		return nil
+	default:
+		panic(fmt.Sprintf("invalid node type: %T", n))
+	}
+}
+
+func (t *Trie) UpdateRootEpoch(epoch types.StateEpoch) {
+	t.rootEpoch = epoch
+}
+
+func (t *Trie) UpdateCurrentEpoch(epoch types.StateEpoch) {
+	t.currentEpoch = epoch
+}
+
+// isExpiredNode reports whether the node is expired or missing; a missing node
+// may have been pruned via an old state snapshot.
+func (t *Trie) isExpiredNode(n node, targetPrefixKey []byte, epoch types.StateEpoch, expired bool) bool {
+	if expired {
+		return true
+	}
+
+	// check whether the trie node is missing
+	_, err := t.resolve(n, targetPrefixKey, epoch)
+	if _, ok := err.(*MissingNodeError); ok {
+		return true
+	}
+
+	return false
+}
diff --git a/trie/trie_expiry.go b/trie/trie_expiry.go
new file mode 100644
index 0000000000..8c7c8e32f4
--- /dev/null
+++ b/trie/trie_expiry.go
@@ -0,0 +1,79 @@
+package trie
+
+import (
+	"bytes"
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/core/types"
+)
+
+func (t *Trie) TryLocalRevive(key []byte) ([]byte, error) {
+	// Short circuit if the trie is already committed and not usable.
+	if t.committed {
+		return nil, ErrCommitted
+	}
+
+	key = keybytesToHex(key)
+	val, newroot, didResolve, err := t.tryLocalRevive(t.root, key, 0, t.getRootEpoch())
+	if err == nil && didResolve {
+		t.root = newroot
+		t.rootEpoch = t.currentEpoch
+	}
+	return val, err
+}
+
+func (t *Trie) tryLocalRevive(origNode node, key []byte, pos int, epoch types.StateEpoch) ([]byte, node, bool, error) {
+	expired := t.epochExpired(origNode, epoch)
+	switch n := (origNode).(type) {
+	case nil:
+		return nil, nil, false, nil
+	case valueNode:
+		return n, n, expired, nil
+	case *shortNode:
+		if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
+			// the key is not in the trie, but still renew the node to expand the path
+			if t.renewNode(epoch, false, expired) {
+				n = n.copy()
+				n.setEpoch(t.currentEpoch)
+				n.flags = t.newFlag()
+				return nil, n, true, nil
+			}
+			return nil, n, false, nil
+		}
+		value, newnode, didResolve, err := t.tryLocalRevive(n.Val, key, pos+len(n.Key), epoch)
+		if err == nil && t.renewNode(epoch, didResolve, expired) {
+			n = n.copy()
+			n.Val = newnode
+			n.setEpoch(t.currentEpoch)
+			n.flags = t.newFlag()
+			didResolve = true
+		}
+		return value, n, didResolve, err
+	case *fullNode:
+		value, newnode, didResolve, err := t.tryLocalRevive(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])))
+		if err == nil && t.renewNode(epoch, didResolve, expired) {
+			n = n.copy()
+			n.Children[key[pos]] = newnode
+			n.setEpoch(t.currentEpoch)
+			if newnode != nil {
+				n.UpdateChildEpoch(int(key[pos]), t.currentEpoch)
+			}
+			n.flags = t.newFlag()
+			didResolve = true
+		}
+		return value, n, didResolve, err
+	case hashNode:
+		child, err := t.resolveAndTrack(n, key[:pos])
+		if err != nil {
+			return nil, n, true, err
+		}
+
+		if err = t.resolveEpochMetaAndTrack(child, epoch, key[:pos]); err != nil {
+			return nil, n, true, err
+		}
+		value, newnode, _, err := t.tryLocalRevive(child, key, pos, epoch)
+		return value, newnode, true, err
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
+	}
+}
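Unlike TryRevive, TryLocalRevive needs no external witness: it re-walks the path using the nodes still available locally and re-stamps their epochs. A minimal sketch (illustrative; currentEpoch and storageKey are hypothetical):

    // Sketch: revive a key from locally available nodes, without a proof.
    tr.SetEpoch(currentEpoch)
    val, err := tr.TryLocalRevive(storageKey)
    if err != nil {
        // a missing or expired node means the data is gone locally and a
        // witness-based TryRevive is required instead
        return err
    }
    _ = val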
"github.com/ethereum/go-ethereum/trie/triestate" ) +var ( + accountMetaTimer = metrics.NewRegisteredTimer("trie/reader/accountmeta", nil) + epochMetaTimer = metrics.NewRegisteredTimer("trie/reader/epochmeta", nil) + nodeTimer = metrics.NewRegisteredTimer("trie/reader/node", nil) +) + // Reader wraps the Node method of a backing trie store. type Reader interface { // Node retrieves the trie node blob with the provided trie identifier, node path and @@ -39,24 +52,41 @@ type Reader interface { // trieReader is a wrapper of the underlying node reader. It's not safe // for concurrent usage. type trieReader struct { - owner common.Hash - reader Reader - banned map[string]struct{} // Marker to prevent node from being accessed, for tests + owner common.Hash + reader Reader + emReader *epochmeta.Reader + banned map[string]struct{} // Marker to prevent node from being accessed, for tests } // newTrieReader initializes the trie reader with the given node reader. func newTrieReader(stateRoot, owner common.Hash, db *Database) (*trieReader, error) { + var err error + if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { if stateRoot == (common.Hash{}) { log.Error("Zero state root hash!") } - return &trieReader{owner: owner}, nil + tr := &trieReader{owner: owner} + if db.snapTree != nil { + tr.emReader, err = epochmeta.NewReader(db.snapTree, new(big.Int), stateRoot) + if err != nil { + return nil, err + } + } + return tr, nil } reader, err := db.Reader(stateRoot) if err != nil { return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err} } - return &trieReader{owner: owner, reader: reader}, nil + tr := trieReader{owner: owner, reader: reader} + if db.snapTree != nil { + tr.emReader, err = epochmeta.NewReader(db.snapTree, new(big.Int), stateRoot) + if err != nil { + return nil, err + } + } + return &tr, nil } // newEmptyReader initializes the pure in-memory reader. All read operations @@ -69,6 +99,9 @@ func newEmptyReader() *trieReader { // information. An MissingNodeError will be returned in case the node is // not found or any error is encountered. func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) { + defer func(start time.Time) { + nodeTimer.Update(time.Since(start)) + }(time.Now()) // Perform the logics in tests for preventing trie node access. if r.banned != nil { if _, ok := r.banned[string(path)]; ok { @@ -79,7 +112,7 @@ func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) { return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} } blob, err := r.reader.Node(r.owner, path, hash) - if err != nil || len(blob) == 0 { + if err != nil || (!epochmeta.IsEpochMetaPath(path) && len(blob) == 0) { return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} } return blob, nil @@ -99,3 +132,36 @@ func (l *trieLoader) OpenTrie(root common.Hash) (triestate.Trie, error) { func (l *trieLoader) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (triestate.Trie, error) { return New(StorageTrieID(stateRoot, addrHash, root), l.db) } + +// epochMeta resolve from epoch meta storage +func (r *trieReader) epochMeta(path []byte) ([]byte, error) { + defer func(start time.Time) { + epochMetaTimer.Update(time.Since(start)) + }(time.Now()) + if r.emReader == nil { + return nil, fmt.Errorf("cannot resolve epochmeta without db, path: %#x", path) + } + + // epoch meta cloud be empty, because epoch0 or delete? 
+ blob, err := r.emReader.Get(r.owner, string(path)) + if err != nil { + return nil, fmt.Errorf("resolve epoch meta err, path: %#x, err: %v", path, err) + } + return blob, nil +} + +// accountMeta resolve account metadata +func (r *trieReader) accountMeta() ([]byte, error) { + defer func(start time.Time) { + accountMetaTimer.Update(time.Since(start)) + }(time.Now()) + if r.emReader == nil { + return nil, errors.New("cannot resolve epoch meta without db for account") + } + + blob, err := r.emReader.Get(r.owner, string(epochmeta.AccountMetadataPath)) + if err != nil { + return nil, fmt.Errorf("resolve epoch meta err for account, err: %v", err) + } + return blob, nil +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 35ccc77201..cb80fa8af2 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -36,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/sha3" ) @@ -995,6 +996,262 @@ func TestCommitSequenceSmallRoot(t *testing.T) { } } +func TestRevive(t *testing.T) { + trie, vals := nonRandomTrieWithExpiry(100) + + oriRootHash := trie.Hash() + + for _, kv := range vals { + key := kv.k + val := kv.v + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveByPath(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + proofCache := makeRawMPTProofCache(prefixKey, proof) + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(key, proofCache.CacheNubs()) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + + // Verifiy value exists after revive + v, err := trie.Get(key) + assert.NoError(t, err, "Get failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + assert.Equal(t, val, v, "value mismatch, got %x, exp %x, key %x, prefixKey %x", v, val, key, prefixKey) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "root hash mismatch, got %x, exp %x, key %x, prefixKey %x", currRootHash, oriRootHash, key, prefixKey) + + // Reset trie + trie, _ = nonRandomTrieWithExpiry(100) + } + } +} + +func TestReviveCustom(t *testing.T) { + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data, 10) + + oriRootHash := trie.Hash() + + for k, v := range data { + key := []byte(k) + val := []byte(v) + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + var proofList proofList + err := trie.ProveByPath(key, prefixKey, &proofList) + assert.NoError(t, err) + + trie.ExpireByPrefix(prefixKey) + + proofCache := makeRawMPTProofCache(prefixKey, proofList) + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(key, proofCache.cacheNubs) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + + res, err := trie.Get(key) + assert.NoError(t, err, "Get failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + assert.Equal(t, val, res, "value mismatch, got %x, exp %x, key %x, prefixKey %x", res, val, key, prefixKey) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "root hash mismatch, got %x, exp %x, 
key %x, prefixKey %x", currRootHash, oriRootHash, key, prefixKey) + + // Reset trie + trie = createCustomTrie(data, 10) + } + } +} + +// TestReviveBadProof tests that a trie cannot be revived from a bad proof +func TestReviveBadProof(t *testing.T) { + dataA := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + dataB := map[string]string{ + "qwer": "A", "qwet": "B", "qwrt": "C", "qwry": "D", + "abcd": "E", "abce": "F", "abde": "G", "abdf": "H", + } + + trieA := createCustomTrie(dataA, 0) + trieB := createCustomTrie(dataB, 0) + + var proofB proofList + + err := trieB.ProveByPath([]byte("abcd"), nil, &proofB) + assert.NoError(t, err) + + // Expire trie A + trieA.ExpireByPrefix(nil) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(nil, proofB) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trieA.TryRevive([]byte("abcd"), proofCache.cacheNubs) + assert.Error(t, err) + + // Verify value does exists after revive + val, err := trieA.Get([]byte("abcd")) + assert.Error(t, err, "Get failed, key %x, val %x", []byte("abcd"), val) + assert.NotEqual(t, []byte("A"), val) +} + +// TestReviveBadProofAfterUpdate tests that after reviving a path and +// then update the value, old proof should be invalid +func TestReviveBadProofAfterUpdate(t *testing.T) { + trie, vals := nonRandomTrieWithExpiry(100) + + for _, kv := range vals { + key := kv.k + val := kv.v + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveByPath(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + proofCache := makeRawMPTProofCache(prefixKey, proof) + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(key, proofCache.CacheNubs()) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + + // Verify value exists after revive + v, err := trie.Get(key) + assert.NoError(t, err, "Get failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + assert.Equal(t, val, v, "value mismatch, got %x, exp %x, key %x, prefixKey %x", v, val, key, prefixKey) + + trie.Update(key, []byte("new value")) + v, err = trie.Get(key) + assert.NoError(t, err, "Get failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + assert.Equal(t, []byte("new value"), v, "value mismatch, got %x, exp %x, key %x, prefixKey %x", v, val, key, prefixKey) + + _, err = trie.TryRevive(key, proofCache.CacheNubs()) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + + v, err = trie.Get(key) + assert.NoError(t, err, "Get failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + assert.Equal(t, []byte("new value"), v, "value mismatch, got %x, exp %x, key %x, prefixKey %x", v, val, key, prefixKey) + + // Reset trie + trie, _ = nonRandomTrieWithExpiry(100) + } + } +} + +func TestPartialReviveFullProof(t *testing.T) { + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data, 10) + key := []byte("abcd") + val := []byte("A") + + // Get proof + var proof proofList + err := trie.ProveByPath(key, nil, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix([]byte{6, 1}) + assert.NoError(t, err) + + // Construct 
MPTProofCache + proofCache := makeRawMPTProofCache(nil, proof) + + // Verify proof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(key, proofCache.cacheNubs) + assert.NoError(t, err) + + // Validate trie + resVal, err := trie.Get(key) + assert.NoError(t, err) + assert.Equal(t, val, resVal) +} + +func createCustomTrie(data map[string]string, epoch types.StateEpoch) *Trie { + db := NewDatabase(rawdb.NewMemoryDatabase(), nil) + trie := NewEmpty(db) + trie.rootEpoch = epoch + trie.currentEpoch = epoch + trie.enableExpiry = true + for k, v := range data { + trie.MustUpdate([]byte(k), []byte(v)) + } + + return trie +} + +func getFullNodePrefixKeys(t *Trie, key []byte) [][]byte { + var prefixKeys [][]byte + key = keybytesToHex(key) + tn := t.root + currPath := []byte{} + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + tn = nil + } else { + tn = n.Val + prefixKeys = append(prefixKeys, currPath) + currPath = append(currPath, n.Key...) + key = key[len(n.Key):] + } + case *fullNode: + tn = n.Children[key[0]] + currPath = append(currPath, key[0]) + key = key[1:] + default: + return nil + } + } + + // // Remove the first item in prefixKeys, which is the empty key + // if len(prefixKeys) > 0 { + // prefixKeys = prefixKeys[1:] + // } + + return prefixKeys +} + // BenchmarkCommitAfterHashFixedSize benchmarks the Commit (after Hash) of a fixed number of updates to a trie. // This benchmark is meant to capture the difference on efficiency of small versus large changes. Typically, // storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple @@ -1181,3 +1438,12 @@ func TestDecodeNode(t *testing.T) { decodeNode(hash, elems) } } + +func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { + return MPTProofCache{ + MPTProof: MPTProof{ + RootKeyHex: rootKeyHex, + Proof: proof, + }, + } +} diff --git a/trie/triedb/pathdb/difflayer.go b/trie/triedb/pathdb/difflayer.go index d25ac1c601..74ff69fedd 100644 --- a/trie/triedb/pathdb/difflayer.go +++ b/trie/triedb/pathdb/difflayer.go @@ -20,6 +20,8 @@ import ( "fmt" "sync" + "github.com/ethereum/go-ethereum/trie/epochmeta" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/trie/trienode" @@ -111,9 +113,9 @@ func (dl *diffLayer) node(owner common.Hash, path []byte, hash common.Hash, dept if ok { // If the trie node is not hash matched, or marked as removed, // bubble up an error here. It shouldn't happen at all. 
-		if n.Hash != hash {
+		if !epochmeta.IsEpochMetaPath(path) && n.Hash != hash {
 			dirtyFalseMeter.Mark(1)
-			log.Error("Unexpected trie node in diff layer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
+			log.Debug("Unexpected trie node in diff layer", "root", dl.root, "owner", owner, "path", path, "expect", hash, "got", n.Hash)
 			return nil, newUnexpectedNodeError("diff", hash, n.Hash, owner, path)
 		}
 		dirtyHitMeter.Mark(1)
diff --git a/trie/triedb/pathdb/disklayer.go b/trie/triedb/pathdb/disklayer.go
index 87718290f9..145f8a5e4f 100644
--- a/trie/triedb/pathdb/disklayer.go
+++ b/trie/triedb/pathdb/disklayer.go
@@ -21,6 +21,9 @@ import (
 	"fmt"
 	"sync"
 
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/trie/epochmeta"
+
 	"github.com/VictoriaMetrics/fastcache"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
@@ -126,14 +129,15 @@ func (dl *diskLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]b
 			h := newHasher()
 			defer h.release()
 
-			got := h.hash(blob)
-			if got == hash {
+			raw, _ := types.DecodeTypedTrieNodeRaw(blob)
+			got := h.hash(raw)
+			if epochmeta.IsEpochMetaPath(path) || got == hash {
 				cleanHitMeter.Mark(1)
 				cleanReadMeter.Mark(int64(len(blob)))
 				return blob, nil
 			}
 			cleanFalseMeter.Mark(1)
-			log.Error("Unexpected trie node in clean cache", "owner", owner, "path", path, "expect", hash, "got", got)
+			log.Debug("Unexpected trie node in clean cache", "owner", owner, "path", path, "expect", hash, "got", got)
 		}
 		cleanMissMeter.Mark(1)
 	}
@@ -147,9 +151,9 @@ func (dl *diskLayer) Node(owner common.Hash, path []byte, hash common.Hash) ([]b
 	} else {
 		nBlob, nHash = rawdb.ReadStorageTrieNode(dl.db.diskdb, owner, path)
 	}
-	if nHash != hash {
+	if !epochmeta.IsEpochMetaPath(path) && nHash != hash {
 		diskFalseMeter.Mark(1)
-		log.Error("Unexpected trie node in disk", "owner", owner, "path", path, "expect", hash, "got", nHash)
+		log.Debug("Unexpected trie node in disk", "owner", owner, "path", path, "expect", hash, "got", nHash)
 		return nil, newUnexpectedNodeError("disk", hash, nHash, owner, path)
 	}
 	if dl.cleans != nil && len(nBlob) > 0 {
diff --git a/trie/triedb/pathdb/journal.go b/trie/triedb/pathdb/journal.go
index d8c7d39fb9..bc5838cecf 100644
--- a/trie/triedb/pathdb/journal.go
+++ b/trie/triedb/pathdb/journal.go
@@ -161,7 +161,11 @@ func (db *Database) loadDiskLayer(r *rlp.Stream) (layer, error) {
 		subset := make(map[string]*trienode.Node)
 		for _, n := range entry.Nodes {
 			if len(n.Blob) > 0 {
-				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
+				raw, err := types.DecodeTypedTrieNodeRaw(n.Blob)
+				if err != nil {
+					return nil, err
+				}
+				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(raw), n.Blob)
 			} else {
 				subset[string(n.Path)] = trienode.NewDeleted()
 			}
@@ -199,7 +203,11 @@ func (db *Database) loadDiffLayer(parent layer, r *rlp.Stream) (layer, error) {
 		subset := make(map[string]*trienode.Node)
 		for _, n := range entry.Nodes {
 			if len(n.Blob) > 0 {
-				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(n.Blob), n.Blob)
+				raw, err := types.DecodeTypedTrieNodeRaw(n.Blob)
+				if err != nil {
+					return nil, err
+				}
+				subset[string(n.Path)] = trienode.New(crypto.Keccak256Hash(raw), n.Blob)
 			} else {
 				subset[string(n.Path)] = trienode.NewDeleted()
 			}
diff --git a/trie/triedb/pathdb/layertree.go b/trie/triedb/pathdb/layertree.go
index d314779910..3590691d4c 100644
--- a/trie/triedb/pathdb/layertree.go
+++ b/trie/triedb/pathdb/layertree.go
@@ -21,6 +21,8 @@ import (
 	"fmt"
 	"sync"
 
+	"github.com/ethereum/go-ethereum/log"
+
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/trie/trienode"
@@ -105,6 +107,7 @@ func (tree *layerTree) add(root common.Hash, parentRoot common.Hash, block uint6
 	tree.lock.Lock()
 	tree.layers[l.rootHash()] = l
+	log.Debug("pathdb snap tree update", "root", root, "number", block, "layers", len(tree.layers))
 	tree.lock.Unlock()
 
 	return nil
 }
@@ -190,6 +193,7 @@ func (tree *layerTree) cap(root common.Hash, layers int) error {
 			remove(root)
 		}
 	}
+	log.Debug("pathdb snap tree cap", "root", root, "layers", len(tree.layers))
 	return nil
 }
 
diff --git a/trie/triedb/pathdb/nodebuffer.go b/trie/triedb/pathdb/nodebuffer.go
index 67de225b04..3f6a17c078 100644
--- a/trie/triedb/pathdb/nodebuffer.go
+++ b/trie/triedb/pathdb/nodebuffer.go
@@ -20,6 +20,8 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/ethereum/go-ethereum/trie/epochmeta"
+
 	"github.com/VictoriaMetrics/fastcache"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
@@ -68,9 +70,9 @@ func (b *nodebuffer) node(owner common.Hash, path []byte, hash common.Hash) (*tr
 	if !ok {
 		return nil, nil
 	}
-	if n.Hash != hash {
+	if !epochmeta.IsEpochMetaPath(path) && n.Hash != hash {
 		dirtyFalseMeter.Mark(1)
-		log.Error("Unexpected trie node in node buffer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
+		log.Debug("Unexpected trie node in node buffer", "owner", owner, "path", path, "expect", hash, "got", n.Hash)
 		return nil, newUnexpectedNodeError("dirty", hash, n.Hash, owner, path)
 	}
 	return n, nil
diff --git a/trie/triedb/pathdb/testutils.go b/trie/triedb/pathdb/testutils.go
index d6fdacb421..070e13dea3 100644
--- a/trie/triedb/pathdb/testutils.go
+++ b/trie/triedb/pathdb/testutils.go
@@ -20,6 +20,8 @@ import (
 	"bytes"
 	"fmt"
 
+	"github.com/ethereum/go-ethereum/rlp"
+
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
@@ -130,6 +132,7 @@ func hash(states map[common.Hash][]byte) (common.Hash, []byte) {
 	if len(input) == 0 {
 		return types.EmptyRootHash, nil
 	}
+	input, _ = rlp.EncodeToBytes(input)
 	return crypto.Keccak256Hash(input), input
 }
diff --git a/trie/trienode/node.go b/trie/trienode/node.go
index 98d5588b6d..38bd67cfe5 100644
--- a/trie/trienode/node.go
+++ b/trie/trienode/node.go
@@ -21,6 +21,8 @@ import (
 	"sort"
 	"strings"
 
+	"github.com/ethereum/go-ethereum/trie/epochmeta"
+
 	"github.com/ethereum/go-ethereum/common"
 )
 
@@ -59,19 +61,21 @@ type leaf struct {
 // NodeSet contains a set of nodes collected during the commit operation.
 // Each node is keyed by path. It's not thread-safe to use.
 type NodeSet struct {
-	Owner   common.Hash
-	Leaves  []*leaf
-	Nodes   map[string]*Node
-	updates int // the count of updated and inserted nodes
-	deletes int // the count of deleted nodes
+	Owner      common.Hash
+	Leaves     []*leaf
+	Nodes      map[string]*Node
+	EpochMetas map[string][]byte
+	updates    int // the count of updated and inserted nodes
+	deletes    int // the count of deleted nodes
 }
 
 // NewNodeSet initializes a node set. The owner is zero for the account trie and
 // the owning account address hash for storage tries.
 func NewNodeSet(owner common.Hash) *NodeSet {
 	return &NodeSet{
-		Owner: owner,
-		Nodes: make(map[string]*Node),
+		Owner:      owner,
+		Nodes:      make(map[string]*Node),
+		EpochMetas: make(map[string][]byte),
 	}
 }
 
@@ -99,8 +103,18 @@ func (set *NodeSet) AddNode(path []byte, n *Node) {
 	set.Nodes[string(path)] = n
 }
 
+// AddBranchNodeEpochMeta adds the provided epoch meta to the set.
+func (set *NodeSet) AddBranchNodeEpochMeta(path []byte, blob []byte) {
+	set.EpochMetas[string(path)] = blob
+}
+
+// AddAccountMeta adds the provided account metadata to the set.
+func (set *NodeSet) AddAccountMeta(blob []byte) {
+	set.EpochMetas[string(epochmeta.AccountMetadataPath)] = blob
+}
+
 // Merge adds a set of nodes into the set.
-func (set *NodeSet) Merge(owner common.Hash, nodes map[string]*Node) error {
+func (set *NodeSet) Merge(owner common.Hash, nodes map[string]*Node, metas map[string][]byte) error {
 	if set.Owner != owner {
 		return fmt.Errorf("nodesets belong to different owner are not mergeable %x-%x", set.Owner, owner)
 	}
@@ -116,6 +130,9 @@ func (set *NodeSet) Merge(owner common.Hash, nodes map[string]*Node) error {
 		}
 		set.AddNode([]byte(path), node)
 	}
+	for path, meta := range metas {
+		set.EpochMetas[path] = meta
+	}
 	return nil
 }
 
@@ -183,7 +200,7 @@ func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
 func (set *MergedNodeSet) Merge(other *NodeSet) error {
 	subset, present := set.Sets[other.Owner]
 	if present {
-		return subset.Merge(other.Owner, other.Nodes)
+		return subset.Merge(other.Owner, other.Nodes, other.EpochMetas)
 	}
 	set.Sets[other.Owner] = other
 	return nil
@@ -197,3 +214,11 @@ func (set *MergedNodeSet) Flatten() map[common.Hash]map[string]*Node {
 	}
 	return nodes
 }
+
+func (set *MergedNodeSet) FlattenEpochMeta() map[common.Hash]map[string][]byte {
+	nodes := make(map[common.Hash]map[string][]byte)
+	for owner, set := range set.Sets {
+		nodes[owner] = set.EpochMetas
+	}
+	return nodes
+}
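A sketch of how the new epoch-meta plumbing on NodeSet is meant to flow through a commit (the path and blob values here are hypothetical):

    // Sketch: carry epoch metadata alongside dirty nodes during commit.
    set := trienode.NewNodeSet(owner)
    set.AddNode(path, trienode.New(hash, blob))
    set.AddBranchNodeEpochMeta(path, epochMetaBlob) // epoch map blob of a branch node

    merged := trienode.NewMergedNodeSet()
    if err := merged.Merge(set); err != nil {
        return err
    }
    byOwner := merged.FlattenEpochMeta() // map[common.Hash]map[string][]byte
    _ = byOwner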
diff --git a/trie/typed_trie_node_test.go b/trie/typed_trie_node_test.go
new file mode 100644
index 0000000000..490b101c4d
--- /dev/null
+++ b/trie/typed_trie_node_test.go
@@ -0,0 +1,146 @@
+package trie
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/stretchr/testify/assert"
+)
+
+var (
+	fullNode1 = fullNode{
+		EpochMap: randomEpochMap(),
+		Children: [17]node{
+			&shortNode{
+				Key: common.FromHex("0x2e2"),
+				Val: valueNode(common.FromHex("0x1")),
+			},
+			&shortNode{
+				Key: common.FromHex("0x31f"),
+				Val: valueNode(common.FromHex("0x2")),
+			},
+			hashNode(common.FromHex("0x1dce34c5cc509511f743349d758b8c38af8ac831432dbbfd989436acd3dbdeb8")),
+			hashNode(common.FromHex("0x8bf421d69d8aacac46f15f0abd517e61e7ffe6b314a15a4fbce3e2a54323fa81")),
+		},
+	}
+	fullNode2 = fullNode{
+		EpochMap: randomEpochMap(),
+		Children: [17]node{
+			hashNode(common.FromHex("0xac51f786e6cee2f4575d19789c1e7ae91da54f2138f415c0f95f127c2893eff9")),
+			hashNode(common.FromHex("0x83254958a3640af7a740dfcb32a02edfa1224e0ef65c28b1ff60c0b17eacb5d1")),
+			hashNode(common.FromHex("0xc5f95b4bdbd1a17736a9162cd551d60c60252ea22d5016198ee6e5a5d04ac03a")),
+			hashNode(common.FromHex("0xfe0654cc989b62dec1758daf6c4a29997f1f618d456981dd1d32f73c74c75151")),
+		},
+	}
+	shortNode1 = shortNode{
+		Key: common.FromHex("0xdf21"),
+		Val: hashNode(common.FromHex("0x1dce34c5cc509511f743349d758b8c38af8ac831432dbbfd989436acd3dbdeb8")),
+	}
+	shortNode2 = shortNode{
+		Key: common.FromHex("0xdf21"),
+		Val: valueNode(common.FromHex("0x1af23")),
+	}
+)
+
+func TestSimpleTypedNode_Encode_Decode(t *testing.T) {
+	tests := []struct {
+		n   types.TypedTrieNode
+		err bool
+	}{
+		{
+			n: types.TrieNodeRaw{},
+		},
+		{
+			n:   types.TrieNodeRaw(common.FromHex("0x2465176C461AfB316ebc773C61fAEe85A6515DAA")),
+			err: true,
+		},
+		{
+			n: types.TrieNodeRaw(nodeToBytes(&shortNode1)),
+		},
+		{
+			n: types.TrieNodeRaw(nodeToBytes(&shortNode2)),
+		},
+		{
+			n: types.TrieNodeRaw(nodeToBytes(&fullNode1)),
+		},
+		{
+			n: types.TrieNodeRaw(nodeToBytes(&fullNode2)),
+		},
+		{
+			n: &types.TrieBranchNodeWithEpoch{
+				EpochMap: fullNode1.EpochMap,
+				Blob:     nodeToBytes(&fullNode1),
+			},
+		},
+		{
+			n: &types.TrieBranchNodeWithEpoch{
+				EpochMap: fullNode2.EpochMap,
+				Blob:     nodeToBytes(&fullNode2),
+			},
+		},
+		{
+			n: &types.TrieBranchNodeWithEpoch{
+				EpochMap: randomEpochMap(),
+				Blob:     nodeToBytes(&shortNode1),
+			},
+		},
+		{
+			n: &types.TrieBranchNodeWithEpoch{
+				EpochMap: randomEpochMap(),
+				Blob:     nodeToBytes(&shortNode2),
+			},
+		},
+	}
+
+	for i, item := range tests {
+		enc := types.EncodeTypedTrieNode(item.n)
+		t.Log(common.Bytes2Hex(enc))
+		rn, err := types.DecodeTypedTrieNode(enc)
+		if item.err {
+			assert.Error(t, err, i)
+			continue
+		}
+		assert.NoError(t, err, i)
+		assert.Equal(t, item.n, rn, i)
+	}
+}
+
+func TestNode2Bytes_Encode(t *testing.T) {
+	tests := []struct {
+		tn  types.TypedTrieNode
+		n   node
+		err bool
+	}{
+		{
+			tn: &types.TrieBranchNodeWithEpoch{
+				EpochMap: fullNode1.EpochMap,
+				Blob:     nodeToBytes(&fullNode1),
+			},
+			n: &fullNode1,
+		},
+		{
+			tn: &types.TrieBranchNodeWithEpoch{
+				EpochMap: fullNode2.EpochMap,
+				Blob:     nodeToBytes(&fullNode2),
+			},
+			n: &fullNode2,
+		},
+	}
+
+	for i, item := range tests {
+		enc1 := nodeToBytesWithEpoch(item.n)
+		enc2 := types.EncodeTypedTrieNode(item.tn)
+		t.Log(common.Bytes2Hex(enc1), common.Bytes2Hex(enc2))
+		assert.Equal(t, enc2, enc1, i)
+	}
+}
+
+func randomEpochMap() [16]types.StateEpoch {
+	var ret [16]types.StateEpoch
+	for i := range ret {
+		ret[i] = types.StateEpoch(rand.Int() % 10000)
+	}
+	return ret
+}
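Finally, note why the journal and disk-layer hunks above hash the decoded raw node rather than the stored blob: a persisted blob may carry the typed epoch envelope, while the node hash must stay that of the plain node. A round-trip sketch using only the functions referenced in this patch (epochMap and bn are hypothetical inputs):

    // Sketch: the node hash is computed over the raw node, not the envelope.
    enc := types.EncodeTypedTrieNode(&types.TrieBranchNodeWithEpoch{
        EpochMap: epochMap,        // a [16]types.StateEpoch
        Blob:     nodeToBytes(bn), // RLP of the underlying branch node
    })
    raw, err := types.DecodeTypedTrieNodeRaw(enc)
    if err != nil {
        return err
    }
    hash := crypto.Keccak256Hash(raw) // equals the node hash; Keccak256Hash(enc) would not
    _ = hash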