Skip to content

Commit

Permalink
SNOW-118945: Fix ArgumentError where insert with autoincrement failed…
Browse files Browse the repository at this point in the history
… due to incompatible column type affinity (#297)

* bug fix

* fix old test

* add autoincrement test

* fix tests

* review feedbacks

* consistent naming
  • Loading branch information
sfc-gh-aling authored Jun 3, 2022
1 parent e2cb627 commit 3f7ed9e
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 10 deletions.
5 changes: 5 additions & 0 deletions DESCRIPTION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ https://github.com/snowflakedb/snowflake-sqlalchemy
Release Notes
-------------------------------------------------------------------------------

- v1.3.5(Unreleased)

- Fixed a bug where insert with autoincrement failed due to incompatible column type affinity #124
- Fixed a bug when creating a column with sequence, default value was set incorrectly

- v1.3.4(April 27,2022)

- Fixed a bug where identifier max length was set to the wrong value and added relevant schema introspection
Expand Down
7 changes: 5 additions & 2 deletions base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from sqlalchemy import util as sa_util
from sqlalchemy.engine import default
from sqlalchemy.schema import Table
from sqlalchemy.schema import Sequence, Table
from sqlalchemy.sql import compiler, expression
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.util.compat import string_types
Expand Down Expand Up @@ -362,7 +362,10 @@ def get_column_specification(self, column, **kwargs):
if column.table is not None \
and column is column.table._autoincrement_column and \
column.server_default is None:
colspec.append('AUTOINCREMENT')
if isinstance(column.default, Sequence):
colspec.append(f"DEFAULT {column.default.name}.nextval")
else:
colspec.append('AUTOINCREMENT')

return ' '.join(colspec)

Expand Down
7 changes: 7 additions & 0 deletions custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#

import sqlalchemy.types as sqltypes
import sqlalchemy.util as util

TEXT = sqltypes.VARCHAR
CHARACTER = sqltypes.CHAR
Expand Down Expand Up @@ -51,3 +52,9 @@ class TIMESTAMP_NTZ(SnowflakeType):

class GEOGRAPHY(SnowflakeType):
__visit_name__ = 'GEOGRAPHY'


class _CUSTOM_DECIMAL(SnowflakeType, sqltypes.DECIMAL):
@util.memoized_property
def _type_affinity(self):
return sqltypes.INTEGER if self.scale == 0 else sqltypes.DECIMAL
6 changes: 3 additions & 3 deletions snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
SnowflakeIdentifierPreparer,
SnowflakeTypeCompiler,
)
from .custom_types import ARRAY, GEOGRAPHY, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, TIMESTAMP_TZ, VARIANT
from .custom_types import _CUSTOM_DECIMAL, ARRAY, GEOGRAPHY, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, TIMESTAMP_TZ, VARIANT

colspecs = {}

Expand All @@ -65,7 +65,7 @@
'FLOAT': FLOAT,
'INT': INTEGER,
'INTEGER': INTEGER,
'NUMBER': DECIMAL,
'NUMBER': _CUSTOM_DECIMAL,
# 'OBJECT': ?
'REAL': REAL,
'BYTEINT': SMALLINT,
Expand Down Expand Up @@ -609,7 +609,7 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
ret = cursor.fetchone()
if ret:
return ret[n2i['text']]
except Exxception:
except Exception:
pass
return None

Expand Down
10 changes: 5 additions & 5 deletions test/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters):
metadata = MetaData()
test_table_name = 'insert'
insert_table = Table(test_table_name, metadata,
Column('id', Integer, Sequence(test_table_name + '_id_seq'),
Column('id', Integer, Sequence(f"{test_table_name}_id_seq"),
primary_key=True),
Column('name', String),
Column('fullname', String),
Expand All @@ -22,10 +22,10 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters):
inspector = inspect(engine_testaccount)
columns_in_insert = inspector.get_columns(test_table_name)
assert len(columns_in_insert) == 3
assert columns_in_insert[0]['autoincrement'], 'autoincrement'
assert columns_in_insert[0]['default'] is None, 'default'
assert columns_in_insert[0]['name'] == 'id', 'name'
assert columns_in_insert[0]['primary_key'], 'primary key'
assert columns_in_insert[0]['autoincrement'] is False
assert f"{test_table_name}_id_seq.nextval" in columns_in_insert[0]['default'].lower()
assert columns_in_insert[0]['name'] == 'id'
assert columns_in_insert[0]['primary_key']
assert not columns_in_insert[0]['nullable']

columns_in_insert = inspector.get_columns(test_table_name, schema=db_parameters['schema'])
Expand Down
79 changes: 79 additions & 0 deletions test/test_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved.
#

from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, select


def test_table_with_sequence(engine_testaccount, db_parameters):
# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124
test_table_name = 'sequence'
test_sequence_name = f'{test_table_name}_id_seq'
sequence_table = Table(test_table_name, MetaData(),
Column('id', Integer, Sequence(test_sequence_name), primary_key=True),
Column('data', String(39))
)
sequence_table.create(engine_testaccount)
seq = Sequence(test_sequence_name)
try:
engine_testaccount.execute(sequence_table.insert(), [{'data': 'test_insert_1'}])

select_stmt = select([sequence_table]).order_by('id')
result = engine_testaccount.execute(select_stmt).fetchall()
assert result == [(1, 'test_insert_1')]

autoload_sequence_table = Table(test_table_name, MetaData(), autoload=True, autoload_with=engine_testaccount)

engine_testaccount.execute(autoload_sequence_table.insert(),
[{'data': 'multi_insert_1'}, {'data': 'multi_insert_2'}])

engine_testaccount.execute(autoload_sequence_table.insert(), [{'data': 'test_insert_2'}])

nextid = engine_testaccount.execute(seq)
engine_testaccount.execute(autoload_sequence_table.insert(), [{'id': nextid, 'data': 'test_insert_seq'}])
result = engine_testaccount.execute(select_stmt).fetchall()
assert result == [
(1, 'test_insert_1'),
(2, 'multi_insert_1'),
(3, 'multi_insert_2'),
(4, 'test_insert_2'),
(5, 'test_insert_seq')
]
finally:
sequence_table.drop(engine_testaccount)
seq.drop(engine_testaccount)


def test_table_with_autoincrement(engine_testaccount, db_parameters):
# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124
test_table_name = 'sequence'
autoincrement_table = Table(test_table_name, MetaData(),
Column('id', Integer, autoincrement=True, primary_key=True),
Column('data', String(39))
)
autoincrement_table.create(engine_testaccount)
try:
engine_testaccount.execute(autoincrement_table.insert(), [{'data': 'test_insert_1'}])

select_stmt = select([autoincrement_table]).order_by('id')
result = engine_testaccount.execute(select_stmt).fetchall()
assert result == [(1, 'test_insert_1')]

autoload_sequence_table = Table(test_table_name, MetaData(), autoload=True, autoload_with=engine_testaccount)

engine_testaccount.execute(autoload_sequence_table.insert(),
[{'data': 'multi_insert_1'}, {'data': 'multi_insert_2'}])

engine_testaccount.execute(autoload_sequence_table.insert(), [{'data': 'test_insert_2'}])

result = engine_testaccount.execute(select_stmt).fetchall()
assert result == [
(1, 'test_insert_1'),
(2, 'multi_insert_1'),
(3, 'multi_insert_2'),
(4, 'test_insert_2'),
]
finally:
autoincrement_table.drop(engine_testaccount)

0 comments on commit 3f7ed9e

Please sign in to comment.