diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 2c7eb42d9..80dd1b4b2 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -224,26 +224,30 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: compression = "infer" if url.endswith(".zstd"): # hacky way to detect zstd compression = "zstd" - with fsspec.open(url, "r", compression=compression) as f: - format = _sniff_format_for_dataset(url) - match format: - case ".jsonl": + + format = _sniff_format_for_dataset(url) + match format: + case ".jsonl": + with fsspec.open(url, "r", compression=compression) as f: # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing # which is not nothing, but not ideal. for line in f: if i >= row: yield json.loads(line)[self.text_key] i += 1 - case ".txt": + case ".txt": + with fsspec.open(url, "r", compression=compression) as f: for line in f: if i >= row: yield line i += 1 - case ".json": + case ".json": + with fsspec.open(url, "r", compression=compression) as f: data = json.load(f) for doc in data[row:]: yield doc[self.text_key] - case ".parquet": + case ".parquet": + with fsspec.open(url, "rb", compression=compression) as f: parquet_file = pq.ParquetFile(f) total_rows = parquet_file.metadata.num_rows if row >= total_rows: @@ -271,8 +275,8 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: table = table.slice(start_row_in_group) for record in table.to_pylist(): yield record[self.text_key] - case _: - raise ValueError(f"Unknown format {format}") + case _: + raise ValueError(f"Unknown format {format}") class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]): diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py index d93793bff..a375d9344 100644 --- a/tests/test_sharded_dataset.py +++ b/tests/test_sharded_dataset.py @@ -4,8 +4,8 @@ from levanter.data.sharded_datasource import ( AudioTextUrlDataSource, ParquetDataSource, - _sniff_format_for_dataset, TextUrlDataSource, + _sniff_format_for_dataset, ) from test_utils import skip_if_no_soundlibs