From b0d53a02d06e82bc7cab65e40461e6ea66ee27bd Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 12 Nov 2024 11:58:39 -0800 Subject: [PATCH] tracker.finish to deal with subprocess stuff (#801) we use subprocess in marin to invoke levanter, but subprocesses don't wait on other subprocesses somehow, and so wandb doesn't get a chance to finish. This solves this --- src/levanter/main/train_lm.py | 3 +++ src/levanter/tracker/tensorboard.py | 3 +++ src/levanter/tracker/tracker.py | 22 ++++++++++++++++++++++ src/levanter/tracker/wandb.py | 4 ++++ 4 files changed, 32 insertions(+) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b411bd59e..f2ad3e7ce 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -268,6 +268,9 @@ def compute_log_probs(model, example): checkpointer = trainer.config.checkpointer.create(trainer.run_id) checkpointer.wait_until_finished() + # This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray) + trainer.tracker.finish() + if __name__ == "__main__": levanter.config.main(main)() diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index 360c32171..e819d6459 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}") return + def finish(self): + self.writer.close() + @TrackerConfig.register_subclass("tensorboard") @dataclass diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 8b6816f17..99fd217e5 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass + @abc.abstractmethod + def finish(self): + """ + Finish the tracker. This is called when the tracker is no longer needed. This can, e.g., + force a commit of all metrics. + """ + pass + def __enter__(self): import levanter.tracker.tracker_fns as tracker_fns @@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio for tracker in self.loggers: tracker.log_artifact(artifact_path, name=name, type=type) + def finish(self): + excs = [] + for tracker in self.loggers: + try: + tracker.finish() + except Exception as e: + excs.append(e) + + if excs: + raise RuntimeError("Errors occurred when finishing trackers") from excs[0] + class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" @@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass + def finish(self): + pass + @TrackerConfig.register_subclass("noop") @dataclasses.dataclass diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 18f0251ec..981bebf83 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): self.run.log_artifact(artifact_path, name=name, type=type) + def finish(self): + logger.info("Finishing wandb run...") + self.run.finish() + def is_wandb_available(): try: