From 61f294fe277c47290d022f02526768869cb56d58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Mon, 4 Dec 2023 09:14:50 +0100 Subject: [PATCH] feat: add edsnlp.data support for parquet files with parallel reading / writing --- edsnlp/data/parquet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/edsnlp/data/parquet.py b/edsnlp/data/parquet.py index 81d45c685..2ee794f07 100644 --- a/edsnlp/data/parquet.py +++ b/edsnlp/data/parquet.py @@ -81,11 +81,11 @@ def __init__( if isinstance(path, Path) or "://" in path else f"file://{os.path.abspath(path)}" ) - fs, path = pyarrow.fs.FileSystem.from_uri(path) + fs, fs_path = pyarrow.fs.FileSystem.from_uri(path) fs: pyarrow.fs.FileSystem - fs.create_dir(path, recursive=True) + fs.create_dir(fs_path, recursive=True) if overwrite is False: - dataset = pyarrow.dataset.dataset(path, format="parquet", filesystem=fs) + dataset = pyarrow.dataset.dataset(fs_path, format="parquet", filesystem=fs) if len(list(dataset.get_fragments())): raise FileExistsError( f"Directory {path} already exists and is not empty. " @@ -104,7 +104,6 @@ def __init__( def write_worker(self, records, last=False): # Results will contain a batches of samples ready to be written (or None if # write_in_worker is True) and they have already been written. - n_to_fill = self.num_rows_per_file - len(self.batch) results = [] count = 0 @@ -115,9 +114,10 @@ def write_worker(self, records, last=False): # While there is something to write greedy = last or not self.accumulate while len(records) or greedy and len(self.batch): + n_to_fill = self.num_rows_per_file - len(self.batch) self.batch.extend(records[:n_to_fill]) records = records[n_to_fill:] - if greedy or len(self.batch) == self.num_rows_per_file: + if greedy or len(self.batch) >= self.num_rows_per_file: fragment = pyarrow.Table.from_pydict(ld_to_dl(self.batch)) # type: ignore count += len(self.batch) self.batch = []