Skip to content

Commit

Permalink
Add merge statement (#32)
Browse files Browse the repository at this point in the history
* Add merge statement

 * ideas taken from snowflake-sqlalchemy

* Fix Merge statement so a table, select or subquery may be used as source

* Fix get_table_names, get_view_names for server update

* Quote target table

* Quote field name in update
  • Loading branch information
rad-pat authored Apr 27, 2024
1 parent e0197e5 commit 0c8d54e
Show file tree
Hide file tree
Showing 4 changed files with 535 additions and 3 deletions.
29 changes: 29 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,32 @@ Compatibility

- If the Databend version is v0.9.0 or later, you need databend-sqlalchemy v0.1.0 or later.
- databend-sqlalchemy uses `databend-py <https://github.com/datafuselabs/databend-py>`_ as its internal driver in versions < v0.4.0; from v0.4.0 onward it uses the `Databend driver Python binding <https://github.com/datafuselabs/bendsql/blob/main/bindings/python/README.md>`_ instead. The only difference between the two is the set of connection parameters accepted in the DSN, so refer to the connection parameters documented by the driver that matches the version you use.


Merge Command Support
---------------------

Databend SQLAlchemy supports upserts via its ``Merge`` custom expression.
See `Merge <https://docs.databend.com/sql/sql-commands/dml/dml-merge>`_ for full documentation.

The Merge command can be used as below::

from sqlalchemy.orm import sessionmaker
from sqlalchemy import MetaData, create_engine
from databend_sqlalchemy.databend_dialect import Merge

engine = create_engine(db.url, echo=False)
session = sessionmaker(bind=engine)()
connection = engine.connect()

meta = MetaData()
meta.reflect(bind=session.bind)
t1 = meta.tables['t1']
t2 = meta.tables['t2']

merge = Merge(target=t1, source=t2, on=t1.c.t1key == t2.c.t2key)
merge.when_matched_then_delete().where(t2.c.marked == 1)
merge.when_matched_then_update().where(t2.c.isnewstatus == 1).values(val = t2.c.newval, status=t2.c.newstatus)
merge.when_matched_then_update().values(val=t2.c.newval)
merge.when_not_matched_then_insert().values(val=t2.c.newval, status=t2.c.newstatus)
connection.execute(merge)
99 changes: 96 additions & 3 deletions databend_sqlalchemy/databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# licensed under the same Apache 2.0 License
import decimal
import re
import operator
import datetime
import sqlalchemy.types as sqltypes
from typing import Any, Dict, Optional, Union
from sqlalchemy import util as sa_util
from sqlalchemy.engine import reflection
from sqlalchemy.sql import compiler, text, bindparam
from sqlalchemy.sql import compiler, text, bindparam, select, TableClause, Select, Subquery
from sqlalchemy.dialects.postgresql.base import PGCompiler, PGIdentifierPreparer
from sqlalchemy.types import (
BIGINT,
Expand All @@ -27,6 +28,7 @@
)
from sqlalchemy.engine import ExecutionContext, default
from sqlalchemy.exc import DBAPIError, NoSuchTableError
from .dml import Merge


# Type decorators
Expand Down Expand Up @@ -231,6 +233,85 @@ def visit_not_like_op_binary(self, binary, operator, **kw):
)


def visit_merge(self, merge, **kw):
    """Compile a ``Merge`` construct into a ``MERGE INTO`` statement.

    :param merge: the merge construct; its ``target``, ``source``, ``on``
        and ``clauses`` attributes are read.
    :param kw: compiler keyword arguments, forwarded to the ``ON``
        condition and to every ``WHEN`` clause.
    :return: the rendered SQL string.
    :raises TypeError: if ``merge.source`` is not a ``TableClause``,
        ``Select`` or ``Subquery``.
    """
    clauses = "\n ".join(
        clause._compiler_dispatch(self, **kw) for clause in merge.clauses
    )

    source_kw = {"asfrom": True}
    source = merge.source
    if isinstance(source, TableClause):
        # A plain table is wrapped in a SELECT subquery that is aliased
        # back to the table's own name.
        source_sql = (
            select(source)
            .subquery()
            .alias(source.name)
            ._compiler_dispatch(self, **source_kw)
        )
    elif isinstance(source, Select):
        # A SELECT is aliased with the name of its first FROM element.
        source_sql = (
            source.subquery()
            .alias(source.get_final_froms()[0].name)
            ._compiler_dispatch(self, **source_kw)
        )
    elif isinstance(source, Subquery):
        source_sql = source._compiler_dispatch(self, **source_kw)
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise TypeError(f"Invalid type for merge source: {source}")

    # Compile the join condition with this compiler. Formatting it with
    # str() would compile it separately from this statement, so any bound
    # values in the condition would not be collected into the statement's
    # parameters.
    on_sql = merge.on._compiler_dispatch(self, **kw)

    target_table = self.preparer.format_table(merge.target)
    return (
        f"MERGE INTO {target_table}\n"
        f" USING {source_sql}\n"
        f" ON {on_sql}\n"
        f"{clauses}"
    )

def visit_when_merge_matched_update(self, merge_matched_update, **kw):
    """Render the ``WHEN MATCHED ... THEN UPDATE`` clause of a MERGE.

    :param merge_matched_update: clause object; its ``predicate`` and
        ``set`` attributes are read.
    :param kw: compiler keyword arguments, forwarded to every compiled
        sub-expression. A truthy ``deterministic`` entry makes the
        assignments come out sorted by column name.
    :return: the rendered clause; ``UPDATE *`` when no values were given.
    """
    predicate = merge_matched_update.predicate
    if predicate is None:
        condition = ""
    else:
        condition = f" AND {predicate._compiler_dispatch(self, **kw)}"
    header = f"WHEN MATCHED{condition} THEN\n\tUPDATE"

    assignments = merge_matched_update.set
    if not assignments:
        return f"{header} *"

    pairs = assignments.items()
    if kw.get("deterministic", False):
        pairs = sorted(pairs, key=operator.itemgetter(0))

    rendered = []
    for column, value in pairs:
        quoted = self.preparer.quote_identifier(column)
        rendered.append(f"{quoted} = {value._compiler_dispatch(self, **kw)}")
    return f"{header} SET {', '.join(rendered)}"

def visit_when_merge_matched_delete(self, merge_matched_delete, **kw):
    """Render the ``WHEN MATCHED ... THEN DELETE`` clause of a MERGE.

    :param merge_matched_delete: clause object; only its ``predicate``
        attribute is read.
    :param kw: compiler keyword arguments, forwarded to the predicate.
    :return: the rendered clause.
    """
    predicate = merge_matched_delete.predicate
    if predicate is None:
        return "WHEN MATCHED THEN DELETE"
    condition = predicate._compiler_dispatch(self, **kw)
    return f"WHEN MATCHED AND {condition} THEN DELETE"

def visit_when_merge_unmatched(self, merge_unmatched, **kw):
    """Render the ``WHEN NOT MATCHED ... THEN INSERT`` clause of a MERGE.

    :param merge_unmatched: clause object; its ``predicate`` and ``set``
        attributes are read.
    :param kw: compiler keyword arguments, forwarded to every compiled
        sub-expression. A truthy ``deterministic`` entry makes the columns
        (and their values) come out sorted by column name.
    :return: the rendered clause; ``INSERT *`` when no values were given.
    """
    predicate = merge_unmatched.predicate
    case_predicate = (
        f" AND {predicate._compiler_dispatch(self, **kw)}"
        if predicate is not None
        else ""
    )
    insert_str = f"WHEN NOT MATCHED{case_predicate} THEN\n\tINSERT"
    if not merge_unmatched.set:
        return f"{insert_str} *"

    items = list(merge_unmatched.set.items())
    if kw.get("deterministic", False):
        items.sort(key=operator.itemgetter(0))

    # Quote the column names the same way the WHEN MATCHED ... UPDATE clause
    # does, so reserved words and special characters produce valid SQL.
    # NOTE(review): quoting is unconditional; confirm this matches the
    # identifier case-sensitivity settings of the target server.
    columns = ", ".join(
        self.preparer.quote_identifier(name) for name, _ in items
    )
    values = ", ".join(
        value._compiler_dispatch(self, **kw) for _, value in items
    )
    return f"{insert_str} ({columns}) VALUES ({values})"
class DatabendExecutionContext(default.DefaultExecutionContext):
@sa_util.memoized_property
def should_autocommit(self):
Expand Down Expand Up @@ -489,12 +570,17 @@ def get_indexes(self, connection, table_name, schema=None, **kw):

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
query = text("""
table_name_query = """
select table_name
from information_schema.tables
where table_schema = :schema_name
"""
if self.server_version_info <= (1, 2, 410):
table_name_query += """
and engine NOT LIKE '%VIEW%'
"""
query = text(
table_name_query
).bindparams(
bindparam("schema_name", type_=sqltypes.Unicode)
)
Expand All @@ -506,13 +592,20 @@ def get_table_names(self, connection, schema=None, **kw):

@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
query = text(
view_name_query = """
select table_name
from information_schema.views
where table_schema = :schema_name
"""
if self.server_version_info <= (1, 2, 410):
view_name_query = """
select table_name
from information_schema.tables
where table_schema = :schema_name
and engine LIKE '%VIEW%'
"""
query = text(
view_name_query
).bindparams(
bindparam("schema_name", type_=sqltypes.Unicode)
)
Expand Down
100 changes: 100 additions & 0 deletions databend_sqlalchemy/dml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python
#
# Note: parts of the file come from https://github.com/snowflakedb/snowflake-sqlalchemy
# licensed under the same Apache 2.0 License

from sqlalchemy.sql.selectable import Select, Subquery, TableClause
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.expression import select


class _OnMergeBaseClause(ClauseElement):
    """Shared state and fluent setters for the ``WHEN ...`` clauses of a MERGE.

    No ``__visit_name__`` is declared here; each concrete subclass sets
    its own.
    """

    def __init__(self):
        # Optional extra condition, replaced by where().
        self.predicate = None
        # Column name -> value mapping, replaced by values().
        self.set = {}

    def __repr__(self):
        """Return `` AND <predicate>``, or ``""`` when no predicate is set."""
        if self.predicate is None:
            return ""
        return " AND " + str(self.predicate)

    def values(self, **kwargs):
        """Replace the column/value mapping and return ``self`` for chaining."""
        self.set = kwargs
        return self

    def where(self, expr):
        """Replace the clause predicate and return ``self`` for chaining."""
        self.predicate = expr
        return self


class WhenMergeMatchedUpdateClause(_OnMergeBaseClause):
    """``WHEN MATCHED [AND <predicate>] THEN UPDATE`` clause of a MERGE."""

    __visit_name__ = "when_merge_matched_update"

    def __repr__(self):
        """Return a readable, uncompiled rendering of the clause."""
        # The base __repr__ must be called explicitly: interpolating the bare
        # ``super()`` proxy prints the proxy object, not the predicate text.
        case_predicate = super().__repr__()
        update_str = f"WHEN MATCHED{case_predicate} THEN UPDATE"
        if not self.set:
            return f"{update_str} *"

        set_values = ", ".join(
            f"{column} = {value}" for column, value in self.set.items()
        )
        return f"{update_str} SET {set_values}"


class WhenMergeMatchedDeleteClause(_OnMergeBaseClause):
    """``WHEN MATCHED [AND <predicate>] THEN DELETE`` clause of a MERGE."""

    __visit_name__ = "when_merge_matched_delete"

    def __repr__(self):
        """Return a readable, uncompiled rendering of the clause."""
        # The base __repr__ must be called explicitly: interpolating the bare
        # ``super()`` proxy prints the proxy object, not the predicate text.
        case_predicate = super().__repr__()
        return f"WHEN MATCHED{case_predicate} THEN DELETE"


class WhenMergeUnMatchedClause(_OnMergeBaseClause):
    """``WHEN NOT MATCHED [AND <predicate>] THEN INSERT`` clause of a MERGE."""

    __visit_name__ = "when_merge_unmatched"

    def __repr__(self):
        """Return a readable, uncompiled rendering of the clause."""
        # The base __repr__ must be called explicitly: interpolating the bare
        # ``super()`` proxy prints the proxy object, not the predicate text.
        case_predicate = super().__repr__()
        insert_str = f"WHEN NOT MATCHED{case_predicate} THEN INSERT"
        if not self.set:
            return f"{insert_str} *"

        columns = ", ".join(self.set)
        values = ", ".join(str(value) for value in self.set.values())
        return f"{insert_str} ({columns}) VALUES ({values})"


class Merge(UpdateBase):
    """A ``MERGE INTO <target> USING <source> ON <condition>`` construct.

    ``WHEN`` clauses are appended, in call order, by the ``when_*`` methods;
    each returns the new clause object so that ``where()`` / ``values()``
    can be chained onto it.

    :param target: the table merged into.
    :param source: a ``TableClause``, ``Select`` or ``Subquery`` supplying
        the rows to merge.
    :param on: the condition matching source rows to target rows.
    :raises TypeError: if ``source`` is not one of the supported types.
    """

    __visit_name__ = "merge"
    _bind = None

    def __init__(self, target, source, on):
        if not isinstance(source, (TableClause, Select, Subquery)):
            # TypeError is a subclass of Exception, so callers that caught
            # the previous bare Exception still catch this.
            raise TypeError(f'Invalid type for merge source: {source}')
        self.target = target
        self.source = source
        self.on = on
        self.clauses = []

    def __repr__(self):
        """Return a readable, uncompiled rendering of the statement."""
        clauses = " ".join(repr(clause) for clause in self.clauses)
        if isinstance(self.source, Select):
            # A Select exposes no ``name`` to alias with, so it is shown as
            # a bare parenthesised query instead of raising inside repr().
            using = f"({self.source})"
        else:
            using = f"({select(self.source)}) AS {self.source.name}"
        return f"MERGE INTO {self.target} USING {using} ON {self.on}" + (
            f" {clauses}" if clauses else ""
        )

    def when_matched_then_update(self):
        """Append and return a ``WHEN MATCHED THEN UPDATE`` clause."""
        clause = WhenMergeMatchedUpdateClause()
        self.clauses.append(clause)
        return clause

    def when_matched_then_delete(self):
        """Append and return a ``WHEN MATCHED THEN DELETE`` clause."""
        clause = WhenMergeMatchedDeleteClause()
        self.clauses.append(clause)
        return clause

    def when_not_matched_then_insert(self):
        """Append and return a ``WHEN NOT MATCHED THEN INSERT`` clause."""
        clause = WhenMergeUnMatchedClause()
        self.clauses.append(clause)
        return clause
Loading

0 comments on commit 0c8d54e

Please sign in to comment.