
Commit

a bit worried the bookkeeping isn't quite right on resume, but we're almost there.
dlwh committed Nov 3, 2024
1 parent cb85654 commit 47441c0
Showing 3 changed files with 79 additions and 56 deletions.
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_pile.yaml
@@ -1,4 +1,4 @@
data: !include data/pile_source_old.yaml
data: !include data/pile_mixture.yaml
model:
type: gpt2
hidden_dim: 768
122 changes: 76 additions & 46 deletions src/levanter/store/cache.py
@@ -35,8 +35,8 @@
from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch
from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor
from ..data.sharded_datasource import ShardedDataSource
from ..utils.fsspec_utils import async_remove
from ..utils.fsspec_utils import exists as fsspec_exists
from ..utils.fsspec_utils import remove as fsspec_remove
from ..utils.ray_utils import (
ExceptionInfo,
RefBox,
@@ -672,7 +672,13 @@ def __init__(
memory=2 * self._options.target_bytes_per_flush,
).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor)

self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard")
self._tokenize_pbar = tqdm(
total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard"
)
self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard")
self._report_totals = _ProgressReport(0, 0, 0)
self._copy_report_totals = _ProgressReport(0, 0, 0)
self._last_update = time.time()

except Exception:
# Ray behaves poorly if the constructor of an actor fails, so we catch and log here
@@ -763,15 +769,39 @@ async def _do_notify_async():
def _report_progress(self, report: "_ProgressReport"):
import humanfriendly

self._tokenize_pbar.update(report.total_shards_completed)
mb_str = humanfriendly.format_size(report.total_bytes)
self._tokenize_pbar.set_postfix(
{
"rows": report.total_rows,
"shards": report.total_shards_completed,
"mb": mb_str,
}
)
self._tokenize_pbar.update(report.new_shards)
self._report_totals.new_shards += report.new_shards
self._report_totals.new_rows += report.new_rows
self._report_totals.new_bytes += report.new_bytes

if time.time() - self._last_update > 10.0:
self._last_update = time.time()

mb_str = humanfriendly.format_size(self._report_totals.new_bytes)
self._tokenize_pbar.set_postfix(
{
"rows": self._report_totals.new_rows,
"shards": self._report_totals.new_shards,
"size": mb_str,
}
)

def _report_copy_progress(self, report: "_ProgressReport"):
# TODO: log bytes copied
self._copy_pbar.update(report.new_shards)
self._copy_report_totals.new_shards += report.new_shards
self._copy_report_totals.new_rows += report.new_rows
self._copy_report_totals.new_bytes += report.new_bytes

if time.time() - self._last_update > 10.0:
self._last_update = time.time()
self._copy_pbar.set_postfix(
{
"shards": report.new_shards,
"rows": report.new_rows,
# "size": humanfriendly.format_size(report.new_bytes),
}
)


def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()):
@@ -855,12 +885,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):

shard_groups = _assign_shards_to_groups(source, options.num_shard_groups)

logger.info(
logger.debug(
f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}."
)

unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group"

processor_ref = ray.put(processor)
source_ref = ray.put(source)

@@ -928,7 +956,6 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):

copy_refs: dict[str, ray.ObjectRef] = {}
last_ref: ray.ObjectRef | None = None
copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1)

for group in shard_groups:
# first make sure it's either done this run or already done
@@ -944,7 +971,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group])
if shards_copied == len(shard_groups[group]) or group == first_group:
assert initial_ledger.total_num_rows >= total_rows_from_caches
copying_pbar.update(1)
parent._report_copy_progress.remote(
_ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows)
)

elif shards_copied > 0:
# In theory we can handle this, but it's a bit tricky, so we're going to punt for now
@@ -976,8 +1005,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
continue

if copy_refs.get(group) is not None:
ray.wait([copy_refs[group]], fetch_local=False)
copying_pbar.update(1)
ledger = ray.get(copy_refs[group])
ledgers[group] = ledger
parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows)
)

# refs form a linked list implicitly, so we can just wait on the last one
if last_ref is not None:
@@ -989,23 +1021,23 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
ledger._serialize_and_commit(cache_dir)
parent._notify_updated_ledger.remote(ledger)

# clean up the temporary caches
_clean_up_temp_caches(paths, first_group)


def _clean_up_temp_caches(paths, first_group):
async def cleanup():
futures = []
for group, path in paths.items():
if group == first_group:
continue

if fsspec_exists(path):
futures.append(async_remove(path, recursive=True))

await asyncio.gather(*futures)
for group, path in paths.items():
if group == first_group:
continue

asyncio.run(cleanup())
if fsspec_exists(path):
for i in range(10):
# this is crashy for some reason
try:
fsspec_remove(path, recursive=True)
break
except Exception:
logger.exception(f"Failed to remove {path} on attempt {i}")
time.sleep(1)


def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]:
@@ -1074,6 +1106,9 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R
assert not dest_ledger.is_finished

parent._notify_updated_ledger.remote(dest_ledger)
parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows)
)

return dest_ledger

@@ -1179,9 +1214,9 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore

@dataclass
class _ProgressReport:
total_rows: int
total_bytes: float
total_shards_completed: int
new_rows: int = 0
new_bytes: float = 0
new_shards: int = 0
# TODO: other counts


@@ -1216,13 +1251,13 @@ def _tokenize_one_shard_group(
total_rows = ledger.total_num_rows
found_shard_with_rows = False

report = _ProgressReport(total_rows, 0, 0)
if total_rows > 0:
report_fn(_ProgressReport(new_rows=total_rows), ledger)

for shard_name in shards:
if shard_name in ledger.finished_shards:
logger.info(f"Shard {shard_name} already processed.")
report.total_shards_completed += 1
report_fn(report, ledger)
report_fn(_ProgressReport(new_shards=1), ledger)
continue

logger.debug(f"Processing {shard_name}.")
@@ -1247,6 +1282,7 @@

this_batch_size += len(batch)
rows_this_shard += len(batch)
total_rows += len(batch)

if prepared_batch is None:
prepared_batch = this_prepared
@@ -1259,10 +1295,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):

if batch_byte_size > options.target_bytes_per_flush:
writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
report.total_rows += this_batch_size
report.total_bytes += batch_byte_size

report_fn(report, writer.ledger)
report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)

nice_bytes = humanfriendly.format_size(batch_byte_size)
logger.debug(
@@ -1276,9 +1309,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch))
nice_bytes = humanfriendly.format_size(batch_byte_size)

report.total_rows += this_batch_size
report.total_bytes += batch_byte_size
report_fn(report, writer.ledger)
report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)

writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
logger.debug(
@@ -1287,15 +1318,14 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
this_batch_size = 0
prepared_batch = None

report.total_shards_completed += 1
writer.finish_shard(shard_name, rows_this_shard)

report_fn(report, writer.ledger)
report_fn(_ProgressReport(new_shards=1), writer.ledger)

if not force_unfinalized:
writer.finish()

logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.")
logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.")

return writer.ledger

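The bulk of this change to cache.py swaps cumulative progress totals (total_rows, total_bytes, total_shards_completed) for per-call deltas that the cache actor accumulates itself and surfaces through tqdm on a roughly 10-second throttle. The following is a minimal standalone sketch of that pattern under those assumptions; ProgressDelta and ProgressAggregator are hypothetical names, not the repository's actual classes.

```python
import time
from dataclasses import dataclass

import humanfriendly
from tqdm import tqdm


@dataclass
class ProgressDelta:
    # each report carries only *new* work since the previous report
    new_rows: int = 0
    new_bytes: float = 0
    new_shards: int = 0


class ProgressAggregator:
    def __init__(self, total_shards: int, desc: str):
        self._pbar = tqdm(total=total_shards, desc=desc, unit="shard")
        self._totals = ProgressDelta()
        self._last_update = time.time()

    def report(self, delta: ProgressDelta) -> None:
        # advance the bar by newly finished shards only
        self._pbar.update(delta.new_shards)
        self._totals.new_rows += delta.new_rows
        self._totals.new_bytes += delta.new_bytes
        self._totals.new_shards += delta.new_shards

        # throttle postfix refreshes so frequent reports stay cheap
        if time.time() - self._last_update > 10.0:
            self._last_update = time.time()
            self._pbar.set_postfix(
                rows=self._totals.new_rows,
                shards=self._totals.new_shards,
                size=humanfriendly.format_size(self._totals.new_bytes),
            )


if __name__ == "__main__":
    agg = ProgressAggregator(total_shards=4, desc="tokenizing")
    agg.report(ProgressDelta(new_rows=1000, new_bytes=2**20, new_shards=1))
```

Reporting deltas rather than running totals means a resumed or restarted worker cannot double-count work it already reported, which is the bookkeeping concern the commit message alludes to.
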
11 changes: 2 additions & 9 deletions src/levanter/utils/fsspec_utils.py
@@ -1,11 +1,7 @@
import asyncio

import braceexpand
import fsspec
from fsspec.asyn import AsyncFileSystem

from levanter.utils.thread_utils import _executor, blocking_wait


def exists(url, **kwargs) -> bool:
"""Check if a file exists on a remote filesystem."""
@@ -40,10 +36,7 @@ def remove(url, *, recursive=False, **kwargs):
# TODO: better to use a STS deletion policy or job for this one.
fs, path = fsspec.core.url_to_fs(url, **kwargs)

if isinstance(fs, AsyncFileSystem):
blocking_wait(fs._rm(path, recursive=recursive))
else:
fs.rm(path, recursive=recursive)
fs.rm(path, recursive=recursive)


async def async_remove(url, *, recursive=False, **kwargs):
@@ -53,4 +46,4 @@ async def async_remove(url, *, recursive=False, **kwargs):
if isinstance(fs, AsyncFileSystem):
return await fs._rm(path, recursive=recursive)
else:
return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive))
fs.rm(path, recursive=recursive)
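
With the fsspec helpers simplified to plain synchronous removal, the cleanup path in cache.py wraps them in a retry loop. A sketch of that wrapper, using only standard fsspec calls (remove_with_retries is a hypothetical name; the commit's actual logic lives inline in _clean_up_temp_caches):

```python
import logging
import time

import fsspec

logger = logging.getLogger(__name__)


def remove_with_retries(url: str, attempts: int = 10, delay: float = 1.0) -> None:
    """Best-effort recursive removal; retries because remote deletes can be flaky."""
    fs, path = fsspec.core.url_to_fs(url)
    if not fs.exists(path):
        return
    for i in range(attempts):
        try:
            fs.rm(path, recursive=True)
            return
        except Exception:
            logger.exception(f"Failed to remove {path} on attempt {i}")
            time.sleep(delay)
```
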
