add parquet support
nikil-ravi committed Oct 13, 2024
1 parent 944a19f commit 52bff4f
Showing 2 changed files with 72 additions and 2 deletions.
31 changes: 30 additions & 1 deletion src/levanter/data/sharded_datasource.py
@@ -20,6 +20,8 @@
import datasets
import fsspec
import numpy as np
import pyarrow.parquet as pq
import pandas as pd

from levanter.utils import fsspec_utils

@@ -149,6 +151,10 @@ def datasource_from_json(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict]:
    return JsonDataSource(urls_or_paths)


def datasource_from_parquet(urls_or_paths: Sequence[str]) -> ShardedDataSource[dict]:
    return ParquetDataSource(urls_or_paths)


class WrappedHFDataSource(ShardedDataSource[dict]):
"""
This class is responsible for loading a dataset from HuggingFace Datasets and returning the shards.
@@ -238,6 +244,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                data = json.load(f)
                for doc in data[row:]:
                    yield doc[self.text_key]
            case ".parquet":
                table = pq.read_table(f)
                sliced_table = table.slice(row)
                for record in sliced_table.to_pylist():
                    yield record[self.text_key]  # assumes text_key is in record
            case _:
                raise ValueError(f"Unknown format {format}")

@@ -313,7 +324,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[Tuple[np.ndar


def _sniff_format_for_dataset(url):
-    good_formats = [".jsonl", ".txt", ".json"]
+    good_formats = [".jsonl", ".txt", ".json", ".parquet"]
    format_from_url = None
    # try both with and without compression (could be gz, bz2, etc, so look at the "first" extension)
    extensions = [os.path.splitext(url)[1], os.path.splitext(os.path.splitext(url)[0])[1]]
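For illustration, the sniffer checks both the outermost extension and the one beneath it, so a compressed name like data.jsonl.gz still resolves to a known format. A standalone sketch of that splitext dance (the path here is hypothetical):

import os

url = "data.jsonl.gz"
# first entry: outer extension; second entry: the extension under the compression suffix
extensions = [os.path.splitext(url)[1], os.path.splitext(os.path.splitext(url)[0])[1]]
assert extensions == [".gz", ".jsonl"]  # the inner extension matches a known format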
@@ -417,6 +428,24 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
            return iter(data[row:])


class ParquetDataSource(ShardedDataSource[dict]):
    def __init__(self, urls):
        self.urls = urls
        self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls)

    @property
    def shard_names(self) -> Sequence[str]:
        return list(self._shard_name_to_url_mapping.keys())

    def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
        url = self._shard_name_to_url_mapping[shard_name]
        with fsspec.open(url, "rb", compression="infer") as f:  # binary mode: parquet is not a text format
            table = pq.read_table(f)
            sliced_table = table.slice(row)  # zero-copy slicing
            for record in sliced_table.to_pylist():
                yield record


def _mk_shard_name_mapping(urls):
    _shard_name_to_url_mapping = {}
    # remove common prefix
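For reference, a minimal sketch of how the new Parquet path might be exercised end to end, using only the API added above (the shard file name and column names are hypothetical):

import pyarrow as pa
import pyarrow.parquet as pq

from levanter.data.sharded_datasource import datasource_from_parquet

# write a small hypothetical shard to disk
table = pa.table({"text": ["hello", "world"], "id": [1, 2]})
pq.write_table(table, "example_shard.parquet")

# each URL becomes one shard; open_shard_at_row streams dicts starting at the given row
source = datasource_from_parquet(["example_shard.parquet"])
for shard in source.shard_names:
    for record in source.open_shard_at_row(shard, row=0):
        print(record["text"], record["id"])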
43 changes: 42 additions & 1 deletion tests/test_sharded_dataset.py
@@ -1,6 +1,6 @@
import tempfile

-from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import AudioTextUrlDataSource, _sniff_format_for_dataset, ParquetDataSource
from test_utils import skip_if_no_soundlibs


@@ -24,6 +24,47 @@ def test_sniff_format_for_json():
        assert _sniff_format_for_dataset(f.name) == ".json"


def test_sniff_format_for_parquet():
    import pyarrow as pa
    import pyarrow.parquet as pq

    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        table = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
        pq.write_table(table, f.name)
        f.flush()

        assert _sniff_format_for_dataset(f.name) == ".parquet"


@skip_if_no_soundlibs
def test_resolve_audio_pointer():
    AudioTextUrlDataSource.resolve_audio_pointer("https://ccrma.stanford.edu/~jos/mp3/trumpet.mp3", 16_000)


def test_basic_parquet_datasource_read_row():
    import pyarrow as pa
    import pyarrow.parquet as pq

    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        # Create a simple dataset
        data = {
            "column1": ["value1", "value2", "value3"],
            "column2": [10, 20, 30],
        }
        table = pa.Table.from_pydict(data)
        pq.write_table(table, f.name)

        # Instantiate the ParquetDataSource
        datasource = ParquetDataSource([f.name])

        # sanity check: read data starting from row 1
        row_data = list(datasource.open_shard_at_row(shard_name=f.name.replace(".", "_"), row=1))

        # Verify the output
        assert len(row_data) == 2  # we expect 2 rows starting from index 1
        assert row_data[0]["column1"] == "value2"
        assert row_data[0]["column2"] == 20
        assert row_data[1]["column1"] == "value3"
        assert row_data[1]["column2"] == 30
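A footnote on the table.slice(row) call used in both new code paths: pyarrow.Table.slice is zero-copy, so resuming from a row offset does not duplicate the underlying buffers. A quick standalone illustration (the column name is hypothetical):

import pyarrow as pa

table = pa.table({"column1": ["value1", "value2", "value3"]})

# slice(offset) keeps rows [offset, num_rows) without copying buffers
tail = table.slice(1)
assert tail.num_rows == 2
assert tail.to_pylist() == [{"column1": "value2"}, {"column1": "value3"}]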
