diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py
index 044f18af..f258f9d6 100644
--- a/tests/end2end_tests/test_fsdp_warmstart.py
+++ b/tests/end2end_tests/test_fsdp_warmstart.py
@@ -2,7 +2,7 @@
 import os
 import tempfile
 from pathlib import Path
-from typing import List
+from typing import Any, Dict, List
 
 import pytest
 import torch
@@ -37,6 +37,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
         """Consumes a message from a message broker."""
         self.message_list.append(message)
 
+    def consume_dict(self, mesasge_dict: Dict[str, Any]):
+        pass
+
 
 class SaveAllResultSubscriberConfig(BaseModel):
     pass
@@ -126,7 +129,7 @@ def test_warm_start(self):
         # we collect the loss values from rank 0 and store them in the temporary experiment folder
         if dist.get_rank() == 0:
             messages_0: List[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list
-            loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "CLMCrossEntropyLoss average")
+            loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "train loss avg")
             with open(loss_values_experiment_0_path, "w") as f:
                 json.dump(loss_scores_0, f)
 
@@ -146,7 +149,7 @@
         # and store them in the temporary experiment folder
         if dist.get_rank() == 0:
             messages_1: List[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list
-            loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "CLMCrossEntropyLoss average")
+            loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "train loss avg")
             with open(loss_values_experiment_1_path, "w") as f:
                 json.dump(loss_scores_1, f)
 