From 3bdf5ba489ad20659c83e694d6354d2027627b4e Mon Sep 17 00:00:00 2001
From: Jason Wang
Date: Sun, 26 May 2024 13:14:10 -0700
Subject: [PATCH] clean up

---
 src/levanter/trainer.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 62ee4f9eb..328f20f6c 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -9,7 +9,21 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Protocol, Sequence, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import equinox as eqx
 import jax
@@ -536,9 +550,9 @@ class TrainerConfig:
     tensor_parallel_axes: Optional[List[str]] = None  # Axes, if any, to use for tensor parallelism
 
     # TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that.
-    axis_resources: ResourceMapping = field(default_factory=dict)
+    axis_resources: Mapping[str, Any] = field(default_factory=dict)
     """mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred"""
-    parameter_axis_resources: ResourceMapping = field(default_factory=dict)  # overrides axis_mapping for parameter
+    parameter_axis_resources: Mapping[str, Any] = field(default_factory=dict)  # overrides axis_mapping for parameter
     """logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred"""
 
     """Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings"""