diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md
index f20488ee2..0b00c0800 100644
--- a/docs/Configuration-Guide.md
+++ b/docs/Configuration-Guide.md
@@ -302,7 +302,7 @@ which are common to all optimizers (and most have to do with learning rate sched
 | `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` |
 | `min_lr_ratio` | The minimum learning rate ratio. | `0.1` |
 | `warmup` | Warmup fraction or number of steps | `0.01` |
-| `stable` | Stable fraction or number of steps | `0.0` |
+| `decay` | Decay fraction or number of steps | `None` |
 | `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` |
 | `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` |
diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py
index d814a6b64..7b684efeb 100644
--- a/src/levanter/optim/config.py
+++ b/src/levanter/optim/config.py
@@ -26,8 +26,8 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
     """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
     warmup: float = 0.01
     """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
-    stable: float = 0.00
-    """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
+    decay: Optional[float] = None
+    """fraction of training steps to use as decay, or steps to use. None means full decay"""
     rewarmup: float = 0.0
     "If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup."
     cooldown: Optional[float] = None
@@ -174,8 +174,12 @@ def lr_scheduler(self, num_train_steps):
                 schedules.append(warmup)
                 boundaries.append(start + warmup_steps)
 
-            stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps)
-            lr_decay_steps = cycle_steps - stable_steps - warmup_steps
+            lr_decay_steps = (
+                _convert_ratio_or_steps(self.decay, cycle_steps)
+                if self.decay is not None
+                else cycle_steps - warmup_steps
+            )
+            stable_steps = cycle_steps - warmup_steps - lr_decay_steps
 
             if stable_steps != 0:
                 stable = optax.constant_schedule(self.learning_rate)
diff --git a/tests/test_optimizer_config.py b/tests/test_optimizer_config.py
index 9c5b91d7c..70737df7c 100644
--- a/tests/test_optimizer_config.py
+++ b/tests/test_optimizer_config.py
@@ -8,11 +8,10 @@ def test_no_stable_weirdness():
         learning_rate=2e-6,  # 2x10^-6
         weight_decay=0.0,
         warmup=0.03,
-        stable=0.0,
         min_lr_ratio=0.0,
         lr_schedule="linear",
         max_grad_norm=None,
-        haps=None,
+        cycles=None,
         weight_decay_modules=None,
         default_weight_decay_mask=None,
     )
@@ -33,10 +32,8 @@ def test_constant_schedule():
         learning_rate=1e-3,
         weight_decay=0.0,
         warmup=0.0,
-        stable=0.0,
         min_lr_ratio=1.0,  # No decay
         lr_schedule="constant",
-        haps=None,
         cycles=None,
     )
 
@@ -52,10 +49,8 @@ def test_warmup_and_cosine_decay():
         learning_rate=1e-2,
         weight_decay=0.0,
         warmup=0.1,  # 10% of steps
-        stable=0.0,
         min_lr_ratio=0.1,
         lr_schedule="cosine",
-        haps=None,
         cycles=None,
     )
 
@@ -75,7 +70,6 @@ def test_linear_schedule_with_cycles():
         learning_rate=5e-4,
         weight_decay=0.0,
         warmup=50,
-        stable=0.0,
         min_lr_ratio=0.2,
         lr_schedule="linear",
         cycles=2,
@@ -105,30 +99,33 @@ def test_linear_schedule_with_cycles():
     assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5)
 
 
-def test_haps_schedule():
+def test_wsds_schedule():
     optimizer = AdamConfig(
         learning_rate=1e-3,
         weight_decay=0.0,
         warmup=0.0,
-        stable=0.0,
+        decay=0.1,
         min_lr_ratio=0.1,
         lr_schedule="cosine",
-        haps=[300, 700],
+        cycles=[300, 700],
     )
 
     sched_fn = optimizer.lr_scheduler(1000)
 
-    # Before first haps
+    # First cycle
     assert np.isclose(sched_fn(0), 1e-3)
+    assert np.isclose(sched_fn(269), 1e-3)
+    assert sched_fn(271) < 1e-3
 
-    # First haps
+    # Second cycle
     assert np.isclose(sched_fn(300), 1e-3)
+    assert np.isclose(sched_fn(659), 1e-3)
+    assert sched_fn(661) < 1e-3
 
-    # After first haps
-    assert sched_fn(301) < 1e-3
-
-    # Before second haps
-    assert sched_fn(699) < sched_fn(301)
+    # Third cycle
+    assert np.isclose(sched_fn(701), 1e-3)
+    assert np.isclose(sched_fn(969), 1e-3)
+    assert sched_fn(971) < 1e-3
 
 
 def test_inv_sqrt_decay_schedule():
@@ -136,10 +133,9 @@ def test_inv_sqrt_decay_schedule():
         learning_rate=1e-3,
         weight_decay=0.0,
         warmup=0.1,
-        stable=0.0,
         min_lr_ratio=0.1,
         lr_schedule="inv_sqrt",
-        haps=None,
+        cycles=None,
     )
 
     sched_fn = optimizer.lr_scheduler(100_000)
@@ -157,7 +153,6 @@ def test_rewarmup_schedule():
         learning_rate=1e-2,
         weight_decay=0.0,
         warmup=0.2,  # 20% of cycle
-        stable=0.0,
         min_lr_ratio=0.2,
         lr_schedule="linear",
         cycles=2,
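For reference, here is a minimal sketch of how the reworked knobs compose, mirroring the updated `test_wsds_schedule` above. The `levanter.optim` import path and the extra assertion points are assumptions for illustration, not part of this patch.

```python
# Sketch (not part of the patch): configuring the new `decay` knob together with
# explicit `cycles` boundaries, as exercised by the updated test.
import numpy as np
from levanter.optim import AdamConfig  # import path assumed

optimizer = AdamConfig(
    learning_rate=1e-3,
    weight_decay=0.0,
    warmup=0.0,
    decay=0.1,            # decay over the last 10% of each cycle; earlier steps stay stable
    min_lr_ratio=0.1,
    lr_schedule="cosine",
    cycles=[300, 700],    # cycle boundaries in steps (WSD-S style)
)

sched_fn = optimizer.lr_scheduler(1000)
assert np.isclose(sched_fn(0), 1e-3)  # stable portion of the first cycle
assert sched_fn(295) < 1e-3           # inside the 30-step decay tail of the first cycle
```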