Skip to content

Commit

Permalink
Merge pull request #186 from Modalities/fix_num_params
Browse files Browse the repository at this point in the history
Fix: Computation of Total Number of Parameters
  • Loading branch information
mali-git authored Jul 15, 2024
2 parents b380a66 + 2715fbb commit 8e79acc
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/modalities/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from modalities.registry.registry import Registry
from modalities.running_env.cuda_env import CudaEnv
from modalities.trainer import Trainer
from modalities.util import compute_number_of_trainable_parameters
from modalities.util import get_total_number_of_trainable_parameters


@click.group()
Expand Down Expand Up @@ -255,7 +255,7 @@ def run(self, components: TrainingComponentsInstantiationModel):
num_ranks=components.settings.cuda_env.world_size,
)
wrapped_model = components.wrapped_model
num_params = compute_number_of_trainable_parameters(wrapped_model)
num_params = get_total_number_of_trainable_parameters(wrapped_model)
components.evaluation_subscriber.consume_dict({"No. Parameters": num_params})
logging.info(f"Training model with {num_params} parameters.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
from modalities.config.config import PrecisionEnum
from modalities.util import compute_number_of_trainable_parameters
from modalities.util import get_local_number_of_trainable_parameters


class TorchCheckpointLoading(CheckpointLoadingIF):
Expand Down Expand Up @@ -46,7 +46,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module:
# set the model to the correct device and precision
# model = model.to(self.precision.value)
print(
f"Model loaded with {compute_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
f"Model loaded with {get_local_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
)
return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Dict

import rich
import wandb
import yaml
from rich.console import Group
from rich.panel import Panel

import wandb
from modalities.batch import EvaluationResultBatch
from modalities.config.config import WandbMode
from modalities.logging_broker.messages import Message
Expand Down
7 changes: 4 additions & 3 deletions src/modalities/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.running_env.env_utils import MixedPrecisionSettings
from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
from modalities.util import compute_number_of_trainable_parameters
from modalities.util import get_local_number_of_trainable_parameters


class ModelFactory:
Expand All @@ -34,7 +34,8 @@ def get_fsdp_wrapped_model(
sharding_strategy: ShardingStrategy,
) -> FSDP:
print(
f"Unsharded number of parameters on rank {dist.get_rank()}: {compute_number_of_trainable_parameters(model)}"
f"Unsharded number of parameters on rank {dist.get_rank()}: "
f"{get_local_number_of_trainable_parameters(model)}"
)
# Here, FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead!
# we also might want to have different auto wrap policies later...
Expand All @@ -52,7 +53,7 @@ def get_fsdp_wrapped_model(
)
print(
f"Sharded number of parameters on rank {dist.get_rank()}:"
f"{compute_number_of_trainable_parameters(fsdp_model)}"
f"{get_local_number_of_trainable_parameters(fsdp_model)}"
)

return fsdp_model
Expand Down
4 changes: 2 additions & 2 deletions src/modalities/optimizers/optimizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
from modalities.exceptions import OptimizerError
from modalities.models.model import NNModel
from modalities.util import compute_number_of_trainable_parameters
from modalities.util import get_local_number_of_trainable_parameters

OptimizerGroups = List[Dict[str, List[nn.Parameter] | float]]

Expand Down Expand Up @@ -166,7 +166,7 @@ def _assert_completeness_of_optimizer_groups(model: FSDP, optimizer_groups: Opti
checks that the number of parameters in the optimizer groups
sum up to the total number of model parameters as expected
"""
num_params_check = compute_number_of_trainable_parameters(model)
num_params_check = get_local_number_of_trainable_parameters(model)
num_params = sum(p.numel() for optimizer_group in optimizer_groups for p in optimizer_group["params"])
if num_params != num_params_check:
raise OptimizerError(
Expand Down
12 changes: 11 additions & 1 deletion src/modalities/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
import torch.distributed as dist
from pydantic import ValidationError
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.types import Number

from modalities.exceptions import TimeRecorderStateError
from modalities.running_env.fsdp.reducer import Reducer
Expand Down Expand Up @@ -56,10 +58,18 @@ def format_metrics_to_gb(item):
return metric_num


def compute_number_of_trainable_parameters(model: torch.nn.Module):
def get_local_number_of_trainable_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
num_params = get_local_number_of_trainable_parameters(model)
num_params_tensor = torch.tensor(num_params).cuda()
dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
total_num_params = num_params_tensor.item()
return total_num_params


class TimeRecorderStates(Enum):
RUNNING = "RUNNING"
STOPPED = "STOPPED"
Expand Down
46 changes: 46 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch

import modalities
import modalities.util
from modalities.batch import DatasetBatch
from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters


def configure_dataloader_mock(
Expand All @@ -23,3 +26,46 @@ def configure_dataloader_mock(
llm_data_loader_mock.__len__ = lambda _: num_batches

return llm_data_loader_mock, batches


def test_get_local_number_of_trainable_parameters():
# Create a simple model with trainable parameters
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Calculate the expected number of trainable parameters
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Call the function and check the result
assert get_local_number_of_trainable_parameters(model) == expected_params


def test_get_total_number_of_trainable_parameters():
# Create a simple model with trainable parameters
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Calculate the expected number of trainable parameters
expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Create a mock FSDP model
class MockFSDP:
def __init__(self, model):
self.model = model

fsdp_model = MockFSDP(model)

# Mock the dist.all_reduce function
def mock_all_reduce(tensor, op):
tensor.item = lambda: tensor
return tensor

def mock_cuda(tensor):
return tensor

def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
return get_local_number_of_trainable_parameters(model.model)

modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
torch.distributed.all_reduce = mock_all_reduce
torch.Tensor.cuda = mock_cuda

assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
11 changes: 2 additions & 9 deletions tests/test_yaml_configs/config_lorem_ipsum.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,7 @@ batch_progress_subscriber:
instance_key: eval_dataloaders
pass_type: BY_REFERENCE


evaluation_subscriber:
component_key: results_subscriber
variant_key: wandb
config:
local_rank: ${settings.cuda_env.local_rank}
project: modalities_lorem_ipsum
mode: ONLINE
directory: "."
experiment_id: ${settings.experiment_id}
config_file_path: ${settings.config_file_path}
variant_key: dummy
config: {}

0 comments on commit 8e79acc

Please sign in to comment.