diff --git a/core/state/database.go b/core/state/database.go index f64a363c0b..58366be0ad 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -158,17 +158,20 @@ type Trie interface { // with the node that proves the absence of the key. Prove(key []byte, proofDb ethdb.KeyValueWriter) error - // ProvePath generate proof state in trie. - ProvePath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error + // ProveByPath generate proof state in trie. + ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error - // ReviveTrie revive expired state from proof. - ReviveTrie(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error) + // TryRevive revive expired state from proof. + TryRevive(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error) // SetEpoch set current epoch in trie, it must set in initial period, or it will get error behavior. SetEpoch(types.StateEpoch) // Epoch get current epoch in trie Epoch() types.StateEpoch + + // TryLocalRevive it revive using local non-pruned states + TryLocalRevive(addr common.Address, key []byte) ([]byte, error) } // NewDatabase creates a backing store for state. The returned database is safe for diff --git a/core/state/state_expiry.go b/core/state/state_expiry.go index 6870205fa8..232c4d0582 100644 --- a/core/state/state_expiry.go +++ b/core/state/state_expiry.go @@ -14,12 +14,34 @@ import ( var ( reviveStorageTrieTimer = metrics.NewRegisteredTimer("state/revivetrie/rt", nil) + EnableLocalRevive = false // indicate if using local revive ) // fetchExpiredStorageFromRemote request expired state from remote full state node; func fetchExpiredStorageFromRemote(fullDB ethdb.FullStateDB, stateRoot common.Hash, addr common.Address, root common.Hash, tr Trie, prefixKey []byte, key common.Hash) (map[string][]byte, error) { log.Debug("fetching expired storage from remoteDB", "addr", addr, "prefix", prefixKey, "key", key) + if EnableLocalRevive { + // if there need revive expired state, try to revive locally, when the node is not being pruned, just renew the epoch + val, err := tr.TryLocalRevive(addr, key.Bytes()) + log.Debug("fetchExpiredStorageFromRemote TryLocalRevive", "addr", addr, "key", key, "val", val, "err", err) + if _, ok := err.(*trie.MissingNodeError); !ok { + return nil, err + } + switch err.(type) { + case *trie.MissingNodeError: + // cannot revive locally, request from remote + case nil: + ret := make(map[string][]byte, 1) + ret[key.String()] = val + return ret, nil + default: + return nil, err + } + } + + // cannot revive locally, fetch remote proof proofs, err := fullDB.GetStorageReviveProof(stateRoot, addr, root, []string{common.Bytes2Hex(prefixKey)}, []string{common.Bytes2Hex(key[:])}) + log.Debug("fetchExpiredStorageFromRemote GetStorageReviveProof", "addr", addr, "key", key, "proofs", len(proofs), "err", err) if err != nil { return nil, err } @@ -60,7 +82,7 @@ func reviveStorageTrie(addr common.Address, tr Trie, proof types.ReviveStoragePr return nil, err } - nubs, err := tr.ReviveTrie(key, proofCache.CacheNubs()) + nubs, err := tr.TryRevive(key, proofCache.CacheNubs()) if err != nil { return nil, err } diff --git a/core/state/state_object.go b/core/state/state_object.go index c8d20392c2..53428729ab 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -18,6 +18,7 @@ package state import ( "bytes" + "encoding/hex" "fmt" "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/log" @@ -81,10 +82,11 @@ type stateObject struct { dirtyStorage Storage // Storage entries that have been modified in the current transaction execution, reset for every transaction // for state expiry feature - pendingReviveTrie Trie // pendingReviveTrie it contains pending revive trie nodes, could update & commit later - pendingReviveState map[string]common.Hash // pendingReviveState for block, when R&W, access revive state first, saved in hash key - pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed(only read now, update/delete/insert will auto update epoch), it will update epoch index late - originStorageEpoch map[common.Hash]types.StateEpoch // originStorageEpoch record origin state epoch, prevent frequency epoch update + pendingReviveTrie Trie // pendingReviveTrie it contains pending revive trie nodes, could update & commit later + pendingReviveState map[string]common.Hash // pendingReviveState for block, when R&W, access revive state first, saved in hash key + pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed(only read now, update/delete/insert will auto update epoch), it will update epoch index late + originStorageEpoch map[common.Hash]types.StateEpoch // originStorageEpoch record origin state epoch, prevent frequency epoch update + pendingFutureReviveState map[common.Hash]int // pendingFutureReviveState record empty state in snapshot. it should preftech first, and allow check in updateTrie // Cache flags. dirtyCode bool // true if the code was updated @@ -120,18 +122,19 @@ func newObject(db *StateDB, address common.Address, acct *types.StateAccount) *s } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - origin: origin, - data: *acct, - sharedOriginStorage: storageMap, - originStorage: make(Storage), - pendingStorage: make(Storage), - dirtyStorage: make(Storage), - pendingReviveState: make(map[string]common.Hash), - pendingAccessedState: make(map[common.Hash]int), - originStorageEpoch: make(map[common.Hash]types.StateEpoch), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + origin: origin, + data: *acct, + sharedOriginStorage: storageMap, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), + pendingReviveState: make(map[string]common.Hash), + pendingAccessedState: make(map[common.Hash]int), + pendingFutureReviveState: make(map[common.Hash]int), + originStorageEpoch: make(map[common.Hash]types.StateEpoch), } } @@ -264,7 +267,6 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { // If no live objects are available, attempt to use snapshots var ( enc []byte - sv snapshot.SnapValue err error value common.Hash ) @@ -274,15 +276,13 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { // handle state expiry situation if s.db.EnableExpire() { var dbError error - sv, err, dbError = s.getExpirySnapStorage(key) + enc, err, dbError = s.getExpirySnapStorage(key) if dbError != nil { s.db.setError(fmt.Errorf("state expiry getExpirySnapStorage, contract: %v, key: %v, err: %v", s.address, key, dbError)) return common.Hash{} } - // if query success, just set val, otherwise request from trie - if err == nil && sv != nil { - value.SetBytes(sv.GetVal()) - s.originStorageEpoch[key] = sv.GetEpoch() + if len(enc) > 0 { + value.SetBytes(enc) } } else { enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) @@ -300,7 +300,7 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { } // If the snapshot is unavailable or reading from it fails, load from the database. - if s.needLoadFromTrie(err, sv) { + if s.db.snap == nil || err != nil { getCommittedStorageTrieMeter.Mark(1) start := time.Now() var tr Trie @@ -383,6 +383,17 @@ func (s *stateObject) finalise(prefetch bool) { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure } } + + // try prefetch future revive states + for key := range s.pendingFutureReviveState { + if val, ok := s.dirtyStorage[key]; ok { + if val != s.originStorage[key] { + continue + } + } + slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure + } + if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { s.db.prefetcher.prefetch(s.addrHash, s.data.Root, s.address, slotsToPrefetch) } @@ -417,6 +428,8 @@ func (s *stateObject) updateTrie() (Trie, error) { err error ) if s.db.EnableExpire() { + // if EnableExpire, just use PendingReviveTrie, but prefetcher.trie is useful too, it warms up the db cache. + // and when no state expired or pruned, it will directly use prefetcher.trie too. tr, err = s.getPendingReviveTrie() } else { tr, err = s.getTrie() @@ -458,6 +471,23 @@ func (s *stateObject) updateTrie() (Trie, error) { wg.Add(1) go func() { defer wg.Done() + if s.db.EnableExpire() { + // revive state first, to figure out if there have conflict expiry path or local revive + for key := range s.pendingFutureReviveState { + _, err = tr.GetStorage(s.address, key.Bytes()) + if err == nil { + continue + } + enErr, ok := err.(*trie.ExpiredNodeError) + if !ok { + s.db.setError(fmt.Errorf("state object pendingFutureReviveState err, contract: %v, key: %v, err: %v", s.address, key, err)) + continue + } + if _, err = fetchExpiredStorageFromRemote(s.db.fullStateDB, s.db.originalRoot, s.address, s.data.Root, tr, enErr.Path, key); err != nil { + s.db.setError(fmt.Errorf("state object pendingFutureReviveState fetchExpiredStorageFromRemote err, contract: %v, key: %v, err: %v", s.address, key, err)) + } + } + } for key, value := range dirtyStorage { if len(value) == 0 { if err := tr.DeleteStorage(s.address, key[:]); err != nil { @@ -535,6 +565,9 @@ func (s *stateObject) updateTrie() (Trie, error) { if len(s.pendingAccessedState) > 0 { s.pendingAccessedState = make(map[common.Hash]int) } + if len(s.pendingFutureReviveState) > 0 { + s.pendingFutureReviveState = make(map[common.Hash]int) + } if len(s.originStorageEpoch) > 0 { s.originStorageEpoch = make(map[common.Hash]types.StateEpoch) } @@ -682,6 +715,10 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject { for k, v := range s.pendingAccessedState { obj.pendingAccessedState[k] = v } + obj.pendingFutureReviveState = make(map[common.Hash]int, len(s.pendingFutureReviveState)) + for k, v := range s.pendingFutureReviveState { + obj.pendingFutureReviveState[k] = v + } obj.originStorageEpoch = make(map[common.Hash]types.StateEpoch, len(s.originStorageEpoch)) for k, v := range s.originStorageEpoch { obj.originStorageEpoch[k] = v @@ -784,6 +821,16 @@ func (s *stateObject) accessState(key common.Hash) { } } +// futureReviveState record future revive state, it will load on prefetcher or updateTrie +func (s *stateObject) futureReviveState(key common.Hash) { + if !s.db.EnableExpire() { + return + } + + count := s.pendingFutureReviveState[key] + s.pendingFutureReviveState[key] = count + 1 +} + // TODO(0xbundler): add hash key cache later func (s *stateObject) queryFromReviveState(reviveState map[string]common.Hash, key common.Hash) (common.Hash, bool) { khash := crypto.HashData(s.db.hasher, key[:]) @@ -814,14 +861,10 @@ func (s *stateObject) fetchExpiredFromRemote(prefixKey []byte, key common.Hash, } kvs, err := fetchExpiredStorageFromRemote(s.db.fullStateDB, s.db.originalRoot, s.address, s.data.Root, tr, prefixKey, key) - if err != nil { - // Keys may not exist in the trie, so they can't be revived. - if _, ok := err.(*trie.KeyDoesNotExistError); ok { - return nil, nil - } - return nil, fmt.Errorf("revive storage trie failed, err: %v", err) + return nil, err } + for k, v := range kvs { s.pendingReviveState[k] = common.BytesToHash(v) } @@ -831,7 +874,7 @@ func (s *stateObject) fetchExpiredFromRemote(prefixKey []byte, key common.Hash, return val.Bytes(), nil } -func (s *stateObject) getExpirySnapStorage(key common.Hash) (snapshot.SnapValue, error, error) { +func (s *stateObject) getExpirySnapStorage(key common.Hash) ([]byte, error, error) { enc, err := s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) if err != nil { return nil, err, nil @@ -845,23 +888,28 @@ func (s *stateObject) getExpirySnapStorage(key common.Hash) (snapshot.SnapValue, } if val == nil { + // record access empty kv, try touch in updateTrie for duplication + s.futureReviveState(key) return nil, nil, nil } + s.originStorageEpoch[key] = val.GetEpoch() if !types.EpochExpired(val.GetEpoch(), s.db.epoch) { - return val, nil, nil + return val.GetVal(), nil, nil } - // TODO(0xbundler): if found value not been pruned, just return - //if len(val.GetVal()) > 0 { - // return val, nil, nil - //} + // if found value not been pruned, just return, local revive later + if EnableLocalRevive && len(val.GetVal()) > 0 { + s.futureReviveState(key) + log.Debug("getExpirySnapStorage GetVal", "addr", s.address, "key", key, "val", hex.EncodeToString(val.GetVal())) + return val.GetVal(), nil, nil + } - // handle from remoteDB, if got err just setError, just return to revert in consensus version. + // handle from remoteDB, if got err just setError, or return to revert in consensus version. valRaw, err := s.fetchExpiredFromRemote(nil, key, true) if err != nil { return nil, nil, err } - return snapshot.NewValueWithEpoch(val.GetEpoch(), valRaw), nil, nil + return valRaw, nil, nil } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 4c981c1b13..44405c3171 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -845,7 +845,7 @@ func (s *BlockChainAPI) GetStorageReviveProof(ctx context.Context, stateRoot com var proof proofList prefixKey := prefixKeys[i] - if err := storageTrie.ProvePath(crypto.Keccak256(key.Bytes()), prefixKey, &proof); err != nil { + if err := storageTrie.ProveByPath(crypto.Keccak256(key.Bytes()), prefixKey, &proof); err != nil { return nil, err } storageProof[i] = types.ReviveStorageProof{ diff --git a/light/trie.go b/light/trie.go index 3daa7025bd..81a45b4bfd 100644 --- a/light/trie.go +++ b/light/trie.go @@ -115,10 +115,6 @@ type odrTrie struct { trie *trie.Trie } -func (t *odrTrie) ReviveTrie(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error) { - panic("not implemented") -} - func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { key = crypto.Keccak256(key) var enc []byte @@ -224,10 +220,6 @@ func (t *odrTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { return errors.New("not implemented, needs client/server interface split") } -func (t *odrTrie) ProvePath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { - return errors.New("not implemented, needs client/server interface split") -} - func (t *odrTrie) Epoch() types.StateEpoch { return types.StateEpoch0 } @@ -260,10 +252,22 @@ func (t *odrTrie) do(key []byte, fn func() error) error { } } -func (db *odrTrie) NoTries() bool { +func (t *odrTrie) NoTries() bool { return false } +func (t *odrTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { + return errors.New("not implemented, needs client/server interface split") +} + +func (t *odrTrie) TryRevive(key []byte, proof []*trie.MPTProofNub) ([]*trie.MPTProofNub, error) { + return nil, errors.New("not implemented, needs client/server interface split") +} + +func (t *odrTrie) TryLocalRevive(addr common.Address, key []byte) ([]byte, error) { + return nil, errors.New("not implemented, needs client/server interface split") +} + type nodeIterator struct { trie.NodeIterator t *odrTrie diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 4ee8cce1e6..28aef864cd 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -89,14 +89,18 @@ func (t *EmptyTrie) GetStorageAndUpdateEpoch(addr common.Address, key []byte) ([ return nil, nil } -func (t *EmptyTrie) ProvePath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { +func (t *EmptyTrie) SetEpoch(epoch types.StateEpoch) { +} + +func (t *EmptyTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { return nil } -func (t *EmptyTrie) SetEpoch(epoch types.StateEpoch) { +func (t *EmptyTrie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { + return nil, nil } -func (t *EmptyTrie) ReviveTrie(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { +func (t *EmptyTrie) TryLocalRevive(addr common.Address, key []byte) ([]byte, error) { return nil, nil } diff --git a/trie/errors.go b/trie/errors.go index e6d61f0228..0c1b785bf0 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -83,17 +83,3 @@ func NewExpiredNodeError(path []byte, epoch types.StateEpoch) error { func (err *ExpiredNodeError) Error() string { return fmt.Sprintf("expired trie node, path: %v, epoch: %v", err.Path, err.Epoch) } - -type KeyDoesNotExistError struct { - Key []byte -} - -func NewKeyDoesNotExistError(key []byte) error { - return &KeyDoesNotExistError{ - Key: key, - } -} - -func (err *KeyDoesNotExistError) Error() string { - return "key does not exist" -} diff --git a/trie/proof.go b/trie/proof.go index c67b3d6c3b..7e62c4a8b1 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -181,7 +181,7 @@ func (t *Trie) traverseNodes(tn node, prefixKey, suffixKey []byte, nodes *[]node return tn, nil } -func (t *Trie) ProvePath(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { +func (t *Trie) ProveByPath(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { if t.committed { return ErrCommitted @@ -233,135 +233,8 @@ func (t *Trie) ProvePath(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValue return nil } -// VerifyPathProof reconstructs the trie from the given proof and verifies the root hash. -func VerifyPathProof(keyHex []byte, prefixKeyHex []byte, proofList [][]byte, epoch types.StateEpoch) (node, hashNode, error) { - - if len(proofList) == 0 { - return nil, nil, fmt.Errorf("proof list is empty") - } - - n, err := ConstructTrieFromProof(keyHex, prefixKeyHex, proofList, epoch) - if err != nil { - return nil, nil, err - } - - // hash the root node - hasher := newHasher(false) - defer returnHasherToPool(hasher) - hn, cn := hasher.hash(n, true) - if hash, ok := hn.(hashNode); ok { - return cn, hash, nil - } - - return nil, nil, fmt.Errorf("path proof verification failed") -} - -// ConstructTrieFromProof constructs a trie from the given proof. It returns the root node of the trie. -func ConstructTrieFromProof(keyHex []byte, prefixKeyHex []byte, proofList [][]byte, epoch types.StateEpoch) (node, error) { - if len(proofList) == 0 { - return nil, nil - } - h := newHasher(false) - defer returnHasherToPool(h) - keyHex = keyHex[len(prefixKeyHex):] - - root, err := decodeNode(nil, proofList[0]) - if err != nil { - return nil, fmt.Errorf("decode proof root %#x, err: %v", proofList[0], err) - } - // update epoch - switch n := root.(type) { - case *shortNode: - n.setEpoch(epoch) - case *fullNode: - n.setEpoch(epoch) - } - - parentNode := root - for i := 1; i < len(proofList); i++ { - n, err := decodeNode(nil, proofList[i]) - if err != nil { - return nil, fmt.Errorf("decode proof item %#x, err: %v", proofList[i], err) - } - - // verify proof continuous - keyrest, child := get(parentNode, keyHex, false) - switch cld := child.(type) { - case nil: - return nil, NewKeyDoesNotExistError(keyHex) - case hashNode: - hashed, _ := h.hash(n, false) - if !bytes.Equal(cld, hashed.(hashNode)) { - return nil, fmt.Errorf("the child node of shortNode is not a hashNode or doesn't match the hash in the proof") - } - default: - // proof's child cannot contain valueNode/shortNode/fullNode - return nil, fmt.Errorf("worng proof, got unexpect node, fstr: %v", child.fstring("")) - } - - // update epoch - switch n := n.(type) { - case *shortNode: - n.setEpoch(epoch) - case *fullNode: - n.setEpoch(epoch) - } - - // Link the parent and child. - switch sn := parentNode.(type) { - case *shortNode: - sn.Val = n - case *fullNode: - sn.Children[keyHex[0]] = n - sn.UpdateChildEpoch(int(keyHex[0]), epoch) - } - - // reset - parentNode = n - keyHex = keyrest - } - - return root, nil -} - -// updateEpochInChildNodes traverse down a node and update the epoch of the child nodes -func updateEpochInChildNodes(tn *node, key []byte, epoch types.StateEpoch) error { - - node := *tn - startNode := node - - for len(key) > 0 && node != nil { - switch n := node.(type) { - case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - // The trie doesn't contain the key. - node = nil - } else { - node = n.Val - key = key[len(n.Key):] - } - n.setEpoch(epoch) - case *fullNode: - node = n.Children[key[0]] - n.UpdateChildEpoch(int(key[0]), epoch) - n.setEpoch(epoch) - - key = key[1:] - case nil, hashNode, valueNode: - *tn = startNode - return nil - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) - } - } - - *tn = startNode - - return nil -} - -func (t *StateTrie) ProvePath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { - return t.trie.ProvePath(key, path, proofDb) +func (t *StateTrie) ProveByPath(key []byte, path []byte, proofDb ethdb.KeyValueWriter) error { + return t.trie.ProveByPath(key, path, proofDb) } // VerifyProof checks merkle proofs. The given proof must contain the value for diff --git a/trie/proof_test.go b/trie/proof_test.go index 16beee31b7..62f337f87b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -73,31 +73,6 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database { return provers } -func TestOneElementPathProof(t *testing.T) { - trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), nil)) - updateString(trie, "k", "v") - - var proofList proofList - - trie.Prove([]byte("k"), &proofList) - if proofList == nil { - t.Fatalf("nil proof") - } - - if len(proofList) != 1 { - t.Errorf("proof should have one element") - } - - _, hn, err := VerifyPathProof(keybytesToHex([]byte("k")), nil, proofList, 0) - if err != nil { - t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proofList) - } - - if common.BytesToHash(hn) != trie.Hash() { - t.Fatalf("verified root mismatch: have %x, want %x", hn, trie.Hash()) - } -} - func TestProof(t *testing.T) { trie, vals := randomTrie(500) root := trie.Hash() diff --git a/trie/secure_trie.go b/trie/secure_trie.go index b695609997..e665c8c58d 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -320,7 +320,12 @@ func (t *StateTrie) getSecKeyCache() map[string][]byte { return t.secKeyCache } -func (t *StateTrie) ReviveTrie(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { +func (t *StateTrie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { key = t.hashKey(key) - return t.trie.ReviveTrie(key, proof) + return t.trie.TryRevive(key, proof) +} + +func (t *StateTrie) TryLocalRevive(_ common.Address, key []byte) ([]byte, error) { + key = t.hashKey(key) + return t.trie.TryLocalRevive(key) } diff --git a/trie/trie.go b/trie/trie.go index 0fbb8f69ac..6a74636b45 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -1171,16 +1171,12 @@ func (t *Trie) Owner() common.Hash { return t.owner } -// ReviveTrie attempts to revive a trie from a list of MPTProofNubs. +// TryRevive attempts to revive a trie from a list of MPTProofNubs. // ReviveTrie performs full or partial revive and returns a list of successful // nubs. ReviveTrie does not guarantee that a value will be revived completely, // if the proof is not fully valid. -func (t *Trie) ReviveTrie(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { - key = keybytesToHex(key) - return t.TryRevive(key, proof) -} - func (t *Trie) TryRevive(key []byte, proof []*MPTProofNub) ([]*MPTProofNub, error) { + key = keybytesToHex(key) successNubs := make([]*MPTProofNub, 0, len(proof)) reviveMeter.Mark(int64(len(proof))) // Revive trie with each proof nub diff --git a/trie/trie_expiry.go b/trie/trie_expiry.go new file mode 100644 index 0000000000..1dd6f098a1 --- /dev/null +++ b/trie/trie_expiry.go @@ -0,0 +1,72 @@ +package trie + +import ( + "bytes" + "fmt" + "github.com/ethereum/go-ethereum/core/types" +) + +func (t *Trie) TryLocalRevive(key []byte) ([]byte, error) { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return nil, ErrCommitted + } + + key = keybytesToHex(key) + val, newroot, didResolve, err := t.tryLocalRevive(t.root, key, 0, t.getRootEpoch()) + if err == nil && didResolve { + t.root = newroot + t.rootEpoch = t.currentEpoch + } + return val, err +} + +func (t *Trie) tryLocalRevive(origNode node, key []byte, pos int, epoch types.StateEpoch) ([]byte, node, bool, error) { + expired := t.epochExpired(origNode, epoch) + switch n := (origNode).(type) { + case nil: + return nil, nil, false, nil + case valueNode: + return n, n, expired, nil + case *shortNode: + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + // key not found in trie + return nil, n, false, nil + } + value, newnode, didResolve, err := t.tryLocalRevive(n.Val, key, pos+len(n.Key), epoch) + if err == nil && t.renewNode(epoch, didResolve, expired) { + n = n.copy() + n.Val = newnode + n.setEpoch(t.currentEpoch) + didResolve = true + } + return value, n, didResolve, err + case *fullNode: + value, newnode, didResolve, err := t.tryLocalRevive(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos]))) + if err == nil && t.renewNode(epoch, didResolve, expired) { + n = n.copy() + n.Children[key[pos]] = newnode + n.setEpoch(t.currentEpoch) + if newnode != nil { + n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) + } + didResolve = true + } + return value, n, didResolve, err + case hashNode: + child, err := t.resolveAndTrack(n, key[:pos]) + if err != nil { + return nil, n, true, err + } + + if child, ok := child.(*fullNode); ok { + if err = t.resolveEpochMeta(child, epoch, key[:pos]); err != nil { + return nil, n, true, err + } + } + value, newnode, _, err := t.tryLocalRevive(child, key, pos, epoch) + return value, newnode, true, err + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} diff --git a/trie/trie_test.go b/trie/trie_test.go index c35f1609ee..b59b0fb3e5 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1008,7 +1008,7 @@ func TestRevive(t *testing.T) { for _, prefixKey := range prefixKeys { // Generate proof var proof proofList - err := trie.ProvePath(key, prefixKey, &proof) + err := trie.ProveByPath(key, prefixKey, &proof) assert.NoError(t, err) // Expire trie @@ -1019,7 +1019,7 @@ func TestRevive(t *testing.T) { assert.NoError(t, err) // Revive trie - _, err = trie.TryRevive(keybytesToHex(key), proofCache.CacheNubs()) + _, err = trie.TryRevive(key, proofCache.CacheNubs()) assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) // Verifiy value exists after revive @@ -1053,7 +1053,7 @@ func TestReviveCustom(t *testing.T) { prefixKeys := getFullNodePrefixKeys(trie, key) for _, prefixKey := range prefixKeys { var proofList proofList - err := trie.ProvePath(key, prefixKey, &proofList) + err := trie.ProveByPath(key, prefixKey, &proofList) assert.NoError(t, err) trie.ExpireByPrefix(prefixKey) @@ -1063,7 +1063,7 @@ func TestReviveCustom(t *testing.T) { assert.NoError(t, err) // Revive trie - _, err = trie.TryRevive(keybytesToHex(key), proofCache.cacheNubs) + _, err = trie.TryRevive(key, proofCache.cacheNubs) assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) res := trie.MustGet(key)