Skip to content

Commit

Permalink
feat: add potential embedding key
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
JoanFM committed Jul 21, 2023
1 parent 1cc99e7 commit 2af1eda
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ docs = ann.get_docs(limit=10, offset=0, order_by='x', ascending=True)
After you have indexed the `docs`, you can update the docs in the index by calling `ann.update()`:

```python
updated_docs = [{'id': '0', 'embedding': [], 'price': 6}]
updated_docs = [{'id': '0', 'embedding': np.random.random([128]).astype(np.float32), 'price': 6}]

ann.update(updated_docs)
```
Expand Down
13 changes: 7 additions & 6 deletions annlite/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
self,
n_dim: int,
metric: Union[str, Metric] = 'cosine',
embedding_field: str = 'embedding',
n_cells: int = 1,
n_subvectors: Optional[int] = None,
n_clusters: Optional[int] = 256,
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
self.n_probe = max(n_probe, n_cells)
self.n_cells = n_cells
self.size_limit = 2048
self._embedding_field = embedding_field

if isinstance(metric, str):
metric = Metric.from_string(metric)
Expand Down Expand Up @@ -172,7 +174,7 @@ def __init__(
total_size = 0
# TODO: add a progress bar
for docs in self.documents_generator(0, batch_size=1024):
x = np.array([doc['embedding'] for doc in docs])
x = np.array([doc[self._embedding_field] for doc in docs])
total_size += x.shape[0]
self.partial_train(x, auto_save=True, force_train=True)
if total_size >= MAX_TRAINING_DATA_SIZE:
Expand Down Expand Up @@ -280,8 +282,7 @@ def index(self, docs: 'List', **kwargs):
if not self.is_trained:
raise RuntimeError(f'The indexer is not trained, cannot add new documents')

# TODO: Obtain the embeddings from the dict or change index signature
x = np.array([doc['embedding'] for doc in docs])
x = np.array([doc[self._embedding_field] for doc in docs])
n_data, _ = self._sanity_check(x)

assigned_cells = (
Expand Down Expand Up @@ -312,7 +313,7 @@ def update(
raise RuntimeError(f'The indexer is not trained, cannot add new documents')

# TODO: Obtain the embeddings from the dict or change index signature
x = np.array([doc['embedding'] for doc in docs])
x = np.array([doc[self._embedding_field] for doc in docs])
n_data, _ = self._sanity_check(x)

assigned_cells = (
Expand Down Expand Up @@ -347,7 +348,7 @@ def search(
if not self.is_trained:
raise RuntimeError(f'The indexer is not trained, cannot add new documents')

query_np = np.array([doc['embedding'] for doc in docs])
query_np = np.array([doc[self._embedding_field] for doc in docs])

_, match_docs = self.search_by_vectors(
query_np, filter=filter, limit=limit, include_metadata=include_metadata
Expand Down Expand Up @@ -778,7 +779,7 @@ def _rebuild_index_from_local(self):
f'Rebuild the index of cell-{cell_id} ({cell_size} docs)...'
)
for docs in self.documents_generator(cell_id, batch_size=10240):
x = np.array([doc['embedding'] for doc in docs])
x = np.array([doc[self._embedding_field] for doc in docs])

assigned_cells = np.ones(len(docs), dtype=np.int64) * cell_id
super().insert(x, assigned_cells, docs, only_index=True)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def test_local_backup_restore(tmpdir):
index = AnnLite(n_dim=D, data_path=tmpdir / 'workspace' / '0')
index.index(docs)

tmpname = uuid.uuid4().hex
index.backup()
index.close()

Expand All @@ -278,3 +277,24 @@ def test_local_backup_restore(tmpdir):
status = index.stat
assert int(status['total_docs']) == N
assert int(status['index_size']) == N


def test_index_search_different_field(tmpdir):
    """Index and query docs whose vectors live under a custom 'encoding' key.

    Exercises the `embedding_field` constructor argument: documents carry
    their vectors under 'encoding' instead of the default 'embedding', and
    search results must still come back ordered by ascending distance.
    """
    vectors = np.random.random((N, D)).astype(np.float32)

    ann = AnnLite(
        n_dim=D, data_path=str(tmpdir), embedding_field='encoding', metric='euclidean'
    )
    ann.index([{'id': f'{i}', 'encoding': vectors[i]} for i in range(N)])

    queries = [{'encoding': vec} for vec in vectors[:5]]
    results = ann.search(queries)

    # Matches for the first query must be sorted by ascending euclidean score.
    scores = [match['scores']['euclidean'] for match in results[0]]
    for earlier, later in zip(scores, scores[1:]):
        assert earlier <= later

0 comments on commit 2af1eda

Please sign in to comment.