diff --git a/.vscode/launch.json b/.vscode/launch.json index 972b57a..14b75df 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -208,6 +208,18 @@ "console": "integratedTerminal", "cwd": "${workspaceFolder}" }, + { + "name": "test_multiple_databases)", + "type": "python", + "request": "launch", + "module": "pytest", + "args": [ + "tests/test_multiple_databases.py", + "-sx" + ], + "console": "integratedTerminal", + "cwd": "${workspaceFolder}" + }, { "name": "test migrations)", "type": "python", diff --git a/Makefile b/Makefile index ea73400..bb86342 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,7 @@ test: pytest tests/test_query_caching.py -s -x; pytest tests/test_querying.py -s -x; pytest tests/test_query_no_caching.py -s -x; + pytest tests/test_multiple_databases.py -s -x; test-migrations: python tests/migrations/test_migrations.py diff --git a/pydbantic/core.py b/pydbantic/core.py index d4ae348..aaf54f2 100644 --- a/pydbantic/core.py +++ b/pydbantic/core.py @@ -28,6 +28,7 @@ Boolean, Date, DateTime, + Float, Integer, LargeBinary, Numeric, @@ -79,16 +80,17 @@ def get_model_getter(model, primary_key, primary_key_value): class BaseMeta: - translations: dict = { - str: sqlalchemy.String, - int: sqlalchemy.Integer, - float: sqlalchemy.Float, - bool: sqlalchemy.Boolean, - dict: sqlalchemy.LargeBinary, - list: sqlalchemy.LargeBinary, - tuple: sqlalchemy.LargeBinary, - } - tables: dict = {} + def __init__(self): + self.translations: dict = { + str: sqlalchemy.String, + int: sqlalchemy.Integer, + float: sqlalchemy.Float, + bool: sqlalchemy.Boolean, + dict: sqlalchemy.LargeBinary, + list: sqlalchemy.LargeBinary, + tuple: sqlalchemy.LargeBinary, + } + self.tables: dict = {} def Relationship( @@ -116,6 +118,7 @@ def Relationship( Boolean, JSON, VARCHAR, + Float, ] @@ -402,8 +405,6 @@ def not_matches(self, choices: List[Any]) -> DataBaseModelCondition: class DataBaseModel(BaseModel): - __metadata__: BaseMeta = BaseMeta() - class Config: arbitrary_types_allowed = True @@ -497,6 +498,7 @@ def deserialize(cls: Type[T], data, expected_type=None): @classmethod def setup(cls: Type[T], database) -> None: + cls.__metadata__ = database.__metadata__ cls.__tablename__ = getattr(cls, "__tablename__", cls.__name__) cls.update_backward_refs() diff --git a/pydbantic/database.py b/pydbantic/database.py index 59e83c6..a4e3f2e 100644 --- a/pydbantic/database.py +++ b/pydbantic/database.py @@ -14,7 +14,7 @@ from sqlalchemy import create_engine from pydbantic.cache import Redis -from pydbantic.core import DatabaseInit, DataBaseModel, TableMeta +from pydbantic.core import BaseMeta, DatabaseInit, DataBaseModel, TableMeta from pydbantic.translations import DEFAULT_TRANSLATIONS @@ -49,6 +49,7 @@ def __init__( else {}, ) self.use_alembic = use_alembic + self.__metadata__: BaseMeta = BaseMeta() self.DEFAULT_TRANSLATIONS = DEFAULT_TRANSLATIONS diff --git a/tests/test_multiple_databases.py b/tests/test_multiple_databases.py new file mode 100644 index 0000000..11d15b9 --- /dev/null +++ b/tests/test_multiple_databases.py @@ -0,0 +1,26 @@ +import pytest + +from pydbantic import Database, DataBaseModel + + +@pytest.mark.asyncio +async def test_multiple_database(): + class DB1Model1(DataBaseModel): + __tablename__ = "model1" + data: str + data2: str + + class DB2Model1(DataBaseModel): + __tablename__ = "model1" + data: str + data2: str + + db1 = await Database.create("sqlite:///db1", tables=[DB1Model1]) + db2 = await Database.create("sqlite:///db2", tables=[DB2Model1]) + + assert not await DB1Model1.all() + assert not await DB2Model1.all() + + await DB1Model1.create(data="1", data2="2") + assert await DB1Model1.all() + assert not await DB2Model1.all()