Skip to content

Commit

Permalink
Improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Feb 9, 2025
1 parent 8a7040d commit 00cd08e
Showing 1 changed file with 71 additions and 68 deletions.
139 changes: 71 additions & 68 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ class Item(Base):


def create_items():
session = Session(engine)
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
session.commit()
with Session(engine) as session:
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
session.commit()


class TestSqlalchemy:
Expand Down Expand Up @@ -129,11 +129,11 @@ def test_orm(self):
item2 = Item(embedding=[4, 5, 6])
item3 = Item()

session = Session(engine)
session.add(item)
session.add(item2)
session.add(item3)
session.commit()
with Session(engine) as session:
session.add(item)
session.add(item2)
session.add(item3)
session.commit()

stmt = select(Item)
with Session(engine) as session:
Expand All @@ -148,11 +148,11 @@ def test_orm(self):
assert items[2].embedding is None

def test_vector(self):
session = Session(engine)
session.add(Item(id=1, embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.embedding.tolist() == [1, 2, 3]
with Session(engine) as session:
session.add(Item(id=1, embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.embedding.tolist() == [1, 2, 3]

def test_vector_l2_distance(self):
create_items()
Expand Down Expand Up @@ -203,11 +203,11 @@ def test_vector_l1_distance_orm(self):
assert [v.id for v in items] == [1, 3, 2]

def test_halfvec(self):
session = Session(engine)
session.add(Item(id=1, half_embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.half_embedding.to_list() == [1, 2, 3]
with Session(engine) as session:
session.add(Item(id=1, half_embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.half_embedding.to_list() == [1, 2, 3]

def test_halfvec_l2_distance(self):
create_items()
Expand Down Expand Up @@ -258,11 +258,11 @@ def test_halfvec_l1_distance_orm(self):
assert [v.id for v in items] == [1, 3, 2]

def test_bit(self):
session = Session(engine)
session.add(Item(id=1, binary_embedding='101'))
session.commit()
item = session.get(Item, 1)
assert item.binary_embedding == '101'
with Session(engine) as session:
session.add(Item(id=1, binary_embedding='101'))
session.commit()
item = session.get(Item, 1)
assert item.binary_embedding == '101'

def test_bit_hamming_distance(self):
create_items()
Expand All @@ -289,11 +289,11 @@ def test_bit_jaccard_distance_orm(self):
assert [v.id for v in items] == [2, 3, 1]

def test_sparsevec(self):
session = Session(engine)
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.sparse_embedding.to_list() == [1, 2, 3]
with Session(engine) as session:
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
session.commit()
item = session.get(Item, 1)
assert item.sparse_embedding.to_list() == [1, 2, 3]

def test_sparsevec_l2_distance(self):
create_items()
Expand Down Expand Up @@ -405,24 +405,24 @@ def test_sum_orm(self):

def test_bad_dimensions(self):
item = Item(embedding=[1, 2])
session = Session(engine)
session.add(item)
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
session.commit()
with Session(engine) as session:
session.add(item)
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
session.commit()

def test_bad_ndim(self):
item = Item(embedding=np.array([[1, 2, 3]]))
session = Session(engine)
session.add(item)
with pytest.raises(StatementError, match='expected ndim to be 1'):
session.commit()
with Session(engine) as session:
session.add(item)
with pytest.raises(StatementError, match='expected ndim to be 1'):
session.commit()

def test_bad_dtype(self):
item = Item(embedding=np.array(['one', 'two', 'three']))
session = Session(engine)
session.add(item)
with pytest.raises(StatementError, match='could not convert string to float'):
session.commit()
with Session(engine) as session:
session.add(item)
with pytest.raises(StatementError, match='could not convert string to float'):
session.commit()

def test_inspect(self):
columns = inspect(engine).get_columns('sqlalchemy_orm_item')
Expand All @@ -433,44 +433,48 @@ def test_literal_binds(self):
assert "embedding <-> '[1.0,2.0,3.0]'" in str(sql)

def test_insert(self):
session.execute(insert(Item).values(embedding=np.array([1, 2, 3])))
with Session(engine) as session:
session.execute(insert(Item).values(embedding=np.array([1, 2, 3])))

def test_insert_bulk(self):
session.execute(insert(Item), [{'embedding': np.array([1, 2, 3])}])
with Session(engine) as session:
session.execute(insert(Item), [{'embedding': np.array([1, 2, 3])}])

# register_vector in psycopg2 tests changes this behavior
# def test_insert_text(self):
# session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
# with Session(engine) as session:
# session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})

def test_automap(self):
metadata = MetaData()
metadata.reflect(engine, only=['sqlalchemy_orm_item'])
AutoBase = automap_base(metadata=metadata)
AutoBase.prepare()
AutoItem = AutoBase.classes.sqlalchemy_orm_item
session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}])
item = session.query(AutoItem).first()
assert item.embedding.tolist() == [1, 2, 3]
with Session(engine) as session:
session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}])
item = session.query(AutoItem).first()
assert item.embedding.tolist() == [1, 2, 3]

def test_vector_array(self):
session = Session(array_engine)
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
session.commit()
with Session(array_engine) as session:
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
session.commit()

# this fails if the driver does not cast arrays
item = session.get(Item, 1)
assert item.embeddings[0].tolist() == [1, 2, 3]
assert item.embeddings[1].tolist() == [4, 5, 6]
# this fails if the driver does not cast arrays
item = session.get(Item, 1)
assert item.embeddings[0].tolist() == [1, 2, 3]
assert item.embeddings[1].tolist() == [4, 5, 6]

def test_halfvec_array(self):
session = Session(array_engine)
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
session.commit()
with Session(array_engine) as session:
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
session.commit()

# this fails if the driver does not cast arrays
item = session.get(Item, 1)
assert item.half_embeddings[0].to_list() == [1, 2, 3]
assert item.half_embeddings[1].to_list() == [4, 5, 6]
# this fails if the driver does not cast arrays
item = session.get(Item, 1)
assert item.half_embeddings[0].to_list() == [1, 2, 3]
assert item.half_embeddings[1].to_list() == [4, 5, 6]

def test_half_precision(self):
create_items()
Expand All @@ -479,13 +483,12 @@ def test_half_precision(self):
assert [v.id for v in items] == [1, 3, 2]

def test_binary_quantize(self):
session = Session(engine)
session.add(Item(id=1, embedding=[-1, -2, -3]))
session.add(Item(id=2, embedding=[1, -2, 3]))
session.add(Item(id=3, embedding=[1, 2, 3]))
session.commit()

with Session(engine) as session:
session.add(Item(id=1, embedding=[-1, -2, -3]))
session.add(Item(id=2, embedding=[1, -2, 3]))
session.add(Item(id=3, embedding=[1, 2, 3]))
session.commit()

distance = func.cast(func.binary_quantize(Item.embedding), BIT(3)).hamming_distance(func.binary_quantize(func.cast([3, -1, 2], VECTOR(3))))
items = session.query(Item).order_by(distance).all()
assert [v.id for v in items] == [2, 3, 1]
Expand Down

0 comments on commit 00cd08e

Please sign in to comment.