refactor out and support ParamWatchCallback
dlwh committed Nov 26, 2024
1 parent c9e2f97 commit 0a7d78f
Showing 1 changed file with 101 additions and 159 deletions.
src/levanter/callbacks.py: 260 changes (101 additions, 159 deletions)
@@ -2,20 +2,15 @@
import copy
import logging as pylogging
import os
import re
import subprocess
import sys
import tempfile
import threading
import time
import warnings
from abc import ABC
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Callable, Generic, Optional, TypeVar

import humanfriendly
import jax
import jax.numpy as jnp
from jaxtyping import PyTree
@@ -317,100 +312,7 @@ def update_pbar(step: StepInfo):
return update_pbar


def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False):
"""
Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds.
    We only log when hooks are invoked, so there's not much point in sampling much more frequently
    than the hook runs. Running the sampler in a separate thread means samples land at effectively
    random points within a step, which should give a more representative picture.

    :param sample_interval: seconds to wait between memory-usage samples
    :param log_individual_devices: whether to also log per-device memory usage
    :return: a hook function that logs the most recent memory sample when invoked
"""

directory = "/dev/shm"
    # macOS doesn't have /dev/shm
if not os.path.exists(directory):
directory = tempfile.gettempdir()

tempfile_name = os.path.join(directory, f"memory_usage_{os.getpid()}.prof")

# a lot of this code is lifted from https://github.com/ayaka14732/jax-smi CC-0

def inner():
import posix
import time

while True:
jax.profiler.save_device_memory_profile(f"{tempfile_name}.new")
posix.rename(f"{tempfile_name}.new", tempfile_name)
time.sleep(sample_interval)

thread = threading.Thread(target=inner, daemon=True)
thread.start()

def log_memory_usage(step: StepInfo):
process = subprocess.run(
args=f"go tool pprof -tags {tempfile_name}".split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
)

if process.returncode != 0:
warnings.warn("failed to run pprof. Is go installed?")
return

output = process.stdout.decode("utf-8")

# output looks like this:
# 2.4MB (12.53%): TFRT_CPU_0
# 2.4MB (12.50%): TFRT_CPU_1
# 2.4MB (12.50%): TFRT_CPU_2
# 2.4MB (12.50%): TFRT_CPU_3
# 2.4MB (12.50%): TFRT_CPU_4
# 2.4MB (12.50%): TFRT_CPU_5
# 2.4MB (12.50%): TFRT_CPU_6
# 2.4MB (12.50%): TFRT_CPU_7
#
# kind: Total 19.5MB
# 18.9MB (97.20%): buffer
# 558.4kB ( 2.80%): executable

# gpus look like this:
# 1.0MB ( 0.00%): gpu:0
per_device, by_kind = output.split("kind: Total ")

# first, get the total memory usage
regex = re.compile(r"^(\d+\.\d+[a-zA-Z]+)")
match = regex.search(by_kind)
if match:
memory_usage = humanfriendly.parse_size(match.group(1))
levanter.tracker.log({"memory/total": memory_usage / 1e6}, step=step.step)

# this works for the "kind" and the individual devices
regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)")

if log_individual_devices:
# now, get the memory usage per device.
# split the output at kind: Total
for match in regex.finditer(per_device):
memory_usage = humanfriendly.parse_size(match.group(1))
device_name = match.group(3)
levanter.tracker.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step)

# now, get the memory usage per kind.
# same regex as above
for match in regex.finditer(by_kind):
memory_usage = match.group(1)
memory_usage = humanfriendly.parse_size(memory_usage)
levanter.tracker.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step)

return log_memory_usage
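
The removed hook leaned on `jax.profiler.save_device_memory_profile` plus the `go tool pprof` CLI. A minimal standalone sketch of that approach, assuming the go toolchain is installed (the /tmp path is illustrative):

    import jax
    import jax.numpy as jnp

    # Allocate something so the profile has content, then dump a
    # device-memory profile in pprof format.
    x = jnp.ones((1024, 1024))
    jax.profiler.save_device_memory_profile("/tmp/memory.prof")

    # Inspect from a shell; this is the command the removed hook parsed:
    #   go tool pprof -tags /tmp/memory.prof
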


def profile(path: str, start_step: int, num_steps: int, create_perfetto_link: bool) -> Callable[[StepInfo], None]:
print(f"create_perfetto_link: {create_perfetto_link}")

def profiler_callback_fn(step: StepInfo):
# -1 b/c step is the finished step
if step.step == start_step - 1:
@@ -506,7 +408,7 @@ class GradWatchCallback(JitCallback[S, M, dict[str, float | Histogram]]):

def __init__(
self,
prefix: Optional[str] = None,
prefix: str = "grad",
include_histogram: bool = True,
split_scan_layers: bool = True,
):
@@ -515,70 +417,110 @@ def __init__(
self.split_scan_layers = split_scan_layers

def inside_step(self, state: TrainerState[M], grad: M):
return self._generate_statistics_for("grad", grad)
# levanter.tracker.jit_log(logs, step=state._step)
return summary_statistics_for_tree(self.prefix, grad, self.split_scan_layers, self.include_histogram)

def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram]):
levanter.tracker.log(cb_info, step=step_info.step)

def _generate_statistics_for(self, kind: str, tree: M) -> dict[str, float | Histogram]:
if self.split_scan_layers:
is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) or is_named_array(n) # noqa: E731
else:
is_leaf = is_named_array

def _rec_log_magnitudes(norms, hists, path_prefix, tree):
leaf_key_paths = jax_utils.leaf_key_paths(tree, prefix=path_prefix, is_leaf=is_leaf)
del path_prefix
for key_path, g in zip(
jax.tree.leaves(leaf_key_paths, is_leaf=is_leaf),
jax.tree.leaves(tree, is_leaf=is_leaf),
strict=True,
):
if self.split_scan_layers and isinstance(g, haliax.nn.Stacked):
vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, g.Block)({}, {}, "", g.stacked)

for k, v in vmapped_norms.items():
for i in range(g.Block.size):
norms[f"{key_path}/{i}/{k}"] = v[i]

for k, v in vmapped_hists.items():
for i in range(g.Block.size):
hists[f"{key_path}/{i}/{k}"] = jax.tree.map(
lambda x: x[i] if is_jax_array_like(x) else x, v
)

elif isinstance(g, NamedArray):
# TODO: add linalg.norm to Haliax
norms[key_path] = jnp.linalg.norm(g.array)
if self.include_histogram:
hist = Histogram.from_named_array(g)
hists[key_path] = hist
elif is_jax_array_like(g):
norms[key_path] = jnp.linalg.norm(g)

if self.include_histogram:
hist = Histogram.from_array(g)
hists[key_path] = hist

return norms, hists

norms_to_log: dict[str, jax.Array] = {}
hists_to_log: dict[str, Histogram] = {}

_rec_log_magnitudes(norms_to_log, hists_to_log, None, tree)

to_log: dict = {}

if self.prefix is not None:
log_prefix = self.prefix + "/" + kind
else:
log_prefix = kind

for key, value in norms_to_log.items():
to_log[f"{log_prefix}/norm/{key}"] = value
class ParamWatchCallback(JitCallback[S, M, dict[str, float | Histogram]]):
"""
    Emulates the behavior of Wandb's PyTorch-only built-in wandb.watch, but for model parameters rather than gradients.
Args:
prefix (str): The prefix to use for logging.
        include_histogram (bool): Whether to include histograms of the parameters.
        split_scan_layers (bool): Whether to split the scan layers into separate histograms/norms.
"""

def __init__(
self,
prefix: str = "params",
include_histogram: bool = True,
split_scan_layers: bool = True,
):
self.prefix = prefix
self.include_histogram = include_histogram
self.split_scan_layers = split_scan_layers

def inside_step(self, state: TrainerState[M], grad: M):
return summary_statistics_for_tree(
self.prefix, state.trainable_model, self.split_scan_layers, self.include_histogram
)

def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram]):
levanter.tracker.log(cb_info, step=step_info.step)


def summary_statistics_for_tree(
prefix: str, tree: M, split_scan_layers: bool, include_histogram: bool
) -> dict[str, float | Histogram]:
"""
Computes the summary statistics for a tree of (named) arrays.
This function is designed to allow you to emulate the behavior of Wandb's PyTorch-only built-in gradient logging,
    but also works for any PyTree. It computes the Frobenius norm of each array,
and optionally the histogram as well.
Args:
prefix: The prefix to use for logging.
tree: The tree of arrays to compute the summary statistics for.
split_scan_layers: Whether to split the scan layers into separate histograms/norms. Recommended.
        include_histogram: Whether to include histograms of the arrays. This increases overhead significantly.
    Returns:
        A dict mapping log keys (e.g. "{prefix}/norm/{path}") to scalar norms and, if requested, histograms.
"""
if split_scan_layers:
is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) or is_named_array(n) # noqa: E731
else:
is_leaf = is_named_array

def _rec_log_magnitudes(norms, hists, path_prefix, tree):
leaf_key_paths = jax_utils.leaf_key_paths(tree, prefix=path_prefix, is_leaf=is_leaf)
del path_prefix
for key_path, g in zip(
jax.tree.leaves(leaf_key_paths, is_leaf=is_leaf),
jax.tree.leaves(tree, is_leaf=is_leaf),
strict=True,
):
if split_scan_layers and isinstance(g, haliax.nn.Stacked):
vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, g.Block)({}, {}, "", g.stacked)

for k, v in vmapped_norms.items():
for i in range(g.Block.size):
norms[f"{key_path}/{i}/{k}"] = v[i]

for k, v in vmapped_hists.items():
for i in range(g.Block.size):
hists[f"{key_path}/{i}/{k}"] = jax.tree.map(lambda x: x[i] if is_jax_array_like(x) else x, v)

elif isinstance(g, NamedArray):
# TODO: add linalg.norm to Haliax
norms[key_path] = jnp.linalg.norm(g.array)
if include_histogram:
hist = Histogram.from_named_array(g)
hists[key_path] = hist
elif is_jax_array_like(g):
norms[key_path] = jnp.linalg.norm(g)

if include_histogram:
hist = Histogram.from_array(g)
hists[key_path] = hist

return norms, hists

norms_to_log: dict[str, jax.Array] = {}
hists_to_log: dict[str, Histogram] = {}

_rec_log_magnitudes(norms_to_log, hists_to_log, None, tree)

to_log: dict = {}

for key, value in norms_to_log.items():
to_log[f"{prefix}/norm/{key}"] = value

for key, value in hists_to_log.items():
to_log[f"{log_prefix}/hist/{key}"] = value
for key, value in hists_to_log.items():
to_log[f"{prefix}/hist/{key}"] = value

return to_log
return to_log
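
Because the statistics logic is now a free function, it can also be called directly on any PyTree of arrays. A small sketch (the dict tree and key names are illustrative):

    import jax.numpy as jnp
    from levanter.callbacks import summary_statistics_for_tree

    tree = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
    stats = summary_statistics_for_tree("params", tree, split_scan_layers=False, include_histogram=False)
    # stats maps e.g. "params/norm/w" -> 4.0 (the Frobenius norm of w)
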
