diff --git a/CHANGELOG.md b/CHANGELOG.md index 43ec3f6..c3d2537 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 2.1.10 + +* [FEAT] add support for AUTOINCREMENT + # 2.1.9 * [FIX] pin MySQL Connector/Python to 8.3.0 diff --git a/mysql_to_sqlite3/__init__.py b/mysql_to_sqlite3/__init__.py index c7a4032..5197a6c 100644 --- a/mysql_to_sqlite3/__init__.py +++ b/mysql_to_sqlite3/__init__.py @@ -1,4 +1,5 @@ """Utility to transfer data from MySQL to SQLite 3.""" -__version__ = "2.1.9" + +__version__ = "2.1.10" from .transporter import MySQLtoSQLite diff --git a/mysql_to_sqlite3/cli.py b/mysql_to_sqlite3/cli.py index f74f73c..3cc85eb 100644 --- a/mysql_to_sqlite3/cli.py +++ b/mysql_to_sqlite3/cli.py @@ -1,4 +1,5 @@ """The command line interface of MySQLtoSQLite.""" + import os import sys import typing as t diff --git a/mysql_to_sqlite3/sqlite_utils.py b/mysql_to_sqlite3/sqlite_utils.py index e2fb7e3..3c2473c 100644 --- a/mysql_to_sqlite3/sqlite_utils.py +++ b/mysql_to_sqlite3/sqlite_utils.py @@ -52,3 +52,20 @@ def convert_date(value: t.Any) -> date: return date.fromisoformat(value.decode()) except ValueError as err: raise ValueError(f"DATE field contains {err}") # pylint: disable=W0707 + + +Integer_Types: t.Set[str] = { + "INTEGER", + "INTEGER UNSIGNED", + "INT", + "INT UNSIGNED", + "BIGINT", + "BIGINT UNSIGNED", + "MEDIUMINT", + "MEDIUMINT UNSIGNED", + "SMALLINT", + "SMALLINT UNSIGNED", + "TINYINT", + "TINYINT UNSIGNED", + "NUMERIC", +} diff --git a/mysql_to_sqlite3/transporter.py b/mysql_to_sqlite3/transporter.py index 0be7548..171b17c 100644 --- a/mysql_to_sqlite3/transporter.py +++ b/mysql_to_sqlite3/transporter.py @@ -21,6 +21,7 @@ from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS from mysql_to_sqlite3.sqlite_utils import ( CollatingSequences, + Integer_Types, adapt_decimal, adapt_timedelta, convert_date, @@ -384,24 +385,42 @@ def _build_create_table_sql(self, table_name: str) -> str: column_type=row["Type"], # type: ignore[arg-type] sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled, ) - sql += '\n\t"{name}" {type} {notnull} {default} {collation},'.format( - name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"], - type=column_type, - notnull="NULL" if row["Null"] == "YES" else "NOT NULL", - default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]), - collation=self._data_type_collation_sequence(self._collation, column_type), - ) + if row["Key"] == "PRI" and row["Extra"] == "auto_increment": + if column_type in Integer_Types: + sql += '\n\t"{name}" INTEGER PRIMARY KEY AUTOINCREMENT,'.format( + name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"], + ) + else: + self._logger.warning( + 'Primary key "%s" in table "%s" is not an INTEGER type! Skipping.', + row["Field"], + table_name, + ) + else: + sql += '\n\t"{name}" {type} {notnull} {default} {collation},'.format( + name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"], + type=column_type, + notnull="NULL" if row["Null"] == "YES" else "NOT NULL", + default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]), + collation=self._data_type_collation_sequence(self._collation, column_type), + ) self._mysql_cur_dict.execute( """ - SELECT INDEX_NAME AS `name`, - IF (NON_UNIQUE = 0 AND INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`, - IF (NON_UNIQUE = 0 AND INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`, - GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns` - FROM information_schema.STATISTICS - WHERE TABLE_SCHEMA = %s - AND TABLE_NAME = %s - GROUP BY INDEX_NAME, NON_UNIQUE + SELECT s.INDEX_NAME AS `name`, + IF (NON_UNIQUE = 0 AND s.INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`, + IF (NON_UNIQUE = 0 AND s.INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`, + IF (c.EXTRA = 'auto_increment', 1, 0) AS `auto_increment`, + GROUP_CONCAT(s.COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`, + GROUP_CONCAT(c.COLUMN_TYPE ORDER BY SEQ_IN_INDEX) AS `types` + FROM information_schema.STATISTICS AS s + JOIN information_schema.COLUMNS AS c + ON s.TABLE_SCHEMA = c.TABLE_SCHEMA + AND s.TABLE_NAME = c.TABLE_NAME + AND s.COLUMN_NAME = c.COLUMN_NAME + WHERE s.TABLE_SCHEMA = %s + AND s.TABLE_NAME = %s + GROUP BY s.INDEX_NAME, s.NON_UNIQUE, c.EXTRA """, (self._mysql_database, table_name), ) @@ -437,17 +456,33 @@ def _build_create_table_sql(self, table_name: str) -> str: elif isinstance(index["columns"], str): columns = index["columns"] + types: str = "" + if isinstance(index["types"], bytes): + types = index["types"].decode() + elif isinstance(index["types"], str): + types = index["types"] + if len(columns) > 0: if index["primary"] in {1, "1"}: - primary += "\n\tPRIMARY KEY ({})".format( - ", ".join(f'"{column}"' for column in columns.split(",")) - ) + if (index["auto_increment"] not in {1, "1"}) or any( + self._translate_type_from_mysql_to_sqlite( + column_type=_type, + sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled, + ) + not in Integer_Types + for _type in types.split(",") + ): + primary += "\n\tPRIMARY KEY ({})".format( + ", ".join(f'"{column}"' for column in columns.split(",")) + ) else: indices += """CREATE {unique} INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});""".format( unique="UNIQUE" if index["unique"] in {1, "1"} else "", - name=f"{table_name}_{index_name}" - if (table_collisions > 0 or self._prefix_indices) - else index_name, + name=( + f"{table_name}_{index_name}" + if (table_collisions > 0 or self._prefix_indices) + else index_name + ), table=table_name, columns=", ".join(f'"{column}"' for column in columns.split(",")), ) @@ -481,9 +516,11 @@ def _build_create_table_sql(self, table_name: str) -> str: c.UPDATE_RULE, c.DELETE_RULE """.format( - JOIN="JOIN" - if (server_version is not None and server_version[0] == 8 and server_version[2] > 19) - else "LEFT JOIN" + JOIN=( + "JOIN" + if (server_version is not None and server_version[0] == 8 and server_version[2] > 19) + else "LEFT JOIN" + ) ), (self._mysql_database, table_name, "FOREIGN KEY"), ) diff --git a/mysql_to_sqlite3/types.py b/mysql_to_sqlite3/types.py index 52cf1f8..b2aebc2 100644 --- a/mysql_to_sqlite3/types.py +++ b/mysql_to_sqlite3/types.py @@ -1,4 +1,5 @@ """Types for mysql-to-sqlite3.""" + import os import typing as t from logging import Logger diff --git a/tests/func/mysql_to_sqlite3_test.py b/tests/func/mysql_to_sqlite3_test.py index 81e7821..3bf2527 100644 --- a/tests/func/mysql_to_sqlite3_test.py +++ b/tests/func/mysql_to_sqlite3_test.py @@ -433,14 +433,14 @@ def test_transfer_transfers_all_tables_from_mysql_to_sqlite( mysql_inspect: Inspector = inspect(mysql_engine) mysql_tables: t.List[str] = mysql_inspect.get_table_names() - mysql_connector_connection: t.Union[ - PooledMySQLConnection, MySQLConnection, CMySQLConnection - ] = mysql.connector.connect( - user=mysql_credentials.user, - password=mysql_credentials.password, - host=mysql_credentials.host, - port=mysql_credentials.port, - database=mysql_credentials.database, + mysql_connector_connection: t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection] = ( + mysql.connector.connect( + user=mysql_credentials.user, + password=mysql_credentials.password, + host=mysql_credentials.host, + port=mysql_credentials.port, + database=mysql_credentials.database, + ) ) server_version: t.Tuple[int, ...] = mysql_connector_connection.get_server_version() @@ -490,9 +490,7 @@ def test_transfer_transfers_all_tables_from_mysql_to_sqlite( AND i.CONSTRAINT_TYPE = :constraint_type """.format( # MySQL 8.0.19 still works with "LEFT JOIN" everything above requires "JOIN" - JOIN="JOIN" - if (server_version[0] == 8 and server_version[2] > 19) - else "LEFT JOIN" + JOIN="JOIN" if (server_version[0] == 8 and server_version[2] > 19) else "LEFT JOIN" ) ).bindparams( table_schema=mysql_credentials.database, @@ -1183,14 +1181,14 @@ def test_transfer_limited_rows_from_mysql_to_sqlite( mysql_inspect: Inspector = inspect(mysql_engine) mysql_tables: t.List[str] = mysql_inspect.get_table_names() - mysql_connector_connection: t.Union[ - PooledMySQLConnection, MySQLConnection, CMySQLConnection - ] = mysql.connector.connect( - user=mysql_credentials.user, - password=mysql_credentials.password, - host=mysql_credentials.host, - port=mysql_credentials.port, - database=mysql_credentials.database, + mysql_connector_connection: t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection] = ( + mysql.connector.connect( + user=mysql_credentials.user, + password=mysql_credentials.password, + host=mysql_credentials.host, + port=mysql_credentials.port, + database=mysql_credentials.database, + ) ) server_version: t.Tuple[int, ...] = mysql_connector_connection.get_server_version() @@ -1240,9 +1238,7 @@ def test_transfer_limited_rows_from_mysql_to_sqlite( AND i.CONSTRAINT_TYPE = :constraint_type """.format( # MySQL 8.0.19 still works with "LEFT JOIN" everything above requires "JOIN" - JOIN="JOIN" - if (server_version[0] == 8 and server_version[2] > 19) - else "LEFT JOIN" + JOIN="JOIN" if (server_version[0] == 8 and server_version[2] > 19) else "LEFT JOIN" ) ).bindparams( table_schema=mysql_credentials.database, diff --git a/tox.ini b/tox.ini index 1ea4e4a..9d02090 100644 --- a/tox.ini +++ b/tox.ini @@ -80,7 +80,7 @@ deps = mypy>=1.3.0 -rrequirements_dev.txt commands = - mypy mysql_to_sqlite3 --enable-incomplete-feature=Unpack + mypy mysql_to_sqlite3 [testenv:linters] basepython = python3