diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 4be8de84..fd139fef 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -287,8 +287,8 @@ def run(self, components: TrainingComponentsInstantiationModel): os.makedirs(experiment_path, exist_ok=True) shutil.copy(self.config_path, experiment_path / self.config_path.name) - evaluation_result_publisher, process_publisher = self.get_logging_publishers( - progress_subscriber=components.batch_progress_subscriber, + evaluation_result_publisher, progress_publisher = self.get_logging_publishers( + progress_subscriber=components.progress_subscriber, results_subscriber=components.evaluation_subscriber, global_rank=components.settings.cuda_env.global_rank, local_rank=components.settings.cuda_env.local_rank, @@ -303,7 +303,7 @@ def run(self, components: TrainingComponentsInstantiationModel): ) trainer = Trainer( global_rank=components.settings.cuda_env.global_rank, - progress_publisher=process_publisher, + progress_publisher=progress_publisher, num_target_steps=components.settings.training_target.num_target_steps, num_target_tokens=components.settings.training_target.num_target_tokens, num_seen_train_steps=components.settings.training_progress.num_seen_steps, @@ -316,7 +316,7 @@ def run(self, components: TrainingComponentsInstantiationModel): # Evaluator evaluator = Evaluator( - progress_publisher=process_publisher, + progress_publisher=progress_publisher, evaluation_result_publisher=evaluation_result_publisher, ) diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 4c133c72..690559b3 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -172,7 +172,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel train_dataset: PydanticDatasetIFType train_dataloader: PydanticLLMDataLoaderIFType eval_dataloaders: List[PydanticLLMDataLoaderIFType] - batch_progress_subscriber: PydanticMessageSubscriberIFType + progress_subscriber: PydanticMessageSubscriberIFType evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType gradient_clipper: PydanticGradientClipperIFType diff --git a/tests/end2end_tests/gpt2_train_num_steps_8.yaml b/tests/end2end_tests/gpt2_train_num_steps_8.yaml index 3f8ecf88..0ca58132 100644 --- a/tests/end2end_tests/gpt2_train_num_steps_8.yaml +++ b/tests/end2end_tests/gpt2_train_num_steps_8.yaml @@ -238,7 +238,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml index ffc095e7..03ba9c16 100644 --- a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml +++ b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml @@ -278,7 +278,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml index 8d4396f5..96f2f7db 100644 --- a/tests/test_yaml_configs/config_lorem_ipsum.yaml +++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml @@ -302,7 +302,7 @@ gradient_clipper: pass_type: BY_REFERENCE norm_type: P2_NORM -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tutorials/getting_started/example_config.yaml b/tutorials/getting_started/example_config.yaml index 1d06f856..ee68d9ec 100644 --- a/tutorials/getting_started/example_config.yaml +++ b/tutorials/getting_started/example_config.yaml @@ -275,7 +275,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tutorials/library_usage/config_lorem_ipsum.yaml b/tutorials/library_usage/config_lorem_ipsum.yaml index ed1c35ab..bd7cd59c 100644 --- a/tutorials/library_usage/config_lorem_ipsum.yaml +++ b/tutorials/library_usage/config_lorem_ipsum.yaml @@ -316,7 +316,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml index a271bcc5..d4c7ec2f 100644 --- a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml +++ b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml @@ -238,7 +238,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: dummy config: {} diff --git a/tutorials/warmstart/configs/pre_training_config.yaml b/tutorials/warmstart/configs/pre_training_config.yaml index c23f0fad..9bf1de87 100644 --- a/tutorials/warmstart/configs/pre_training_config.yaml +++ b/tutorials/warmstart/configs/pre_training_config.yaml @@ -238,7 +238,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: diff --git a/tutorials/warmstart/configs/warmstart_config.yaml b/tutorials/warmstart/configs/warmstart_config.yaml index f6454f8e..1858d9a1 100644 --- a/tutorials/warmstart/configs/warmstart_config.yaml +++ b/tutorials/warmstart/configs/warmstart_config.yaml @@ -276,7 +276,7 @@ gradient_clipper: norm_type: P2_NORM max_norm: 1.0 -batch_progress_subscriber: +progress_subscriber: component_key: progress_subscriber variant_key: rich config: