From 04aa5bca2ee60c73de91507e5eb7472a6cf6d7a6 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 12 Nov 2024 22:39:46 -0800 Subject: [PATCH] Added test for arrays with SQLAlchemy async - #101 --- tests/test_sqlalchemy.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index f8e4bb1..77c03fc 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -483,3 +483,25 @@ async def test_async(self): assert avg.first() == '[2.5,3.5,4.5]' await engine.dispose() + + @pytest.mark.asyncio + @pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+') + async def test_async_vector_array(self): + engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test') + async_session = async_sessionmaker(engine, expire_on_commit=False) + + @event.listens_for(engine.sync_engine, "connect") + def connect(dbapi_connection, connection_record): + from pgvector.psycopg import register_vector_async + dbapi_connection.run_async(register_vector_async) + + async with async_session() as session: + async with session.begin(): + session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) + + # this fails if the driver does not cast arrays + item = await session.get(Item, 1) + assert item.embeddings[0].tolist() == [1, 2, 3] + assert item.embeddings[1].tolist() == [4, 5, 6] + + await engine.dispose()