Skip to content

Commit

Permalink
fixes, and another test
Browse files Browse the repository at this point in the history
  • Loading branch information
nikil-ravi committed Oct 15, 2024
1 parent 8871fcc commit 698e98f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/levanter/data/sharded_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
if row >= total_rows:
return iter([])

num_row_groups = parquet_file.metadata.num_row_groups

# Compute cumulative row counts
row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups]
row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
cumulative_rows = [0]
for count in row_counts:
cumulative_rows.append(cumulative_rows[-1] + count)
Expand Down Expand Up @@ -465,8 +467,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
if row >= total_rows:
return iter([])

# compute cumulative row counts
row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups]
num_row_groups = parquet_file.metadata.num_row_groups

# Compute cumulative row counts
row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
cumulative_rows = [0]
for count in row_counts:
cumulative_rows.append(cumulative_rows[-1] + count)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_sharded_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
import tempfile

from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
from levanter.data.sharded_datasource import (
AudioTextUrlDataSource,
ParquetDataSource,
_sniff_format_for_dataset,
TextUrlDataSource,
)
from test_utils import skip_if_no_soundlibs


Expand Down Expand Up @@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
assert row_data[0]["column2"] == 20
assert row_data[1]["column1"] == "value3"
assert row_data[1]["column2"] == 30


def test_text_url_data_source_parquet():
import pyarrow as pa
import pyarrow.parquet as pq

with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
data = {
"text": ["line1", "line2", "line3", "line4", "line5", "line6"],
"column2": [10, 20, 30, 40, 50, 60],
}
table = pa.Table.from_pydict(data)
pq.write_table(table, f.name)

datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")

assert len(datasource.shard_names) == 1, "Expected only one shard"
shard_name = datasource.shard_names[0]

# Read data starting from row 2
row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))

# Verify the output
expected_texts = ["line3", "line4", "line5", "line6"]
assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"

0 comments on commit 698e98f

Please sign in to comment.