diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 909d52cf..8e481478 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,6 +11,7 @@ Source code is also available at: - (Unreleased) + - Fixed quoting of `_` as column name - Add support for dynamic tables and required options - Add support for hybrid tables - Fixed SAWarning when registering functions with existing name in default namespace diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 023f7afb..9c29695a 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools import operator import re +import string from typing import List from sqlalchemy import exc as sa_exc @@ -112,7 +113,8 @@ AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE ) - +# used for quoting identifiers ie. table names, column names, etc. +ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"})) """ Overwrite methods to handle Snowflake BCR change: @@ -437,6 +439,7 @@ def _join_left_to_right( class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS def __init__(self, dialect, **kw): quote = '"' diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 40207b41..55451c2f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Integer, String, and_, func, select +from sqlalchemy import Integer, String, and_, func, insert, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table from sqlalchemy.testing.assertions import AssertsCompiledSQL @@ -33,6 +33,21 @@ def test_now_func(self): dialect="snowflake", ) + def test_underscore_as_valid_identifier(self): + _table = table( + "table_1745924", + column("ca", Integer), + column("cb", String), + column("_", String), + ) + + stmt = insert(_table).values(ca=1, cb="test", _="test_") + self.assert_compile( + stmt, + 'INSERT INTO table_1745924 (ca, cb, "_") VALUES (%(ca)s, %(cb)s, %(_)s)', + dialect="snowflake", + ) + def test_multi_table_delete(self): statement = table1.delete().where(table1.c.id == table2.c.id) self.assert_compile( diff --git a/tests/test_quote.py b/tests/test_quote.py index ca6f36dd..0dd69059 100644 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -38,3 +38,26 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters): finally: insert_table.drop(engine_testaccount) return insert_table + + +def test_table_column_as_underscore(engine_testaccount): + metadata = MetaData() + test_table_name = "table_1745924" + insert_table = Table( + test_table_name, + metadata, + Column("ca", Integer), + Column("cb", String), + Column("_", String), + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_insert = inspector.get_columns(test_table_name) + assert len(columns_in_insert) == 3 + assert columns_in_insert[0]["name"] == "ca" + assert columns_in_insert[1]["name"] == "cb" + assert columns_in_insert[2]["name"] == "_" + finally: + insert_table.drop(engine_testaccount) + return insert_table diff --git a/tests/test_quote_identifiers.py b/tests/test_quote_identifiers.py new file mode 100644 index 00000000..58575fad --- /dev/null +++ b/tests/test_quote_identifiers.py @@ -0,0 +1,43 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select + + +@pytest.mark.parametrize( + "identifier", + ( + pytest.param("_", id="underscore"), + pytest.param(".", id="dot"), + ), +) +def test_insert_with_identifier_as_column_name(identifier: str, engine_testaccount): + expected_identifier = f"test: {identifier}" + metadata = MetaData() + table = Table( + "table_1745924", + metadata, + Column("ca", Integer), + Column("cb", String), + Column(identifier, String), + ) + + try: + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as connection: + connection.execute( + insert(table).values( + { + "ca": 1, + "cb": "test", + identifier: f"test: {identifier}", + } + ) + ) + result = connection.execute(select(table)).fetchall() + assert result == [(1, "test", expected_identifier)] + finally: + metadata.drop_all(engine_testaccount)