Skip to content

Commit

Permalink
Improve test_log_to_wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Apr 1, 2024
1 parent 92c141b commit 83ccb39
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/train/test_wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,11 @@ def test_init_wandb(mock_wandb_init: MagicMock, mock_giga_config):

@patch("wandb.log")
def test_log_to_wandb(mock_wandb_log: MagicMock):
model = MagicMock() # type: ignore
model.__class__ = transformers.LlamaForCausalLM
optimizer = MagicMock()
optimizer.__class__ = torch.optim.AdamW
model = MagicMock(spec=transformers.LlamaForCausalLM)
optimizer = MagicMock(spec=torch.optim.AdamW)
log_to_wandb(
mts=ModelTrainingState(
model=model, # type: ignore
model=model,
optimizer=optimizer,
step=5,
epoch=1,
Expand All @@ -88,3 +86,16 @@ def test_log_to_wandb(mock_wandb_log: MagicMock):
losses={"train": 0.5, "val": 0.4},
tokens_so_far=4242,
)
assert mock_wandb_log.call_count == 1
mock_wandb_log.assert_called_with(
{
"epoch": 1,
"epoch_iter": 5,
"global_iter": 55,
"tokens": 4242,
"loss/train": 0.5,
"loss/val": 0.4,
"lr": 0.007,
},
step=55,
)

0 comments on commit 83ccb39

Please sign in to comment.