diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 9e31bd39..3206bce0 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -168,7 +168,8 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        do_compile = True
+        # do_compile = True
+        do_compile = False
         # self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))
         self.split_silu_mul = GLUActivation(config.hidden_act)
         if do_compile:
diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py
index 5a6b797e..816f044d 100644
--- a/src/nanotron/serialize/main.py
+++ b/src/nanotron/serialize/main.py
@@ -63,7 +63,7 @@ def save(
     sanity_checks: bool = True,
 ) -> None:
     assert isinstance(training_metadata, TrainingMetadata)
-    assert isinstance(valid_metadata, TrainingMetadata)
+    assert (valid_metadata is None) or isinstance(valid_metadata, TrainingMetadata)

     try:
         if should_save_config:
diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py
index 35204899..6c55d5e6 100644
--- a/src/nanotron/serialize/metadata.py
+++ b/src/nanotron/serialize/metadata.py
@@ -64,7 +64,7 @@ class CheckpointMetadata:
     tp: int
     dp: int
     train_meta: TrainingMetadata
-    valid_meta: TrainingMetadata
+    valid_meta: Optional[TrainingMetadata]


 @dataclasses.dataclass
@@ -130,7 +130,7 @@ def save_meta(parallel_context: ParallelContext, training_metadata: TrainingMetadata, valid_metadata: TrainingMetadata):
     assert isinstance(training_metadata, TrainingMetadata)
-    assert isinstance(valid_metadata, TrainingMetadata)
+    assert (valid_metadata is None) or isinstance(valid_metadata, TrainingMetadata)

     if dist.get_rank(parallel_context.world_pg) != 0:
         return
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index da8d1bd4..65569e7c 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -229,7 +229,7 @@ def __init__(
             parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
         )
         assert isinstance(checkpoint_metadata.train_meta, TrainingMetadata)
-        assert isinstance(checkpoint_metadata.valid_meta, TrainingMetadata)
+        assert (checkpoint_metadata.valid_meta is None) or isinstance(checkpoint_metadata.valid_meta, TrainingMetadata)
         log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0)

         self.metadata: TrainingMetadata = checkpoint_metadata.train_meta
@@ -688,7 +688,6 @@ def train_step_logs(
         if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
             assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"
-            lr = self.lr_scheduler.get_last_lr()[0]

             log_entries = [
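
The serialization hunks above all make the same change: validation metadata becomes optional, so a checkpoint saved without any validation run no longer trips the `isinstance` asserts. A minimal sketch of the resulting contract is below; note that `TrainingMetadata` here is a simplified stand-in with a placeholder field, not nanotron's real dataclass.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class TrainingMetadata:
    consumed_train_samples: int = 0  # placeholder field for this sketch


@dataclasses.dataclass
class CheckpointMetadata:
    tp: int
    dp: int
    train_meta: TrainingMetadata
    valid_meta: Optional[TrainingMetadata]  # may be None when no validation ran


def check_metadata(training_metadata: TrainingMetadata,
                   valid_metadata: Optional[TrainingMetadata]) -> None:
    # Training metadata is always required ...
    assert isinstance(training_metadata, TrainingMetadata)
    # ... while validation metadata is only type-checked when it is provided,
    # mirroring the relaxed asserts in save(), save_meta() and the trainer.
    assert (valid_metadata is None) or isinstance(valid_metadata, TrainingMetadata)


# Both calls pass; before this patch the second one raised an AssertionError.
check_metadata(TrainingMetadata(consumed_train_samples=128), TrainingMetadata())
check_metadata(TrainingMetadata(consumed_train_samples=128), None)
```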