Skip to content

Commit

Permalink
Added test for arrays with SQLAlchemy async - #101
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Nov 13, 2024
1 parent d23844e commit 04aa5bc
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,25 @@ async def test_async(self):
assert avg.first() == '[2.5,3.5,4.5]'

await engine.dispose()

@pytest.mark.asyncio
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
async def test_async_vector_array(self):
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')
async_session = async_sessionmaker(engine, expire_on_commit=False)

@event.listens_for(engine.sync_engine, "connect")
def connect(dbapi_connection, connection_record):
from pgvector.psycopg import register_vector_async
dbapi_connection.run_async(register_vector_async)

async with async_session() as session:
async with session.begin():
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))

# this fails if the driver does not cast arrays
item = await session.get(Item, 1)
assert item.embeddings[0].tolist() == [1, 2, 3]
assert item.embeddings[1].tolist() == [4, 5, 6]

await engine.dispose()

0 comments on commit 04aa5bc

Please sign in to comment.