Skip to content

Commit

Permalink
feat: resharding proof (#12418)
Browse files Browse the repository at this point in the history
Generate a proof for resharding (`retain_multi_range`) and check that it is
valid against the existing retain tests, which already cover both memtrie and
disk trie.

Surprisingly, #12390 makes this change almost invisible. The
actual work here is to move `&mut TrieRecorder` inside
`TrieChangesTracker`, which is inside `MemTrieUpdate`.

Reasoning: when we work with `MemTrieUpdate`, it is in fact the unique owner
of the logic to record the proof (what I called the "2nd stage" in the previous
PR, after the "1st stage" of trie lookups). And now, if we are sure that
the proof doesn't need values, `MemTrieUpdate` can fully hold the `&mut`,
because memtrie has all the necessary information, and we can record
nodes directly on node access instead of doing it retroactively (which
was really bad).

`TrieRecorder` is now passed as a separate mode for processing memtrie
updates. We indeed need three modes: on loading we don't record
anything, for non-validators we save trie changes, and for validators we
also save proofs. After that, all we need to do for
`retain_split_shard` is to create a `TrieRecorder` and, at the end, get the
recorded storage from it.

The next step will be to validate this proof in the actual state witness
processing...
  • Loading branch information
Longarithm authored Nov 14, 2024
1 parent 8d180da commit 6934fc2
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 151 deletions.
20 changes: 10 additions & 10 deletions chain/chain/src/resharding/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::event_type::{ReshardingEventType, ReshardingSplitShardParams};
use super::types::ReshardingSender;
use crate::flat_storage_resharder::{FlatStorageResharder, FlatStorageResharderController};
use crate::types::RuntimeAdapter;
use crate::ChainStoreUpdate;
use near_chain_configs::{MutableConfigValue, ReshardingConfig, ReshardingHandle};
use near_chain_primitives::Error;
use near_epoch_manager::EpochManagerAdapter;
Expand All @@ -14,10 +15,10 @@ use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::{get_block_shard_uid, ShardLayout};
use near_primitives::types::chunk_extra::ChunkExtra;
use near_store::adapter::{StoreAdapter, StoreUpdateAdapter};
use near_store::trie::mem::mem_trie_update::TrackingMode;
use near_store::trie::ops::resharding::RetainMode;
use near_store::{DBCol, PartialStorage, ShardTries, ShardUId, Store};

use crate::ChainStoreUpdate;
use near_store::trie::TrieRecorder;
use near_store::{DBCol, ShardTries, ShardUId, Store};

pub struct ReshardingManager {
store: Store,
Expand Down Expand Up @@ -201,16 +202,15 @@ impl ReshardingManager {
"Creating child memtrie by retaining nodes in parent memtrie..."
);
let mut mem_tries = mem_tries.write().unwrap();
let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), true)?;
let mut trie_recorder = TrieRecorder::new();
let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder);
let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), mode)?;

let (trie_changes, _) =
mem_trie_update.retain_split_shard(&boundary_account, retain_mode);
// TODO(#12019): proof generation
let partial_state = PartialState::default();
let partial_state_len = match &partial_state {
let trie_changes = mem_trie_update.retain_split_shard(&boundary_account, retain_mode);
let partial_storage = trie_recorder.recorded_storage();
let partial_state_len = match &partial_storage.nodes {
PartialState::TrieValues(values) => values.len(),
};
let partial_storage = PartialStorage { nodes: partial_state };
let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap();
let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes);
// TODO(resharding): set all fields of `ChunkExtra`. Consider stronger
Expand Down
3 changes: 2 additions & 1 deletion core/store/src/trie/mem/loading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::adapter::StoreAdapter;
use crate::flat::FlatStorageStatus;
use crate::trie::mem::arena::Arena;
use crate::trie::mem::construction::TrieConstructor;
use crate::trie::mem::mem_trie_update::TrackingMode;
use crate::trie::mem::parallel_loader::load_memtrie_in_parallel;
use crate::trie::ops::insert_delete::GenericTrieUpdateInsertDelete;
use crate::{DBCol, NibbleSlice, Store};
Expand Down Expand Up @@ -155,7 +156,7 @@ pub fn load_trie_from_flat_state_and_delta(
let old_state_root = get_state_root(store, prev_hash, shard_uid)?;
let new_state_root = get_state_root(store, hash, shard_uid)?;

let mut trie_update = mem_tries.update(old_state_root, false)?;
let mut trie_update = mem_tries.update(old_state_root, TrackingMode::None)?;
for (key, value) in changes.0 {
match value {
Some(value) => {
Expand Down
196 changes: 99 additions & 97 deletions core/store/src/trie/mem/mem_trie_update.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

use near_primitives::errors::StorageError;
use near_primitives::hash::{hash, CryptoHash};
Expand All @@ -10,6 +9,7 @@ use crate::trie::ops::interface::{
GenericNodeOrIndex, GenericTrieUpdate, GenericTrieValue, GenericUpdatedNodeId,
GenericUpdatedTrieNode, GenericUpdatedTrieNodeWithSize,
};
use crate::trie::trie_recording::TrieRecorder;
use crate::trie::{Children, MemTrieChanges, TrieRefcountDeltaMap};
use crate::{RawTrieNode, RawTrieNodeWithSize, TrieChanges};

Expand Down Expand Up @@ -67,41 +67,67 @@ impl UpdatedMemTrieNodeWithSize {
}
}

/// Keeps hashes and encoded trie nodes accessed on updating memtrie.
pub struct TrieAccesses {
/// Hashes and encoded trie nodes.
pub nodes: HashMap<CryptoHash, Arc<[u8]>>,
/// Selects what is tracked while a memtrie update is processed.
///
/// Tracking allows using in-memory tries to construct the trie node changes
/// entirely (for both in-memory and on-disk updates) because it's much faster.
pub enum TrackingMode<'a> {
/// Don't track any nodes.
None,
/// Track disk refcount changes for trie nodes.
Refcounts,
/// Track disk refcount changes and record all accessed trie nodes.
/// The latter is needed to record the storage proof, which is handled by
/// `TrieRecorder`.
/// The main case where recording is needed is a branch with two children,
/// one of which got removed. In this case we need to read the other child
/// and squash it together with the parent.
RefcountsAndAccesses(&'a mut TrieRecorder),
}

/// Tracks intermediate trie changes, final version of which is to be committed
/// to disk after finishing trie update.
struct TrieChangesTracker {
struct TrieChangesTracker<'a> {
/// Counts hashes deleted so far.
/// Includes hashes of both trie nodes and state values!
refcount_deleted_hashes: BTreeMap<CryptoHash, u32>,
/// Counts state values inserted so far.
/// Separated from `refcount_deleted_hashes` to postpone hash computation
/// as far as possible.
refcount_inserted_values: BTreeMap<Vec<u8>, u32>,
/// All observed internal nodes.
/// Needed to prepare recorded storage.
/// Recorder for observed internal nodes.
/// Note that negative `refcount_deleted_hashes` does not fully cover it,
/// as node or value of the same hash can be removed and inserted for the
/// same update in different parts of trie!
accesses: TrieAccesses,
recorder: Option<&'a mut TrieRecorder>,
}

impl TrieChangesTracker {
impl<'a> TrieChangesTracker<'a> {
/// Creates a tracker whose refcount maps start out empty.
///
/// `recorder` is stored as given: when it is `Some`, nodes passed to
/// `record` are also forwarded to it; when it is `None`, only refcount
/// changes are tracked.
fn with_recorder(recorder: Option<&'a mut TrieRecorder>) -> Self {
    let refcount_deleted_hashes = BTreeMap::new();
    let refcount_inserted_values = BTreeMap::new();
    Self { refcount_deleted_hashes, refcount_inserted_values, recorder }
}

/// Registers an access to an existing memtrie node.
///
/// The counter for the node's hash in `refcount_deleted_hashes` is always
/// incremented. If a recorder is attached, the node is additionally
/// borsh-serialized (as a raw trie node with size) and handed to the
/// recorder under its hash.
fn record<M: ArenaMemory>(&mut self, node: &MemTrieNodeView<'a, M>) {
    let node_hash = node.node_hash();
    *self.refcount_deleted_hashes.entry(node_hash).or_default() += 1;
    // The serialized bytes are consumed only by the recorder, so skip the
    // allocation and encoding entirely when tracking refcounts only.
    if let Some(recorder) = self.recorder.as_mut() {
        let raw_node_serialized = borsh::to_vec(&node.to_raw_trie_node_with_size()).unwrap();
        recorder.record(&node_hash, raw_node_serialized.into());
    }
}

/// Prepare final refcount difference and also return all trie accesses.
fn finalize(self) -> (TrieRefcountDeltaMap, TrieAccesses) {
fn finalize(self) -> TrieRefcountDeltaMap {
let mut refcount_delta_map = TrieRefcountDeltaMap::new();
for (value, rc) in self.refcount_inserted_values {
refcount_delta_map.add(hash(&value), value, rc);
}
for (hash, rc) in self.refcount_deleted_hashes {
refcount_delta_map.subtract(hash, rc);
}
(refcount_delta_map, self.accesses)
refcount_delta_map
}
}

Expand All @@ -117,7 +143,7 @@ pub struct MemTrieUpdate<'a, M: ArenaMemory> {
pub updated_nodes: Vec<Option<UpdatedMemTrieNodeWithSize>>,
/// Tracks trie changes necessary to make on-disk updates and recorded
/// storage.
tracked_trie_changes: Option<TrieChangesTracker>,
nodes_tracker: Option<TrieChangesTracker<'a>>,
}

impl<'a, M: ArenaMemory> GenericTrieUpdate<'a, MemTrieNodeId, FlatStateValue>
Expand Down Expand Up @@ -153,38 +179,34 @@ impl<'a, M: ArenaMemory> GenericTrieUpdate<'a, MemTrieNodeId, FlatStateValue>
}

fn store_value(&mut self, value: GenericTrieValue) -> FlatStateValue {
// First, set the value which will be stored in memtrie.
let flat_value = match &value {
GenericTrieValue::MemtrieOnly(value) => return value.clone(),
GenericTrieValue::MemtrieAndDisk(value) => FlatStateValue::on_disk(value.as_slice()),
let (flat_value, full_value) = match value {
// If value is provided only for memtrie, it is flat, so we can't
// record nodes. Just return flat value back.
// TODO: check consistency with trie recorder setup.
// `GenericTrieValue::MemtrieOnly` must not be used if
// `nodes_tracker` is set and vice versa.
GenericTrieValue::MemtrieOnly(flat_value) => return flat_value,
GenericTrieValue::MemtrieAndDisk(full_value) => {
(FlatStateValue::on_disk(full_value.as_slice()), full_value)
}
};

// Then, record disk changes if needed.
let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() else {
return flat_value;
};
let GenericTrieValue::MemtrieAndDisk(value) = value else {
// Otherwise, record disk changes if needed.
let Some(nodes_tracker) = self.nodes_tracker.as_mut() else {
return flat_value;
};
tracked_node_changes
.refcount_inserted_values
.entry(value)
.and_modify(|rc| *rc += 1)
.or_insert(1);
*nodes_tracker.refcount_inserted_values.entry(full_value).or_default() += 1;

flat_value
}

fn delete_value(&mut self, value: FlatStateValue) -> Result<(), StorageError> {
if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() {
let hash = value.to_value_ref().hash;
tracked_node_changes
.refcount_deleted_hashes
.entry(hash)
.and_modify(|rc| *rc += 1)
.or_insert(1);
}
let Some(nodes_tracker) = self.nodes_tracker.as_mut() else {
return Ok(());
};

let hash = value.to_value_ref().hash;
*nodes_tracker.refcount_deleted_hashes.entry(hash).or_default() += 1;
Ok(())
}
}
Expand All @@ -194,23 +216,17 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
root: Option<MemTrieNodeId>,
memory: &'a M,
shard_uid: String,
track_trie_changes: bool,
mode: TrackingMode<'a>,
) -> Self {
let mut trie_update = Self {
root,
memory,
shard_uid,
updated_nodes: vec![],
tracked_trie_changes: if track_trie_changes {
Some(TrieChangesTracker {
refcount_inserted_values: BTreeMap::new(),
refcount_deleted_hashes: BTreeMap::new(),
accesses: TrieAccesses { nodes: HashMap::new() },
})
} else {
None
},
let nodes_tracker = match mode {
TrackingMode::None => None,
TrackingMode::Refcounts => Some(TrieChangesTracker::with_recorder(None)),
TrackingMode::RefcountsAndAccesses(recorder) => {
Some(TrieChangesTracker::with_recorder(Some(recorder)))
}
};
let mut trie_update =
Self { root, memory, shard_uid, updated_nodes: vec![], nodes_tracker };
assert_eq!(trie_update.convert_existing_to_updated(root), 0usize);
trie_update
}
Expand All @@ -230,29 +246,14 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
/// If the original node is None, it is a marker for the root of an empty
/// trie.
fn convert_existing_to_updated(&mut self, node: Option<MemTrieNodeId>) -> UpdatedMemTrieNodeId {
match node {
None => self.new_updated_node(UpdatedMemTrieNodeWithSize::empty()),
Some(node) => {
if let Some(tracked_trie_changes) = self.tracked_trie_changes.as_mut() {
let node_view = node.as_ptr(self.memory).view();
let node_hash = node_view.node_hash();
let raw_node_serialized =
borsh::to_vec(&node_view.to_raw_trie_node_with_size()).unwrap();
tracked_trie_changes
.accesses
.nodes
.insert(node_hash, raw_node_serialized.into());
tracked_trie_changes
.refcount_deleted_hashes
.entry(node_hash)
.and_modify(|rc| *rc += 1)
.or_insert(1);
}
self.new_updated_node(UpdatedMemTrieNodeWithSize::from_existing_node_view(
node.as_ptr(self.memory).view(),
))
}
let Some(node) = node else {
return self.new_updated_node(UpdatedMemTrieNodeWithSize::empty());
};
let node_view = node.as_ptr(self.memory).view();
if let Some(tracked_trie_changes) = self.nodes_tracker.as_mut() {
tracked_trie_changes.record(&node_view);
}
self.new_updated_node(UpdatedMemTrieNodeWithSize::from_existing_node_view(node_view))
}

/// Inserts the given key value pair into the trie.
Expand Down Expand Up @@ -419,11 +420,11 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
}

/// Converts the updates to trie changes as well as memtrie changes.
pub(crate) fn to_trie_changes(mut self) -> (TrieChanges, TrieAccesses) {
pub(crate) fn to_trie_changes(mut self) -> TrieChanges {
let old_root =
self.root.map(|root| root.as_ptr(self.memory).view().node_hash()).unwrap_or_default();
let (mut refcount_changes, accesses) = self
.tracked_trie_changes
let mut refcount_changes = self
.nodes_tracker
.take()
.expect("Cannot to_trie_changes for memtrie changes only")
.finalize();
Expand All @@ -436,20 +437,17 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
}
let (insertions, deletions) = refcount_changes.into_changes();

(
TrieChanges {
old_root,
new_root: mem_trie_changes
.node_ids_with_hashes
.last()
.map(|(_, hash)| *hash)
.unwrap_or_default(),
insertions,
deletions,
mem_trie_changes: Some(mem_trie_changes),
},
accesses,
)
TrieChanges {
old_root,
new_root: mem_trie_changes
.node_ids_with_hashes
.last()
.map(|(_, hash)| *hash)
.unwrap_or_default(),
insertions,
deletions,
mem_trie_changes: Some(mem_trie_changes),
}
}
}

Expand Down Expand Up @@ -522,6 +520,8 @@ mod tests {
use rand::Rng;
use std::collections::{HashMap, HashSet};

use super::TrackingMode;

struct TestTries {
mem: MemTries,
disk: ShardTries,
Expand All @@ -544,26 +544,28 @@ mod tests {
}

fn make_all_changes(&mut self, changes: Vec<(Vec<u8>, Option<Vec<u8>>)>) -> TrieChanges {
let mut update = self.mem.update(self.state_root, true).unwrap_or_else(|_| {
panic!("Trying to update root {:?} but it's not in memtries", self.state_root)
});
let mut update =
self.mem.update(self.state_root, TrackingMode::Refcounts).unwrap_or_else(|_| {
panic!("Trying to update root {:?} but it's not in memtries", self.state_root)
});
for (key, value) in changes {
if let Some(value) = value {
update.insert(&key, value).unwrap();
} else {
update.generic_delete(0, &key).unwrap();
}
}
update.to_trie_changes().0
update.to_trie_changes()
}

fn make_memtrie_changes_only(
&mut self,
changes: Vec<(Vec<u8>, Option<Vec<u8>>)>,
) -> MemTrieChanges {
let mut update = self.mem.update(self.state_root, false).unwrap_or_else(|_| {
panic!("Trying to update root {:?} but it's not in memtries", self.state_root)
});
let mut update =
self.mem.update(self.state_root, TrackingMode::None).unwrap_or_else(|_| {
panic!("Trying to update root {:?} but it's not in memtries", self.state_root)
});
for (key, value) in changes {
if let Some(value) = value {
update.insert_memtrie_only(&key, FlatStateValue::on_disk(&value)).unwrap();
Expand Down Expand Up @@ -942,7 +944,7 @@ mod tests {
changes: &str,
) -> CryptoHash {
let changes = parse_changes(changes);
let mut update = memtrie.update(prev_state_root, false).unwrap();
let mut update = memtrie.update(prev_state_root, TrackingMode::None).unwrap();

for (key, value) in changes {
if let Some(value) = value {
Expand Down
Loading

0 comments on commit 6934fc2

Please sign in to comment.