Added additional tests
davidmezzetti committed Aug 22, 2020
1 parent daf6f5b commit 7c7819a
Showing 3 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
build/
dist/
htmlcov/
*egg-info/
__pycache__/
.coverage
30 changes: 30 additions & 0 deletions test/python/testann.py
@@ -50,15 +50,45 @@ def testSave(self):
        model.save(index)
        model.load(index)

    def testSearch(self):
        """
        Tests ANN search
        """

        # Generate ANN index
        model = self.backend("annoy")

        # Generate query vector
        query = np.random.rand(300).astype(np.float32)
        self.normalize(query)

        # Ensure top result has similarity > 0
        self.assertGreater(model.search(query, 1)[0][1], 0)
    def backend(self, name, length=100):
        """
        Builds an ANN index model for the given backend. Helper method for the tests above.
        """

        # Generate test data
        data = np.random.rand(length, 300).astype(np.float32)
        self.normalize(data)

        model = ANN.create({"backend": name, "dimensions": data.shape[1]})
        model.index(data)

        return model

    def normalize(self, embeddings):
        """
        Normalizes embeddings using L2 normalization. The operation is applied directly to the input array.

        Args:
            embeddings: input embeddings matrix
        """

        # Calculation is different for matrices vs vectors
        if len(embeddings.shape) > 1:
            embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
        else:
            embeddings /= np.linalg.norm(embeddings)
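
As a side note, the in-place L2 normalization above gives every row unit length, so an inner product between rows equals their cosine similarity. A minimal standalone sketch of that property, using the same numpy calls (illustrative only, not part of this commit):

import numpy as np

# After in-place L2 normalization, every row has unit norm, so the dot
# product of two rows equals their cosine similarity.
data = np.random.rand(2, 300).astype(np.float32)
data /= np.linalg.norm(data, axis=1)[:, np.newaxis]

dot = np.dot(data[0], data[1])
cosine = dot / (np.linalg.norm(data[0]) * np.linalg.norm(data[1]))

assert np.allclose(np.linalg.norm(data, axis=1), 1.0)
assert np.isclose(dot, cosine)
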
38 changes: 38 additions & 0 deletions test/python/testembeddings.py
@@ -9,6 +9,7 @@
import numpy as np

from txtai.embeddings import Embeddings
from txtai.vectors import WordVectors

class TestEmbeddings(unittest.TestCase):
"""
@@ -72,3 +73,40 @@ def testSimilarity(self):
        uid = np.argmax(self.embeddings.similarity("feel good story", self.data))

        self.assertEqual(self.data[uid], self.data[4])

    def testWords(self):
        """
        Tests embeddings backed by word vectors
        """

        # Initialize model path
        path = os.path.join(tempfile.gettempdir(), "model")
        os.makedirs(path, exist_ok=True)

        # Build tokens file
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as output:
            tokens = output.name
            for x in self.data:
                output.write(x + "\n")

        # Word vectors path
        vectors = os.path.join(path, "test-300d")

        # Build word vectors, if they don't already exist
        WordVectors.build(tokens, 300, 1, vectors)

        # Create dataset
        data = [(x, row, None) for x, row in enumerate(self.data)]

        # Create embeddings model, backed by word vectors
        embeddings = Embeddings({"path": vectors + ".magnitude",
                                 "scoring": "bm25",
                                 "pca": 3,
                                 "quantize": True})

        # Call scoring and index methods
        embeddings.score(data)
        embeddings.index(data)

        # Test search
        self.assertIsNotNone(embeddings.search("win", 1))
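
For completeness, a minimal sketch of running the new test modules with unittest's discovery API; the test/python start directory is taken from the file paths in this diff, and the snippet is illustrative rather than part of the commit:

import unittest

# Discover and run the test modules under test/python (e.g. testann.py,
# testembeddings.py) and report results to the console.
suite = unittest.defaultTestLoader.discover("test/python", pattern="test*.py")
unittest.TextTestRunner(verbosity=2).run(suite)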
