From 6934fc2893b673b7aa6fe72d868db72a2e2820b7 Mon Sep 17 00:00:00 2001 From: Aleksandr Logunov Date: Thu, 14 Nov 2024 01:30:41 +0100 Subject: [PATCH] feat: resharding proof (#12418) Generate proof for resharding (retain_multi_range) and check that it is valid against existing retain tests, already including memtrie and disktrie. Surprisingly, #12390 allows to make the change almost invisible. The actual work here is to move `&mut TrieRecorder` inside `TrieChangesTracker` which is inside `MemTrieUpdate`. Reasoning: when we work with `MemTrieUpdate`, it is in fact unique owner of the logic to record proof (what I called "2nd stage" in the previous PR, after "1st stage" of trie lookups). And now, if we are sure that proof doesn't need values, `MemTrieUpdate` can fully hold the `&mut`, because memtrie has all the necessary information, and we can record nodes directly on node access instead of doing it retroactively (which was really bad). `TrieRecorder` now is passed as a separate mode to process memtrie updates. We indeed need three modes - on loading we don't record anything, for non-validators we save trie changes, for validators we also save proofs. And after that, all we need to do for `retain_split_shard` is to create `TrieRecorder` and in the end get recorded storage from it. Next step will be to validate this proof in the actual state witness processing... --- chain/chain/src/resharding/manager.rs | 20 +-- core/store/src/trie/mem/loading.rs | 3 +- core/store/src/trie/mem/mem_trie_update.rs | 196 +++++++++++---------- core/store/src/trie/mem/mem_tries.rs | 17 +- core/store/src/trie/mem/resharding.rs | 32 ++-- core/store/src/trie/mod.rs | 32 ++-- 6 files changed, 149 insertions(+), 151 deletions(-) diff --git a/chain/chain/src/resharding/manager.rs b/chain/chain/src/resharding/manager.rs index a0ee43cf482..2771b377b21 100644 --- a/chain/chain/src/resharding/manager.rs +++ b/chain/chain/src/resharding/manager.rs @@ -5,6 +5,7 @@ use super::event_type::{ReshardingEventType, ReshardingSplitShardParams}; use super::types::ReshardingSender; use crate::flat_storage_resharder::{FlatStorageResharder, FlatStorageResharderController}; use crate::types::RuntimeAdapter; +use crate::ChainStoreUpdate; use near_chain_configs::{MutableConfigValue, ReshardingConfig, ReshardingHandle}; use near_chain_primitives::Error; use near_epoch_manager::EpochManagerAdapter; @@ -14,10 +15,10 @@ use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::{get_block_shard_uid, ShardLayout}; use near_primitives::types::chunk_extra::ChunkExtra; use near_store::adapter::{StoreAdapter, StoreUpdateAdapter}; +use near_store::trie::mem::mem_trie_update::TrackingMode; use near_store::trie::ops::resharding::RetainMode; -use near_store::{DBCol, PartialStorage, ShardTries, ShardUId, Store}; - -use crate::ChainStoreUpdate; +use near_store::trie::TrieRecorder; +use near_store::{DBCol, ShardTries, ShardUId, Store}; pub struct ReshardingManager { store: Store, @@ -201,16 +202,15 @@ impl ReshardingManager { "Creating child memtrie by retaining nodes in parent memtrie..." ); let mut mem_tries = mem_tries.write().unwrap(); - let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), true)?; + let mut trie_recorder = TrieRecorder::new(); + let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder); + let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), mode)?; - let (trie_changes, _) = - mem_trie_update.retain_split_shard(&boundary_account, retain_mode); - // TODO(#12019): proof generation - let partial_state = PartialState::default(); - let partial_state_len = match &partial_state { + let trie_changes = mem_trie_update.retain_split_shard(&boundary_account, retain_mode); + let partial_storage = trie_recorder.recorded_storage(); + let partial_state_len = match &partial_storage.nodes { PartialState::TrieValues(values) => values.len(), }; - let partial_storage = PartialStorage { nodes: partial_state }; let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap(); let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes); // TODO(resharding): set all fields of `ChunkExtra`. Consider stronger diff --git a/core/store/src/trie/mem/loading.rs b/core/store/src/trie/mem/loading.rs index ce7068b3e4e..2004f75b3c7 100644 --- a/core/store/src/trie/mem/loading.rs +++ b/core/store/src/trie/mem/loading.rs @@ -5,6 +5,7 @@ use crate::adapter::StoreAdapter; use crate::flat::FlatStorageStatus; use crate::trie::mem::arena::Arena; use crate::trie::mem::construction::TrieConstructor; +use crate::trie::mem::mem_trie_update::TrackingMode; use crate::trie::mem::parallel_loader::load_memtrie_in_parallel; use crate::trie::ops::insert_delete::GenericTrieUpdateInsertDelete; use crate::{DBCol, NibbleSlice, Store}; @@ -155,7 +156,7 @@ pub fn load_trie_from_flat_state_and_delta( let old_state_root = get_state_root(store, prev_hash, shard_uid)?; let new_state_root = get_state_root(store, hash, shard_uid)?; - let mut trie_update = mem_tries.update(old_state_root, false)?; + let mut trie_update = mem_tries.update(old_state_root, TrackingMode::None)?; for (key, value) in changes.0 { match value { Some(value) => { diff --git a/core/store/src/trie/mem/mem_trie_update.rs b/core/store/src/trie/mem/mem_trie_update.rs index c39acf1634d..2257e5d0c21 100644 --- a/core/store/src/trie/mem/mem_trie_update.rs +++ b/core/store/src/trie/mem/mem_trie_update.rs @@ -1,5 +1,4 @@ use std::collections::{BTreeMap, HashMap}; -use std::sync::Arc; use near_primitives::errors::StorageError; use near_primitives::hash::{hash, CryptoHash}; @@ -10,6 +9,7 @@ use crate::trie::ops::interface::{ GenericNodeOrIndex, GenericTrieUpdate, GenericTrieValue, GenericUpdatedNodeId, GenericUpdatedTrieNode, GenericUpdatedTrieNodeWithSize, }; +use crate::trie::trie_recording::TrieRecorder; use crate::trie::{Children, MemTrieChanges, TrieRefcountDeltaMap}; use crate::{RawTrieNode, RawTrieNodeWithSize, TrieChanges}; @@ -67,15 +67,25 @@ impl UpdatedMemTrieNodeWithSize { } } -/// Keeps hashes and encoded trie nodes accessed on updating memtrie. -pub struct TrieAccesses { - /// Hashes and encoded trie nodes. - pub nodes: HashMap>, +/// Allows using in-memory tries to construct the trie node changes entirely +/// (for both in-memory and on-disk updates) because it's much faster. +pub enum TrackingMode<'a> { + /// Don't track any nodes. + None, + /// Track disk refcount changes for trie nodes. + Refcounts, + /// Track disk refcount changes and record all accessed trie nodes. + /// The latter one is needed to record storage proof which is handled by + /// `TrieRecorder`. + /// The main case why recording is needed is a branch with two children, + /// one of which got removed. In this case we need to read another child + /// and squash it together with parent. + RefcountsAndAccesses(&'a mut TrieRecorder), } /// Tracks intermediate trie changes, final version of which is to be committed /// to disk after finishing trie update. -struct TrieChangesTracker { +struct TrieChangesTracker<'a> { /// Counts hashes deleted so far. /// Includes hashes of both trie nodes and state values! refcount_deleted_hashes: BTreeMap, @@ -83,17 +93,33 @@ struct TrieChangesTracker { /// Separated from `refcount_deleted_hashes` to postpone hash computation /// as far as possible. refcount_inserted_values: BTreeMap, u32>, - /// All observed internal nodes. - /// Needed to prepare recorded storage. + /// Recorder for observed internal nodes. /// Note that negative `refcount_deleted_hashes` does not fully cover it, /// as node or value of the same hash can be removed and inserted for the /// same update in different parts of trie! - accesses: TrieAccesses, + recorder: Option<&'a mut TrieRecorder>, } -impl TrieChangesTracker { +impl<'a> TrieChangesTracker<'a> { + fn with_recorder(recorder: Option<&'a mut TrieRecorder>) -> Self { + Self { + refcount_deleted_hashes: BTreeMap::new(), + refcount_inserted_values: BTreeMap::new(), + recorder, + } + } + + fn record(&mut self, node: &MemTrieNodeView<'a, M>) { + let node_hash = node.node_hash(); + let raw_node_serialized = borsh::to_vec(&node.to_raw_trie_node_with_size()).unwrap(); + *self.refcount_deleted_hashes.entry(node_hash).or_default() += 1; + if let Some(recorder) = self.recorder.as_mut() { + recorder.record(&node_hash, raw_node_serialized.into()); + } + } + /// Prepare final refcount difference and also return all trie accesses. - fn finalize(self) -> (TrieRefcountDeltaMap, TrieAccesses) { + fn finalize(self) -> TrieRefcountDeltaMap { let mut refcount_delta_map = TrieRefcountDeltaMap::new(); for (value, rc) in self.refcount_inserted_values { refcount_delta_map.add(hash(&value), value, rc); @@ -101,7 +127,7 @@ impl TrieChangesTracker { for (hash, rc) in self.refcount_deleted_hashes { refcount_delta_map.subtract(hash, rc); } - (refcount_delta_map, self.accesses) + refcount_delta_map } } @@ -117,7 +143,7 @@ pub struct MemTrieUpdate<'a, M: ArenaMemory> { pub updated_nodes: Vec>, /// Tracks trie changes necessary to make on-disk updates and recorded /// storage. - tracked_trie_changes: Option, + nodes_tracker: Option>, } impl<'a, M: ArenaMemory> GenericTrieUpdate<'a, MemTrieNodeId, FlatStateValue> @@ -153,38 +179,34 @@ impl<'a, M: ArenaMemory> GenericTrieUpdate<'a, MemTrieNodeId, FlatStateValue> } fn store_value(&mut self, value: GenericTrieValue) -> FlatStateValue { - // First, set the value which will be stored in memtrie. - let flat_value = match &value { - GenericTrieValue::MemtrieOnly(value) => return value.clone(), - GenericTrieValue::MemtrieAndDisk(value) => FlatStateValue::on_disk(value.as_slice()), + let (flat_value, full_value) = match value { + // If value is provided only for memtrie, it is flat, so we can't + // record nodes. Just return flat value back. + // TODO: check consistency with trie recorder setup. + // `GenericTrieValue::MemtrieOnly` must not be used if + // `nodes_tracker` is set and vice versa. + GenericTrieValue::MemtrieOnly(flat_value) => return flat_value, + GenericTrieValue::MemtrieAndDisk(full_value) => { + (FlatStateValue::on_disk(full_value.as_slice()), full_value) + } }; - // Then, record disk changes if needed. - let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() else { - return flat_value; - }; - let GenericTrieValue::MemtrieAndDisk(value) = value else { + // Otherwise, record disk changes if needed. + let Some(nodes_tracker) = self.nodes_tracker.as_mut() else { return flat_value; }; - tracked_node_changes - .refcount_inserted_values - .entry(value) - .and_modify(|rc| *rc += 1) - .or_insert(1); + *nodes_tracker.refcount_inserted_values.entry(full_value).or_default() += 1; flat_value } fn delete_value(&mut self, value: FlatStateValue) -> Result<(), StorageError> { - if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() { - let hash = value.to_value_ref().hash; - tracked_node_changes - .refcount_deleted_hashes - .entry(hash) - .and_modify(|rc| *rc += 1) - .or_insert(1); - } + let Some(nodes_tracker) = self.nodes_tracker.as_mut() else { + return Ok(()); + }; + let hash = value.to_value_ref().hash; + *nodes_tracker.refcount_deleted_hashes.entry(hash).or_default() += 1; Ok(()) } } @@ -194,23 +216,17 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { root: Option, memory: &'a M, shard_uid: String, - track_trie_changes: bool, + mode: TrackingMode<'a>, ) -> Self { - let mut trie_update = Self { - root, - memory, - shard_uid, - updated_nodes: vec![], - tracked_trie_changes: if track_trie_changes { - Some(TrieChangesTracker { - refcount_inserted_values: BTreeMap::new(), - refcount_deleted_hashes: BTreeMap::new(), - accesses: TrieAccesses { nodes: HashMap::new() }, - }) - } else { - None - }, + let nodes_tracker = match mode { + TrackingMode::None => None, + TrackingMode::Refcounts => Some(TrieChangesTracker::with_recorder(None)), + TrackingMode::RefcountsAndAccesses(recorder) => { + Some(TrieChangesTracker::with_recorder(Some(recorder))) + } }; + let mut trie_update = + Self { root, memory, shard_uid, updated_nodes: vec![], nodes_tracker }; assert_eq!(trie_update.convert_existing_to_updated(root), 0usize); trie_update } @@ -230,29 +246,14 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// If the original node is None, it is a marker for the root of an empty /// trie. fn convert_existing_to_updated(&mut self, node: Option) -> UpdatedMemTrieNodeId { - match node { - None => self.new_updated_node(UpdatedMemTrieNodeWithSize::empty()), - Some(node) => { - if let Some(tracked_trie_changes) = self.tracked_trie_changes.as_mut() { - let node_view = node.as_ptr(self.memory).view(); - let node_hash = node_view.node_hash(); - let raw_node_serialized = - borsh::to_vec(&node_view.to_raw_trie_node_with_size()).unwrap(); - tracked_trie_changes - .accesses - .nodes - .insert(node_hash, raw_node_serialized.into()); - tracked_trie_changes - .refcount_deleted_hashes - .entry(node_hash) - .and_modify(|rc| *rc += 1) - .or_insert(1); - } - self.new_updated_node(UpdatedMemTrieNodeWithSize::from_existing_node_view( - node.as_ptr(self.memory).view(), - )) - } + let Some(node) = node else { + return self.new_updated_node(UpdatedMemTrieNodeWithSize::empty()); + }; + let node_view = node.as_ptr(self.memory).view(); + if let Some(tracked_trie_changes) = self.nodes_tracker.as_mut() { + tracked_trie_changes.record(&node_view); } + self.new_updated_node(UpdatedMemTrieNodeWithSize::from_existing_node_view(node_view)) } /// Inserts the given key value pair into the trie. @@ -419,11 +420,11 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } /// Converts the updates to trie changes as well as memtrie changes. - pub(crate) fn to_trie_changes(mut self) -> (TrieChanges, TrieAccesses) { + pub(crate) fn to_trie_changes(mut self) -> TrieChanges { let old_root = self.root.map(|root| root.as_ptr(self.memory).view().node_hash()).unwrap_or_default(); - let (mut refcount_changes, accesses) = self - .tracked_trie_changes + let mut refcount_changes = self + .nodes_tracker .take() .expect("Cannot to_trie_changes for memtrie changes only") .finalize(); @@ -436,20 +437,17 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } let (insertions, deletions) = refcount_changes.into_changes(); - ( - TrieChanges { - old_root, - new_root: mem_trie_changes - .node_ids_with_hashes - .last() - .map(|(_, hash)| *hash) - .unwrap_or_default(), - insertions, - deletions, - mem_trie_changes: Some(mem_trie_changes), - }, - accesses, - ) + TrieChanges { + old_root, + new_root: mem_trie_changes + .node_ids_with_hashes + .last() + .map(|(_, hash)| *hash) + .unwrap_or_default(), + insertions, + deletions, + mem_trie_changes: Some(mem_trie_changes), + } } } @@ -522,6 +520,8 @@ mod tests { use rand::Rng; use std::collections::{HashMap, HashSet}; + use super::TrackingMode; + struct TestTries { mem: MemTries, disk: ShardTries, @@ -544,9 +544,10 @@ mod tests { } fn make_all_changes(&mut self, changes: Vec<(Vec, Option>)>) -> TrieChanges { - let mut update = self.mem.update(self.state_root, true).unwrap_or_else(|_| { - panic!("Trying to update root {:?} but it's not in memtries", self.state_root) - }); + let mut update = + self.mem.update(self.state_root, TrackingMode::Refcounts).unwrap_or_else(|_| { + panic!("Trying to update root {:?} but it's not in memtries", self.state_root) + }); for (key, value) in changes { if let Some(value) = value { update.insert(&key, value).unwrap(); @@ -554,16 +555,17 @@ mod tests { update.generic_delete(0, &key).unwrap(); } } - update.to_trie_changes().0 + update.to_trie_changes() } fn make_memtrie_changes_only( &mut self, changes: Vec<(Vec, Option>)>, ) -> MemTrieChanges { - let mut update = self.mem.update(self.state_root, false).unwrap_or_else(|_| { - panic!("Trying to update root {:?} but it's not in memtries", self.state_root) - }); + let mut update = + self.mem.update(self.state_root, TrackingMode::None).unwrap_or_else(|_| { + panic!("Trying to update root {:?} but it's not in memtries", self.state_root) + }); for (key, value) in changes { if let Some(value) = value { update.insert_memtrie_only(&key, FlatStateValue::on_disk(&value)).unwrap(); @@ -942,7 +944,7 @@ mod tests { changes: &str, ) -> CryptoHash { let changes = parse_changes(changes); - let mut update = memtrie.update(prev_state_root, false).unwrap(); + let mut update = memtrie.update(prev_state_root, TrackingMode::None).unwrap(); for (key, value) in changes { if let Some(value) = value { diff --git a/core/store/src/trie/mem/mem_tries.rs b/core/store/src/trie/mem/mem_tries.rs index b74cae0a67f..5ad105d0c1b 100644 --- a/core/store/src/trie/mem/mem_tries.rs +++ b/core/store/src/trie/mem/mem_tries.rs @@ -18,7 +18,7 @@ use super::arena::FrozenArena; use super::flexible_data::value::ValueView; use super::iter::STMemTrieIterator; use super::lookup::memtrie_lookup; -use super::mem_trie_update::{construct_root_from_changes, MemTrieUpdate}; +use super::mem_trie_update::{construct_root_from_changes, MemTrieUpdate, TrackingMode}; use super::node::{MemTrieNodeId, MemTrieNodePtr}; /// `MemTries` (logically) owns the memory of multiple tries. @@ -177,19 +177,14 @@ impl MemTries { .set(self.roots.len() as i64); } - pub fn update( - &self, + pub fn update<'a>( + &'a self, root: CryptoHash, - track_trie_changes: bool, - ) -> Result, StorageError> { + mode: TrackingMode<'a>, + ) -> Result, StorageError> { let root_id = if root == CryptoHash::default() { None } else { Some(self.get_root(&root)?.id()) }; - Ok(MemTrieUpdate::new( - root_id, - &self.arena.memory(), - self.shard_uid.to_string(), - track_trie_changes, - )) + Ok(MemTrieUpdate::new(root_id, &self.arena.memory(), self.shard_uid.to_string(), mode)) } /// Returns an iterator over the memtrie for the given trie root. diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index a0c14e785e7..624616a6f7b 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -7,7 +7,7 @@ use crate::trie::trie_storage_update::TrieStorageUpdate; use crate::{Trie, TrieChanges}; use super::arena::ArenaMemory; -use super::mem_trie_update::{MemTrieUpdate, TrieAccesses}; +use super::mem_trie_update::MemTrieUpdate; use near_primitives::errors::StorageError; use near_primitives::types::{AccountId, StateRoot}; use std::ops::Range; @@ -23,7 +23,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { self, boundary_account: &AccountId, retain_mode: RetainMode, - ) -> (TrieChanges, TrieAccesses) { + ) -> TrieChanges { let intervals = boundary_account_to_intervals(boundary_account, retain_mode); self.retain_multi_range(&intervals) } @@ -33,12 +33,11 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { /// /// Returns changes to be applied to in-memory trie and proof of the /// retain operation. - fn retain_multi_range(mut self, intervals: &[Range>]) -> (TrieChanges, TrieAccesses) { + fn retain_multi_range(mut self, intervals: &[Range>]) -> TrieChanges { debug_assert!(intervals.iter().all(|range| range.start < range.end)); let intervals_nibbles = intervals_to_nibbles(intervals); - // TODO(#12074): consider handling the case when no changes are made. - // TODO(#12074): restore proof as well. + // TODO(resharding): consider handling the case when no changes are made. self.retain_multi_range_recursive(0, vec![], &intervals_nibbles).unwrap(); self.to_trie_changes() } @@ -72,10 +71,7 @@ impl Trie { } } -// TODO(#12074): tests for -// - `retain_split_shard` API -// - checking not accessing not-inlined values -// - proof correctness +// TODO(resharding): consider adding tests for `retain_split_shard` API. #[cfg(test)] mod tests { use rand::rngs::StdRng; @@ -89,10 +85,12 @@ mod tests { use crate::test_utils::TestTriesBuilder; use crate::trie::mem::iter::MemTrieIterator; + use crate::trie::mem::mem_trie_update::TrackingMode; use crate::trie::mem::mem_tries::MemTries; use crate::trie::mem::nibbles_utils::{ all_two_nibble_nibbles, hex_to_nibbles, multi_hex_to_nibbles, }; + use crate::trie::trie_recording::TrieRecorder; use crate::trie::trie_storage::TrieMemoryPartialStorage; use crate::trie::Trie; @@ -135,17 +133,24 @@ mod tests { let expected_disk_state_root = trie.retain_multi_range(&retain_multi_ranges).unwrap(); let mut memtries = MemTries::new(ShardUId::single_shard()); - let mut update = memtries.update(Trie::EMPTY_ROOT, false).unwrap(); + let mut update = memtries.update(Trie::EMPTY_ROOT, TrackingMode::None).unwrap(); for (key, value) in initial_entries { update.insert(&key, value).unwrap(); } let memtrie_changes = update.to_mem_trie_changes_only(); let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes); - let update = memtries.update(state_root, true).unwrap(); - let (mut trie_changes, _) = update.retain_multi_range(&retain_multi_ranges); + let mut trie_recorder = TrieRecorder::new(); + let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder); + let update = memtries.update(state_root, mode).unwrap(); + let mut trie_changes = update.retain_multi_range(&retain_multi_ranges); let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); let mem_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes); + let proof = trie_recorder.recorded_storage(); + + let partial_trie = Trie::from_recorded_storage(proof, state_root, false); + let expected_proof_based_state_root = + partial_trie.retain_multi_range(&retain_multi_ranges).unwrap(); let entries = if mem_state_root != StateRoot::default() { let state_root_ptr = memtries.get_root(&mem_state_root).unwrap(); @@ -164,6 +169,9 @@ mod tests { // Check state root with disk-trie state root. assert_eq!(mem_state_root, expected_disk_state_root); + + // Check state root resulting by retain based on partial storage. + assert_eq!(mem_state_root, expected_proof_based_state_root); } #[test] diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index 20fce75f144..1ec918c85a8 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -1,7 +1,6 @@ use self::accounting_cache::TrieAccountingCache; use self::iterator::DiskTrieIterator; use self::mem::flexible_data::value::ValueView; -use self::trie_recording::TrieRecorder; use self::trie_storage::TrieMemoryPartialStorage; use crate::flat::{FlatStateChanges, FlatStorageChunkView}; pub use crate::trie::config::TrieConfig; @@ -19,7 +18,7 @@ pub use crate::trie::trie_storage::{TrieCache, TrieCachingStorage, TrieDBStorage use crate::StorageError; use borsh::{BorshDeserialize, BorshSerialize}; pub use from_flat::construct_trie_from_flat; -use mem::mem_trie_update::{UpdatedMemTrieNodeId, UpdatedMemTrieNodeWithSize}; +use mem::mem_trie_update::{TrackingMode, UpdatedMemTrieNodeId, UpdatedMemTrieNodeWithSize}; use mem::mem_tries::MemTries; use near_primitives::challenge::PartialState; use near_primitives::hash::{hash, CryptoHash}; @@ -40,9 +39,10 @@ use std::cell::RefCell; use std::collections::{BTreeMap, HashSet}; use std::fmt::Write; use std::hash::Hash; +use std::ops::DerefMut; use std::str; use std::sync::{Arc, RwLock, RwLockReadGuard}; -pub use trie_recording::{SubtreeSize, TrieRecorderStats}; +pub use trie_recording::{SubtreeSize, TrieRecorder, TrieRecorderStats}; #[cfg(test)] use trie_storage_update::UpdatedTrieStorageNode; use trie_storage_update::{TrieStorageUpdate, UpdatedTrieStorageNodeWithSize}; @@ -1638,30 +1638,22 @@ impl Trie { match &self.memtries { Some(memtries) => { - // If we have in-memory tries, use it to construct the changes entirely (for - // both in-memory and on-disk updates) because it's much faster. let guard = memtries.read().unwrap(); - let mut trie_update = guard.update(self.root, true)?; + let mut recorder = self.recorder.as_ref().map(|recorder| recorder.borrow_mut()); + let tracking_mode = match &mut recorder { + Some(recorder) => TrackingMode::RefcountsAndAccesses(recorder.deref_mut()), + None => TrackingMode::Refcounts, + }; + + let mut trie_update = guard.update(self.root, tracking_mode)?; for (key, value) in changes { match value { Some(arr) => trie_update.insert(&key, arr)?, None => trie_update.generic_delete(0, &key)?, } } - let (trie_changes, trie_accesses) = trie_update.to_trie_changes(); - - // Retroactively record all accessed trie items which are - // required to process trie update but were not recorded at - // processing lookups. - // The main case is a branch with two children, one of which - // got removed, so we need to read another one and squash it - // together with parent. - if let Some(recorder) = &self.recorder { - for (node_hash, serialized_node) in trie_accesses.nodes { - recorder.borrow_mut().record(&node_hash, serialized_node); - } - } - Ok(trie_changes) + + Ok(trie_update.to_trie_changes()) } None => { let mut trie_update = TrieStorageUpdate::new(&self);