Merge pull request #14 from allenai/mixer-paths

Mixer can use s3 or local paths

rodneykinney authored Jul 11, 2023
2 parents eb57219 + d458c90 commit 4228002
Showing 10 changed files with 422 additions and 292 deletions.
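
The substance of the change shows up in src/deduper.rs below: a `FileCache` (imported from the crate's `shard` module, which is not part of the rendered diff) resolves a document location to a local file and touches S3 only when the location is an `s3://` URL. The sketch below is illustrative only: the name `LocalOrS3Cache`, its single `work_dir` field, and the method bodies are assumptions; the real `FileCache` holds an S3 client and a `WorkDirConfig`, as the diff shows.

```rust
use std::io;
use std::path::PathBuf;

/// Illustrative stand-in only: the real `FileCache` lives in the crate's
/// `shard` module (not rendered in this diff) and holds an S3 client plus a
/// `WorkDirConfig`. The name, field, and method bodies here are assumptions
/// made to show the shape of the abstraction.
pub struct LocalOrS3Cache {
    /// Staging directory, analogous to the input/output dirs in `WorkDirConfig`.
    pub work_dir: PathBuf,
}

impl LocalOrS3Cache {
    /// Resolve a document location to a readable local path.
    /// Local paths pass through unchanged; `s3://` locations would be
    /// staged into the work dir first (the download itself is omitted here).
    pub fn prepare_input(&self, location: &str) -> io::Result<PathBuf> {
        match location.strip_prefix("s3://") {
            Some(key) => Ok(self.work_dir.join("input").join(key)), // real code downloads the object here
            None => Ok(PathBuf::from(location)),
        }
    }

    /// Choose the local path where output for `location` should be written,
    /// creating parent directories as needed.
    pub fn prepare_output(&self, location: &str) -> io::Result<PathBuf> {
        let key = location.strip_prefix("s3://").unwrap_or(location);
        let local = self.work_dir.join("output").join(key.trim_start_matches('/'));
        if let Some(parent) = local.parent() {
            std::fs::create_dir_all(parent)?;
        }
        Ok(local)
    }

    /// Publish the finished output: an upload for `s3://` locations,
    /// effectively a no-op for local ones (the upload is omitted here).
    pub fn finalize_output(&self, location: &str) -> io::Result<()> {
        let _needs_upload = location.starts_with("s3://"); // real code uploads the staged file
        Ok(())
    }
}
```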
12 changes: 12 additions & 0 deletions .cargo/config.toml
@@ -0,0 +1,12 @@
# See https://pyo3.rs/v0.14.2/building_and_distribution.html#macos
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
9 changes: 8 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,10 +1,9 @@
[package]
name = "dolma"
version = "0.6.2"
version = "0.6.3"
edition = "2021"
license = "Apache-2.0"


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "dolma"
@@ -30,3 +29,4 @@ threadpool = "1.8.1"
tokio = {version = "1.27.0", features = ["full"]}
tokio-util = "0.7.7"
unicode-segmentation = "1.7"
glob = "0.3.1"
13 changes: 6 additions & 7 deletions Makefile
@@ -28,25 +28,24 @@ setup:
publish:
maturin publish

test: setup develop setup-test test-python test-rust clean-test
test: setup develop setup-test test-python test-rust

test-python:
pytest -vs tests/python

test-rust:
cargo test -- --nocapture

clean-test:
test-rust-clean:
rm -rf tests/work/*
aws s3 rm --recursive s3://ai2-llm/pretraining-data/tests/mixer/

setup-test:
test-rust-setup:
aws s3 cp tests/data/documents.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz
aws s3 cp tests/data/pii-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/pii/head/0000.json.gz
aws s3 cp tests/data/toxicity-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/toxicity/head/0000.json.gz
aws s3 cp tests/data/sample-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/sample/head/0000.json.gz
aws s3 cp tests/data/duplicate-paragraphs.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/duplicate_paragraphs/head/0000.json.gz
aws s3 sync tests/data/expected s3://ai2-llm/pretraining-data/tests/mixer/expected --exclude ".*" --exclude "*/.*"

test-rust: test-rust-clean test-rust-setup
cargo test -- --nocapture

develop:
maturin develop --extras=dev
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "0.6.2"
version = "0.6.3"
description = "Data filters"
license = {text = "Apache-2.0"}
readme = "README.md"
129 changes: 53 additions & 76 deletions src/deduper.rs
@@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::path::PathBuf;
use std::process;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
@@ -15,8 +15,8 @@ use threadpool::ThreadPool;

use crate::bloom_filter::BloomFilter;
use crate::s3_util;
use crate::s3_util::{download_to_file, upload_file};
use crate::shard::shard_config::WorkDirConfig;
use crate::shard::FileCache;

use deduper_config::*;

@@ -74,40 +74,44 @@ pub fn run(config: DeduperConfig) {
// For doc-level deduping, check the Bloom filter for existence of the configured key and set the configured attribute to true.
// For paragraph-level deduping, check the Bloom filter for existence of a paragraph in the text and add a span to the configured attribute.
fn write_attributes(
doc_path: String,
docs_location: String,
work_dirs: WorkDirConfig,
dedupe_config: DedupeConfig,
bloom_filter: Arc<BloomFilter>,
) -> Result<(), io::Error> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();

let s3_client = s3_util::new_client(None)?;

let input_work_dir = Path::new(&work_dirs.input);
let output_work_dir = Path::new(&work_dirs.output);
let cache = FileCache {
s3_client: Box::new(s3_util::new_client(None)?),
work: work_dirs.clone(),
};

let output_path = {
let attrs_location = {
let mut attr_prefix = "/attributes/".to_owned();
attr_prefix.push_str(&dedupe_config.name);
attr_prefix.push_str("/");
doc_path.to_owned().replace("/documents/", &attr_prefix)
docs_location
.to_owned()
.replace("/documents/", &attr_prefix)
};
let local_output = output_work_dir.join(&output_path);
let local_output = cache.prepare_output(&attrs_location)?;
if local_output.exists() {
log::info!("Skipping {:?} because it already exists", output_path);
log::info!("Skipping {:?} because it already exists", attrs_location);
return Ok(());
}
log::info!(
"Writing attributes for {} to {}",
docs_location,
local_output.display()
);

std::fs::create_dir_all(local_output.parent().unwrap())?;

let tmp_output_path = output_work_dir.join(output_path.clone() + ".tmp");
log::info!(
"Writing attributes for {} to {}",
docs_location,
local_output.display()
);
{
let local_input = input_work_dir.join(Path::new(&doc_path));
log::info!("Downloading {} to {}", doc_path, local_input.display());
rt.block_on(download_to_file(&s3_client, &doc_path, &local_input))?;
let local_input = cache.prepare_input(&docs_location)?;
let input_file = OpenOptions::new()
.read(true)
.write(false)
@@ -120,7 +124,7 @@ fn write_attributes(
.write(true)
.create(true)
.truncate(true)
.open(&tmp_output_path)?;
.open(&local_output)?;

let mut writer = BufWriter::with_capacity(
1024 * 1024,
@@ -132,7 +136,12 @@ fn write_attributes(
match line {
Ok(_) => {}
Err(e) => {
log::error!("Error reading line {} of {}: {}", line_number, &doc_path, e);
log::error!(
"Error reading line {} of {}: {}",
line_number,
&docs_location,
e
);
break;
}
}
@@ -223,23 +232,7 @@ fn write_attributes(
}
std::fs::remove_file(local_input)?;
}

log::info!(
"Uploading {} to {}",
&tmp_output_path.display(),
&output_path
);
rt.block_on(upload_file(&s3_client, &output_path, &tmp_output_path))?;

{
// Create empty file to indicate that the shard is done.
OpenOptions::new()
.create(true)
.write(true)
.open(&local_output)?;
std::fs::remove_file(&tmp_output_path)?;
}

cache.finalize_output(&attrs_location)?;
Ok(())
}

@@ -303,16 +296,14 @@ pub mod deduper_config {
}

#[cfg(test)]
pub mod test {
mod test {
use std::fs::OpenOptions;
use std::io;
use std::io::{BufRead, BufReader};
use std::path::Path;

use flate2::read::MultiGzDecoder;

use crate::s3_util;
use crate::s3_util::download_to_file;

use super::*;

@@ -352,53 +343,39 @@ pub mod test {
}

#[test]
pub fn test_dedupe_by_url() -> Result<(), io::Error> {
fn test_dedupe_by_url() -> Result<(), io::Error> {
let config = DeduperConfig::read_from_file("tests/config/dedupe-by-url.json").unwrap();
run(config);

let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let s3_client = s3_util::new_client(None)?;

let local_output_file = "tests/work/output/dedupe-by-url.json.gz";
let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz";
rt.block_on(download_to_file(
&s3_client,
remote_output_file,
Path::new(local_output_file),
))?;
run(config.clone());

let cache = FileCache {
s3_client: Box::new(s3_util::new_client(None)?),
work: config.work_dir.clone(),
};

let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz")?;

compare_contents(
"tests/data/expected/dedupe-by-url.json.gz",
local_output_file,
&local_output_file.display().to_string(),
);
Ok(())
}

#[test]
pub fn test_dedupe_paragraphs() -> Result<(), io::Error> {
fn test_dedupe_paragraphs() -> Result<(), io::Error> {
let config = DeduperConfig::read_from_file("tests/config/dedupe-paragraphs.json").unwrap();
run(config);

let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let s3_client = s3_util::new_client(None)?;

let local_output_file = "tests/work/output/dedupe-paragraphs.json.gz";
let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz";
rt.block_on(download_to_file(
&s3_client,
remote_output_file,
Path::new(local_output_file),
))?;
run(config.clone());

let cache = FileCache {
s3_client: Box::new(s3_util::new_client(None)?),
work: config.work_dir.clone(),
};

let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz")?;

compare_contents(
"tests/data/expected/dedupe-paragraphs.json.gz",
local_output_file,
&local_output_file.display().to_string(),
);
Ok(())
}
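
Taken together, the write_attributes changes above collapse the I/O handling into one sequence: resolve the output path, skip shards that are already done, resolve the input, process it, then publish the result. The sketch below condenses that flow, reusing the illustrative `LocalOrS3Cache` defined earlier; `process_lines` is a hypothetical stand-in for the dedupe logic.

```rust
use std::io;
use std::path::Path;

/// Hypothetical stand-in for the per-line dedupe work that write_attributes
/// performs against the Bloom filter.
fn process_lines(_input: &Path, _output: &Path) -> io::Result<()> {
    Ok(())
}

/// Condensed view of the new I/O sequence in write_attributes; error messages,
/// gzip handling, and the Bloom-filter logic are omitted.
fn write_attributes_sketch(
    cache: &LocalOrS3Cache,
    docs_location: &str,
    attrs_location: &str,
) -> io::Result<()> {
    let local_output = cache.prepare_output(attrs_location)?; // where to write locally
    if local_output.exists() {
        return Ok(()); // shard already processed: skip it
    }
    let local_input = cache.prepare_input(docs_location)?; // downloads only for s3:// inputs
    process_lines(&local_input, &local_output)?;
    // (the real code also removes its cached copy of the input here)
    cache.finalize_output(attrs_location) // uploads only for s3:// outputs
}
```

The test changes follow the same idea: instead of spinning up a tokio runtime and calling download_to_file directly, they build the cache from config.work_dir and call prepare_input on the expected attributes location, so the same assertions work whether that location is on S3 or on local disk.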
(The remaining changed files are not rendered here.)
