Skip to content

Commit

Permalink
chore: Fix linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Jan 15, 2025
1 parent 836371f commit 942a321
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions test/embeddings/test_spacy.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@


import pytest
from chromadbx.embeddings.spacy import SpacyEmbeddingFunction
spacy = pytest.importorskip("spacy", reason="spacy not installed")
import subprocess

def download_model(model_name: str):
spacy = pytest.importorskip("spacy", reason="spacy not installed")


def download_model(model_name: str) -> None:
s = subprocess.run(["python", "-m", "spacy", "download", model_name])
assert s.returncode == 0

@pytest.fixture(scope="module")
def model():
download_model("en_core_web_lg")
return "en_core_web_lg"


def test_spacy():
def test_spacy() -> None:
download_model("en_core_web_lg")
ef = SpacyEmbeddingFunction()
texts = ["hello world", "goodbye world"] * 1000
embeddings = ef(texts)
Expand All @@ -24,7 +20,8 @@ def test_spacy():
assert len(embeddings[0]) == 300
assert len(embeddings[1]) == 300

def test_spacy_with_model():

def test_spacy_with_model() -> None:
download_model("en_core_web_sm")
ef = SpacyEmbeddingFunction(model_name="en_core_web_sm")
embeddings = ef(["hello world", "goodbye world"])
Expand All @@ -34,7 +31,7 @@ def test_spacy_with_model():
assert len(embeddings[1]) == 96


def test_spacy_with_invalid_model():
def test_spacy_with_invalid_model() -> None:
with pytest.raises(ValueError) as e:
ef = SpacyEmbeddingFunction(model_name="invalid_model")
SpacyEmbeddingFunction(model_name="invalid_model")
assert "spacy model 'invalid_model' are not downloaded yet" in str(e.value)

0 comments on commit 942a321

Please sign in to comment.