Skip to content

Commit

Permalink
Set default value for dataloader_idx
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Nov 8, 2023
1 parent 1ef896b commit 7a22542
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def generic_step(
return losses["loss"]

def validation_step(
self, data: Mapping[str, Any], _, dataloader_idx: int
self, data: Mapping[str, Any], _, dataloader_idx: int = 0
) -> torch.Tensor:
batch = {frozenset(data.keys()): data}
for domain in data.keys():
Expand All @@ -197,7 +197,7 @@ def validation_step(
return self.generic_step(batch, mode="val/ood")

def test_step(
self, data: Mapping[str, Any], _, dataloader_idx: int
self, data: Mapping[str, Any], _, dataloader_idx: int = 0
) -> torch.Tensor:
batch = {frozenset(data.keys()): data}
for domain in data.keys():
Expand Down

0 comments on commit 7a22542

Please sign in to comment.