From 47441c0fd12b91eada4e4133a515f007a074de16 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 2 Nov 2024 22:59:13 -0700 Subject: [PATCH] a bit worried the bookkeeping isn't quite right on resume, but we're almost there. --- config/gpt2_small_fast_pile.yaml | 2 +- src/levanter/store/cache.py | 122 ++++++++++++++++++----------- src/levanter/utils/fsspec_utils.py | 11 +-- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index 3a21732a7..291213d75 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -1,4 +1,4 @@ -data: !include data/pile_source_old.yaml +data: !include data/pile_mixture.yaml model: type: gpt2 hidden_dim: 768 diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index f01e3b881..cd59aef4c 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -35,8 +35,8 @@ from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource -from ..utils.fsspec_utils import async_remove from ..utils.fsspec_utils import exists as fsspec_exists +from ..utils.fsspec_utils import remove as fsspec_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -672,7 +672,13 @@ def __init__( memory=2 * self._options.target_bytes_per_flush, ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) - self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard") + self._tokenize_pbar = tqdm( + total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" + ) + self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard") + self._report_totals = _ProgressReport(0, 0, 0) + self._copy_report_totals = _ProgressReport(0, 0, 0) + self._last_update = time.time() except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here @@ -763,15 +769,39 @@ async def _do_notify_async(): def _report_progress(self, report: "_ProgressReport"): import humanfriendly - self._tokenize_pbar.update(report.total_shards_completed) - mb_str = humanfriendly.format_size(report.total_bytes) - self._tokenize_pbar.set_postfix( - { - "rows": report.total_rows, - "shards": report.total_shards_completed, - "mb": mb_str, - } - ) + self._tokenize_pbar.update(report.new_shards) + self._report_totals.new_shards += report.new_shards + self._report_totals.new_rows += report.new_rows + self._report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + + mb_str = humanfriendly.format_size(self._report_totals.new_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": self._report_totals.new_rows, + "shards": self._report_totals.new_shards, + "size": mb_str, + } + ) + + def _report_copy_progress(self, report: "_ProgressReport"): + # TODO: log bytes copied + self._copy_pbar.update(report.new_shards) + self._copy_report_totals.new_shards += report.new_shards + self._copy_report_totals.new_rows += report.new_rows + self._copy_report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + self._copy_pbar.set_postfix( + { + "shards": report.new_shards, + "rows": report.new_rows, + # "size": humanfriendly.format_size(report.new_bytes), + } + ) def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): @@ -855,12 +885,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - logger.info( + logger.debug( f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) - unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" - processor_ref = ray.put(processor) source_ref = ray.put(source) @@ -928,7 +956,6 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None - copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1) for group in shard_groups: # first make sure it's either done this run or already done @@ -944,7 +971,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) if shards_copied == len(shard_groups[group]) or group == first_group: assert initial_ledger.total_num_rows >= total_rows_from_caches - copying_pbar.update(1) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows) + ) elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now @@ -976,8 +1005,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): continue if copy_refs.get(group) is not None: - ray.wait([copy_refs[group]], fetch_local=False) - copying_pbar.update(1) + ledger = ray.get(copy_refs[group]) + ledgers[group] = ledger + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) + ) # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: @@ -989,23 +1021,23 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ledger._serialize_and_commit(cache_dir) parent._notify_updated_ledger.remote(ledger) - # clean up the temporary caches _clean_up_temp_caches(paths, first_group) def _clean_up_temp_caches(paths, first_group): - async def cleanup(): - futures = [] - for group, path in paths.items(): - if group == first_group: - continue - - if fsspec_exists(path): - futures.append(async_remove(path, recursive=True)) - - await asyncio.gather(*futures) + for group, path in paths.items(): + if group == first_group: + continue - asyncio.run(cleanup()) + if fsspec_exists(path): + for i in range(10): + # this is crashy for some reason + try: + fsspec_remove(path, recursive=True) + break + except Exception: + logger.exception(f"Failed to remove {path} on attempt {i}") + time.sleep(1) def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: @@ -1074,6 +1106,9 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R assert not dest_ledger.is_finished parent._notify_updated_ledger.remote(dest_ledger) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) + ) return dest_ledger @@ -1179,9 +1214,9 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore @dataclass class _ProgressReport: - total_rows: int - total_bytes: float - total_shards_completed: int + new_rows: int = 0 + new_bytes: float = 0 + new_shards: int = 0 # TODO: other counts @@ -1216,13 +1251,13 @@ def _tokenize_one_shard_group( total_rows = ledger.total_num_rows found_shard_with_rows = False - report = _ProgressReport(total_rows, 0, 0) + if total_rows > 0: + report_fn(_ProgressReport(new_rows=total_rows), ledger) for shard_name in shards: if shard_name in ledger.finished_shards: logger.info(f"Shard {shard_name} already processed.") - report.total_shards_completed += 1 - report_fn(report, ledger) + report_fn(_ProgressReport(new_shards=1), ledger) continue logger.debug(f"Processing {shard_name}.") @@ -1247,6 +1282,7 @@ def _tokenize_one_shard_group( this_batch_size += len(batch) rows_this_shard += len(batch) + total_rows += len(batch) if prepared_batch is None: prepared_batch = this_prepared @@ -1259,10 +1295,7 @@ def _tokenize_one_shard_group( if batch_byte_size > options.target_bytes_per_flush: writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) nice_bytes = humanfriendly.format_size(batch_byte_size) logger.debug( @@ -1276,9 +1309,7 @@ def _tokenize_one_shard_group( batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) nice_bytes = humanfriendly.format_size(batch_byte_size) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) logger.debug( @@ -1287,15 +1318,14 @@ def _tokenize_one_shard_group( this_batch_size = 0 prepared_batch = None - report.total_shards_completed += 1 writer.finish_shard(shard_name, rows_this_shard) - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_shards=1), writer.ledger) if not force_unfinalized: writer.finish() - logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") + logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") return writer.ledger diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index cc03c174b..c8d3931fe 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,11 +1,7 @@ -import asyncio - import braceexpand import fsspec from fsspec.asyn import AsyncFileSystem -from levanter.utils.thread_utils import _executor, blocking_wait - def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" @@ -40,10 +36,7 @@ def remove(url, *, recursive=False, **kwargs): # TODO: better to use a STS deletion policy or job for this one. fs, path = fsspec.core.url_to_fs(url, **kwargs) - if isinstance(fs, AsyncFileSystem): - blocking_wait(fs._rm(path, recursive=recursive)) - else: - fs.rm(path, recursive=recursive) + fs.rm(path, recursive=recursive) async def async_remove(url, *, recursive=False, **kwargs): @@ -53,4 +46,4 @@ async def async_remove(url, *, recursive=False, **kwargs): if isinstance(fs, AsyncFileSystem): return await fs._rm(path, recursive=recursive) else: - return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive)) + fs.rm(path, recursive=recursive)