Skip to content

Commit

Permalink
fix: fixed failing warmstart gpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
le1nux committed Aug 4, 2024
1 parent b9bc212 commit 120596d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/end2end_tests/test_fsdp_warmstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import tempfile
from pathlib import Path
from typing import List
from typing import Any, Dict, List

import pytest
import torch
Expand Down Expand Up @@ -37,6 +37,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
"""Consumes a message from a message broker."""
self.message_list.append(message)

def consume_dict(self, mesasge_dict: Dict[str, Any]):
pass


class SaveAllResultSubscriberConfig(BaseModel):
pass
Expand Down Expand Up @@ -126,7 +129,7 @@ def test_warm_start(self):
# we collect the loss values from rank 0 and store them in the temporary experiment folder
if dist.get_rank() == 0:
messages_0: List[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list
loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "CLMCrossEntropyLoss average")
loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "train loss avg")
with open(loss_values_experiment_0_path, "w") as f:
json.dump(loss_scores_0, f)

Expand All @@ -146,7 +149,7 @@ def test_warm_start(self):
# and store them in the temporary experiment folder
if dist.get_rank() == 0:
messages_1: List[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list
loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "CLMCrossEntropyLoss average")
loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "train loss avg")
with open(loss_values_experiment_1_path, "w") as f:
json.dump(loss_scores_1, f)

Expand Down

0 comments on commit 120596d

Please sign in to comment.