From 872d4a07181e51e6c21f916168e57e74f568c948 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 11 Sep 2024 15:12:24 +0200 Subject: [PATCH] refactor: calling BatchProgress only Progress from now on --- src/modalities/evaluator.py | 18 +++++++++--------- src/modalities/logging_broker/messages.py | 4 ++-- ...ss_subscriber.py => progress_subscriber.py} | 0 .../subscriber_impl/subscriber_factory.py | 2 +- tests/conftest.py | 2 +- tests/test_evaluator.py | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) rename src/modalities/logging_broker/subscriber_impl/{batch_progress_subscriber.py => progress_subscriber.py} (100%) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index e47aa124..c564e14f 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -6,7 +6,7 @@ from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch from modalities.dataloader.dataloader import LLMDataLoader -from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes +from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher from modalities.models.model import model_predict_batch from modalities.running_env.fsdp.reducer import Reducer @@ -19,16 +19,16 @@ class Evaluator: def __init__( self, - batch_progress_publisher: MessagePublisher[BatchProgressUpdate], + progress_publisher: MessagePublisher[ProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], ) -> None: """Initializes the Evaluator class. Args: - batch_progress_publisher (MessagePublisher[BatchProgressUpdate]): Publisher for batch progress updates + progress_publisher (MessagePublisher[ProgressUpdate]): Publisher for progress updates evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Publisher for evaluation results """ - self.batch_progress_publisher = batch_progress_publisher + self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher def evaluate_batch( @@ -79,7 +79,7 @@ def evaluate( cumulated_loss = torch.zeros(3).to(device) Evaluator._publish_progress( - batch_progress_publisher=self.batch_progress_publisher, + progress_publisher=self.progress_publisher, num_eval_steps_done=0, # Reset progress bar dataloader_tag=data_loader.dataloader_tag, ) @@ -98,7 +98,7 @@ def evaluate( thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) Evaluator._publish_progress( - batch_progress_publisher=self.batch_progress_publisher, + progress_publisher=self.progress_publisher, num_eval_steps_done=batch_id + 1, dataloader_tag=data_loader.dataloader_tag, ) @@ -138,16 +138,16 @@ def evaluate( @staticmethod def _publish_progress( - batch_progress_publisher: MessagePublisher[BatchProgressUpdate], + progress_publisher: MessagePublisher[ProgressUpdate], num_eval_steps_done: int, dataloader_tag: str, ): - payload = BatchProgressUpdate( + payload = ProgressUpdate( num_steps_done=num_eval_steps_done, experiment_status=ExperimentStatus.EVALUATION, dataloader_tag=dataloader_tag, ) - batch_progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE) + progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE) @staticmethod def _publish_evaluation_result( diff --git a/src/modalities/logging_broker/messages.py b/src/modalities/logging_broker/messages.py index dea10659..2c0a66c3 100644 --- a/src/modalities/logging_broker/messages.py +++ b/src/modalities/logging_broker/messages.py @@ -29,8 +29,8 @@ class ExperimentStatus(Enum): @dataclass -class BatchProgressUpdate: - """Object holding the state of the current batch computation progress.""" +class ProgressUpdate: + """Object holding the state of the current batch / step computation progress.""" num_steps_done: int # Note: in case of ExperimentState.TRAIN, dataset_batch_id=global_train_batch_id diff --git a/src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py similarity index 100% rename from src/modalities/logging_broker/subscriber_impl/batch_progress_subscriber.py rename to src/modalities/logging_broker/subscriber_impl/progress_subscriber.py diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index 98cf54a0..c55f4ce0 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -4,7 +4,7 @@ from modalities.config.config import WandbMode from modalities.dataloader.dataloader import LLMDataLoader -from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( +from modalities.logging_broker.subscriber_impl.progress_subscriber import ( DummyProgressSubscriber, RichProgressSubscriber, ) diff --git a/tests/conftest.py b/tests/conftest.py index b2e60f0b..9208c9ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,7 +182,7 @@ def progress_publisher_mock(): def trainer(progress_publisher_mock, gradient_clipper_mock): return Trainer( global_rank=int(os.getenv("RANK")), - batch_progress_publisher=progress_publisher_mock, + progress_publisher=progress_publisher_mock, evaluation_result_publisher=progress_publisher_mock, gradient_acc_steps=1, gradient_clipper=gradient_clipper_mock, diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 260bb7d4..271ddd5d 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -25,7 +25,7 @@ def test_evaluate_cpu( llm_data_loader_mock.batch_size = batch_size evaluator = Evaluator( - batch_progress_publisher=progress_publisher_mock, + progress_publisher=progress_publisher_mock, evaluation_result_publisher=progress_publisher_mock, )