From 8871fccaf40f7bec23543d3da725f9180c27b8ef Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:27:50 -0700
Subject: [PATCH 1/3] changes in how parquet is read

---
 src/levanter/data/sharded_datasource.py | 60 +++++++++++++++++++++----
 1 file changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 208116ca6..e43b12fde 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -244,10 +244,31 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                     for doc in data[row:]:
                         yield doc[self.text_key]
                 case ".parquet":
-                    table = pq.read_table(f)
-                    sliced_table = table.slice(row)
-                    for record in sliced_table.to_pylist():
-                        yield record[self.text_key]  # assumes text_key is in record
+                    parquet_file = pq.ParquetFile(f)
+                    total_rows = parquet_file.metadata.num_rows
+                    if row >= total_rows:
+                        return iter([])
+
+                    # Compute cumulative row counts
+                    row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups]
+                    cumulative_rows = [0]
+                    for count in row_counts:
+                        cumulative_rows.append(cumulative_rows[-1] + count)
+
+                    # Find the starting row group and row within it
+                    for idx, cum_row in enumerate(cumulative_rows):
+                        if cum_row > row:
+                            row_group_index = idx - 1
+                            start_row_in_group = row - cumulative_rows[row_group_index]
+                            break
+
+                    # Read from the starting row group onwards
+                    for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                        table = parquet_file.read_row_group(rg_idx, columns=[self.text_key])
+                        if rg_idx == row_group_index:
+                            table = table.slice(start_row_in_group)
+                        for record in table.to_pylist():
+                            yield record[self.text_key]
                 case _:
                     raise ValueError(f"Unknown format {format}")
 
@@ -439,10 +460,33 @@ def shard_names(self) -> Sequence[str]:
     def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         url = self._shard_name_to_url_mapping[shard_name]
         with fsspec.open(url, "rb", compression="infer") as f:
-            table = pq.read_table(f)
-            sliced_table = table.slice(row)  # zero-copy slicing
-            for record in sliced_table.to_pylist():
-                yield record
+            parquet_file = pq.ParquetFile(f)
+            total_rows = parquet_file.metadata.num_rows
+            if row >= total_rows:
+                return iter([])
+
+            # compute cumulative row counts
+            row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups]
+            cumulative_rows = [0]
+            for count in row_counts:
+                cumulative_rows.append(cumulative_rows[-1] + count)
+
+            # find starting row group and also find the row within it
+            for idx, cum_row in enumerate(cumulative_rows):
+                if cum_row > row:
+                    row_group_index = idx - 1
+                    start_row_in_group = row - cumulative_rows[row_group_index]
+                    break
+
+            # read from the starting row group onwards
+            for rg_idx in range(row_group_index, parquet_file.num_row_groups):
+                table = parquet_file.read_row_group(rg_idx)
+
+                # if we're in the row group we want, slice the table at/from the row we want
+                if rg_idx == row_group_index:
+                    table = table.slice(start_row_in_group)
+
+                yield from table.to_pylist()
 
 
 def _mk_shard_name_mapping(urls):

From 698e98f5345e0b7d31ec17e2acf8de9006d57bb6 Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:42:18 -0700
Subject: [PATCH 2/3] fixes, and another test

---
 src/levanter/data/sharded_datasource.py | 10 +++++---
 tests/test_sharded_dataset.py           | 33 ++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index e43b12fde..2c7eb42d9 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -249,8 +249,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                     if row >= total_rows:
                         return iter([])
 
+                    num_row_groups = parquet_file.metadata.num_row_groups
+
                     # Compute cumulative row counts
-                    row_counts = [rg.num_rows for rg in parquet_file.metadata.row_groups]
+                    row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
                     cumulative_rows = [0]
                     for count in row_counts:
                         cumulative_rows.append(cumulative_rows[-1] + count)
@@ -465,8 +467,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
             if row >= total_rows:
                 return iter([])
 
-            # compute cumulative row counts
-            row_counts = [rg_meta.num_rows for rg_meta in parquet_file.metadata.row_groups]
+            num_row_groups = parquet_file.metadata.num_row_groups
+
+            # Compute cumulative row counts
+            row_counts = [parquet_file.metadata.row_group(i).num_rows for i in range(num_row_groups)]
             cumulative_rows = [0]
             for count in row_counts:
                 cumulative_rows.append(cumulative_rows[-1] + count)
diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py
index 90ab6c34b..d93793bff 100644
--- a/tests/test_sharded_dataset.py
+++ b/tests/test_sharded_dataset.py
@@ -1,7 +1,12 @@
 import os
 import tempfile
 
-from levanter.data.sharded_datasource import AudioTextUrlDataSource, ParquetDataSource, _sniff_format_for_dataset
+from levanter.data.sharded_datasource import (
+    AudioTextUrlDataSource,
+    ParquetDataSource,
+    _sniff_format_for_dataset,
+    TextUrlDataSource,
+)
 from test_utils import skip_if_no_soundlibs
 
 
@@ -68,3 +73,29 @@ def test_basic_parquet_datasource_read_row():
         assert row_data[0]["column2"] == 20
         assert row_data[1]["column1"] == "value3"
         assert row_data[1]["column2"] == 30
+
+
+def test_text_url_data_source_parquet():
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as f:
+        data = {
+            "text": ["line1", "line2", "line3", "line4", "line5", "line6"],
+            "column2": [10, 20, 30, 40, 50, 60],
+        }
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, f.name)
+
+        datasource = TextUrlDataSource([os.path.abspath(f.name)], text_key="text")
+
+        assert len(datasource.shard_names) == 1, "Expected only one shard"
+        shard_name = datasource.shard_names[0]
+
+        # Read data starting from row 2
+        row_data = list(datasource.open_shard_at_row(shard_name=shard_name, row=2))
+
+        # Verify the output
+        expected_texts = ["line3", "line4", "line5", "line6"]
+        assert len(row_data) == len(expected_texts), f"Expected {len(expected_texts)} rows starting from index 2"
+        assert row_data == expected_texts, f"Expected texts {expected_texts}, got {row_data}"

From 38a7e688b1f3fb61ea77a2f98060a8e99071e3b5 Mon Sep 17 00:00:00 2001
From: Nikil Ravi <nravi@stanford.edu>
Date: Tue, 15 Oct 2024 01:57:18 -0700
Subject: [PATCH 3/3] open in binary mode

---
 src/levanter/data/sharded_datasource.py | 22 +++++++++++++---------
 tests/test_sharded_dataset.py           |  2 +-
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 2c7eb42d9..80dd1b4b2 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -224,26 +224,30 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
         compression = "infer"
         if url.endswith(".zstd"):  # hacky way to detect zstd
             compression = "zstd"
-        with fsspec.open(url, "r", compression=compression) as f:
-            format = _sniff_format_for_dataset(url)
-            match format:
-                case ".jsonl":
+
+        format = _sniff_format_for_dataset(url)
+        match format:
+            case ".jsonl":
+                with fsspec.open(url, "r", compression=compression) as f:
                     # TODO: would be nice if we could seek faster than this. Right now, all we do is skip json parsing
                     # which is not nothing, but not ideal.
                     for line in f:
                         if i >= row:
                             yield json.loads(line)[self.text_key]
                         i += 1
-                case ".txt":
+            case ".txt":
+                with fsspec.open(url, "r", compression=compression) as f:
                     for line in f:
                         if i >= row:
                             yield line
                         i += 1
-                case ".json":
+            case ".json":
+                with fsspec.open(url, "r", compression=compression) as f:
                     data = json.load(f)
                     for doc in data[row:]:
                         yield doc[self.text_key]
-                case ".parquet":
+            case ".parquet":
+                with fsspec.open(url, "rb", compression=compression) as f:
                     parquet_file = pq.ParquetFile(f)
                     total_rows = parquet_file.metadata.num_rows
                     if row >= total_rows:
@@ -271,8 +275,8 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]:
                             table = table.slice(start_row_in_group)
                         for record in table.to_pylist():
                             yield record[self.text_key]
-                case _:
-                    raise ValueError(f"Unknown format {format}")
+            case _:
+                raise ValueError(f"Unknown format {format}")
 
 
 class AudioTextUrlDataSource(ShardedDataSource[Tuple[np.ndarray, int, str]]):
diff --git a/tests/test_sharded_dataset.py b/tests/test_sharded_dataset.py
index d93793bff..a375d9344 100644
--- a/tests/test_sharded_dataset.py
+++ b/tests/test_sharded_dataset.py
@@ -4,8 +4,8 @@
 from levanter.data.sharded_datasource import (
     AudioTextUrlDataSource,
     ParquetDataSource,
-    _sniff_format_for_dataset,
     TextUrlDataSource,
+    _sniff_format_for_dataset,
 )
 from test_utils import skip_if_no_soundlibs