Skip to content

Commit

Permalink
fix accuracy not creating label tensor on appropriate device
Browse files Browse the repository at this point in the history
  • Loading branch information
icetube23 committed Jul 9, 2024
1 parent b060c83 commit 908f421
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ideal_words/ideal_words.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def _accuracy(self, approx: str) -> float:

# find nearest neighbor for each approximated embedding and compare against expected nearest neighbors
approx_matches = dists.argmin(dim=0)
expected_matches = torch.arange(len(approx_matches))
expected_matches = torch.arange(len(approx_matches), device=self.device)

return (approx_matches == expected_matches).float().mean().cpu().item()

Expand Down

0 comments on commit 908f421

Please sign in to comment.