diff --git a/tools/fork-network/src/cli.rs b/tools/fork-network/src/cli.rs
index f4f3a10fff5..4db225fb334 100644
--- a/tools/fork-network/src/cli.rs
+++ b/tools/fork-network/src/cli.rs
@@ -6,8 +6,7 @@ use near_chain::types::{RuntimeAdapter, Tip};
 use near_chain::{ChainStore, ChainStoreAccess};
 use near_chain_configs::{Genesis, GenesisConfig, GenesisValidationMode, NEAR_BASE};
 use near_crypto::PublicKey;
-use near_epoch_manager::shard_assignment::shard_id_to_uid;
-use near_epoch_manager::{EpochManager, EpochManagerAdapter, EpochManagerHandle};
+use near_epoch_manager::{EpochManager, EpochManagerAdapter};
 use near_mirror::key_mapping::{map_account, map_key};
 use near_o11y::default_subscriber_with_opentelemetry;
 use near_o11y::env_filter::make_env_filter;
@@ -16,7 +15,6 @@ use near_primitives::account::id::AccountType;
 use near_primitives::account::{AccessKey, AccessKeyPermission, Account, AccountContract};
 use near_primitives::borsh;
 use near_primitives::epoch_manager::{EpochConfig, EpochConfigStore};
-use near_primitives::hash::CryptoHash;
 use near_primitives::serialize::dec_format;
 use near_primitives::shard_layout::{ShardLayout, ShardUId};
 use near_primitives::state::FlatStateValue;
@@ -24,7 +22,7 @@ use near_primitives::state_record::StateRecord;
 use near_primitives::trie_key::col;
 use near_primitives::trie_key::trie_key_parsers::parse_account_id_from_account_key;
 use near_primitives::types::{
-    AccountId, AccountInfo, Balance, BlockHeight, EpochId, NumBlocks, NumSeats, ShardId, StateRoot,
+    AccountId, AccountInfo, Balance, BlockHeight, EpochId, NumBlocks, NumSeats, StateRoot,
 };
 use near_primitives::version::{ProtocolVersion, PROTOCOL_VERSION};
 use near_store::adapter::StoreAdapter;
@@ -37,11 +35,10 @@ use near_store::{
 use nearcore::{load_config, open_storage, NearConfig, NightshadeRuntime, NightshadeRuntimeExt};
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use serde::Deserialize;
-use std::collections::{BTreeMap, HashSet};
+use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
-use std::str::FromStr;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
@@ -84,7 +81,11 @@ enum SubCommand {
 }

 #[derive(clap::Parser)]
-struct InitCmd;
+struct InitCmd {
+    /// If given, the shard layout in this file will be used to generate the forked genesis state
+    #[arg(long)]
+    pub shard_layout_file: Option<PathBuf>,
+}

 #[derive(clap::Parser)]
 struct FinalizeCmd;
@@ -128,18 +129,20 @@ struct SetValidatorsCmd {
     pub num_seats: Option<NumSeats>,
 }

-const FORKED_ROOTS_KEY_PREFIX: &str = "FORK_TOOL_SHARD_ID:";
+const FORKED_ROOTS_KEY_PREFIX: &[u8; 20] = b"FORK_TOOL_SHARD_UID:";

-fn parse_state_roots_key(key: &[u8]) -> anyhow::Result<ShardId> {
-    let key = std::str::from_utf8(key)?;
+fn parse_state_roots_key(key: &[u8]) -> anyhow::Result<ShardUId> {
     // Sanity check assertion since we should be iterating based on this prefix
     assert!(key.starts_with(FORKED_ROOTS_KEY_PREFIX));
-    let int_part = &key[FORKED_ROOTS_KEY_PREFIX.len()..];
-    ShardId::from_str(int_part).with_context(|| format!("Failed parsing ShardId from {}", int_part))
+    let shard_uid_part = &key[FORKED_ROOTS_KEY_PREFIX.len()..];
+    borsh::from_slice(shard_uid_part)
+        .with_context(|| format!("Failed parsing ShardUId from fork tool key {:?}", key))
 }

-fn make_state_roots_key(shard_id: ShardId) -> Vec<u8> {
-    format!("{FORKED_ROOTS_KEY_PREFIX}{shard_id}").into_bytes()
+pub(crate) fn make_state_roots_key(shard_uid: ShardUId) -> Vec<u8> {
+    let mut key = FORKED_ROOTS_KEY_PREFIX.to_vec();
+    key.append(&mut 
borsh::to_vec(&shard_uid).unwrap()); + key } /// The minimum set of columns that will be needed to start a node after the `finalize` command runs @@ -209,8 +212,8 @@ impl ForkNetworkCommand { near_config.config.store.state_snapshot_enabled = false; match &self.command { - SubCommand::Init(InitCmd) => { - self.init(near_config, home_dir)?; + SubCommand::Init(InitCmd { shard_layout_file }) => { + self.init(near_config, home_dir, shard_layout_file.as_deref())?; } SubCommand::AmendAccessKeys(AmendAccessKeysCmd { batch_size }) => { self.amend_access_keys(*batch_size, near_config, home_dir)?; @@ -272,7 +275,12 @@ impl ForkNetworkCommand { // Snapshots the DB. // Determines parameters that will be used to initialize the new chain. // After this completes, almost every DB column can be removed, however this command doesn't delete anything itself. - fn write_fork_info(&self, near_config: &mut NearConfig, home_dir: &Path) -> anyhow::Result<()> { + fn write_fork_info( + &self, + near_config: &mut NearConfig, + home_dir: &Path, + shard_layout_file: Option<&Path>, + ) -> anyhow::Result<()> { // Open storage with migration let storage = open_storage(&home_dir, near_config).unwrap(); let store = storage.get_hot_store(); @@ -286,6 +294,19 @@ impl ForkNetworkCommand { let head = store.get_ser::(DBCol::BlockMisc, FINAL_HEAD_KEY)?.unwrap(); let shard_layout = epoch_manager.get_shard_layout(&head.epoch_id)?; let all_shard_uids: Vec<_> = shard_layout.shard_uids().collect(); + + let target_shard_layout = match shard_layout_file { + Some(shard_layout_file) => { + let layout = std::fs::read_to_string(shard_layout_file).with_context(|| { + format!("failed reading shard layout file at {}", shard_layout_file.display()) + })?; + serde_json::from_str(&layout).with_context(|| { + format!("failed parsing shard layout file at {}", shard_layout_file.display()) + })? + } + None => shard_layout, + }; + // Flat state can be at different heights for different shards. // That is fine, we'll simply lookup state root for each . let fork_heads = get_fork_heads(&all_shard_uids, store.clone())?; @@ -298,54 +319,55 @@ impl ForkNetworkCommand { ); // Move flat storage to the max height for consistency across shards. - let (block_height, desired_block_hash) = - fork_heads.iter().map(|head| (head.height, head.hash)).max().unwrap(); + let desired_flat_head = fork_heads.iter().max_by_key(|b| b.height).unwrap(); - let desired_block_header = chain.get_block_header(&desired_block_hash)?; + let desired_block_header = chain.get_block_header(&desired_flat_head.hash)?; let epoch_id = desired_block_header.epoch_id(); let flat_storage_manager = FlatStorageManager::new(store.flat_store()); // Advance flat heads to the same (max) block height to ensure // consistency of state across the shards. 
- let state_roots: Vec<(ShardId, StateRoot)> = shard_layout - .shard_ids() - .map(|shard_id| { - let shard_uid = - shard_id_to_uid(epoch_manager.as_ref(), shard_id, epoch_id).unwrap(); + let state_roots: Vec<(ShardUId, StateRoot)> = all_shard_uids + .into_iter() + .map(|shard_uid| { flat_storage_manager.create_flat_storage_for_shard(shard_uid).unwrap(); let flat_storage = flat_storage_manager.get_flat_storage_for_shard(shard_uid).unwrap(); - flat_storage.update_flat_head(&desired_block_hash).unwrap(); - let chunk_extra = chain.get_chunk_extra(&desired_block_hash, &shard_uid).unwrap(); + flat_storage.update_flat_head(&desired_flat_head.hash).unwrap(); + let chunk_extra = + chain.get_chunk_extra(&desired_flat_head.hash, &shard_uid).unwrap(); let state_root = chunk_extra.state_root(); - tracing::info!(?shard_id, ?epoch_id, ?state_root); - (shard_id, *state_root) + tracing::info!(?shard_uid, ?epoch_id, ?state_root); + (shard_uid, *state_root) }) .collect(); // Increment height to represent that some changes were made to the original state. tracing::info!( - block_height, - ?desired_block_hash, + ?desired_flat_head, ?state_roots, ?epoch_id, "Moved flat heads to a common block" ); - let block_height = block_height + 1; let mut store_update = store.store_update(); store_update.set_ser(DBCol::Misc, b"FORK_TOOL_EPOCH_ID", epoch_id)?; - store_update.set_ser(DBCol::Misc, b"FORK_TOOL_BLOCK_HASH", &desired_block_hash)?; - store_update.set(DBCol::Misc, b"FORK_TOOL_BLOCK_HEIGHT", &block_height.to_le_bytes()); - for (shard_id, state_root) in state_roots.iter() { - store_update.set_ser(DBCol::Misc, &make_state_roots_key(*shard_id), state_root)?; + store_update.set_ser(DBCol::Misc, b"FORK_TOOL_FLAT_HEAD", &desired_flat_head)?; + store_update.set_ser(DBCol::Misc, b"FORK_TOOL_SHARD_LAYOUT", &target_shard_layout)?; + for (shard_uid, state_root) in state_roots.iter() { + store_update.set_ser(DBCol::Misc, &make_state_roots_key(*shard_uid), state_root)?; } store_update.commit()?; Ok(()) } - fn init(&self, near_config: &mut NearConfig, home_dir: &Path) -> anyhow::Result<()> { - self.write_fork_info(near_config, home_dir)?; + fn init( + &self, + near_config: &mut NearConfig, + home_dir: &Path, + shard_layout_file: Option<&Path>, + ) -> anyhow::Result<()> { + self.write_fork_info(near_config, home_dir, shard_layout_file)?; let mut unwanted_cols = Vec::new(); for col in DBCol::iter() { if !COLUMNS_TO_KEEP.contains(&col) && !SETUP_COLUMNS_TO_KEEP.contains(&col) { @@ -380,32 +402,35 @@ impl ForkNetworkCommand { &near_config.genesis.config, Some(home_dir), ); - let (prev_state_roots, prev_hash, epoch_id, _block_height) = - self.get_state_roots_and_hash(epoch_manager.as_ref(), store.clone())?; - tracing::info!(?prev_state_roots, ?epoch_id, ?prev_hash); + let (prev_state_roots, flat_head, epoch_id, target_shard_layout) = + self.get_state_roots_and_hash(store.clone())?; + tracing::info!(?prev_state_roots, ?epoch_id, ?flat_head); - let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; - let all_shard_uids = shard_layout.shard_uids().collect::>(); + let source_shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; + let all_shard_uids = source_shard_layout.shard_uids().collect::>(); assert_eq!(all_shard_uids.len(), prev_state_roots.len()); let runtime = NightshadeRuntime::from_config(home_dir, store.clone(), &near_config, epoch_manager) .context("could not create the transaction runtime")?; + // TODO: add an option to not load them all at once. As is, this takes an insane amount of memory for mainnet state. 
runtime .get_tries() .load_memtries_for_enabled_shards(&all_shard_uids, &[].into(), true) .unwrap(); - let runtime2 = runtime.clone(); - let shard_layout2 = shard_layout.clone(); + let shard_tries = runtime.get_tries(); + let target_shard_layout2 = target_shard_layout.clone(); let make_storage_mutator: MakeSingleShardStorageMutatorFn = Arc::new(move |update_state| { - StorageMutator::new(&runtime2.clone(), update_state, shard_layout2.clone()) + StorageMutator::new(shard_tries.clone(), update_state, target_shard_layout2.clone()) }); let new_state_roots = self.prepare_state( batch_size, store, - shard_layout, + source_shard_layout, + target_shard_layout, + flat_head, prev_state_roots, make_storage_mutator.clone(), runtime, @@ -439,39 +464,42 @@ impl ForkNetworkCommand { Some(home_dir), ); - let (prev_state_roots, _prev_hash, epoch_id, block_height) = - self.get_state_roots_and_hash(epoch_manager.as_ref(), store.clone())?; + let (prev_state_roots, flat_head, _epoch_id, target_shard_layout) = + self.get_state_roots_and_hash(store.clone())?; - let runtime = NightshadeRuntime::from_config( - home_dir, - store.clone(), - &near_config, - epoch_manager.clone(), - ) - .context("could not create the transaction runtime")?; + let runtime = + NightshadeRuntime::from_config(home_dir, store.clone(), &near_config, epoch_manager) + .context("could not create the transaction runtime")?; let runtime_config_store = RuntimeConfigStore::new(None); let runtime_config = runtime_config_store.get_config(PROTOCOL_VERSION); - let shard_layout = epoch_manager - .get_shard_layout(&epoch_id) - .with_context(|| format!("Failed getting shard layout for epoch {}", &epoch_id.0))?; - let shard_uids = shard_layout.shard_uids().collect::>(); - assert_eq!(shard_uids.len(), prev_state_roots.len()); + let shard_uids = target_shard_layout.shard_uids().collect::>(); + assert_eq!( + shard_uids.iter().collect::>(), + prev_state_roots.iter().map(|(k, _v)| k).collect::>() + ); let flat_store = store.flat_store(); - let mut update_state = Vec::new(); - for (shard_uid, prev_state_root) in shard_uids.iter().zip(prev_state_roots.into_iter()) { - update_state.push(ShardUpdateState::new(&flat_store, *shard_uid, prev_state_root)?); - } - - let storage_mutator = - StorageMutator::new(&runtime, update_state.clone(), shard_layout.clone())?; + // Here we use the same shard layout for source and target, because we assume that amend-access-keys has already + // written the new shard layout to the FORK_TOOL_SHARD_LAYOUT key, and in any case we're not mapping things from + // source shards to target shards in this function + let update_state = ShardUpdateState::new_update_state( + &flat_store, + &target_shard_layout, + &target_shard_layout, + &prev_state_roots, + )?; + let storage_mutator = StorageMutator::new( + runtime.get_tries(), + update_state.clone(), + target_shard_layout.clone(), + )?; let new_validator_accounts = self.add_validator_accounts( validators, runtime_config, home_dir, - &shard_layout, + &target_shard_layout, storage_mutator, )?; let new_state_roots = update_state.into_iter().map(|u| u.state_root()).collect::>(); @@ -480,9 +508,10 @@ impl ForkNetworkCommand { self.make_and_write_genesis( genesis_time, protocol_version, + target_shard_layout, epoch_length, num_seats, - block_height, + flat_head.height + 1, chain_id_suffix, chain_id, new_state_roots.clone(), @@ -517,31 +546,21 @@ impl ForkNetworkCommand { // The Vec returned is in ShardIndex order fn get_state_roots_and_hash( &self, - epoch_manager: &EpochManagerHandle, store: Store, 
-    ) -> anyhow::Result<(Vec<StateRoot>, CryptoHash, EpochId, BlockHeight)> {
+    ) -> anyhow::Result<(HashMap<ShardUId, StateRoot>, BlockInfo, EpochId, ShardLayout)> {
         let epoch_id = EpochId(store.get_ser(DBCol::Misc, b"FORK_TOOL_EPOCH_ID")?.unwrap());
-        let block_hash = store.get_ser(DBCol::Misc, b"FORK_TOOL_BLOCK_HASH")?.unwrap();
-        let block_height = store.get(DBCol::Misc, b"FORK_TOOL_BLOCK_HEIGHT")?.unwrap();
-        let block_height = u64::from_le_bytes(block_height.as_slice().try_into().unwrap());
-        let shard_layout = epoch_manager
-            .get_shard_layout(&epoch_id)
-            .with_context(|| format!("Failed getting shard layout for epoch {}", &epoch_id.0))?;
-        let mut state_roots = vec![None; shard_layout.shard_ids().count()];
-        for item in store.iter_prefix(DBCol::Misc, FORKED_ROOTS_KEY_PREFIX.as_bytes()) {
+        let flat_head = store.get_ser(DBCol::Misc, b"FORK_TOOL_FLAT_HEAD")?.unwrap();
+        let shard_layout = store.get_ser(DBCol::Misc, b"FORK_TOOL_SHARD_LAYOUT")?.unwrap();
+        let mut state_roots = HashMap::new();
+        for item in store.iter_prefix(DBCol::Misc, FORKED_ROOTS_KEY_PREFIX) {
             let (key, value) = item?;
-            let shard_id = parse_state_roots_key(&key)?;
-            let shard_index = shard_layout
-                .get_shard_index(shard_id)
-                .with_context(|| format!("Failed finding shard index for {}", shard_id))?;
+            let shard_uid = parse_state_roots_key(&key)?;
             let state_root: StateRoot = borsh::from_slice(&value)?;
-            assert!(state_roots[shard_index].is_none());
-            state_roots[shard_index] = Some(state_root);
+            state_roots.insert(shard_uid, state_root);
         }
-        let state_roots = state_roots.into_iter().map(|s| s.unwrap()).collect();
-        tracing::info!(?state_roots, ?block_hash, ?epoch_id, block_height);
-        Ok((state_roots, block_hash, epoch_id, block_height))
+        tracing::info!(?state_roots, ?flat_head, ?epoch_id);
+        Ok((state_roots, flat_head, epoch_id, shard_layout))
     }

     /// Checks that `~/.near/data/fork-snapshot/data` exists.
@@ -618,19 +637,18 @@ impl ForkNetworkCommand {
         Ok(first_config)
     }

+    /// Returns info on delayed receipts mapped from this shard, and this shard's state root after
+    /// all updates are applied.
     fn prepare_shard_state(
         &self,
         batch_size: u64,
-        shard_layout: ShardLayout,
+        source_shard_layout: ShardLayout,
+        target_shard_layout: ShardLayout,
         shard_uid: ShardUId,
         store: Store,
         make_storage_mutator: MakeSingleShardStorageMutatorFn,
         update_state: Vec<ShardUpdateState>,
     ) -> anyhow::Result<DelayedReceiptTracker> {
-        // Doesn't support secrets.
-        tracing::info!(?shard_uid);
-        let shard_idx = shard_layout.get_shard_index(shard_uid.shard_id()).unwrap();
-
         let mut storage_mutator: StorageMutator = make_storage_mutator(update_state.clone())?;

         // TODO: allow mutating the state with a secret, so this can be used to prepare a public test network
@@ -641,19 +659,12 @@
         let trie_storage = TrieDBStorage::new(store.trie_store(), shard_uid);
         let mut receipts_tracker =
-            DelayedReceiptTracker::new(shard_uid, shard_layout.shard_ids().count());
+            DelayedReceiptTracker::new(shard_uid, target_shard_layout.shard_ids().count());

         // Iterate over the whole flat storage and do the necessary changes to have access to all accounts.
- let mut index_delayed_receipt = 0; let mut ref_keys_retrieved = 0; let mut records_not_parsed = 0; let mut records_parsed = 0; - let mut access_keys_updated = 0; - let mut accounts_implicit_updated = 0; - let mut contract_data_updated = 0; - let mut contract_code_updated = 0; - let mut postponed_receipts_updated = 0; - let mut received_data_updated = 0; for item in store.flat_store().iter(shard_uid) { let (key, value) = match item { @@ -675,98 +686,57 @@ impl ForkNetworkCommand { } let new_account_id = map_account(&account_id, None); let replacement = map_key(&public_key, None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); + let new_shard_id = + target_shard_layout.account_id_to_shard_id(&new_account_id); + let new_shard_idx = + target_shard_layout.get_shard_index(new_shard_id).unwrap(); - storage_mutator.remove(shard_idx, key)?; + storage_mutator.remove_access_key(shard_uid, account_id, public_key)?; storage_mutator.set_access_key( new_shard_idx, new_account_id, replacement.public_key(), access_key.clone(), )?; - access_keys_updated += 1; } - StateRecord::Account { account_id, account } => { - // TODO(eth-implicit) Change back to is_implicit() when ETH-implicit accounts are supported. - if account_id.get_account_type() == AccountType::NearImplicitAccount { - let new_account_id = map_account(&account_id, None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); - - storage_mutator.remove(shard_idx, key)?; - storage_mutator.set_account(new_shard_idx, new_account_id, account)?; - accounts_implicit_updated += 1; - } + storage_mutator.map_account(shard_uid, account_id, account)?; } StateRecord::Data { account_id, data_key, value } => { - // TODO(eth-implicit) Change back to is_implicit() when ETH-implicit accounts are supported. - if account_id.get_account_type() == AccountType::NearImplicitAccount { - let new_account_id = map_account(&account_id, None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); - - storage_mutator.remove(shard_idx, key)?; - storage_mutator.set_data( - new_shard_idx, - new_account_id, - &data_key, - value, - )?; - contract_data_updated += 1; - } + storage_mutator.map_data(shard_uid, account_id, &data_key, value)?; } StateRecord::Contract { account_id, code } => { - // TODO(eth-implicit) Change back to is_implicit() when ETH-implicit accounts are supported. 
- if account_id.get_account_type() == AccountType::NearImplicitAccount { - let new_account_id = map_account(&account_id, None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); - - storage_mutator.remove(shard_idx, key)?; - storage_mutator.set_code(new_shard_idx, new_account_id, code)?; - contract_code_updated += 1; - } + storage_mutator.map_code(shard_uid, account_id, code)?; } StateRecord::PostponedReceipt(mut receipt) => { - storage_mutator.remove(shard_idx, key)?; + storage_mutator.remove_postponed_receipt( + shard_uid, + receipt.receiver_id().clone(), + *receipt.receipt_id(), + )?; near_mirror::genesis::map_receipt(&mut receipt, None, &default_key); let new_shard_id = - shard_layout.account_id_to_shard_id(receipt.receiver_id()); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); + target_shard_layout.account_id_to_shard_id(receipt.receiver_id()); + let new_shard_idx = + target_shard_layout.get_shard_index(new_shard_id).unwrap(); storage_mutator.set_postponed_receipt(new_shard_idx, &receipt)?; - postponed_receipts_updated += 1; } StateRecord::ReceivedData { account_id, data_id, data } => { - // TODO(eth-implicit) Change back to is_implicit() when ETH-implicit accounts are supported. - if account_id.get_account_type() == AccountType::NearImplicitAccount { - let new_account_id = map_account(&account_id, None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); - - storage_mutator.remove(shard_idx, key)?; - storage_mutator.set_received_data( - new_shard_idx, - new_account_id, - data_id, - &data, - )?; - received_data_updated += 1; - } + storage_mutator.map_received_data(shard_uid, account_id, data_id, &data)?; } StateRecord::DelayedReceipt(receipt) => { let new_account_id = map_account(receipt.receipt.receiver_id(), None); - let new_shard_id = shard_layout.account_id_to_shard_id(&new_account_id); - let new_shard_idx = shard_layout.get_shard_index(new_shard_id).unwrap(); + let new_shard_id = + target_shard_layout.account_id_to_shard_id(&new_account_id); + let new_shard_idx = + target_shard_layout.get_shard_index(new_shard_id).unwrap(); // The index is guaranteed to be set when iterating over the trie rather than reading // serialized StateRecords let index = receipt.index.unwrap(); receipts_tracker.push(new_shard_idx, index); - index_delayed_receipt += 1; } } records_parsed += 1; @@ -774,18 +744,7 @@ impl ForkNetworkCommand { records_not_parsed += 1; } if storage_mutator.should_commit(batch_size) { - tracing::info!( - ?shard_uid, - ref_keys_retrieved, - records_parsed, - updated = access_keys_updated - + accounts_implicit_updated - + contract_data_updated - + contract_code_updated - + postponed_receipts_updated - + index_delayed_receipt - + received_data_updated, - ); + tracing::info!(?shard_uid, ref_keys_retrieved, records_parsed,); storage_mutator.commit()?; storage_mutator = make_storage_mutator(update_state.clone())?; } @@ -793,18 +752,7 @@ impl ForkNetworkCommand { // Commit the remaining updates. 
if storage_mutator.should_commit(1) { - tracing::info!( - ?shard_uid, - ref_keys_retrieved, - records_parsed, - updated = access_keys_updated - + accounts_implicit_updated - + contract_data_updated - + contract_code_updated - + postponed_receipts_updated - + index_delayed_receipt - + received_data_updated, - ); + tracing::info!(?shard_uid, ref_keys_retrieved, records_parsed,); storage_mutator.commit()?; storage_mutator = make_storage_mutator(update_state.clone())?; } @@ -814,13 +762,6 @@ impl ForkNetworkCommand { ref_keys_retrieved, records_parsed, records_not_parsed, - accounts_implicit_updated, - access_keys_updated, - contract_code_updated, - contract_data_updated, - postponed_receipts_updated, - delayed_receipts_updated = index_delayed_receipt, - received_data_updated, num_has_full_key = has_full_key.len(), "Pass 1 done" ); @@ -851,14 +792,14 @@ impl ForkNetworkCommand { { continue; } - let shard_id = shard_layout.account_id_to_shard_id(&account_id); + let shard_id = source_shard_layout.account_id_to_shard_id(&account_id); if shard_id != shard_uid.shard_id() { tracing::warn!( "Account {} belongs to shard {} but was found in flat storage for shard {}", &account_id, shard_id, shard_uid.shard_id(), ); } - let shard_idx = shard_layout.get_shard_index(shard_id).unwrap(); + let shard_idx = source_shard_layout.get_shard_index(shard_id).unwrap(); storage_mutator.set_access_key( shard_idx, account_id, @@ -878,23 +819,44 @@ impl ForkNetworkCommand { Ok(receipts_tracker) } + // TODO: instead of calling this every time, this could be integrated into StorageMutator or something + fn update_source_state_roots( + source_state_roots: &mut HashMap, + target_shard_layout: &ShardLayout, + update_state: &[ShardUpdateState], + ) { + for (shard_uid, state_root) in source_state_roots.iter_mut() { + if target_shard_layout.shard_uids().any(|s| s == *shard_uid) { + let shard_idx = target_shard_layout.get_shard_index(shard_uid.shard_id()).unwrap(); + *state_root = update_state[shard_idx].state_root(); + } + } + } + fn prepare_state( &self, batch_size: u64, store: Store, - shard_layout: ShardLayout, - prev_state_roots: Vec, + source_shard_layout: ShardLayout, + target_shard_layout: ShardLayout, + flat_head: BlockInfo, + mut source_state_roots: HashMap, make_storage_mutator: MakeSingleShardStorageMutatorFn, runtime: Arc, ) -> anyhow::Result> { - let shard_uids = shard_layout.shard_uids().collect::>(); - assert_eq!(shard_uids.len(), prev_state_roots.len()); + let shard_uids = source_shard_layout.shard_uids().collect::>(); + assert_eq!( + shard_uids.iter().collect::>(), + source_state_roots.iter().map(|(k, _v)| k).collect::>() + ); let flat_store = store.flat_store(); - let mut update_state = Vec::new(); - for (shard_uid, prev_state_root) in shard_uids.iter().zip(prev_state_roots.into_iter()) { - update_state.push(ShardUpdateState::new(&flat_store, *shard_uid, prev_state_root)?); - } + let update_state = ShardUpdateState::new_update_state( + &flat_store, + &source_shard_layout, + &target_shard_layout, + &source_state_roots, + )?; // the try_fold().try_reduce() will give a Vec<> of the return values and return early if one fails let receipt_trackers = shard_uids @@ -904,7 +866,8 @@ impl ForkNetworkCommand { |mut trackers, shard_uid| { let t = self.prepare_shard_state( batch_size, - shard_layout.clone(), + source_shard_layout.clone(), + target_shard_layout.clone(), shard_uid, store.clone(), make_storage_mutator.clone(), @@ -922,14 +885,40 @@ impl ForkNetworkCommand { }, )?; + Self::update_source_state_roots( + &mut 
source_state_roots, + &target_shard_layout, + &update_state, + ); + let shard_tries = runtime.get_tries(); + crate::storage_mutator::write_bandwidth_scheduler_state( + &shard_tries, + &source_shard_layout, + &target_shard_layout, + &source_state_roots, + &update_state, + )?; + let default_key = near_mirror::key_mapping::default_extra_key(None).public_key(); + Self::update_source_state_roots( + &mut source_state_roots, + &target_shard_layout, + &update_state, + ); crate::delayed_receipts::write_delayed_receipts( - runtime.as_ref(), + &shard_tries, &update_state, receipt_trackers, - &shard_layout, + &source_state_roots, + &target_shard_layout, &default_key, )?; + crate::storage_mutator::finalize_state( + &shard_tries, + &source_shard_layout, + &target_shard_layout, + flat_head, + )?; let state_roots = update_state.into_iter().map(|u| u.state_root()).collect(); tracing::info!(?state_roots, "All done"); @@ -994,6 +983,7 @@ impl ForkNetworkCommand { &self, genesis_time: DateTime, protocol_version: Option, + shard_layout: ShardLayout, epoch_length: u64, num_seats: &Option, height: BlockHeight, @@ -1025,14 +1015,29 @@ impl ForkNetworkCommand { let original_config = near_config.genesis.config.clone(); + // TODO: consider doing something smarter with these two + let num_block_producer_seats_per_shard = vec![ + original_config + .num_block_producer_seats_per_shard[0]; + shard_layout.num_shards() as usize + ]; + let avg_hidden_validator_seats_per_shard = + if original_config.avg_hidden_validator_seats_per_shard.is_empty() { + Vec::new() + } else { + vec![ + original_config.avg_hidden_validator_seats_per_shard[0]; + shard_layout.num_shards() as usize + ] + }; let new_config = GenesisConfig { chain_id: new_chain_id, genesis_height: height, genesis_time, epoch_length, num_block_producer_seats: epoch_config.num_block_producer_seats, - num_block_producer_seats_per_shard: epoch_config.num_block_producer_seats_per_shard, - avg_hidden_validator_seats_per_shard: epoch_config.avg_hidden_validator_seats_per_shard, + num_block_producer_seats_per_shard, + avg_hidden_validator_seats_per_shard, block_producer_kickout_threshold: 0, chunk_producer_kickout_threshold: 0, chunk_validator_only_kickout_threshold: 0, @@ -1043,7 +1048,7 @@ impl ForkNetworkCommand { fishermen_threshold: epoch_config.fishermen_threshold, minimum_stake_divisor: epoch_config.minimum_stake_divisor, protocol_upgrade_stake_threshold: epoch_config.protocol_upgrade_stake_threshold, - shard_layout: epoch_config.shard_layout.clone(), + shard_layout, num_chunk_only_producer_seats: epoch_config.num_chunk_only_producer_seats, minimum_validators_per_shard: epoch_config.minimum_validators_per_shard, minimum_stake_ratio: epoch_config.minimum_stake_ratio, diff --git a/tools/fork-network/src/delayed_receipts.rs b/tools/fork-network/src/delayed_receipts.rs index ea1272bdf34..69831b48130 100644 --- a/tools/fork-network/src/delayed_receipts.rs +++ b/tools/fork-network/src/delayed_receipts.rs @@ -1,14 +1,12 @@ use crate::storage_mutator::ShardUpdateState; -use near_chain::types::RuntimeAdapter; use near_crypto::PublicKey; use near_primitives::borsh; use near_primitives::receipt::{Receipt, ReceiptOrStateStoredReceipt, TrieQueueIndices}; use near_primitives::shard_layout::{ShardLayout, ShardUId}; use near_primitives::trie_key::TrieKey; -use near_primitives::types::{ShardId, ShardIndex}; -use near_store::Trie; -use nearcore::NightshadeRuntime; +use near_primitives::types::{ShardIndex, StateRoot}; +use near_store::{ShardTries, Trie}; use anyhow::Context; use 
std::borrow::Cow;
@@ -42,17 +40,24 @@
     }
 }

-fn remove_source_receipt_index(trie_updates: &mut HashMap<Vec<u8>, Option<Vec<u8>>>, index: u64) {
-    let key = TrieKey::DelayedReceipt { index };
-
-    if let Entry::Vacant(e) = trie_updates.entry(key.to_vec()) {
+fn remove_source_receipt_index(
+    trie_updates: &mut [HashMap<u64, Option<Vec<u8>>>],
+    source_shard_uid: ShardUId,
+    target_shard_layout: &ShardLayout,
+    index: u64,
+) {
+    if !target_shard_layout.shard_uids().any(|s| s == source_shard_uid) {
+        return;
+    }
+    let shard_idx = target_shard_layout.get_shard_index(source_shard_uid.shard_id()).unwrap();
+    if let Entry::Vacant(e) = trie_updates[shard_idx].entry(index) {
         e.insert(None);
     }
 }

 fn read_delayed_receipt(
     trie: &Trie,
-    source_shard_id: ShardId,
+    source_shard_uid: ShardUId,
     index: u64,
 ) -> anyhow::Result<Option<Receipt>> {
     let key = TrieKey::DelayedReceipt { index };
@@ -60,7 +65,7 @@
         near_store::get_pure::<ReceiptOrStateStoredReceipt>(trie, &key).with_context(|| {
             format!(
                 "failed reading delayed receipt idx {} from shard {} trie",
-                index, source_shard_id,
+                index, source_shard_uid,
             )
         })?;
     Ok(match value {
@@ -69,7 +74,7 @@
             tracing::warn!(
                 "Expected delayed receipt with index {} in shard {} not found",
                 index,
-                source_shard_id,
+                source_shard_uid,
             );
             None
         }
@@ -77,48 +82,42 @@
 }

 fn set_target_delayed_receipt(
-    trie_updates: &mut HashMap<Vec<u8>, Option<Vec<u8>>>,
+    trie_updates: &mut HashMap<u64, Option<Vec<u8>>>,
     target_index: &mut u64,
     mut receipt: Receipt,
     default_key: &PublicKey,
 ) {
-    let target_key = TrieKey::DelayedReceipt { index: *target_index };
-    let target_key = target_key.to_vec();
-    *target_index += 1;
-
     near_mirror::genesis::map_receipt(&mut receipt, None, default_key);
     let value = ReceiptOrStateStoredReceipt::Receipt(Cow::Owned(receipt));
     let value = borsh::to_vec(&value).unwrap();
-    trie_updates.insert(target_key, Some(value));
+    trie_updates.insert(*target_index, Some(value));
+    *target_index += 1;
 }

 // This should be called after push() has been called on each DelayedReceiptTracker in `trackers`
 // for each receipt in its shard. This reads and maps the accounts and keys in all the receipts and
 // writes them to the right shards.
 pub(crate) fn write_delayed_receipts(
-    runtime: &NightshadeRuntime,
+    shard_tries: &ShardTries,
     update_state: &[ShardUpdateState],
     trackers: Vec<DelayedReceiptTracker>,
-    shard_layout: &ShardLayout,
+    source_state_roots: &HashMap<ShardUId, StateRoot>,
+    target_shard_layout: &ShardLayout,
     default_key: &PublicKey,
 ) -> anyhow::Result<()> {
-    assert_eq!(update_state.len(), trackers.len());
     for t in trackers.iter() {
         assert_eq!(update_state.len(), t.indices.len());
     }

-    let shard_tries = runtime.get_tries();
-    let tries = update_state
+    let tries = trackers
         .iter()
-        .enumerate()
-        .map(|(shard_index, update)| {
-            let state_root = update.state_root();
-            let shard_id = shard_layout.get_shard_id(shard_index).unwrap();
-            let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, shard_layout);
-            shard_tries.get_trie_for_shard(shard_uid, state_root)
+        .map(|tracker| {
+            let state_root = source_state_roots.get(&tracker.source_shard_uid).unwrap();
+            let trie = shard_tries.get_trie_for_shard(tracker.source_shard_uid, *state_root);
+            (tracker.source_shard_uid, trie)
         })
-        .collect::<Vec<_>>();
+        .collect::<HashMap<_, _>>();

     // TODO: commit these updates periodically so we don't read everything to memory, which might be too much.
let mut trie_updates = vec![HashMap::new(); update_state.len()]; @@ -128,19 +127,23 @@ pub(crate) fn write_delayed_receipts( // changing this to try to be somewhat fair and take from other shards // before taking twice from the same shard - for (source_shard_idx, target_shard_idx, index) in - trackers.into_iter().enumerate().flat_map(|(source_shard_idx, tracker)| { - tracker.indices.into_iter().enumerate().flat_map(move |(target_shard_idx, indices)| { - indices.into_iter().map(move |index| (source_shard_idx, target_shard_idx, index)) - }) + for (source_shard_uid, target_shard_idx, index) in trackers.into_iter().flat_map(|tracker| { + tracker.indices.into_iter().enumerate().flat_map(move |(target_shard_idx, indices)| { + indices + .into_iter() + .map(move |index| (tracker.source_shard_uid, target_shard_idx, index)) }) - { - let source_shard_id = shard_layout.get_shard_id(source_shard_idx).unwrap(); - - remove_source_receipt_index(&mut trie_updates[source_shard_idx], index); + }) { + let trie = tries.get(&source_shard_uid).unwrap(); + + remove_source_receipt_index( + &mut trie_updates, + source_shard_uid, + target_shard_layout, + index, + ); - let Some(receipt) = read_delayed_receipt(&tries[source_shard_idx], source_shard_id, index)? - else { + let Some(receipt) = read_delayed_receipt(trie, source_shard_uid, index)? else { continue; }; @@ -157,16 +160,18 @@ pub(crate) fn write_delayed_receipts( for (shard_idx, (updates, update_state)) in trie_updates.into_iter().zip(update_state.iter()).enumerate() { - let shard_id = shard_layout.get_shard_id(shard_idx).unwrap(); - let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, shard_layout); + let shard_id = target_shard_layout.get_shard_id(shard_idx).unwrap(); + let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, target_shard_layout); - let mut updates = updates.into_iter().collect::>(); + let mut updates = updates + .into_iter() + .map(|(index, value)| (TrieKey::DelayedReceipt { index }, value)) + .collect::>(); let next_available_index = next_index[shard_idx]; - let key = TrieKey::DelayedReceiptIndices.to_vec(); let indices = TrieQueueIndices { first_index: 0, next_available_index }; let value = borsh::to_vec(&indices).unwrap(); - updates.push((key, Some(value))); + updates.push((TrieKey::DelayedReceiptIndices, Some(value))); crate::storage_mutator::commit_shard(shard_uid, &shard_tries, update_state, updates) .context("failed committing trie changes")? 
} diff --git a/tools/fork-network/src/storage_mutator.rs b/tools/fork-network/src/storage_mutator.rs index c6720ea8a8b..b03d642088b 100644 --- a/tools/fork-network/src/storage_mutator.rs +++ b/tools/fork-network/src/storage_mutator.rs @@ -1,19 +1,25 @@ -use near_chain::types::RuntimeAdapter; use near_crypto::PublicKey; +use near_mirror::key_mapping::map_account; use near_primitives::account::{AccessKey, Account}; +use near_primitives::bandwidth_scheduler::{ + BandwidthSchedulerState, BandwidthSchedulerStateV1, LinkAllowance, +}; use near_primitives::borsh; use near_primitives::hash::CryptoHash; use near_primitives::receipt::Receipt; use near_primitives::shard_layout::{ShardLayout, ShardUId}; use near_primitives::trie_key::TrieKey; -use near_primitives::types::{AccountId, BlockHeight, ShardIndex, StateRoot, StoreKey, StoreValue}; +use near_primitives::types::{ + AccountId, BlockHeight, ShardIndex, StateChangeCause, StateRoot, StoreKey, StoreValue, +}; use near_store::adapter::flat_store::FlatStoreAdapter; use near_store::adapter::StoreUpdateAdapter; -use near_store::flat::{FlatStateChanges, FlatStorageStatus}; +use near_store::flat::{BlockInfo, FlatStateChanges, FlatStorageReadyStatus, FlatStorageStatus}; +use near_store::trie::update::TrieUpdateResult; use near_store::{DBCol, ShardTries}; -use nearcore::NightshadeRuntime; use anyhow::Context; +use std::collections::{HashMap, HashSet}; use std::sync::{Arc, Mutex}; /// Stores the state root and next height we want to pass to apply_memtrie_changes() and delete_until_height() @@ -28,14 +34,14 @@ struct InProgressRoot { #[derive(Clone)] pub(crate) struct ShardUpdateState { - root: Arc>, + root: Arc>>, } impl ShardUpdateState { // here we set the given state root as the one we start with, and we set Self::update_height to be // one bigger than the highest block height we have flat state for. The reason for this is that the // memtries will initially be loaded with nodes referenced by each block height we have deltas for. - pub(crate) fn new( + fn new( flat_store: &FlatStoreAdapter, shard_uid: ShardUId, state_root: CryptoHash, @@ -61,52 +67,105 @@ impl ShardUpdateState { } }; Ok(Self { - root: Arc::new(Mutex::new(InProgressRoot { + root: Arc::new(Mutex::new(Some(InProgressRoot { state_root, update_height: max_delta_height + 1, - })), + }))), }) } + fn new_empty() -> Self { + Self { root: Arc::new(Mutex::new(None)) } + } + + /// Returns a vec of length equal to the number of shards in `target_shard_layout`, + /// indexed by ShardIndex. 
+    /// `state_roots` should have ShardUIds belonging to source_shard_layout
+    pub(crate) fn new_update_state(
+        flat_store: &FlatStoreAdapter,
+        source_shard_layout: &ShardLayout,
+        target_shard_layout: &ShardLayout,
+        state_roots: &HashMap<ShardUId, StateRoot>,
+    ) -> anyhow::Result<Vec<Self>> {
+        let source_shards = source_shard_layout.shard_uids().collect::<HashSet<_>>();
+        assert_eq!(&source_shards, &state_roots.iter().map(|(k, _v)| *k).collect::<HashSet<_>>());
+        let target_shards = target_shard_layout.shard_uids().collect::<HashSet<_>>();
+        let mut update_state = vec![None; target_shards.len()];
+        for (shard_uid, state_root) in state_roots.iter() {
+            if !target_shards.contains(shard_uid) {
+                continue;
+            }
+
+            let state = Self::new(&flat_store, *shard_uid, *state_root)?;
+
+            let shard_idx = target_shard_layout.get_shard_index(shard_uid.shard_id()).unwrap();
+            assert!(update_state[shard_idx].is_none());
+            update_state[shard_idx] = Some(state);
+        }
+        for shard_uid in target_shards {
+            if source_shards.contains(&shard_uid) {
+                continue;
+            }
+            let state = Self::new_empty();
+
+            let shard_idx = target_shard_layout.get_shard_index(shard_uid.shard_id()).unwrap();
+            assert!(update_state[shard_idx].is_none());
+            update_state[shard_idx] = Some(state);
+        }
+
+        Ok(update_state.into_iter().map(|s| s.unwrap()).collect())
+    }
+
     pub(crate) fn state_root(&self) -> CryptoHash {
-        self.root.lock().unwrap().state_root
+        self.root.lock().unwrap().as_ref().map_or_else(CryptoHash::default, |s| s.state_root)
     }
 }

 struct ShardUpdates {
     update_state: ShardUpdateState,
-    updates: Vec<(Vec<u8>, Option<Vec<u8>>)>,
+    updates: Vec<(TrieKey, Option<Vec<u8>>)>,
 }

 impl ShardUpdates {
     fn set(&mut self, key: TrieKey, value: Vec<u8>) {
-        self.updates.push((key.to_vec(), Some(value)));
+        self.updates.push((key, Some(value)));
     }

-    fn remove(&mut self, key: Vec<u8>) {
+    fn remove(&mut self, key: TrieKey) {
         self.updates.push((key, None));
     }
 }

 /// Object that updates the existing state. Combines all changes, commits them
 /// and returns new state roots.
+// TODO: add stats on how many keys are updated/removed/left in place and log it in commit()
 pub(crate) struct StorageMutator {
     updates: Vec<ShardUpdates>,
     shard_tries: ShardTries,
-    shard_layout: ShardLayout,
+    target_shard_layout: ShardLayout,
+    // For efficiency/convenience
+    target_shards: HashSet<ShardUId>,
+}
+
+struct MappedAccountId {
+    new_account_id: AccountId,
+    source: Option<ShardIndex>,
+    target: ShardIndex,
+    need_rewrite: bool,
 }

 impl StorageMutator {
     pub(crate) fn new(
-        runtime: &NightshadeRuntime,
+        shard_tries: ShardTries,
         update_state: Vec<ShardUpdateState>,
-        shard_layout: ShardLayout,
+        target_shard_layout: ShardLayout,
     ) -> anyhow::Result<Self> {
         let updates = update_state
             .into_iter()
             .map(|update_state| ShardUpdates { update_state, updates: Vec::new() })
             .collect();
-        Ok(Self { updates, shard_tries: runtime.get_tries(), shard_layout })
+        let target_shards = target_shard_layout.shard_uids().collect();
+        Ok(Self { updates, shard_tries, target_shard_layout, target_shards })
     }

     fn set(&mut self, shard_idx: ShardIndex, key: TrieKey, value: Vec<u8>) -> anyhow::Result<()> {
@@ -114,7 +173,7 @@
         Ok(())
     }

-    pub(crate) fn remove(&mut self, shard_idx: ShardIndex, key: Vec<u8>) -> anyhow::Result<()> {
+    pub(crate) fn remove(&mut self, shard_idx: ShardIndex, key: TrieKey) -> anyhow::Result<()> {
         self.updates[shard_idx].remove(key);
         Ok(())
     }
@@ -128,6 +187,58 @@
         self.set(shard_idx, TrieKey::Account { account_id }, borsh::to_vec(&value)?)
} + fn mapped_account_id( + &self, + source_shard_uid: ShardUId, + account_id: &AccountId, + ) -> MappedAccountId { + let new_account_id = map_account(&account_id, None); + let target_shard_id = self.target_shard_layout.account_id_to_shard_id(&new_account_id); + let target = self.target_shard_layout.get_shard_index(target_shard_id).unwrap(); + let source = if self.target_shards.contains(&source_shard_uid) { + Some(self.target_shard_layout.get_shard_index(source_shard_uid.shard_id()).unwrap()) + } else { + None + }; + let need_rewrite = account_id != &new_account_id || source != Some(target); + MappedAccountId { new_account_id, source, target, need_rewrite } + } + + pub(crate) fn map_account( + &mut self, + source_shard_uid: ShardUId, + account_id: AccountId, + account: Account, + ) -> anyhow::Result<()> { + let mapped = self.mapped_account_id(source_shard_uid, &account_id); + + if mapped.need_rewrite { + if let Some(source_shard_idx) = mapped.source { + self.remove(source_shard_idx, TrieKey::Account { account_id })?; + } + self.set( + mapped.target, + TrieKey::Account { account_id: mapped.new_account_id }, + borsh::to_vec(&account).unwrap(), + )?; + } + Ok(()) + } + + pub(crate) fn remove_access_key( + &mut self, + source_shard_uid: ShardUId, + account_id: AccountId, + public_key: PublicKey, + ) -> anyhow::Result<()> { + if self.target_shards.contains(&source_shard_uid) { + let shard_idx = + self.target_shard_layout.get_shard_index(source_shard_uid.shard_id()).unwrap(); + self.remove(shard_idx, TrieKey::AccessKey { account_id, public_key })?; + } + Ok(()) + } + pub(crate) fn set_access_key( &mut self, shard_idx: ShardIndex, @@ -142,27 +253,64 @@ impl StorageMutator { ) } - pub(crate) fn set_data( + pub(crate) fn map_data( &mut self, - shard_idx: ShardIndex, + source_shard_uid: ShardUId, account_id: AccountId, data_key: &StoreKey, value: StoreValue, ) -> anyhow::Result<()> { - self.set( - shard_idx, - TrieKey::ContractData { account_id, key: data_key.to_vec() }, - borsh::to_vec(&value)?, - ) + let mapped = self.mapped_account_id(source_shard_uid, &account_id); + + if mapped.need_rewrite { + if let Some(source_shard_idx) = mapped.source { + self.remove( + source_shard_idx, + TrieKey::ContractData { account_id, key: data_key.to_vec() }, + )?; + } + self.set( + mapped.target, + TrieKey::ContractData { account_id: mapped.new_account_id, key: data_key.to_vec() }, + borsh::to_vec(&value)?, + )?; + } + Ok(()) } - pub(crate) fn set_code( + pub(crate) fn map_code( &mut self, - shard_idx: ShardIndex, + source_shard_uid: ShardUId, account_id: AccountId, value: Vec, ) -> anyhow::Result<()> { - self.set(shard_idx, TrieKey::ContractCode { account_id }, value) + let mapped = self.mapped_account_id(source_shard_uid, &account_id); + + if mapped.need_rewrite { + if let Some(source_shard_idx) = mapped.source { + self.remove(source_shard_idx, TrieKey::ContractCode { account_id })?; + } + self.set( + mapped.target, + TrieKey::ContractCode { account_id: mapped.new_account_id }, + value, + )?; + } + Ok(()) + } + + pub(crate) fn remove_postponed_receipt( + &mut self, + source_shard_uid: ShardUId, + receiver_id: AccountId, + receipt_id: CryptoHash, + ) -> anyhow::Result<()> { + if self.target_shards.contains(&source_shard_uid) { + let shard_idx = + self.target_shard_layout.get_shard_index(source_shard_uid.shard_id()).unwrap(); + self.remove(shard_idx, TrieKey::PostponedReceipt { receiver_id, receipt_id })?; + } + Ok(()) } pub(crate) fn set_postponed_receipt( @@ -180,18 +328,37 @@ impl StorageMutator { ) } - 
pub(crate) fn set_received_data( + pub(crate) fn map_received_data( &mut self, - shard_idx: ShardIndex, + source_shard_uid: ShardUId, account_id: AccountId, data_id: CryptoHash, data: &Option>, ) -> anyhow::Result<()> { - self.set( - shard_idx, - TrieKey::ReceivedData { receiver_id: account_id, data_id }, - borsh::to_vec(data)?, - ) + let mapped = self.mapped_account_id(source_shard_uid, &account_id); + + if mapped.need_rewrite { + if let Some(source_shard_idx) = mapped.source { + self.remove( + source_shard_idx, + TrieKey::ReceivedData { receiver_id: account_id, data_id }, + )?; + } + self.set( + mapped.target, + TrieKey::ReceivedData { receiver_id: mapped.new_account_id, data_id }, + borsh::to_vec(data)?, + )?; + } + Ok(()) + } + + fn set_bandwidth_scheduler_state( + &mut self, + shard_idx: ShardIndex, + state: BandwidthSchedulerState, + ) -> anyhow::Result<()> { + self.set(shard_idx, TrieKey::BandwidthSchedulerState, borsh::to_vec(&state)?) } pub(crate) fn should_commit(&self, batch_size: u64) -> bool { @@ -200,36 +367,36 @@ impl StorageMutator { /// Commits any pending trie changes for all shards pub(crate) fn commit(self) -> anyhow::Result<()> { - let Self { updates, shard_tries, shard_layout } = self; + let Self { updates, shard_tries, target_shard_layout, .. } = self; for (shard_index, update) in updates.into_iter().enumerate() { - let shard_id = shard_layout.get_shard_id(shard_index).unwrap(); - let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &shard_layout); + let shard_id = target_shard_layout.get_shard_id(shard_index).unwrap(); + let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &target_shard_layout); commit_shard(shard_uid, &shard_tries, &update.update_state, update.updates)?; } Ok(()) } } -pub(crate) fn commit_shard( - shard_uid: ShardUId, +fn commit_to_existing_state( shard_tries: &ShardTries, - update_state: &ShardUpdateState, - updates: Vec<(Vec, Option>)>, + shard_uid: ShardUId, + root: &mut InProgressRoot, + updates: Vec<(TrieKey, Option>)>, ) -> anyhow::Result<()> { - if updates.is_empty() { - return Ok(()); - } + let updates = + updates.into_iter().map(|(trie_key, value)| (trie_key.to_vec(), value)).collect::>(); - let mut root = update_state.root.lock().unwrap(); let num_updates = updates.len(); tracing::info!(?shard_uid, num_updates, "commit"); let flat_state_changes = FlatStateChanges::from_raw_key_value(&updates); let mut update = shard_tries.store_update(); flat_state_changes.apply_to_flat_state(&mut update.flat_store_update(), shard_uid); - let trie_changes = - shard_tries.get_trie_for_shard(shard_uid, root.state_root).update(updates)?; + let trie_changes = shard_tries + .get_trie_for_shard(shard_uid, root.state_root) + .update(updates) + .with_context(|| format!("failed updating trie for shard {}", shard_uid))?; tracing::info!( ?shard_uid, num_trie_node_insertions = trie_changes.insertions().len(), @@ -245,13 +412,145 @@ pub(crate) fn commit_shard( root.state_root = state_root; tracing::info!(?shard_uid, num_updates, "committing"); - update.store_update().set_ser( - DBCol::Misc, - format!("FORK_TOOL_SHARD_ID:{}", shard_uid.shard_id).as_bytes(), - &state_root, - )?; + let key = crate::cli::make_state_roots_key(shard_uid); + update.store_update().set_ser(DBCol::Misc, &key, &state_root)?; update.commit()?; tracing::info!(?shard_uid, ?state_root, "Commit is done"); Ok(()) } + +fn commit_to_new_state( + shard_tries: &ShardTries, + shard_uid: ShardUId, + updates: Vec<(TrieKey, Option>)>, +) -> anyhow::Result { + let num_updates = updates.len(); + 
tracing::info!(?shard_uid, num_updates, "commit new"); + + let mut trie_update = shard_tries.new_trie_update(shard_uid, StateRoot::default()); + for (key, value) in updates { + match value { + Some(value) => trie_update.set(key, value), + None => trie_update.remove(key), + } + } + trie_update.commit(StateChangeCause::InitialState); + let TrieUpdateResult { trie_changes, state_changes, .. } = + trie_update.finalize().with_context(|| { + format!("Initial trie update finalization failed for shard {}", shard_uid) + })?; + let mut store_update = shard_tries.store_update(); + let state_root = shard_tries.apply_all(&trie_changes, shard_uid, &mut store_update); + FlatStateChanges::from_state_changes(&state_changes) + .apply_to_flat_state(&mut store_update.flat_store_update(), shard_uid); + let key = crate::cli::make_state_roots_key(shard_uid); + store_update.store_update().set_ser(DBCol::Misc, &key, &state_root)?; + tracing::info!(?shard_uid, "committing initial state to new shard"); + store_update + .commit() + .with_context(|| format!("Initial flat storage commit failed for shard {}", shard_uid))?; + + Ok(state_root) +} + +pub(crate) fn commit_shard( + shard_uid: ShardUId, + // TODO: Don't create the Trie object every time + shard_tries: &ShardTries, + update_state: &ShardUpdateState, + updates: Vec<(TrieKey, Option>)>, +) -> anyhow::Result<()> { + if updates.is_empty() { + return Ok(()); + } + + let mut root = update_state.root.lock().unwrap(); + + match root.as_mut() { + Some(root) => commit_to_existing_state(shard_tries, shard_uid, root, updates)?, + None => { + let state_root = commit_to_new_state(shard_tries, shard_uid, updates)?; + // TODO: load memtrie + *root = Some(InProgressRoot { state_root, update_height: 1 }); + } + }; + + Ok(()) +} + +pub(crate) fn write_bandwidth_scheduler_state( + shard_tries: &ShardTries, + source_shard_layout: &ShardLayout, + target_shard_layout: &ShardLayout, + state_roots: &HashMap, + update_state: &[ShardUpdateState], +) -> anyhow::Result<()> { + let source_shards = source_shard_layout.shard_uids().collect::>(); + let target_shards = target_shard_layout.shard_uids().collect::>(); + + if source_shards == target_shards { + return Ok(()); + } + let (shard_uid, state_root) = state_roots.iter().next().unwrap(); + let trie = shard_tries.get_trie_for_shard(*shard_uid, *state_root); + let Some(BandwidthSchedulerState::V1(state)) = near_store::get_bandwidth_scheduler_state(&trie) + .with_context(|| format!("failed getting bandwidth scheduler state for {}", shard_uid))? 
+ else { + return Ok(()); + }; + + let mut link_allowances = Vec::new(); + // TODO: maybe do something other than this + let allowance = state.link_allowances[0].allowance; + for sender in target_shard_layout.shard_ids() { + for receiver in target_shard_layout.shard_ids() { + link_allowances.push(LinkAllowance { sender, receiver, allowance }) + } + } + let new_state = BandwidthSchedulerState::V1(BandwidthSchedulerStateV1 { + link_allowances, + sanity_check_hash: state.sanity_check_hash, + }); + let mut mutator = StorageMutator::new( + shard_tries.clone(), + update_state.to_vec(), + target_shard_layout.clone(), + )?; + for shard_idx in target_shard_layout.shard_indexes() { + mutator.set_bandwidth_scheduler_state(shard_idx, new_state.clone())?; + } + mutator.commit() +} + +// After we rewrite everything in the trie to the target shards, write flat storage statuses for new shards +// TODO: remove all state that belongs to source shards not in the target shard layout +pub(crate) fn finalize_state( + shard_tries: &ShardTries, + source_shard_layout: &ShardLayout, + target_shard_layout: &ShardLayout, + flat_head: BlockInfo, +) -> anyhow::Result<()> { + let source_shards = source_shard_layout.shard_uids().collect::>(); + + for shard_uid in target_shard_layout.shard_uids() { + if source_shards.contains(&shard_uid) { + continue; + } + + let mut trie_update = shard_tries.store_update(); + let store_update = trie_update.store_update(); + store_update + .set_ser( + DBCol::FlatStorageStatus, + &shard_uid.to_bytes(), + &FlatStorageStatus::Ready(FlatStorageReadyStatus { flat_head }), + ) + .unwrap(); + trie_update + .commit() + .with_context(|| format!("failed writing flat storage status for {}", shard_uid))?; + tracing::info!(?shard_uid, "wrote flat storage status for new shard"); + } + Ok(()) +}