Skip to content

Commit

Permalink
Test more engine configurations
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Feb 9, 2025
1 parent a1d8997 commit c792451
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,39 @@

psycopg2_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
pg8000_engine = create_engine(f'postgresql+pg8000://{os.environ["USER"]}@localhost/pgvector_python_test')
psycopg2_array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
psycopg2_type_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')


@event.listens_for(psycopg2_array_engine, "connect")
@event.listens_for(psycopg2_type_engine, "connect")
def psycopg2_connect(dbapi_connection, connection_record):
from pgvector.psycopg2 import register_vector
register_vector(dbapi_connection, globally=False, arrays=True)


engines = [psycopg2_engine, pg8000_engine]
array_engines = [psycopg2_array_engine]
engines = [psycopg2_engine, pg8000_engine, psycopg2_type_engine]
array_engines = [psycopg2_type_engine]
async_engines = []

if sqlalchemy_version > 1:
psycopg_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
engines.append(psycopg_engine)

psycopg_type_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')

@event.listens_for(psycopg_type_engine, "connect")
def psycopg_connect(dbapi_connection, connection_record):
from pgvector.psycopg import register_vector
register_vector(dbapi_connection)

engines.append(psycopg_type_engine)
array_engines.append(psycopg_type_engine)

psycopg_async_engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
async_engines.append(psycopg_async_engine)

asyncpg_engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
async_engines.append(asyncpg_engine)

psycopg_array_engine = create_engine('postgresql+psycopg://localhost/pgvector_python_test')
array_engines.append(psycopg_array_engine)

@event.listens_for(psycopg_array_engine, "connect")
def psycopg_connect(dbapi_connection, connection_record):
from pgvector.psycopg import register_vector
register_vector(dbapi_connection)

setup_engine = engines[0]
with Session(setup_engine) as session:
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
Expand Down Expand Up @@ -169,9 +171,10 @@ def test_orm(self, engine):
stmt = select(Item)
with Session(engine) as session:
items = [v[0] for v in session.execute(stmt).all()]
assert items[0].id in [1, 4, 7]
assert items[1].id in [2, 5, 8]
assert items[2].id in [3, 6, 9]
# TODO improve
assert items[0].id % 3 == 1
assert items[1].id % 3 == 2
assert items[2].id % 3 == 0
assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
assert items[0].embedding.dtype == np.float32
assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))
Expand Down

0 comments on commit c792451

Please sign in to comment.