diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index c7c1a5285..08394d22a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -766,6 +766,10 @@ def _validate_and_set_defaults(self): if self.per_device_eval_parallelism == -1: self.per_device_eval_parallelism = self.per_device_parallelism + if self.replica_dcn_axis_size == -1: + self.replica_dcn_axis_size = self.num_slices + logger.info(f"Setting replica_dcn_axis_size to {self.replica_dcn_axis_size}") + class AllConfig(Protocol): trainer: TrainerConfig