Skip to content

Commit

Permalink
Pin vectors to the CPU after deserialization (#157)
Browse files Browse the repository at this point in the history
* Pin vectors to the CPU after deserialization

* Restore CPU ops after regression test

* Skip test if GPU support is not present

* Use `use_ops` context manager in test

* Typo
  • Loading branch information
shadeMe authored Apr 17, 2023
1 parent eb53bf4 commit ab97146
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
19 changes: 17 additions & 2 deletions sense2vec/sense2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from spacy.vectors import Vectors
from spacy.strings import StringStore
from spacy.util import SimpleFrozenDict
from thinc.api import NumpyOps
import numpy
import srsly

Expand Down Expand Up @@ -247,7 +248,11 @@ def get_other_senses(
result = []
key = key if isinstance(key, str) else self.strings[key]
word, orig_sense = self.split_key(key)
versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word]
versions = (
set([word, word.lower(), word.upper(), word.title()])
if ignore_case
else [word]
)
for text in versions:
for sense in self.senses:
new_key = self.make_key(text, sense)
Expand All @@ -270,7 +275,11 @@ def get_best_sense(
sense_options = senses or self.senses
if not sense_options:
return None
versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word]
versions = (
set([word, word.lower(), word.upper(), word.title()])
if ignore_case
else [word]
)
freqs = []
for text in versions:
for sense in sense_options:
Expand Down Expand Up @@ -304,6 +313,9 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
"""
data = srsly.msgpack_loads(bytes_data)
self.vectors = Vectors().from_bytes(data["vectors"])
# Pin vectors to the CPU so that we don't end up comparing
# numpy and cupy arrays.
self.vectors.to_ops(NumpyOps())
self.freqs = dict(data.get("freqs", []))
self.cfg.update(data.get("cfg", {}))
if "strings" not in exclude and "strings" in data:
Expand Down Expand Up @@ -340,6 +352,9 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
freqs_path = path / "freqs.json"
cache_path = path / "cache"
self.vectors = Vectors().from_disk(path)
# Pin vectors to the CPU so that we don't end up comparing
# numpy and cupy arrays.
self.vectors.to_ops(NumpyOps())
self.cfg.update(srsly.read_json(path / "cfg"))
if freqs_path.exists():
self.freqs = dict(srsly.read_json(freqs_path))
Expand Down
13 changes: 13 additions & 0 deletions sense2vec/tests/test_issue155.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pathlib import Path
import pytest
from sense2vec.sense2vec import Sense2Vec
from thinc.api import use_ops
from thinc.util import has_cupy_gpu


@pytest.mark.skipif(not has_cupy_gpu, reason="requires Cupy/GPU")
def test_issue155():
data_path = Path(__file__).parent / "data"
with use_ops("cupy"):
s2v = Sense2Vec().from_disk(data_path)
s2v.most_similar("beekeepers|NOUN")

0 comments on commit ab97146

Please sign in to comment.