Skip to content

Commit

Permalink
clean up
Browse files: browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed May 26, 2024
1 parent 68ed94e commit 3bdf5ba
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/levanter/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,21 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Protocol, Sequence, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
)

import equinox as eqx
import jax
Expand Down Expand Up @@ -536,9 +550,9 @@ class TrainerConfig:
tensor_parallel_axes: Optional[List[str]] = None # Axes, if any, to use for tensor parallelism

# TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that.
axis_resources: ResourceMapping = field(default_factory=dict)
axis_resources: Mapping[str, Any] = field(default_factory=dict)
"""mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred"""
parameter_axis_resources: ResourceMapping = field(default_factory=dict) # overrides axis_mapping for parameter
parameter_axis_resources: Mapping[str, Any] = field(default_factory=dict) # overrides axis_mapping for parameter
"""logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred"""

"""Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings"""
Expand Down

0 comments on commit 3bdf5ba

Please sign in to comment.