Skip to content

Commit

Permalink
fixing camelyon16
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jun 4, 2024
1 parent 5bcfb74 commit b7e5421
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def evaluate_func(m, test_dls, metric):
def compute_avg_loss_on_client(self, model, dataset):
average_loss = 0.0
count_batch = 0
for X, y in dl(dataset, self.batch_size_test, shuffle=False):
for X, y in dl(dataset, self.batch_size_test, shuffle=False, collate_fn=self.collate_fn):
if torch.cuda.is_available():
X = X.cuda()
y = y.cuda()
Expand Down Expand Up @@ -161,6 +161,7 @@ def robust_metric(y_true, y_pred):
self.pooled_test_dataset,
self.batch_size_test,
shuffle=False,
collate_fn=self.collate_fn,
)
],
robust_metric,
Expand Down

0 comments on commit b7e5421

Please sign in to comment.