diff --git a/Cargo.toml b/Cargo.toml index d4fb213..f9e98a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ tokio = { version = "1", features = ["full"] } hex = "0.4.3" tar = "0.4.43" smooth-json = "0.2.7" +futures = "0.3.31" [dev-dependencies] sqlx = { version = "0.8", features = [ "runtime-tokio", "tls-native-tls", "sqlite", "migrate", "macros", "derive", "postgres"] } diff --git a/src/json_rescue_v5_load.rs b/src/json_rescue_v5_load.rs index e93ac0e..8581b38 100644 --- a/src/json_rescue_v5_load.rs +++ b/src/json_rescue_v5_load.rs @@ -6,9 +6,13 @@ use crate::{ schema_transaction::WarehouseTxMaster, }; use anyhow::Result; +use futures::{stream, StreamExt}; use log::{error, info}; use neo4rs::Graph; use std::path::Path; +use std::sync::Arc; +use tokio::sync::Semaphore; +use tokio::task; /// How many records to read from the archives before attempting insert static LOAD_QUEUE_SIZE: usize = 1000; @@ -21,6 +25,7 @@ pub async fn decompress_and_extract(tgz_file: &Path, pool: &Graph) -> Result Result Result { + let temppath = decompress_to_temppath(tgz_file)?; + let json_vec = list_all_json_files(temppath.path())?; + + let found_count = Arc::new(tokio::sync::Mutex::new(0u64)); + let created_count = Arc::new(tokio::sync::Mutex::new(0u64)); + + let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_INSERT)); + let parse_semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_PARSE)); + + let tasks = json_vec.into_iter().map(|j| { + let semaphore = Arc::clone(&semaphore); + let parse_semaphore = Arc::clone(&parse_semaphore); + let found_count = Arc::clone(&found_count); + let created_count = Arc::clone(&created_count); + let pool = pool.clone(); + + task::spawn(async move { + let _permit = parse_semaphore.acquire().await.unwrap(); // Control parsing concurrency + if let Ok((mut r, _e)) = extract_v5_json_rescue(&j) { + let drain: Vec = r.drain(..).collect(); + + if !drain.is_empty() { + let _db_permit = semaphore.acquire().await.unwrap(); // Control DB insert concurrency + let res = tx_batch( + &drain, + &pool, + QUERY_BATCH_SIZE, + j.file_name().unwrap().to_str().unwrap(), + ) + .await?; + { + let mut fc = found_count.lock().await; + let mut cc = created_count.lock().await; + *fc += drain.len() as u64; + *cc += res.created_tx as u64; + } + } + } + Ok::<(), anyhow::Error>(()) + }) + }); + + // Collect all results + let results: Vec<_> = futures::future::join_all(tasks).await; + + // Check for errors in tasks + for result in results { + if let Err(e) = result { + error!("Task failed: {:?}", e); + } + } + + let found_count = *found_count.lock().await; + let created_count = *created_count.lock().await; + + info!("V5 transactions found: {}", found_count); + info!("V5 transactions processed: {}", created_count); + if found_count != created_count { + error!("transactions loaded don't match transactions extracted"); + } + + Ok(created_count) +} + +use futures::{stream, StreamExt}; +use tokio::sync::Semaphore; +use std::sync::Arc; + +const MAX_CONCURRENT_PARSE: usize = 4; // Number of concurrent parsing tasks +const MAX_CONCURRENT_INSERT: usize = 2; // Number of concurrent database insert tasks + +pub async fn stream_decompress_and_extract(tgz_file: &Path, pool: &Graph) -> Result { + let temppath = decompress_to_temppath(tgz_file)?; + let json_vec = list_all_json_files(temppath.path())?; + + let found_count = Arc::new(tokio::sync::Mutex::new(0u64)); + let created_count = Arc::new(tokio::sync::Mutex::new(0u64)); + + let parse_semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_PARSE)); + let insert_semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_INSERT)); + + // Stream for parsing JSON files + let parse_stream = stream::iter(json_vec).map(|j| { + let parse_semaphore = Arc::clone(&parse_semaphore); + let insert_semaphore = Arc::clone(&insert_semaphore); + let found_count = Arc::clone(&found_count); + let created_count = Arc::clone(&created_count); + let pool = pool.clone(); + + async move { + let _parse_permit = parse_semaphore.acquire().await.unwrap(); + if let Ok((mut records, _e)) = extract_v5_json_rescue(&j) { + let batch = records.drain(..).collect::>(); + + if !batch.is_empty() { + let _insert_permit = insert_semaphore.acquire().await.unwrap(); + let res = tx_batch( + &batch, + &pool, + QUERY_BATCH_SIZE, + j.file_name().unwrap().to_str().unwrap(), + ) + .await?; + + let mut fc = found_count.lock().await; + let mut cc = created_count.lock().await; + *fc += batch.len() as u64; + *cc += res.created_tx as u64; + } + } + Ok::<(), anyhow::Error>(()) + } + }); + + // Process the stream with controlled concurrency + parse_stream + .buffer_unordered(MAX_CONCURRENT_PARSE) + .for_each(|result| async { + if let Err(e) = result { + error!("Failed to process file: {:?}", e); + } + }) + .await; + + // Gather final counts + let found_count = *found_count.lock().await; + let created_count = *created_count.lock().await; + + info!("V5 transactions found: {}", found_count); + info!("V5 transactions processed: {}", created_count); + if found_count != created_count { + error!("transactions loaded don't match transactions extracted"); + } Ok(created_count) } + pub async fn rip(start_dir: &Path, pool: &Graph) -> Result { let tgz_list = list_all_tgz_archives(start_dir)?; info!("tgz archives found: {}", tgz_list.len()); diff --git a/tests/test_json_rescue_v5_load.rs b/tests/test_json_rescue_v5_load.rs index 1a34667..eb198f7 100644 --- a/tests/test_json_rescue_v5_load.rs +++ b/tests/test_json_rescue_v5_load.rs @@ -30,6 +30,50 @@ async fn test_load_all_tgz() -> anyhow::Result<()> { Ok(()) } +#[tokio::test] +async fn test_concurrent_load_all_tgz() -> anyhow::Result<()> { + libra_forensic_db::log_setup(); + + let c = start_neo4j_container(); + let port = c.get_host_port_ipv4(7687); + let pool = get_neo4j_localhost_pool(port) + .await + .expect("could not get neo4j connection pool"); + maybe_create_indexes(&pool) + .await + .expect("could start index"); + + let path = fixtures::v5_json_tx_path().join("0-99900.tgz"); + + let tx_count = json_rescue_v5_load::concurrent_decompress_and_extract(&path, &pool).await?; + + assert!(tx_count == 5244); + + Ok(()) +} + +#[tokio::test] +async fn test_stream_load_all_tgz() -> anyhow::Result<()> { + libra_forensic_db::log_setup(); + + let c = start_neo4j_container(); + let port = c.get_host_port_ipv4(7687); + let pool = get_neo4j_localhost_pool(port) + .await + .expect("could not get neo4j connection pool"); + maybe_create_indexes(&pool) + .await + .expect("could start index"); + + let path = fixtures::v5_json_tx_path().join("0-99900.tgz"); + + let tx_count = json_rescue_v5_load::stream_decompress_and_extract(&path, &pool).await?; + + assert!(tx_count == 5244); + + Ok(()) +} + #[tokio::test] async fn test_load_entrypoint() -> anyhow::Result<()> { libra_forensic_db::log_setup();