diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index 38db67f1..24586bdd 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -8,7 +8,7 @@
 import argparse
 
 from datatrove.executor.local import LocalPipelineExecutor
-from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader
+from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader, ParquetReader
 from datatrove.pipeline.tokens import DocumentTokenizer
 
 
@@ -72,6 +72,18 @@ def get_args():
         "--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None"
     )
 
+    p3 = sp.add_parser(name="parquet")
+    p3.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        help="Path to a .parquet file or hf:// path",
+    )
+    p3.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text")
+    p3.add_argument(
+        "--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None"
+    )
+
     args = parser.parse_args()
     return args
 
@@ -85,6 +97,8 @@ def main(args):
             text_key=args.column,
             dataset_options={"split": args.split},
         )
+    elif args.readers == "parquet":
+        datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
     else:
         datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
 