Skip to content

Commit

Permalink
refactor: add missing type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
mali-git committed Jul 15, 2024
1 parent 6319ff0 commit 2715fbb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/modalities/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.distributed as dist
from pydantic import ValidationError
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.types import Number

from modalities.exceptions import TimeRecorderStateError
from modalities.running_env.fsdp.reducer import Reducer
Expand Down Expand Up @@ -57,11 +58,11 @@ def format_metrics_to_gb(item):
return metric_num


def get_local_number_of_trainable_parameters(model: torch.nn.Module):
def get_local_number_of_trainable_parameters(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_total_number_of_trainable_parameters(model: FSDP):
def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
num_params = get_local_number_of_trainable_parameters(model)
num_params_tensor = torch.tensor(num_params).cuda()
dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
Expand Down

0 comments on commit 2715fbb

Please sign in to comment.