Commit

backup
xrsrke committed Jul 10, 2024
1 parent cb44b19 commit 74bbff2
Showing 4 changed files with 234 additions and 84 deletions.
12 changes: 12 additions & 0 deletions src/nanotron/constants.py
@@ -10,3 +10,15 @@

CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"

GLOBAL_STEP = None
LOG_STATE_INTERVAL = 1
IS_RANK_TO_MONITOR = None
CONFIG = None

TRAINING_CONFIG = None


DEBUG_PATH = "./debug/nn_states_with_bs_2_and_transpose_qkv/acts/"

MONITOR_STATE_PATH = "/fsx/phuc/projects/nanotron/debug/runs"
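
The module-level globals added above (GLOBAL_STEP, IS_RANK_TO_MONITOR, CONFIG, and so on) are read by the monitoring code in the other files of this commit, so something on the training side has to populate them before each step. The sketch below is illustrative only and is not part of the commit; the function name and its arguments are hypothetical placeholders.

from nanotron import constants

def prepare_monitoring_globals(step: int, config, is_rank_to_monitor: bool) -> None:
    # The gradient-accumulation hook checks (GLOBAL_STEP - 1) % LOG_STATE_INTERVAL
    # and IS_RANK_TO_MONITOR, so both must be set before backward() runs.
    constants.GLOBAL_STEP = step
    constants.IS_RANK_TO_MONITOR = is_rank_to_monitor
    # CONFIG.general.run is used to build the dump path under MONITOR_STATE_PATH.
    constants.CONFIG = config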
22 changes: 22 additions & 0 deletions src/nanotron/optim/gradient_accumulator.py
@@ -235,6 +235,28 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
grad = fp32_grad
fp32_param.grad = grad

from nanotron import constants

if (constants.GLOBAL_STEP - 1) % constants.LOG_STATE_INTERVAL == 0 and constants.IS_RANK_TO_MONITOR is True:
import wandb
from nanotron import constants
from nanotron.scaling.monitor import save_tensor

save_tensor(
name=f"{name}.grad",
tensor=fp32_param.grad,
path=f"{constants.MONITOR_STATE_PATH}/{constants.CONFIG.general.run}/{constants.GLOBAL_STEP}/grads",
)

wandb.log(
{
f"{name}:grad:mean": fp32_param.grad.detach().mean().item(),
f"{name}:grad:std": fp32_param.grad.detach().std().item(),
f"{name}:grad:norm": fp32_param.grad.detach().norm().item(),
"iteration_step": constants.GLOBAL_STEP,
}
)

@contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronizations across
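The block added to _accumulate_grad above dumps the fp32 gradient and logs its mean, std, and norm to wandb, but only on the monitored rank and only on steps where (GLOBAL_STEP - 1) % LOG_STATE_INTERVAL == 0. A minimal, self-contained restatement of that gate is shown below, assuming GLOBAL_STEP counts from 1; the helper name should_log does not exist in the commit.

def should_log(global_step: int, log_state_interval: int, is_rank_to_monitor: bool) -> bool:
    # Log on steps 1, 1 + interval, 1 + 2*interval, ... and only on the monitored rank.
    return (global_step - 1) % log_state_interval == 0 and is_rank_to_monitor

assert should_log(1, 10, True)        # first step is logged
assert not should_log(2, 10, True)    # off-interval step is skipped
assert should_log(11, 10, True)       # next on-interval step
assert not should_log(11, 10, False)  # non-monitored ranks never log
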
147 changes: 93 additions & 54 deletions src/nanotron/scaling/monitor.py
@@ -1,50 +1,69 @@
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from nanotron import constants
from nanotron.constants import MONITOR_STATE_PATH
from nanotron.models.base import NanotronModel
from nanotron.parallel import ParallelContext
from torch import nn
from torch.distributed import ReduceOp


def track_weight_and_grad_stats(
name: str,
module: nn.Module,
save_weights: bool,
save_grads: bool,
save_acts: bool,
parallel_context: ParallelContext,
):
def compute_stats(tensors, metrics: List[str] = ["amax"]):
NAME_TO_FUNC = {
"mean": lambda x: x.mean().item(),
"std": lambda x: x.std().item(),
"var": lambda x: x.var().item(),
"norm": lambda x: x.norm().item(),
"min": lambda x: x.min().item(),
"max": lambda x: x.max().item(),
"amax": lambda x: x.abs().max().item(),
def save_tensor(name, tensor, path):
if name is None or name == "":
return

# dp_rank = dist.get_rank(group=parallel_context.dp_pg)
# tp_rank = dist.get_rank(group=parallel_context.tp_pg)
# pp_rank = dist.get_rank(group=parallel_context.pp_pg)

os.makedirs(path, exist_ok=True)

# if dp_rank == 0 and tp_rank == 0 and pp_rank == 0:
torch.save(
tensor,
# f"{path}/{name}_dp_rank_{dp_rank}_and_pp_rank_{pp_rank}_and_tp_rank_{tp_rank}.pt"
f"{path}/{name}.pt",
)


def compute_stats(name, tensors):
tensors = {"tensor": tensors} if not isinstance(tensors, dict) else tensors
stats = {}

for key, tensor in tensors.items():
if tensor.dtype == torch.long or tensor.dtype == torch.int or tensor.dtype == torch.bool:
continue

stats[key] = {}
stats[key] = {
# "mean": tensor.cpu().mean().item(),
"mean": tensor.detach().mean().item(),
"std": tensor.detach().std().item(),
# "data": tensor.detach().cpu().tolist()
# "var": tensor.var().item(),
"norm": tensor.detach().norm().item(),
# "min": tensor.min().item(),
# "max": tensor.max().item(),
"amax": tensor.detach().amax().item(),
}
tensors = {"tensor": tensors} if not isinstance(tensors, dict) else tensors
stats = {}

for key, tensor in tensors.items():
if tensor.dtype == torch.long or tensor.dtype == torch.int or tensor.dtype == torch.bool:
continue
# NOTE: now all reduce mean this across tp ranks
# tp_group = parallel_context.tp_pg
# for metric_name, metric_value in stats[key].items():
# stats[key][metric_name] = torch.tensor(metric_value, device=tensor.device, dtype=tensor.dtype)
# dist.all_reduce(stats[key][metric_name], op=ReduceOp.MAX, group=tp_group)

stats[key] = {}
for metric in metrics:
stats[key][metric] = NAME_TO_FUNC[metric](tensor)
# tp_rank = dist.get_rank(group=tp_group)
# stats[key][f"data:tp_{tp_rank}"] = tensor.detach().cpu().tolist()

# NOTE: now all reduce mean this across tp ranks
tp_group = parallel_context.tp_pg
for metric_name, metric_value in stats[key].items():
stats[key][metric_name] = torch.tensor(metric_value, device=tensor.device, dtype=tensor.dtype)
dist.all_reduce(stats[key][metric_name], op=ReduceOp.MAX, group=tp_group)
return stats[list(stats.keys())[0]] if len(stats) == 1 else stats

return stats[list(stats.keys())[0]] if len(stats) == 1 else stats

def track_weight_and_grad_stats(
name: str, module: nn.Module, parallel_context: ParallelContext, save_path: Optional[Path] = None
):
logs: Dict[str, Dict[str, float]] = {}

if name not in logs:
@@ -55,9 +74,10 @@ def _save_output_stats(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
for param_name in param_names:
if hasattr(module, param_name):
param = getattr(module, param_name)
stats = compute_stats(param.data)
stats = compute_stats(name, param.data)
if stats is not None:
logs[name][param_name] = stats
save_tensor(f"{name}.{param_name}", param.data, path=f"{save_path}/weights/")

inputs = input if isinstance(input, tuple) else (input,)
outputs = output if isinstance(output, tuple) else (output,)
@@ -67,76 +87,95 @@ def _save_output_stats(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
if inp.dtype == torch.long:
# NOTE: this is input ids in transformers
continue
stats = compute_stats(inp)
stats = compute_stats(name, inp)
if stats is not None:
logs[name][f"input:{i}"] = stats
elif len(inputs) == 1:
stats = compute_stats(inputs[0])
stats = compute_stats(name, inputs[0])
if stats is not None:
logs[name]["input"] = stats

if len(outputs) > 1:
for i, out in enumerate(outputs):
stats = compute_stats(out)
stats = compute_stats(name, out)
if name is None or name == "":
assert 1 == 1

if stats is not None:
logs[name][f"output:{i}"] = stats
save_tensor(name, out, path=f"{save_path}/acts/")
elif len(outputs) == 1:
stats = compute_stats(outputs[0])
# if name is None or name == "":
# assert 1 == 1

stats = compute_stats(name, outputs[0])
if stats is not None:
logs[name]["output"] = stats
try:
save_tensor(name, outputs[0], path=f"{save_path}/acts/")
except:
assert 1 == 1

def _save_grad_stats(module: nn.Linear, grad_input, grad_output: torch.Tensor):
if isinstance(grad_output, tuple):
for i, grad in enumerate(grad_output):
if grad is None:
continue

stats = compute_stats(grad)
stats = compute_stats(name, grad)
if stats is not None:
logs[name][f"grad_output:{i}"] = stats
else:
stats = compute_stats(grad_output)
stats = compute_stats(name, grad_output)
if stats is not None:
logs[name]["grad_output"] = stats

if isinstance(grad_input, tuple):
for i, grad in enumerate(grad_input):
if grad is not None:
stats = compute_stats(grad)
stats = compute_stats(name, grad)
if stats is not None:
logs[name][f"grad_input:{i}"] = stats
else:
if grad_input is not None:
stats = compute_stats(grad_input)
stats = compute_stats(name, grad_input)
if stats is not None:
logs[name]["grad_input"] = stats

handles = []
handles.append(module.register_forward_hook(_save_output_stats))
handles.append(module.register_backward_hook(_save_grad_stats))
# handles.append(module.register_backward_hook(_save_grad_stats))
return logs, handles


def monitor_model(
model: NanotronModel,
save_weights: bool = False,
save_grads: bool = False,
save_acts: bool = False,
parallel_context: Optional[ParallelContext] = None,
) -> Tuple[Dict[str, Union[torch.Tensor, float]], List]:
def monitor_nanotron_model(run_name: str, model: NanotronModel, parallel_context: Optional[ParallelContext] = None):
assert parallel_context is not None
assert isinstance(constants.GLOBAL_STEP, int)

save_path = f"{MONITOR_STATE_PATH}/{run_name}/{constants.GLOBAL_STEP}"

def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
"""
Return all the leaf modules (modules without any child modules) in a PyTorch module.
"""
leaf_modules = []
for n, m in module.named_modules():
if not list(m.children()):
leaf_modules.append((n, m))
return leaf_modules

logs = {}
handles = []
# leaf_modules = get_leaf_modules(model)
leaf_modules = [(name, module) for name, module in model.named_modules()]

for name, module in leaf_modules:
module_logs, module_handles = track_weight_and_grad_stats(
name=name,
module=module,
save_weights=save_weights,
save_grads=save_grads,
save_acts=save_acts,
parallel_context=parallel_context,
# save_tensor=True,
save_path=save_path,
)
logs.update(module_logs)
handles.extend(module_handles)
@@ -145,7 +184,7 @@ def monitor_model(


def convert_logs_to_flat_logs(
logs: Dict[str, Dict[str, Dict[str, Union[torch.Tensor, float]]]]
logs: Dict[str, Dict[str, Dict[str, Union[torch.Tensor, float]]]],
) -> Dict[str, Union[torch.Tensor, float]]:
flat_logs = {}
for module_name, components in logs.items():
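Taken together, monitor_nanotron_model attaches forward hooks to every module, compute_stats and save_tensor record per-module statistics and raw tensors under MONITOR_STATE_PATH/<run>/<step>, and convert_logs_to_flat_logs flattens the nested logs for wandb. The sketch below shows one way the pieces could be wired into a training step; it is not from the commit. It assumes monitor_nanotron_model returns (logs, handles) like the per-module helper it wraps, that wandb has already been initialised, and that trainer, model, and batch are placeholders.

import wandb
from nanotron import constants
from nanotron.scaling.monitor import convert_logs_to_flat_logs, monitor_nanotron_model

constants.GLOBAL_STEP = trainer.iteration_step  # must be an int before the call (asserted inside)
logs, handles = monitor_nanotron_model(
    run_name=constants.CONFIG.general.run,
    model=model,
    parallel_context=trainer.parallel_context,
)

trainer.training_step(batch)  # the hooks fire here, filling `logs` and dumping tensors

wandb.log({**convert_logs_to_flat_logs(logs), "iteration_step": constants.GLOBAL_STEP})
for handle in handles:
    handle.remove()  # detach the hooks so the next step starts clean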
