Skip to content

Commit

Permalink
tracker.finish to deal with subprocess stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 12, 2024
1 parent 0503001 commit b49e69d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def compute_log_probs(model, example):
checkpointer = trainer.config.checkpointer.create(trainer.run_id)
checkpointer.wait_until_finished()

# Not strictly necessary, except when Levanter is run in a subprocess (as happens with Ray).
trainer.tracker.finish()


# Script entry point: wrap `main` with levanter.config.main and call the resulting callable.
if __name__ == "__main__":
levanter.config.main(main)()
3 changes: 3 additions & 0 deletions src/levanter/tracker/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}")
return

def finish(self):
    """Close the writer held in ``self.writer``."""
    summary_writer = self.writer
    summary_writer.close()


@TrackerConfig.register_subclass("tensorboard")
@dataclass
Expand Down
22 changes: 22 additions & 0 deletions src/levanter/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """
    Hook for logging the artifact at ``artifact_path``, with an optional ``name`` and ``type``.

    This definition does nothing by itself; it only fixes the signature.
    """

@abc.abstractmethod
def finish(self):
    """
    Shut the tracker down once it is no longer needed.

    Implementations can use this hook to, for example, force all pending metrics to be committed.
    """

def __enter__(self):
import levanter.tracker.tracker_fns as tracker_fns

Expand Down Expand Up @@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
for tracker in self.loggers:
tracker.log_artifact(artifact_path, name=name, type=type)

def finish(self):
    """
    Finish every wrapped tracker, even if some of them fail.

    Each tracker in ``self.loggers`` has its ``finish`` called in order. A failure in one
    tracker does not prevent the remaining trackers from being finished.

    Raises:
        RuntimeError: if at least one tracker raised. The message summarizes every
            failure, and the first failure is attached as ``__cause__``.
    """
    failures = []
    for tracker in self.loggers:
        try:
            tracker.finish()
        except Exception as e:  # deliberately broad: one bad tracker must not block the others
            failures.append((tracker, e))

    if not failures:
        return

    # Only one exception can be chained as the cause, so spell all of them out in the message;
    # otherwise every failure after the first would vanish without a trace.
    details = "; ".join(f"{type(tracker).__name__}: {exc!r}" for tracker, exc in failures)
    raise RuntimeError(
        f"Errors occurred when finishing trackers ({len(failures)} failed): {details}"
    ) from failures[0][1]


class TrackerConfig(draccus.PluginRegistry, abc.ABC):
discover_packages_path = "levanter.tracker"
Expand Down Expand Up @@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """Accept the artifact arguments and do nothing with them."""
    return None

def finish(self):
    """Intentionally do nothing."""
    return None


@TrackerConfig.register_subclass("noop")
@dataclasses.dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/levanter/tracker/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
    """Forward the artifact path, ``name`` and ``type`` to ``self.run.log_artifact``."""
    active_run = self.run
    active_run.log_artifact(artifact_path, name=name, type=type)

def finish(self):
    """Log an informational message, then finish the run held in ``self.run``."""
    logger.info("Finishing wandb run...")
    active_run = self.run
    active_run.finish()


def is_wandb_available():
try:
Expand Down

0 comments on commit b49e69d

Please sign in to comment.