Skip to content

Commit

Permalink
refactor: calling BatchProgress only Progress from now on
Browse files Browse the repository at this point in the history
  • Loading branch information
le1nux committed Sep 11, 2024
1 parent bcc20f3 commit 872d4a0
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
18 changes: 9 additions & 9 deletions src/modalities/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
Expand All @@ -19,16 +19,16 @@ class Evaluator:

def __init__(
self,
batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
progress_publisher: MessagePublisher[ProgressUpdate],
evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
) -> None:
"""Initializes the Evaluator class.
Args:
batch_progress_publisher (MessagePublisher[BatchProgressUpdate]): Publisher for batch progress updates
progress_publisher (MessagePublisher[ProgressUpdate]): Publisher for progress updates
evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Publisher for evaluation results
"""
self.batch_progress_publisher = batch_progress_publisher
self.progress_publisher = progress_publisher
self.evaluation_result_publisher = evaluation_result_publisher

def evaluate_batch(
Expand Down Expand Up @@ -79,7 +79,7 @@ def evaluate(
cumulated_loss = torch.zeros(3).to(device)

Evaluator._publish_progress(
batch_progress_publisher=self.batch_progress_publisher,
progress_publisher=self.progress_publisher,
num_eval_steps_done=0, # Reset progress bar
dataloader_tag=data_loader.dataloader_tag,
)
Expand All @@ -98,7 +98,7 @@ def evaluate(
thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)

Evaluator._publish_progress(
batch_progress_publisher=self.batch_progress_publisher,
progress_publisher=self.progress_publisher,
num_eval_steps_done=batch_id + 1,
dataloader_tag=data_loader.dataloader_tag,
)
Expand Down Expand Up @@ -138,16 +138,16 @@ def evaluate(

@staticmethod
def _publish_progress(
batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
progress_publisher: MessagePublisher[ProgressUpdate],
num_eval_steps_done: int,
dataloader_tag: str,
):
payload = BatchProgressUpdate(
payload = ProgressUpdate(
num_steps_done=num_eval_steps_done,
experiment_status=ExperimentStatus.EVALUATION,
dataloader_tag=dataloader_tag,
)
batch_progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE)
progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE)

@staticmethod
def _publish_evaluation_result(
Expand Down
4 changes: 2 additions & 2 deletions src/modalities/logging_broker/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ExperimentStatus(Enum):


@dataclass
class BatchProgressUpdate:
"""Object holding the state of the current batch computation progress."""
class ProgressUpdate:
"""Object holding the state of the current batch / step computation progress."""

num_steps_done: int
# Note: in case of ExperimentStatus.TRAIN, dataset_batch_id=global_train_batch_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from modalities.config.config import WandbMode
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import (
from modalities.logging_broker.subscriber_impl.progress_subscriber import (
DummyProgressSubscriber,
RichProgressSubscriber,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def progress_publisher_mock():
def trainer(progress_publisher_mock, gradient_clipper_mock):
return Trainer(
global_rank=int(os.getenv("RANK")),
batch_progress_publisher=progress_publisher_mock,
progress_publisher=progress_publisher_mock,
evaluation_result_publisher=progress_publisher_mock,
gradient_acc_steps=1,
gradient_clipper=gradient_clipper_mock,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_evaluate_cpu(
llm_data_loader_mock.batch_size = batch_size

evaluator = Evaluator(
batch_progress_publisher=progress_publisher_mock,
progress_publisher=progress_publisher_mock,
evaluation_result_publisher=progress_publisher_mock,
)

Expand Down

0 comments on commit 872d4a0

Please sign in to comment.