diff --git a/ideal_words/ideal_words.py b/ideal_words/ideal_words.py index f6e9a5d..580dc0d 100644 --- a/ideal_words/ideal_words.py +++ b/ideal_words/ideal_words.py @@ -363,7 +363,7 @@ def _accuracy(self, approx: str) -> float: # find nearest neighbor for each approximated embedding and compare against expected nearest neighbors approx_matches = dists.argmin(dim=0) - expected_matches = torch.arange(len(approx_matches)) + expected_matches = torch.arange(len(approx_matches), device=self.device) return (approx_matches == expected_matches).float().mean().cpu().item()