diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py
index 6d61159bd..c6b3bd783 100644
--- a/src/levanter/optim/config.py
+++ b/src/levanter/optim/config.py
@@ -24,7 +24,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
     min_lr_ratio: float = 0.1
     warmup_ratio: Optional[float] = None  # Deprecated. fraction of training steps to use as warmup
-    """The lr scheduler operates on 4 stages: [warmup] - [stable] - [decay] - [cooldown]"""
+    """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
     warmup: float = 0.01
     """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
     stable: float = 0.00
     """fraction of training steps to use as stable, or steps to use. 0.0 means no stable"""
@@ -32,6 +32,8 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
     cooldown: float = 0.0
     """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
     lr_schedule: str = "cosine"  # constant, cosine, linear
+    haps: Optional[list[int]] = None
+    """list of integers indicating pit stop steps. See paper https://openreview.net/pdf?id=RSsavSvAvN"""
     weight_decay_modules: Optional[list[str] | str] = None
     """A regex or a list of strings to identify where to mask weight.
     For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`"""
@@ -138,22 +140,13 @@ def mask_fn(model):
 
     def lr_scheduler(self, num_train_steps):
         warmup_steps = self._convert_warmup(num_train_steps)
-        stable_steps = _convert_ratio_or_steps(self.stable, num_train_steps)
         cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
-        lr_decay_steps = num_train_steps - warmup_steps - stable_steps - cooldown_steps
-        min_lr = self.learning_rate * self.min_lr_ratio
+        if self.haps is None:
+            self.haps = []
+        self.haps.insert(0, warmup_steps)
+        self.haps.append(num_train_steps - cooldown_steps)
 
-        match self.lr_schedule:
-            case "constant":
-                schedule = optax.constant_schedule(self.learning_rate)
-            case "cosine":
-                schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio)
-            case "linear":
-                schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps)
-            case "inv_sqrt":
-                schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000)
-            case _:
-                raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
+        min_lr = self.learning_rate * self.min_lr_ratio
 
         schedules = []
         boundaries = []
@@ -163,18 +156,37 @@ def lr_scheduler(self, num_train_steps):
             schedules.append(warmup)
             boundaries.append(warmup_steps)
 
-        if stable_steps != 0:
-            stable = optax.constant_schedule(self.learning_rate)
-            schedules.append(stable)
-            boundaries.append(warmup_steps + stable_steps)
-
-        schedules.append(schedule)
+        for start, end in zip(self.haps[:-1], self.haps[1:]):
+            cycle_steps = end - start
+            stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps)
+            lr_decay_steps = cycle_steps - stable_steps
+
+            if stable_steps != 0:
+                stable = optax.constant_schedule(self.learning_rate)
+                schedules.append(stable)
+                boundaries.append(start + stable_steps)
+
+            match self.lr_schedule:
+                case "constant":
+                    schedule = optax.constant_schedule(self.learning_rate)
+                case "cosine":
+                    schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio)
+                case "linear":
+                    schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps)
+                case "inv_sqrt":
+                    schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000)
+                case "inv":
+                    schedule = _inv_decay_schedule(self.learning_rate, min_lr, lr_decay_steps)
+                case _:
+                    raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
+
+            schedules.append(schedule)
+            boundaries.append(end)
 
         if cooldown_steps != 0:
             final_main_lr = schedule(lr_decay_steps)
             cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps)
             schedules.append(cooldown)
-            boundaries.append(num_train_steps - cooldown_steps)
 
         if len(schedules) > 1:
             schedule = optax.join_schedules(schedules, boundaries)
@@ -197,6 +209,14 @@ def schedule(count):
     return schedule
 
 
+def _inv_decay_schedule(lr: float, min_lr: float, decay_steps: int):
+    def schedule(count):
+        decay = jnp.minimum(1.0, 1.0 / ((lr / min_lr - 1) * jnp.maximum(count, 1) / decay_steps + 1))
+        return jnp.maximum(lr * decay, min_lr)
+
+    return schedule
+
+
 def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int):
     if ratio_or_steps < 1.0:
         return int(ratio_or_steps * num_train_steps)
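
A note on how `haps` drives the loop above: `lr_scheduler` prepends `warmup_steps` and appends `num_train_steps - cooldown_steps` to the list, so each consecutive pair of entries delimits one `[stable] - [decay]` cycle. A minimal sketch of that bookkeeping, with made-up step counts (the numbers are illustrative, not from this patch):

```python
# Illustrative numbers only: 10_000 total steps, 500 warmup, 500 cooldown,
# and a single pit stop at step 5_000 passed via the config.
num_train_steps, warmup_steps, cooldown_steps = 10_000, 500, 500

haps = [5_000]
haps.insert(0, warmup_steps)                   # -> [500, 5_000]
haps.append(num_train_steps - cooldown_steps)  # -> [500, 5_000, 9_500]

# Each consecutive pair becomes one [stable] - [decay] cycle.
for start, end in zip(haps[:-1], haps[1:]):
    print(f"cycle: steps {start}..{end} ({end - start} steps)")
# cycle: steps 500..5000 (4500 steps)
# cycle: steps 5000..9500 (4500 steps)
```

Since `stable` is resolved per cycle via `_convert_ratio_or_steps(self.stable, cycle_steps)`, a fractional `stable` carves the same proportion out of every cycle rather than out of the whole run.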
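
The new `"inv"` schedule decays hyperbolically rather than linearly or with a cosine. Its endpoints work out cleanly: `count` is clamped to at least 1, so the schedule starts just under `lr`, and at `count == decay_steps` the denominator becomes exactly `lr / min_lr`, landing the schedule on `min_lr`. A quick endpoint check, with illustrative values for `lr`, `min_lr`, and `decay_steps`:

```python
import jax.numpy as jnp

def _inv_decay_schedule(lr: float, min_lr: float, decay_steps: int):
    # Same function as in the patch: hyperbolic interpolation from ~lr
    # down to min_lr over decay_steps steps.
    def schedule(count):
        decay = jnp.minimum(1.0, 1.0 / ((lr / min_lr - 1) * jnp.maximum(count, 1) / decay_steps + 1))
        return jnp.maximum(lr * decay, min_lr)

    return schedule

s = _inv_decay_schedule(lr=1e-3, min_lr=1e-4, decay_steps=1_000)
print(s(0))      # ~9.91e-4: count clamps to 1, so it starts just under lr
print(s(1_000))  # 1e-4 exactly: decay = 1 / ((lr/min_lr - 1) + 1) = min_lr / lr
```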