add cycle_length (#825)

dlwh authored Nov 25, 2024
1 parent 574f933 commit a0fc09d
Showing 3 changed files with 146 additions and 43 deletions.
60 changes: 36 additions & 24 deletions docs/Configuration-Guide.md
@@ -295,21 +295,21 @@ All optimizers in Levanter are based on the [levanter.optim.OptimizerConfig][] d
which are common to all optimizers (and most have to do with learning rate scheduling):


| Parameter | Description | Default |
|-----------------|-----------------------------------------------------------------------|----------|
| `weight_decay` | The weight decay. | `0.0` |
| `learning_rate` | The learning rate. | `1e-4` |
| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` |
| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` |
| `warmup` | Warmup fraction or number of steps | `0.01` |
| `decay` | Decay fraction or number of steps | `None` |
| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` |
| `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` |

By default, Levanter uses a cosine learning rate schedule with a warmup. The learning rate is decayed to
| Parameter | Description | Default |
|-----------------|-------------------------------------------------------------------------------|----------|
| `weight_decay` | The weight decay. | `0.0` |
| `learning_rate` | The learning rate. | `1e-4` |
| `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` |
| `min_lr_ratio` | The minimum learning rate ratio. | `0.1` |
| `warmup` | Warmup fraction or number of steps | `0.01` |
| `decay` | Decay fraction or number of steps | `None` |
| `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` |
| `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` |
| `cycle_length` | How long each cycle should be, in steps or as a fraction of total steps, or a list of cycle lengths | `None` |

By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to
`min_lr_ratio * learning_rate` over the course of the training run. This is a fairly standard default for LLM training.
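
For illustration, here is a minimal sketch of setting these shared fields in Python via `AdamConfig` (the concrete subclass exercised in the tests in this commit); the values shown are illustrative, not recommendations:

```python
# Minimal sketch, assuming AdamConfig is importable as levanter.optim.AdamConfig,
# the path the docs reference. lr_scheduler returns a function from step to LR.
from levanter.optim import AdamConfig

config = AdamConfig(
    learning_rate=1e-4,    # peak learning rate
    weight_decay=0.0,
    lr_schedule="cosine",  # the default schedule
    min_lr_ratio=0.1,      # decay down to 0.1 * learning_rate
    warmup=0.01,           # 1% of steps spent on linear warmup
)

sched_fn = config.lr_scheduler(num_train_steps=10_000)
print(sched_fn(0), sched_fn(100), sched_fn(9_999))  # start of warmup, peak, near the minimum
```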


#### Learning Rate Schedules

The `lr_schedule` parameter specifies the learning rate schedule. The following schedules are supported:
@@ -328,8 +328,11 @@ By default, there is only one cycle, and Levanter's LR schedule looks like this:
[warmup] -> [stable] -> [decay]
```

But you can specify more with the `cycles` parameter. If you specify an int for `cycles`, the
learning rate will cycle through the schedule `cycles` times. Levanter's LR schedule looks like this:
But you can specify more cycles with either the `cycles` or `cycle_length` parameter.
The LR will be decayed to `min_lr_ratio * learning_rate` at the end of each cycle.
With cycles, Levanter's LR schedule looks like this:


```
[warmup] -> [stable] -> [decay] -> {[rewarmup] -> [stable] -> [decay]} x (cycles - 1)
@@ -348,27 +351,37 @@ Here's what the phases mean:
* `decay`: The decay period. The LR will decay to `min_lr_ratio * learning_rate` over this period.
* `rewarmup`: The re-warmup period. If using cycles, the LR will be re-warmed from the final value of the previous cycle back to the peak value of the next cycle.

Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump
back to the max LR. This is the default, and works surprisingly well. In addition, the stable
and decay phase of the first cycle will generally be different from the stable and decay phase of the other cycles,
since rewarmup and warmup are typically different.

`stable` cannot be specified directly. It is the period between `warmup` and `decay` in the first cycle, and the period
between `rewarmup` and `decay` in subsequent cycles. By default, there is no `stable` period.

All of these parameters can be specified either as a fraction of the steps in a cycle or as an absolute number of
steps.
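
For example, here is a minimal sketch (again using `AdamConfig`, with illustrative values) of the two equivalent ways to request a 50-step warmup for a single-cycle, 1,000-step run:

```python
# Sketch: for a single 1,000-step cycle, both configs request the same 50-step warmup.
# A value <= 1.0 is read as a fraction of the cycle; a larger whole number is read as steps.
from levanter.optim import AdamConfig

as_fraction = AdamConfig(learning_rate=1e-4, warmup=0.05)  # 5% of 1,000 steps -> 50 steps
as_steps = AdamConfig(learning_rate=1e-4, warmup=50)       # 50 steps directly

f_frac = as_fraction.lr_scheduler(1000)
f_steps = as_steps.lr_scheduler(1000)
assert abs(f_frac(50) - f_steps(50)) < 1e-12  # both reach the peak LR at step 50
```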

If you want to use a learning rate schedule with cycles, you can specify the number of cycles with the `cycles`
parameter. The LR will be decayed to `min_lr_ratio * learning_rate` at the end of each cycle.
Here is what the `cycles` and `cycle_length` parameters mean:

* `cycle_length`: If you specify an int or float for `cycle_length`, the learning rate will cycle through the
schedule in cycles of the specified length. This is equivalent to specifying `cycles` as `num_train_steps / cycle_length`.
If `cycle_length` is a float < 1.0, it is interpreted as a fraction of the total number of steps.
If you specify a list of ints, the cycles will have the specified lengths, in steps.
* `cycles`: If you specify an int for `cycles`, the learning rate will cycle through the schedule `cycles` times.
If you specify a list of ints, the learning rate will cycle through the schedule with the specified steps as the minima
of the cycles.

It is an error to specify both `cycles` and `cycle_length`.

You can also specify `cycles` as a list, e.g. `[10000, 25000, 50000]`. In this case,
`cycles` is interpreted as the minima for the cycles, with the first and final steps being cycle minima as well.
Specifying `cycles` as an int `c` is equivalent to specifying a list with the low points evenly spaced at
`[num_train_steps / (c + 1)]`.
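
As a concrete sketch (mirroring the tests added in this commit), here are a few ways to request multiple cycles over a 1,000-step run; per the equivalence described above, the first two should yield the same two-cycle schedule:

```python
# Sketch of the cycle-related knobs; note that cycles and cycle_length are mutually exclusive.
from levanter.optim import AdamConfig

by_count = AdamConfig(learning_rate=5e-4, lr_schedule="linear", cycles=2)           # two cycles
by_length = AdamConfig(learning_rate=5e-4, lr_schedule="linear", cycle_length=500)  # 1000 / 500 = 2 cycles
uneven = AdamConfig(learning_rate=1e-3, lr_schedule="cosine",
                    warmup=0.0, decay=0.1, cycle_length=[300, 400])  # cycles of 300, 400, and the remaining 300 steps

sched_fn = by_length.lr_scheduler(1000)
# The LR decays to min_lr_ratio * learning_rate by steps ~499 and ~999, jumping back
# to the peak at step 500 because rewarmup defaults to 0.
```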

Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump
back to the max LR. This is the default. In addition, the stable
and decay phase of the first cycle will generally be different from the stable and decay phase of the other cycles,
since rewarmup and warmup are typically different.

See [our paper on WSD-S](https://arxiv.org/pdf/2410.05192) for more information on cyclic LR schedules for training LLMs
with short or no rewarmup.
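
As a rough sketch of the two regimes discussed there (an abrupt jump back to the peak versus a short re-warmup at the start of each later cycle), assuming the same `AdamConfig` interface as in the examples above:

```python
# Sketch: cyclic schedules with and without re-warmup between cycles. rewarmup, like
# warmup, is a fraction of each later cycle or an absolute number of steps.
from levanter.optim import AdamConfig

no_rewarmup = AdamConfig(learning_rate=1e-3, lr_schedule="cosine",
                         warmup=0.01, decay=0.1, cycles=4)  # rewarmup defaults to 0.0: LR jumps back to the peak
short_rewarmup = AdamConfig(learning_rate=1e-3, lr_schedule="cosine",
                            warmup=0.01, decay=0.1, cycles=4, rewarmup=0.05)  # 5% of each later cycle re-warms the LR
```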


### AdamConfig

Additionally, [levanter.optim.AdamConfig][] has the following fields:
@@ -381,7 +394,6 @@ Additionally, [levanter.optim.AdamConfig][] has the following fields:
| `max_grad_norm` | The maximum gradient norm (for clipping). | `1.0` |



## LM Model Config

[levanter.models.lm_model.LmConfig][] is a Draccus "choice class" that acts as a base class for all autoregressive
66 changes: 48 additions & 18 deletions src/levanter/optim/config.py
@@ -8,6 +8,7 @@
import draccus
import equinox as eqx
import jax
import numpy as np
import optax
from jax import numpy as jnp

@@ -24,16 +25,18 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):

min_lr_ratio: float = 0.1
"""The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
warmup: float = 0.01
warmup: int | float = 0.01
"""fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
decay: Optional[float] = None
decay: int | float | None = None
"""fraction of training steps to use as decay, or steps to use. None means full decay"""
rewarmup: float = 0.0
rewarmup: int | float = 0.0
"If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup."
cooldown: Optional[float] = None
"""Deprecated, as its semantics are confusing."""
cycles: int | None | list[int] = None
""" Number of cycles to use. If None or 1, use a single cycle. Overriden by haps."""
cycle_length: int | float | None | list[int] = None
""" Length of cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0."""
cycles: int | list[int] | None = None
"""Number of cycles or a list of cycle endpoints. Can use at most one of cycle_length, cycles, or haps."""

lr_schedule: str = "cosine" # constant, cosine, linear
haps: Optional[list[int]] = None
@@ -145,16 +148,13 @@ def mask_fn(model):
def lr_scheduler(self, num_train_steps):
if self.cooldown is not None:
warnings.warn("cooldown is deprecated. Just use the normal schedule.", DeprecationWarning)
cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
cooldown_steps = _convert_frac_or_steps(self.cooldown, num_train_steps)
else:
cooldown_steps = 0

total_main_steps = num_train_steps - cooldown_steps
cooldown_points = self._get_cycle_minima(total_main_steps)

cooldown_points.insert(0, 0)
cooldown_points.append(num_train_steps)

min_lr = self.learning_rate * self.min_lr_ratio

schedules = []
@@ -165,17 +165,17 @@
for cycle, (start, end) in enumerate(zip(cooldown_points[:-1], cooldown_points[1:])):
cycle_steps = end - start
if cycle == 0: # warmup
warmup_steps = _convert_ratio_or_steps(self.warmup, cycle_steps)
warmup_steps = _convert_frac_or_steps(self.warmup, cycle_steps)
else:
warmup_steps = _convert_ratio_or_steps(self.rewarmup, cycle_steps)
warmup_steps = _convert_frac_or_steps(self.rewarmup, cycle_steps)

if warmup_steps != 0:
warmup = optax.linear_schedule(previous_end, self.learning_rate, warmup_steps)
schedules.append(warmup)
boundaries.append(start + warmup_steps)

lr_decay_steps = (
_convert_ratio_or_steps(self.decay, cycle_steps)
_convert_frac_or_steps(self.decay, cycle_steps)
if self.decay is not None
else cycle_steps - warmup_steps
)
@@ -218,7 +218,31 @@ def lr_scheduler(self, num_train_steps):
return schedule

def _get_cycle_minima(self, total_main_steps):
if self.haps is not None:
if self.cycle_length is not None:
if self.cycles is not None:
raise ValueError("Can't use both cycle_length and cycles.")
if self.haps is not None:
warnings.warn("haps is deprecated. Use cycles instead.", DeprecationWarning)
raise ValueError("Can't use both cycle_length and haps.")

if isinstance(self.cycle_length, int | float):
cycle_length = _convert_frac_or_steps(self.cycle_length, total_main_steps)
cooldown_points = [i * cycle_length for i in range(1, total_main_steps // cycle_length)]
if total_main_steps % cycle_length != 0:
warnings.warn(
"Cycle length does not divide total number of steps. The last cycle will be shorter."
)

elif isinstance(self.cycle_length, list):
lengths = np.array(self.cycle_length)
steps = np.cumsum(lengths)
if steps[-1] > total_main_steps:
raise ValueError(f"Cycle lengths exceed total number of steps: {steps[-1]} > {total_main_steps}")
cooldown_points = steps.tolist()
else:
raise ValueError("Invalid cycle_length. Must be a fraction, number of steps, or a list of steps.")

elif self.haps is not None:
warnings.warn("haps is deprecated. Use cycles instead.", DeprecationWarning)
cooldown_points = list(self.haps)
elif isinstance(self.cycles, int):
@@ -228,6 +252,9 @@ def _get_cycle_minima(self, total_main_steps):
cooldown_points = list(self.cycles)
else:
cooldown_points = []

cooldown_points.insert(0, 0)
cooldown_points.append(total_main_steps)
return cooldown_points


@@ -247,11 +274,14 @@ def schedule(count):
return schedule


def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int):
if ratio_or_steps < 1.0:
return int(ratio_or_steps * num_train_steps)
else:
return int(ratio_or_steps)
def _convert_frac_or_steps(frac_or_steps: float | int, num_train_steps: int):
# if it's greater than 1, it must be a whole number of steps
if frac_or_steps < 0.0 or (frac_or_steps > 1.0 and frac_or_steps % 1 != 0):
raise ValueError(f"Invalid fraction {frac_or_steps}. Must be between 0 and 1. You can also use (whole) steps.")
if frac_or_steps <= 1.0:
return int(frac_or_steps * num_train_steps)

return int(frac_or_steps)


@dataclass
63 changes: 62 additions & 1 deletion tests/test_optimizer_config.py
@@ -122,7 +122,7 @@ def test_wsds_schedule():
assert np.isclose(sched_fn(659), 1e-3)
assert sched_fn(661) < 1e-3

# Thrid cycle
# Third cycle
assert np.isclose(sched_fn(701), 1e-3)
assert np.isclose(sched_fn(969), 1e-3)
assert sched_fn(971) < 1e-3
@@ -182,3 +182,64 @@ def test_rewarmup_schedule():
# Final decay phase
assert sched_fn(999 - 1) > sched_fn(999)
assert np.isclose(sched_fn(999), 0.2e-2, atol=1e-4) # End of second decay


def test_linear_schedule_with_cycle_length():
optimizer = AdamConfig(
learning_rate=5e-4,
weight_decay=0.0,
warmup=50,
min_lr_ratio=0.2,
lr_schedule="linear",
cycle_length=500,
)

sched_fn = optimizer.lr_scheduler(1000)

# Warmup phase
assert np.isclose(sched_fn(0), 0.0)
assert np.isclose(sched_fn(50), 5e-4)

num_main_steps = 1000

# First cycle decay
assert np.isclose(sched_fn(499), 0.2 * 5e-4, atol=1e-5)

# Second cycle starts
assert np.isclose(sched_fn(500), 5e-4)

# midway through second cycle
midpoint = 500 - 1 + num_main_steps // 4
assert np.isclose(sched_fn(midpoint), (5e-4 + 0.2 * 5e-4) / 2, atol=1e-5)

# Final value
assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5)


def test_wsds_schedule_with_cycle_points():
optimizer = AdamConfig(
learning_rate=1e-3,
weight_decay=0.0,
warmup=0.0,
decay=0.1,
min_lr_ratio=0.1,
lr_schedule="cosine",
cycle_length=[300, 400],
)

sched_fn = optimizer.lr_scheduler(1000)

# First cycle
assert np.isclose(sched_fn(0), 1e-3)
assert np.isclose(sched_fn(269), 1e-3)
assert sched_fn(271) < 1e-3

# Second cycle
assert np.isclose(sched_fn(300), 1e-3)
assert np.isclose(sched_fn(659), 1e-3)
assert sched_fn(661) < 1e-3

# Third cycle
assert np.isclose(sched_fn(701), 1e-3)
assert np.isclose(sched_fn(969), 1e-3)
assert sched_fn(971) < 1e-3
