Skip to content

Commit

Permalink
tracker.finish to deal with subprocess stuff (#801)
Browse files Browse the repository at this point in the history
We use subprocess in marin to invoke levanter, but subprocesses somehow don't
wait on other subprocesses, so wandb doesn't get a chance to finish. Calling
tracker.finish() explicitly at the end of training solves this.
  • Loading branch information
dlwh authored Nov 12, 2024
1 parent 0503001 commit b0d53a0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def compute_log_probs(model, example):
checkpointer = trainer.config.checkpointer.create(trainer.run_id)
checkpointer.wait_until_finished()

# This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray)
trainer.tracker.finish()


if __name__ == "__main__":
    # Only runs when this module is executed as a script: hand `main` to
    # levanter.config.main and invoke the callable it returns.
    entry_point = levanter.config.main(main)
    entry_point()
3 changes: 3 additions & 0 deletions src/levanter/tracker/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}")
return

def finish(self):
    """
    Finish this tracker by closing its writer (``self.writer``).
    """
    writer = self.writer
    writer.close()


@TrackerConfig.register_subclass("tensorboard")
@dataclass
Expand Down
22 changes: 22 additions & 0 deletions src/levanter/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """
    Log the artifact found at ``artifact_path``. The body here is intentionally empty.

    Args:
        artifact_path: path of the artifact to log.
        name: optional name for the artifact; keyword-only, defaults to None.
        type: optional type for the artifact; keyword-only, defaults to None.
    """

@abc.abstractmethod
def finish(self):
    """
    Shut the tracker down once it is no longer needed.

    Implementations can use this hook to, e.g., force a commit of all metrics.
    """

def __enter__(self):
import levanter.tracker.tracker_fns as tracker_fns

Expand Down Expand Up @@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
for tracker in self.loggers:
tracker.log_artifact(artifact_path, name=name, type=type)

def finish(self):
    """
    Finish every tracker in ``self.loggers``, even if some of them fail.

    Each tracker's ``finish`` is attempted in order. A failure in one tracker does not
    stop the remaining trackers from being finished.

    Raises:
        RuntimeError: if at least one tracker raised while finishing. The message lists
            every failing tracker and its exception, and the first failure is chained
            as ``__cause__``.
    """
    failures = []
    for tracker in self.loggers:
        try:
            tracker.finish()
        except Exception as e:
            # Keep going so one failing tracker can't prevent the others from finishing.
            failures.append((tracker, e))

    if failures:
        # Only one exception can be chained, so spell out all of them in the message;
        # otherwise every failure after the first would be lost.
        details = "; ".join(f"{type(tracker).__name__}: {exc!r}" for tracker, exc in failures)
        raise RuntimeError(f"Errors occurred when finishing trackers: {details}") from failures[0][1]


class TrackerConfig(draccus.PluginRegistry, abc.ABC):
discover_packages_path = "levanter.tracker"
Expand Down Expand Up @@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """
    Accept an artifact-logging request and do nothing with it.

    Args:
        artifact_path: path of the artifact; ignored.
        name: optional artifact name (keyword-only, defaults to None); ignored.
        type: optional artifact type (keyword-only, defaults to None); ignored.
    """
    return None

def finish(self):
    """
    Finish the tracker. There is nothing to do here, so this returns immediately.
    """
    return None


@TrackerConfig.register_subclass("noop")
@dataclasses.dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/levanter/tracker/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """
    Forward an artifact to ``self.run.log_artifact``.

    Args:
        artifact_path: path of the artifact, passed through unchanged.
        name: optional artifact name (keyword-only, defaults to None), passed through as ``name``.
        type: optional artifact type (keyword-only, defaults to None), passed through as ``type``.
    """
    run = self.run
    run.log_artifact(artifact_path, name=name, type=type)

def finish(self):
    """
    Finish the tracker: log an info message, then call ``finish()`` on ``self.run``.
    """
    logger.info("Finishing wandb run...")
    run = self.run
    run.finish()


def is_wandb_available():
try:
Expand Down

0 comments on commit b0d53a0

Please sign in to comment.