diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..43745b78 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,12 @@ +# See https://pyo3.rs/v0.14.2/building_and_distribution.html#macos +[target.x86_64-apple-darwin] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] + +[target.aarch64-apple-darwin] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] diff --git a/Cargo.lock b/Cargo.lock index d380d608..65b986b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -682,7 +682,7 @@ dependencies = [ [[package]] name = "dolma" -version = "0.6.1" +version = "0.6.2" dependencies = [ "ahash", "aws-config", @@ -691,6 +691,7 @@ dependencies = [ "clap", "env_logger", "flate2", + "glob", "jsonpath-rust", "log", "pyo3", @@ -840,6 +841,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.3.20" diff --git a/Cargo.toml b/Cargo.toml index 25d8650c..b72e63d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,9 @@ [package] name = "dolma" -version = "0.6.2" +version = "0.6.3" edition = "2021" license = "Apache-2.0" - # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] name = "dolma" @@ -30,3 +29,4 @@ threadpool = "1.8.1" tokio = {version = "1.27.0", features = ["full"]} tokio-util = "0.7.7" unicode-segmentation = "1.7" +glob = "0.3.1" diff --git a/Makefile b/Makefile index ada3fe2f..27d7aeae 100644 --- a/Makefile +++ b/Makefile @@ -28,25 +28,24 @@ setup: publish: maturin publish -test: setup develop setup-test test-python test-rust clean-test +test: setup develop setup-test test-python test-rust test-python: pytest -vs tests/python -test-rust: - cargo test -- --nocapture - -clean-test: +test-rust-clean: rm -rf tests/work/* aws s3 rm --recursive s3://ai2-llm/pretraining-data/tests/mixer/ -setup-test: +test-rust-setup: aws s3 cp tests/data/documents.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz aws s3 cp tests/data/pii-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/pii/head/0000.json.gz aws s3 cp tests/data/toxicity-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/toxicity/head/0000.json.gz aws s3 cp tests/data/sample-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/sample/head/0000.json.gz aws s3 cp tests/data/duplicate-paragraphs.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/duplicate_paragraphs/head/0000.json.gz - aws s3 sync tests/data/expected s3://ai2-llm/pretraining-data/tests/mixer/expected --exclude ".*" --exclude "*/.*" + +test-rust: test-rust-clean test-rust-setup + cargo test -- --nocapture develop: maturin develop --extras=dev diff --git a/pyproject.toml b/pyproject.toml index e1149f6a..1a3077a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dolma" -version = "0.6.2" +version = "0.6.3" description = "Data filters" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/deduper.rs b/src/deduper.rs index 8b6183b7..586ea7d0 100644 --- a/src/deduper.rs +++ b/src/deduper.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use std::fs::OpenOptions; use std::io; use std::io::{BufRead, BufReader, BufWriter, Write}; -use std::path::{Path, 
PathBuf}; +use std::path::PathBuf; use std::process; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; @@ -15,8 +15,8 @@ use threadpool::ThreadPool; use crate::bloom_filter::BloomFilter; use crate::s3_util; -use crate::s3_util::{download_to_file, upload_file}; use crate::shard::shard_config::WorkDirConfig; +use crate::shard::FileCache; use deduper_config::*; @@ -74,40 +74,44 @@ pub fn run(config: DeduperConfig) { // For doc-level deduping, check the Bloom filter for existence of the configured key and set the configured attribute to true. // For paragraph-level deduping, check the Bloom filter for existence of a paragraph in the text and add a span to the configured attribute. fn write_attributes( - doc_path: String, + docs_location: String, work_dirs: WorkDirConfig, dedupe_config: DedupeConfig, bloom_filter: Arc, ) -> Result<(), io::Error> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - let s3_client = s3_util::new_client(None)?; - - let input_work_dir = Path::new(&work_dirs.input); - let output_work_dir = Path::new(&work_dirs.output); + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: work_dirs.clone(), + }; - let output_path = { + let attrs_location = { let mut attr_prefix = "/attributes/".to_owned(); attr_prefix.push_str(&dedupe_config.name); attr_prefix.push_str("/"); - doc_path.to_owned().replace("/documents/", &attr_prefix) + docs_location + .to_owned() + .replace("/documents/", &attr_prefix) }; - let local_output = output_work_dir.join(&output_path); + let local_output = cache.prepare_output(&attrs_location)?; if local_output.exists() { - log::info!("Skipping {:?} because it already exists", output_path); + log::info!("Skipping {:?} because it already exists", attrs_location); return Ok(()); } + log::info!( + "Writing attributes for {} to {}", + docs_location, + local_output.display() + ); std::fs::create_dir_all(local_output.parent().unwrap())?; - let tmp_output_path = output_work_dir.join(output_path.clone() + ".tmp"); + log::info!( + "Writing attributes for {} to {}", + docs_location, + local_output.display() + ); { - let local_input = input_work_dir.join(Path::new(&doc_path)); - log::info!("Downloading {} to {}", doc_path, local_input.display()); - rt.block_on(download_to_file(&s3_client, &doc_path, &local_input))?; + let local_input = cache.prepare_input(&docs_location)?; let input_file = OpenOptions::new() .read(true) .write(false) @@ -120,7 +124,7 @@ fn write_attributes( .write(true) .create(true) .truncate(true) - .open(&tmp_output_path)?; + .open(&local_output)?; let mut writer = BufWriter::with_capacity( 1024 * 1024, @@ -132,7 +136,12 @@ fn write_attributes( match line { Ok(_) => {} Err(e) => { - log::error!("Error reading line {} of {}: {}", line_number, &doc_path, e); + log::error!( + "Error reading line {} of {}: {}", + line_number, + &docs_location, + e + ); break; } } @@ -223,23 +232,7 @@ fn write_attributes( } std::fs::remove_file(local_input)?; } - - log::info!( - "Uploading {} to {}", - &tmp_output_path.display(), - &output_path - ); - rt.block_on(upload_file(&s3_client, &output_path, &tmp_output_path))?; - - { - // Create empty file to indicate that the shard is done. 
- OpenOptions::new() - .create(true) - .write(true) - .open(&local_output)?; - std::fs::remove_file(&tmp_output_path)?; - } - + cache.finalize_output(&attrs_location)?; Ok(()) } @@ -303,16 +296,14 @@ pub mod deduper_config { } #[cfg(test)] -pub mod test { +mod test { use std::fs::OpenOptions; use std::io; use std::io::{BufRead, BufReader}; - use std::path::Path; use flate2::read::MultiGzDecoder; use crate::s3_util; - use crate::s3_util::download_to_file; use super::*; @@ -352,53 +343,39 @@ pub mod test { } #[test] - pub fn test_dedupe_by_url() -> Result<(), io::Error> { + fn test_dedupe_by_url() -> Result<(), io::Error> { let config = DeduperConfig::read_from_file("tests/config/dedupe-by-url.json").unwrap(); - run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/dedupe-by-url.json.gz"; - let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz"; - rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz")?; compare_contents( "tests/data/expected/dedupe-by-url.json.gz", - local_output_file, + &local_output_file.display().to_string(), ); Ok(()) } #[test] - pub fn test_dedupe_paragraphs() -> Result<(), io::Error> { + fn test_dedupe_paragraphs() -> Result<(), io::Error> { let config = DeduperConfig::read_from_file("tests/config/dedupe-paragraphs.json").unwrap(); - run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/dedupe-paragraphs.json.gz"; - let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz"; - rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz")?; compare_contents( "tests/data/expected/dedupe-paragraphs.json.gz", - local_output_file, + &local_output_file.display().to_string(), ); Ok(()) } diff --git a/src/mixer.rs b/src/mixer.rs index c8afb305..63cf3832 100644 --- a/src/mixer.rs +++ b/src/mixer.rs @@ -54,7 +54,7 @@ pub mod mixer_config { use crate::shard::shard_config::{StreamConfig, WorkDirConfig}; - #[derive(Serialize, Deserialize)] + #[derive(Serialize, Deserialize, Clone)] pub struct MixerConfig { pub streams: Vec, pub processes: usize, @@ -80,12 +80,12 @@ mod test { use std::fs::OpenOptions; use std::io; use std::io::{BufRead, BufReader}; - use std::path::Path; use flate2::read::MultiGzDecoder; use crate::s3_util; - use crate::s3_util::download_to_file; + use crate::shard::FileCache; + use mixer_config::MixerConfig; use super::*; @@ -127,74 +127,83 @@ mod test { #[test] fn test_mixer() -> Result<(), io::Error> { let config = MixerConfig::read_from_file("tests/config/mixer.json")?; - 
run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/mixer.json.gz"; - let remote_output_file = - "s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/mixer-test-0000.json.gz"; - rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; - - compare_contents("tests/data/expected/mixer.json.gz", local_output_file); + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/mixer-test-0000.json.gz")?; + + compare_contents( + "tests/data/expected/mixer.json.gz", + &local_output_file.display().to_string(), + ); + Ok(()) + } + + #[test] + fn test_mixer_local() -> Result<(), io::Error> { + std::fs::create_dir_all("tests/work/mixer-local/input/documents/mixer-local")?; + std::fs::create_dir_all("tests/work/mixer-local/input/attributes/pii/mixer-local")?; + std::fs::create_dir_all("tests/work/mixer-local/input/attributes/toxicity/mixer-local")?; + std::fs::copy( + "tests/data/documents.json.gz", + "tests/work/mixer-local/input/documents/mixer-local/0000.json.gz", + )?; + std::fs::copy( + "tests/data/pii-attributes.json.gz", + "tests/work/mixer-local/input/attributes/pii/mixer-local/0000.json.gz", + )?; + std::fs::copy( + "tests/data/toxicity-attributes.json.gz", + "tests/work/mixer-local/input/attributes/toxicity/mixer-local/0000.json.gz", + )?; + let config = MixerConfig::read_from_file("tests/config/mixer-local.json")?; + run(config.clone()); + + compare_contents( + "tests/data/expected/mixer.json.gz", + "tests/work/mixer-local/output/mixer-local-test-0000.json.gz", + ); Ok(()) } #[test] fn test_email_span_replacement() -> Result<(), io::Error> { let config = MixerConfig::read_from_file("tests/config/email-spans.json")?; - run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/email-spans.json.gz"; - let remote_output_file = - "s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/email-spans-test-0000.json.gz"; - rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; - - compare_contents("tests/data/expected/email-spans.json.gz", local_output_file); + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/email-spans-test-0000.json.gz")?; + + compare_contents( + "tests/data/expected/email-spans.json.gz", + &local_output_file.display().to_string(), + ); Ok(()) } #[test] fn test_paragraph_removal() -> Result<(), io::Error> { let config = MixerConfig::read_from_file("tests/config/paragraph-spans.json")?; - run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/remove-paragraphs.json.gz"; - let remote_output_file = - "s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/paragraph-spans-test-0000.json.gz"; - 
rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/paragraph-spans-test-0000.json.gz")?; compare_contents( "tests/data/expected/remove-paragraphs.json.gz", - local_output_file, + &local_output_file.display().to_string(), ); Ok(()) } @@ -202,26 +211,18 @@ mod test { #[test] fn test_filter_by_span() -> Result<(), io::Error> { let config = MixerConfig::read_from_file("tests/config/filter-by-spans.json")?; - run(config); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - - let local_output_file = "tests/work/output/filter-by-spans.json.gz"; - let remote_output_file = - "s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/filter-by-spans-test-0000.json.gz"; - rt.block_on(download_to_file( - &s3_client, - remote_output_file, - Path::new(local_output_file), - ))?; + run(config.clone()); + + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: config.work_dir.clone(), + }; + + let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/outputs/v1/documents/head/filter-by-spans-test-0000.json.gz")?; compare_contents( "tests/data/expected/filter-by-spans.json.gz", - local_output_file, + &local_output_file.display().to_string(), ); Ok(()) } diff --git a/src/s3_util.rs b/src/s3_util.rs index f5352521..30a0d6ed 100644 --- a/src/s3_util.rs +++ b/src/s3_util.rs @@ -1,23 +1,21 @@ use std::io; use std::path::Path; -use std::str::FromStr; use aws_sdk_s3::config::Region; use aws_sdk_s3::error::ProvideErrorMetadata; use aws_sdk_s3::primitives::ByteStream; use aws_sdk_s3::Client as S3Client; -use regex::Regex; use tokio::fs::File as TokioFile; -pub fn split_path(s3_prefix: &str) -> Result<(&str, &str), &'static str> { +// Split an s3:// url into a bucket and key +pub fn split_url(s3_url: &str) -> Result<(&str, &str), &'static str> { // use a regular expression to check if s3_prefix starts with s3:// - let re = Regex::new(r"^s3://").unwrap(); - if !re.is_match(s3_prefix) { + if !s3_url.starts_with("s3://") { return Err("s3_prefix must start with s3://"); } // split the s3_prefix into parts - let parts: Vec<&str> = s3_prefix.splitn(4, '/').collect(); + let parts: Vec<&str> = s3_url.splitn(4, '/').collect(); // if there are less than 3 parts, then the s3_prefix is invalid if parts.len() < 3 { @@ -27,22 +25,21 @@ pub fn split_path(s3_prefix: &str) -> Result<(&str, &str), &'static str> { let bucket = parts[2]; // if there are not 4 parts, then the object path is empty, so we set it to "/" - let object_path = if parts.len() == 4 { parts[3] } else { "/" }; + let key = if parts.len() == 4 { parts[3] } else { "/" }; - Ok((bucket, object_path)) + Ok((bucket, key)) } pub async fn download_to_file( s3_client: &S3Client, - prefix: &str, + bucket: &str, + key: &str, path: &Path, ) -> Result<(), io::Error> { - let (bucket, key) = split_path(prefix).unwrap(); - let result = s3_client .get_object() .bucket(bucket) - .key(key) + .key(key.clone()) .send() .await .map_err(|e| { @@ -64,13 +61,16 @@ pub async fn download_to_file( Ok(()) } -pub async fn upload_file(s3_client: &S3Client, prefix: &str, path: &Path) -> Result<(), io::Error> { - let (bucket, key) = 
split_path(prefix).unwrap(); - +pub async fn upload_file( + s3_client: &S3Client, + path: &Path, + bucket: &str, + key: &str, +) -> Result<(), io::Error> { s3_client .put_object() .bucket(bucket) - .key(key) + .key(key.clone()) .body(ByteStream::from_path(path).await?) .send() .await @@ -88,8 +88,11 @@ pub async fn upload_file(s3_client: &S3Client, prefix: &str, path: &Path) -> Res Ok(()) } -pub async fn object_size(s3_client: &S3Client, prefix: &str) -> Result { - let (bucket, key) = split_path(prefix).unwrap(); +pub async fn object_size( + s3_client: &S3Client, + bucket: &str, + key: &str, +) -> Result { let resp = s3_client .head_object() .bucket(bucket) @@ -117,15 +120,9 @@ pub fn find_objects_matching_patterns( .unwrap(); let mut stream_inputs: Vec = Vec::new(); - for full_pattern in patterns.iter() { + for pattern in patterns.iter() { let start_size = stream_inputs.len(); - let (bucket, pattern) = split_path(full_pattern).unwrap(); - - let mut output_prefix = String::from_str("s3://").unwrap(); - output_prefix.push_str(bucket); - output_prefix.push_str("/"); - - let mut prefix = pattern.clone().to_string(); + let mut prefix = pattern.clone(); let mut suffix: Option = Some("".to_owned()); let maybe_index = pattern.chars().position(|c| c == '*'); if let Some(index) = maybe_index { @@ -138,12 +135,14 @@ pub fn find_objects_matching_patterns( let mut has_more = true; let mut token: Option = None; while has_more { + let (bucket, key) = split_url(&prefix).unwrap(); let resp = if token.is_some() { + log::info!("Listing objects in bucket={}, prefix={}", bucket, key); rt.block_on( s3_client .list_objects_v2() .bucket(bucket) - .prefix(&prefix) + .prefix(key) .delimiter("/") .continuation_token(token.unwrap()) .send(), @@ -154,16 +153,15 @@ pub fn find_objects_matching_patterns( s3_client .list_objects_v2() .bucket(bucket) - .prefix(&prefix) + .prefix(key) .delimiter("/") .send(), ) .unwrap() }; resp.contents().unwrap_or_default().iter().for_each(|obj| { - let mut full_output_prefix = output_prefix.clone(); - full_output_prefix.push_str(obj.key().unwrap()); - stream_inputs.push(full_output_prefix); + let s3_url = format!("s3://{}/{}", bucket, obj.key().unwrap()); + stream_inputs.push(s3_url); }); suffix.iter().for_each(|s| { resp.common_prefixes() @@ -172,10 +170,8 @@ pub fn find_objects_matching_patterns( .for_each(|sub_folder| { let mut full_path = sub_folder.prefix().unwrap().to_owned(); full_path.push_str(s); - let mut full_output_prefix = output_prefix.clone(); - full_output_prefix.push_str(&full_path); - - stream_inputs.push(full_output_prefix); + let s3_url = format!("s3://{}/{}", bucket, full_path); + stream_inputs.push(s3_url); }); }); token = resp.next_continuation_token().map(String::from); @@ -254,16 +250,16 @@ mod test { } #[test] - fn test_split_path() -> Result<(), ()> { + fn test_split_url() -> Result<(), ()> { // test case when path is correct - let prefix = "s3://my-bucket/my-key"; - let (bucket, key) = split_path(prefix).unwrap(); + let prefix = "s3://my-bucket/my-key-dir/my-key"; + let (bucket, key) = split_url(prefix).unwrap(); assert_eq!(bucket, "my-bucket"); - assert_eq!(key, "my-key"); + assert_eq!(key, "my-key-dir/my-key"); // test case when path is incorrect let prefix = "s3:/my-bucket/my-key"; - let result = split_path(prefix); + let result = split_url(prefix); assert!(result.is_err()); Ok(()) @@ -277,9 +273,8 @@ mod test { .unwrap(); let s3_client = new_client(None)?; - let prefix = - "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz"; 
- let resp = rt.block_on(object_size(&s3_client, prefix)); + let key = "pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz"; + let resp = rt.block_on(object_size(&s3_client, "ai2-llm", key)); let size = resp.unwrap(); assert_eq!(size, 25985); @@ -296,11 +291,11 @@ mod test { let local_output_file = "tests/work/output/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz"; - let remote_output_file: &str = - "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz"; + let s3_path: &str = "pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz"; rt.block_on(download_to_file( &s3_client, - remote_output_file, + "ai2-llm", + s3_path, Path::new(local_output_file), ))?; diff --git a/src/shard.rs b/src/shard.rs index 676d4b9d..7729dc88 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -1,16 +1,17 @@ use std::fs::OpenOptions; use std::io; use std::io::{BufRead, BufReader, BufWriter, Write}; -use std::path::Path; +use std::path::{Path, PathBuf}; +use aws_sdk_s3::Client as S3Client; use flate2::read::MultiGzDecoder; use flate2::write::GzEncoder; use flate2::Compression; +use glob::glob; use rayon::prelude::*; use serde_json::Value; use crate::s3_util; -use crate::s3_util::{download_to_file, object_size, upload_file}; use crate::shard::shard_config::*; // A shard is a unit of work for the mixer. @@ -37,47 +38,30 @@ impl Shard { // since it doesn't account for the size of any attributes to merged, // or documents dropped by the filter. pub fn split_streams(streams: &Vec) -> Result, io::Error> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let s3_client = s3_util::new_client(None)?; - let mut shards: Vec = Vec::new(); for stream_config in streams { let mut stream_shard_count = 0; log::info!("Computing shards for stream {}...", stream_config.name); - let stream_inputs = - s3_util::find_objects_matching_patterns(&s3_client, &stream_config.documents)?; - - let inputs_with_sizes = stream_inputs - .par_iter() - .map(|input| { - let resp = rt.block_on(object_size(&s3_client, input)); + let stream_inputs = find_objects_matching_patterns(&stream_config.documents)?; + let input_count = stream_inputs.len(); + let input_sizes = get_object_sizes(&stream_inputs)?; + let inputs_with_sizes = std::iter::zip(stream_inputs, input_sizes) + .map(|(input, size)| { let mut attr_paths = Vec::new(); for prefix in stream_config.attributes.iter() { let mut attr_prefix = "/attributes/".to_owned(); attr_prefix.push_str(prefix); attr_prefix.push_str("/"); - let attr_path = input.to_owned().replace("/documents/", &attr_prefix); + let attr_path = input.replace("/documents/", &attr_prefix); attr_paths.push(attr_path); } - match resp { - Ok(size) => ( - DocumentPaths { - doc_path: input.to_owned(), - attribute_paths: attr_paths, - }, - size, - ), - Err(_) => ( - DocumentPaths { - doc_path: input.to_owned(), - attribute_paths: attr_paths, - }, - 0, - ), - } + ( + DocumentPaths { + doc_path: input, + attribute_paths: attr_paths, + }, + size, + ) }) .collect::>(); let mut shard_size = inputs_with_sizes[0].1; @@ -128,7 +112,7 @@ impl Shard { } log::info!( "Splitting {} files for {} into {} shards", - stream_inputs.len(), + input_count, stream_config.name, stream_shard_count ); @@ -144,27 +128,19 @@ impl Shard { // Apply span replacements // Upload the output file to S3. 
pub fn process(&self, work_dirs: WorkDirConfig) -> Result<(), io::Error> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - let s3_client = s3_util::new_client(None)?; - - let inputs_dir = Path::new(&work_dirs.input); - let outputs_dir = Path::new(&work_dirs.output); + let cache = FileCache { + s3_client: Box::new(s3_util::new_client(None)?), + work: work_dirs.clone(), + }; - let output_path = outputs_dir.join(self.output.clone()); - std::fs::create_dir_all(output_path.parent().unwrap())?; - - let tmp_output_path = outputs_dir.join(self.output.clone() + ".tmp"); + let output_path = cache.prepare_output(&self.output)?; { let output_file = OpenOptions::new() .read(false) .write(true) .create(true) .truncate(true) - .open(tmp_output_path.clone())?; + .open(output_path.clone())?; let mut writer = BufWriter::with_capacity( 1024 * 1024, @@ -173,37 +149,25 @@ impl Shard { for input_path in self.inputs.iter() { log::info!("Merging {} into {}", input_path.doc_path, self.output); - let local_docs_file = inputs_dir.join(Path::new(&input_path.doc_path)); - log::info!( - "Downloading {} to {}", - input_path.doc_path, - local_docs_file.display() - ); - rt.block_on(download_to_file( - &s3_client, - &input_path.doc_path, - &local_docs_file, - ))?; + let local_docs_file = cache.prepare_input(&input_path.doc_path)?; let mut local_attr_readers = Vec::new(); let mut attr_reader_failure_counts = Vec::new(); for attr in &input_path.attribute_paths { - let local_attr_file = inputs_dir.join(Path::new(&attr)); - log::info!("Downloading {} to {}", attr, local_attr_file.display()); - rt.block_on(download_to_file(&s3_client, &attr, &local_attr_file))?; + let local_attr_file = cache.prepare_input(&attr)?; let f = OpenOptions::new() .read(true) .write(false) .create(false) - .open(local_attr_file.clone())?; + .open(&local_attr_file)?; let attr_reader = BufReader::with_capacity(1024 * 1024, MultiGzDecoder::new(f)); - local_attr_readers.push(attr_reader.lines()); + local_attr_readers.push((local_attr_file, attr_reader.lines())); attr_reader_failure_counts.push(0); } let input_file = OpenOptions::new() .read(true) .write(false) .create(false) - .open(local_docs_file.clone())?; + .open(&local_docs_file)?; let reader = BufReader::with_capacity(1024 * 1024, MultiGzDecoder::new(input_file)); let mut line_number = 0; @@ -226,7 +190,7 @@ impl Shard { let mut data: Value = serde_json::from_str(&line)?; let mut attrs = serde_json::Map::new(); let mut attr_reader_index = 0; - for attr_reader in local_attr_readers.iter_mut() { + for (_, attr_reader) in local_attr_readers.iter_mut() { match attr_reader.next() { Some(Ok(line)) => { let attr_data: Value = serde_json::from_str(&line)?; @@ -374,7 +338,7 @@ impl Shard { writer.write_all(b"\n")?; } } - std::fs::remove_file(local_docs_file)?; + cache.finalize_input(&input_path.doc_path)?; for i in 0..input_path.attribute_paths.len() { if attr_reader_failure_counts[i] > 0 { log::warn!( @@ -383,9 +347,7 @@ impl Shard { attr_reader_failure_counts[i] ); } - std::fs::remove_file( - inputs_dir.join(Path::new(&input_path.attribute_paths[i])), - )?; + cache.finalize_input(&input_path.attribute_paths[i])?; } log::info!( "Dropped {} of {} documents from {}", @@ -395,23 +357,7 @@ impl Shard { ); } } - - log::info!( - "Uploading {} to {}", - &tmp_output_path.display(), - &self.output - ); - rt.block_on(upload_file(&s3_client, &self.output, &tmp_output_path))?; - - { - // Create empty file to indicate that the shard is done. 
- OpenOptions::new() - .create(true) - .write(true) - .open(&output_path)?; - std::fs::remove_file(&tmp_output_path)?; - } - + cache.finalize_output(&self.output)?; Ok(()) } } @@ -531,3 +477,163 @@ pub mod shard_config { } } } + +// Handles input/output files, including S3 downloads/uploads +pub struct FileCache { + pub s3_client: Box, + pub work: WorkDirConfig, +} + +macro_rules! cached_s3_location { + ($url:expr, $dir:expr) => {{ + let (bucket, key) = s3_util::split_url($url).unwrap(); + (bucket, key.clone(), Path::new($dir).join(key.clone())) + }}; +} + +impl FileCache { + // If "location" is a path to a local file that exists, return it + // If it is an S3 URL, download the contents to the working input directory, and return the path + pub fn prepare_input(&self, location: &str) -> Result { + if location.starts_with("s3://") { + let (bucket, key, path) = cached_s3_location!(location, &self.work.input); + log::info!("Downloading {} to {}", location, path.display()); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(s3_util::download_to_file( + &self.s3_client, + bucket, + &key, + &path, + ))?; + Ok(path.clone()) + } else { + let path = Path::new(location); + if path.exists() { + Ok(path.to_path_buf()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("File not found: {}", location), + )) + } + } + } + + // If input was downloaded from S3, delete the local cache + // Otherwise, do nothing + pub fn finalize_input(&self, location: &str) -> Result<(), io::Error> { + if location.starts_with("s3://") { + let (_, _, path) = cached_s3_location!(location, &self.work.input); + std::fs::remove_file(&path)?; + Ok(()) + } else { + Ok(()) + } + } + + // If output is an S3 URL, return a path to a new temporary location in the working output directory + // If it is a local path, return a ".tmp" path in the same directory + pub fn prepare_output(&self, location: &str) -> Result { + if location.starts_with("s3://") { + let (_, _, path) = cached_s3_location!(location, &self.work.output); + std::fs::create_dir_all(path.parent().unwrap())?; + Ok(path.clone()) + } else { + let tmp_location = location.to_owned() + ".tmp"; + let path = Path::new(tmp_location.as_str()); + std::fs::create_dir_all(path.parent().unwrap())?; + Ok(path.to_path_buf()) + } + } + + // If "output" is an S3 URL, upload contents from the temporary file, + // then replace the temporary file with an empty one as a checkpoint + // If "output" is a local path, rename the ".tmp" file to the original name + pub fn finalize_output(&self, location: &str) -> Result<(), io::Error> { + if location.starts_with("s3://") { + let (bucket, key, path) = cached_s3_location!(location, &self.work.output); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(s3_util::upload_file(&self.s3_client, &path, &bucket, &key))?; + std::fs::remove_file(&path)?; + { + // Create empty file to indicate that the shard is done. 
+ OpenOptions::new().create(true).write(true).open(&path)?; + } + Ok(()) + } else { + let tmp_path = location.to_owned() + ".tmp"; + let tmp_path = Path::new(tmp_path.as_str()); + std::fs::rename(&tmp_path, &location)?; + Ok(()) + } + } +} + +pub fn find_objects_matching_patterns(patterns: &Vec) -> Result, io::Error> { + let s3_url_count = patterns.iter().filter(|p| p.starts_with("s3://")).count(); + if s3_url_count == 0 { + let mut matches = Vec::new(); + for pattern in patterns.iter() { + for entry in + glob(pattern).expect(format! {"Invalid file pattern: {}", pattern.clone()}.as_str()) + { + matches.push(entry.unwrap().to_str().unwrap().to_owned()); + } + } + Ok(matches) + } else if s3_url_count == patterns.len() { + let s3_client = s3_util::new_client(None)?; + s3_util::find_objects_matching_patterns(&s3_client, patterns) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Cannot mix S3 and local paths", + )) + } +} + +// Get the size in bytes of a list of objects, either S3 urls or local file paths +pub fn get_object_sizes(locations: &Vec) -> Result, io::Error> { + let s3_url_count = locations.iter().filter(|p| p.starts_with("s3://")).count(); + if s3_url_count == 0 { + let sizes: Vec = locations + .par_iter() + .map(|location| { + let path = Path::new(location); + let metadata = path.metadata().unwrap(); + metadata.len() as usize + }) + .collect(); + Ok(sizes) + } else if s3_url_count == locations.len() { + let s3_client = s3_util::new_client(None)?; + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let sizes = locations + .par_iter() + .map(|location| { + let (bucket, key) = s3_util::split_url(location).unwrap(); + let resp = rt.block_on(s3_util::object_size(&s3_client, &bucket, &key)); + match resp { + Ok(size) => size, + Err(_) => 0, + } + }) + .collect(); + Ok(sizes) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Cannot mix S3 and local paths", + )) + } +} diff --git a/tests/config/mixer-local.json b/tests/config/mixer-local.json new file mode 100644 index 00000000..1d2d206f --- /dev/null +++ b/tests/config/mixer-local.json @@ -0,0 +1,33 @@ +{ + "streams": [ + { + "name": "mixer-local-test", + "documents": [ + "tests/work/mixer-local/input/documents/*/0000.json.gz" + ], + "output": { + "path": "tests/work/mixer-local/output", + "max_size_in_bytes": 100000 + }, + "attributes": [ + "pii", + "toxicity" + ], + "filter": { + "include": [ + "$.metadata[?(@.length < 10000)]" + ], + "exclude": [ + "$.metadata[?(@.length < 500)]", + "$.attributes[?(@.pii.too_much_pii == true)]", + "$.attributes[?(@.toxicity > 0.8)]" + ] + } + } + ], + "work_dir": { + "input": "tests/work/mixer-local/input", + "output": "tests/work/mixer-local/output" + }, + "processes": 1 +}
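
Editor's note (sketch, not part of the patch): the core change in this diff replaces the per-call download/upload plumbing in src/deduper.rs, src/mixer.rs, and src/shard.rs with the new FileCache in src/shard.rs, which accepts either local paths or s3:// URLs. A minimal usage sketch of the prepare/finalize lifecycle, written against the crate-internal paths used elsewhere in this patch; the helper name copy_one_file is hypothetical and only for illustration:

use std::io;
use std::path::PathBuf;

use crate::s3_util;
use crate::shard::shard_config::WorkDirConfig;
use crate::shard::FileCache;

// Hypothetical helper showing the prepare/finalize lifecycle introduced by this patch.
fn copy_one_file(docs_location: &str, out_location: &str, work: WorkDirConfig) -> Result<(), io::Error> {
    let cache = FileCache {
        s3_client: Box::new(s3_util::new_client(None)?),
        work,
    };
    // Local inputs are returned as-is; s3:// inputs are downloaded under work.input first.
    let local_in: PathBuf = cache.prepare_input(docs_location)?;
    // For s3:// outputs this is a staging path under work.output; for local outputs it is "<path>.tmp".
    let local_out: PathBuf = cache.prepare_output(out_location)?;
    std::fs::copy(&local_in, &local_out)?;
    // Upload (or rename) the staged output, then drop any cached copy of the input.
    cache.finalize_output(out_location)?;
    cache.finalize_input(docs_location)?;
    Ok(())
}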
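
Editor's note (sketch, not part of the patch): shard splitting likewise now goes through the new find_objects_matching_patterns and get_object_sizes helpers in src/shard.rs, which accept either local glob patterns or s3:// URL patterns, but refuse a mix of the two. A short sketch expanding the documents pattern from the new tests/config/mixer-local.json, again using the crate-internal paths from this patch:

use std::io;

use crate::shard::{find_objects_matching_patterns, get_object_sizes};

// Expand the local glob used by tests/config/mixer-local.json and report the size of each match.
fn list_local_inputs() -> Result<(), io::Error> {
    let patterns = vec!["tests/work/mixer-local/input/documents/*/0000.json.gz".to_owned()];
    // All-local patterns are expanded with the glob crate; all-s3:// patterns go through ListObjectsV2.
    let inputs = find_objects_matching_patterns(&patterns)?;
    let sizes = get_object_sizes(&inputs)?;
    for (path, size) in inputs.iter().zip(sizes.iter()) {
        println!("{} ({} bytes)", path, size);
    }
    Ok(())
}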