From 28adcf55b519dd4aa195969cfabcfb0c09009649 Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Thu, 12 Nov 2020 16:06:50 +0300 Subject: [PATCH 1/7] fix sqlalchemy defaults --- databases/backends/postgres.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 223ec290..ef00a79f 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -4,6 +4,7 @@ import asyncpg from sqlalchemy.dialects.postgresql import pypostgresql +from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement @@ -19,6 +20,22 @@ _result_processors = {} # type: dict +class APGCompiler_psycopg2(PGCompiler): + def construct_params(self, params=None, _group_number=None, _check=True): + pd = super().construct_params(params, _group_number, _check) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default: typing.Any) -> typing.Any: + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg + + class PostgresBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any @@ -31,6 +48,7 @@ def __init__( def _get_dialect(self) -> Dialect: dialect = pypostgresql.dialect(paramstyle="pyformat") + dialect.statement_compiler = APGCompiler_psycopg2 dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ From 71f4616160d5628e339d2dc5757cd42c576c88d5 Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 00:01:35 +0300 Subject: [PATCH 2/7] add tests for column defaults when inserting --- databases/core.py | 5 ++++ tests/test_databases.py | 54 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/databases/core.py b/databases/core.py index 2bab6735..cec261e7 100644 --- a/databases/core.py +++ b/databases/core.py @@ -301,6 +301,11 @@ def _build_query( elif values: return query.values(**values) + # for case when `table.insert()` called without `.values()` it has to be + # called to produce `insert_prefetch` for compiled query + if query.__visit_name__ == "insert": + return query.values() + return query diff --git a/tests/test_databases.py b/tests/test_databases.py index c7317688..163a9deb 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -70,6 +70,15 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), ) +# Used to test column default values +default_values = sqlalchemy.Table( + "default_values", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42), + sqlalchemy.Column("without_default", sqlalchemy.Integer), +) + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -651,6 +660,50 @@ async def test_json_field(database_url): assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_with_default_values(database_url): + """ + Test insert with column default values + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + query = default_values.insert() + values = {"without_default": 1} + inserted_id = await database.execute(query, values) + + # fetch_one() + query = default_values.select().where(default_values.c.id == inserted_id) + result = await database.fetch_one(query=query) + assert result["with_default"] == 42 + assert result["without_default"] == values["without_default"] + + # test without passing values and without calling `values()` + # execute() + query = default_values.insert() + inserted_id = await database.execute(query) + + # fetch_one() + query = default_values.select().where(default_values.c.id == inserted_id) + result = await database.fetch_one(query=query) + assert result["with_default"] == 42 + assert result["without_default"] is None + + # test pass other than default value + # execute() + query = default_values.insert() + values = {"with_default": 5} + inserted_id = await database.execute(query, values) + + # fetch_one() + query = default_values.select().where(default_values.c.id == inserted_id) + result = await database.fetch_one(query=query) + assert result["with_default"] == values["with_default"] + assert result["without_default"] is None + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_custom_field(database_url): @@ -915,7 +968,6 @@ def test_global_connection_is_initialized_lazily(database_url): @async_adapter async def run_database_queries(): async with database: - async def db_lookup(): await database.fetch_one("SELECT pg_sleep(1)") From 9e34f535e421ccf496d35b2750206798c205ea31 Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 13:14:04 +0300 Subject: [PATCH 3/7] cover default callable case --- tests/test_databases.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 163a9deb..6b60283c 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -76,6 +76,7 @@ def process_result_value(self, value, dialect): metadata, sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42), + sqlalchemy.Column("with_callable_default", sqlalchemy.String(length=100), default=lambda: "default_value"), sqlalchemy.Column("without_default", sqlalchemy.Integer), ) @@ -701,7 +702,16 @@ async def test_insert_with_default_values(database_url): query = default_values.select().where(default_values.c.id == inserted_id) result = await database.fetch_one(query=query) assert result["with_default"] == values["with_default"] - assert result["without_default"] is None + + # test callable default + # execute() + query = default_values.insert() + inserted_id = await database.execute(query) + + # fetch_one() + query = default_values.select().where(default_values.c.id == inserted_id) + result = await database.fetch_one(query=query) + assert result["with_callable_default"] == "default_value" @pytest.mark.parametrize("database_url", DATABASE_URLS) From bd5af23ad81b74f4b8e9c261f8504a42a268d153 Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 13:22:37 +0300 Subject: [PATCH 4/7] fix default column value test for mysql and sqlite --- tests/test_databases.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 6b60283c..9806c8e9 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -673,10 +673,10 @@ async def test_insert_with_default_values(database_url): # execute() query = default_values.insert() values = {"without_default": 1} - inserted_id = await database.execute(query, values) + await database.execute(query, values) # fetch_one() - query = default_values.select().where(default_values.c.id == inserted_id) + query = default_values.select().where(default_values.c.id == 1) result = await database.fetch_one(query=query) assert result["with_default"] == 42 assert result["without_default"] == values["without_default"] @@ -684,10 +684,10 @@ async def test_insert_with_default_values(database_url): # test without passing values and without calling `values()` # execute() query = default_values.insert() - inserted_id = await database.execute(query) + await database.execute(query) # fetch_one() - query = default_values.select().where(default_values.c.id == inserted_id) + query = default_values.select().where(default_values.c.id == 2) result = await database.fetch_one(query=query) assert result["with_default"] == 42 assert result["without_default"] is None @@ -696,20 +696,20 @@ async def test_insert_with_default_values(database_url): # execute() query = default_values.insert() values = {"with_default": 5} - inserted_id = await database.execute(query, values) + await database.execute(query, values) # fetch_one() - query = default_values.select().where(default_values.c.id == inserted_id) + query = default_values.select().where(default_values.c.id == 3) result = await database.fetch_one(query=query) assert result["with_default"] == values["with_default"] # test callable default # execute() query = default_values.insert() - inserted_id = await database.execute(query) + await database.execute(query) # fetch_one() - query = default_values.select().where(default_values.c.id == inserted_id) + query = default_values.select().where(default_values.c.id == 4) result = await database.fetch_one(query=query) assert result["with_callable_default"] == "default_value" From 62c3ab8d325eb8b2be29af2244b34fe3491681f0 Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 14:12:03 +0300 Subject: [PATCH 5/7] refactor default params mixin, add support for aiomysql and aiosqlite --- databases/backends/common.py | 22 ++++++++++++++++++++++ databases/backends/mysql.py | 17 ++++++++++++++--- databases/backends/postgres.py | 18 +++--------------- databases/backends/sqlite.py | 19 +++++++++++++++---- 4 files changed, 54 insertions(+), 22 deletions(-) create mode 100644 databases/backends/common.py diff --git a/databases/backends/common.py b/databases/backends/common.py new file mode 100644 index 00000000..e8206670 --- /dev/null +++ b/databases/backends/common.py @@ -0,0 +1,22 @@ +import typing + + +class ConstructDefaultParamsMixin: + """ + A mixin to support column default values for insert queries for asyncpg, + aiomysql and aiosqlite + """ + + def construct_params(self, params=None, _group_number=None, _check=True): + pd = super().construct_params(params, _group_number, _check) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default: typing.Any) -> typing.Any: + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index b6476add..fc10bfcb 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -9,24 +9,35 @@ from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") +class MySQLCompiler(ConstructDefaultParamsMixin, pymysql.dialect.statement_compiler): + pass + + class MySQLBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = pymysql.dialect(paramstyle="pyformat") - self._dialect.supports_native_decimal = True + self._dialect = self._get_dialect() self._pool = None + def _get_dialect(self) -> Dialect: + dialect = pymysql.dialect(paramstyle="pyformat") + + dialect.statement_compiler = MySQLCompiler + dialect.supports_native_decimal = True + + return dialect + def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index ef00a79f..e967f126 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -4,13 +4,13 @@ import asyncpg from sqlalchemy.dialects.postgresql import pypostgresql -from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement from sqlalchemy.sql.schema import Column from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend @@ -20,20 +20,8 @@ _result_processors = {} # type: dict -class APGCompiler_psycopg2(PGCompiler): - def construct_params(self, params=None, _group_number=None, _check=True): - pd = super().construct_params(params, _group_number, _check) - - for column in self.prefetch: - pd[column.key] = self._exec_default(column.default) - - return pd - - def _exec_default(self, default: typing.Any) -> typing.Any: - if default.is_callable: - return default.arg(self.dialect) - else: - return default.arg +class APGCompiler_psycopg2(ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler): + pass class PostgresBackend(DatabaseBackend): diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 28ceb6fb..0339efce 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -8,25 +8,36 @@ from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") +class SQLiteCompiler(ConstructDefaultParamsMixin, pysqlite.dialect.statement_compiler): + pass + + class SQLiteBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = pysqlite.dialect(paramstyle="qmark") - # aiosqlite does not support decimals - self._dialect.supports_native_decimal = False + self._dialect = self._get_dialect() self._pool = SQLitePool(self._database_url, **self._options) + def _get_dialect(self) -> Dialect: + dialect = pysqlite.dialect(paramstyle="qmark") + + # aiosqlite does not support decimals + dialect.supports_native_decimal = False + dialect.statement_compiler = SQLiteCompiler + + return dialect + async def connect(self) -> None: pass # assert self._pool is None, "DatabaseBackend is already running" From 0272630a6b341b8a1476c284f0217cc7d0b1e7bf Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 15:11:47 +0300 Subject: [PATCH 6/7] fix tests running with asyncpg and aiopg in one run (when sharing same db) --- tests/test_databases.py | 59 +++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 9806c8e9..25cc5fc4 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -663,54 +663,79 @@ async def test_json_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_insert_with_default_values(database_url): +async def test_insert_with_scalar_default(database_url): """ - Test insert with column default values + Test insert with scalar column default value """ async with Database(database_url) as database: async with database.transaction(force_rollback=True): - # execute() query = default_values.insert() values = {"without_default": 1} await database.execute(query, values) - # fetch_one() - query = default_values.select().where(default_values.c.id == 1) + query = default_values.select().order_by(default_values.c.id.desc()) result = await database.fetch_one(query=query) + assert result["with_default"] == 42 assert result["without_default"] == values["without_default"] - # test without passing values and without calling `values()` - # execute() + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_default_values_with_no_values_called(database_url): + """ + Test insert default values without calling ``values()`` on insert and + without passing ``values`` to ``execute()``. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): query = default_values.insert() await database.execute(query) - # fetch_one() - query = default_values.select().where(default_values.c.id == 2) + query = default_values.select().order_by(default_values.c.id.desc()) result = await database.fetch_one(query=query) + assert result["with_default"] == 42 assert result["without_default"] is None - # test pass other than default value - # execute() + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_default_values_with_overriden_default(database_url): + """ + Test if we provide value for a column having default value, the first one + should be set, not default one. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): query = default_values.insert() values = {"with_default": 5} await database.execute(query, values) - # fetch_one() - query = default_values.select().where(default_values.c.id == 3) + query = default_values.select().order_by(default_values.c.id.desc()) result = await database.fetch_one(query=query) + assert result["with_default"] == values["with_default"] - # test callable default - # execute() + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_callable_default(database_url): + """ + Test insert with column having callable default. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): query = default_values.insert() await database.execute(query) - # fetch_one() - query = default_values.select().where(default_values.c.id == 4) + query = default_values.select().order_by(default_values.c.id.desc()) result = await database.fetch_one(query=query) + assert result["with_callable_default"] == "default_value" From c9fe0546c3fddd343742b2fd9fbce453ca5b1c7b Mon Sep 17 00:00:00 2001 From: Andrey Laguta Date: Fri, 13 Nov 2020 15:25:36 +0300 Subject: [PATCH 7/7] fix linting --- databases/backends/common.py | 17 ++++++++++++++--- databases/backends/postgres.py | 4 +++- tests/test_databases.py | 8 +++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/databases/backends/common.py b/databases/backends/common.py index e8206670..ca4c3fc4 100644 --- a/databases/backends/common.py +++ b/databases/backends/common.py @@ -1,5 +1,8 @@ import typing +from sqlalchemy import ColumnDefault +from sqlalchemy.engine.default import DefaultDialect + class ConstructDefaultParamsMixin: """ @@ -7,15 +10,23 @@ class ConstructDefaultParamsMixin: aiomysql and aiosqlite """ - def construct_params(self, params=None, _group_number=None, _check=True): - pd = super().construct_params(params, _group_number, _check) + prefetch: typing.List + dialect: DefaultDialect + + def construct_params( + self, + params: typing.Optional[typing.Mapping] = None, + _group_number: typing.Any = None, + _check: bool = True, + ) -> typing.Dict: + pd = super().construct_params(params, _group_number, _check) # type: ignore for column in self.prefetch: pd[column.key] = self._exec_default(column.default) return pd - def _exec_default(self, default: typing.Any) -> typing.Any: + def _exec_default(self, default: ColumnDefault) -> typing.Any: if default.is_callable: return default.arg(self.dialect) else: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index e967f126..749f3c58 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -20,7 +20,9 @@ _result_processors = {} # type: dict -class APGCompiler_psycopg2(ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler): +class APGCompiler_psycopg2( + ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler +): pass diff --git a/tests/test_databases.py b/tests/test_databases.py index 25cc5fc4..5f078a2f 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -76,7 +76,11 @@ def process_result_value(self, value, dialect): metadata, sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42), - sqlalchemy.Column("with_callable_default", sqlalchemy.String(length=100), default=lambda: "default_value"), + sqlalchemy.Column( + "with_callable_default", + sqlalchemy.String(length=100), + default=lambda: "default_value", + ), sqlalchemy.Column("without_default", sqlalchemy.Integer), ) @@ -1003,7 +1007,9 @@ def test_global_connection_is_initialized_lazily(database_url): @async_adapter async def run_database_queries(): async with database: + async def db_lookup(): + await database.fetch_one("SELECT pg_sleep(1)") await asyncio.gather(db_lookup(), db_lookup())