From a2e5dc71e92bce89e72825ff527a34b11f4a15c2 Mon Sep 17 00:00:00 2001
From: Szymon Cyranik
Date: Thu, 22 Aug 2024 14:38:29 +0200
Subject: [PATCH 1/4] feat(sqlserver): add SQLServer class with connection handling and query execution

---
 mindsql/databases/sqlserver.py | 185 +++++++++++++++++++++++++++++++++
 1 file changed, 185 insertions(+)
 create mode 100644 mindsql/databases/sqlserver.py

diff --git a/mindsql/databases/sqlserver.py b/mindsql/databases/sqlserver.py
new file mode 100644
index 0000000..97de59a
--- /dev/null
+++ b/mindsql/databases/sqlserver.py
@@ -0,0 +1,185 @@
+from typing import List, Optional
+from urllib.parse import urlparse
+
+import pandas as pd
+import pyodbc
+
+from . import IDatabase
+from .._utils import logger
+from .._utils.constants import ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, INVALID_DB_CONNECTION_OBJECT, \
+    CONNECTION_ESTABLISH_ERROR_CONSTANT, SQLSERVER_SHOW_DATABASE_QUERY, SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY, \
+    SQLSERVER_SHOW_CREATE_TABLE_QUERY
+
+log = logger.init_loggers("SQL Server")
+
+
+class SQLServer(IDatabase):
+    @staticmethod
+    def create_connection(url: str, **kwargs) -> any:
+        """
+        Connects to a SQL Server database using the provided ODBC connection string.
+
+        Parameters:
+            url (str): The connection string to the SQL Server database in the format:
+                'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password'
+            **kwargs: Additional keyword arguments forwarded to pyodbc.connect.
+
+        Returns:
+            connection: A connection to the SQL Server database, or None if connecting failed (the error is logged).
+        """
+        try:
+            return pyodbc.connect(url, **kwargs)
+        except pyodbc.Error as e:
+            log.error(ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", e))
+            return None
+
+    def execute_sql(self, connection, sql: str) -> Optional[pd.DataFrame]:
+        """
+        Runs an SQL query using the provided connection and returns the results as a pandas DataFrame.
+
+        Parameters:
+            connection: The connection object for the database.
+            sql (str): The SQL query to be executed.
+
+        Returns:
+            pd.DataFrame: The rows of the first result set that has columns; an empty DataFrame if the statement
+            produced no result set; None if the query failed (the error is logged).
+
+        Raises:
+            ValueError: If no connection was provided or it is not a SQL Server connection.
+        """
+        self.validate_connection(connection)
+        cursor = None
+        try:
+            cursor = connection.cursor()
+            cursor.execute(sql)
+            # Statements without a result set (INSERT, DDL, row-count messages inside a batch) leave
+            # cursor.description as None; skip ahead to the first result set that has columns.
+            while cursor.description is None:
+                if not cursor.nextset():
+                    return pd.DataFrame()
+            columns = [column[0] for column in cursor.description]
+            data = [list(row) for row in cursor.fetchall()]
+            return pd.DataFrame(data, columns=columns)
+        except pyodbc.Error as e:
+            log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
+            return None
+        finally:
+            # Release the cursor on every path, including failed queries.
+            if cursor is not None:
+                cursor.close()
+
+    def get_databases(self, connection) -> List[str]:
+        """
+        Gets the names of the databases returned by SQLSERVER_SHOW_DATABASE_QUERY on the given connection.
+
+        Parameters:
+            connection: The connection object for the database.
+
+        Returns:
+            List[str]: The database names, or an empty list if the query failed (the error is logged).
+
+        Raises:
+            ValueError: If no connection was provided or it is not a SQL Server connection.
+        """
+        self.validate_connection(connection)
+        cursor = None
+        try:
+            cursor = connection.cursor()
+            cursor.execute(SQLSERVER_SHOW_DATABASE_QUERY)
+            return [row[0] for row in cursor.fetchall()]
+        except pyodbc.Error as e:
+            log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
+            return []
+        finally:
+            # Release the cursor on every path, including failed queries.
+            if cursor is not None:
+                cursor.close()
+
+    def get_table_names(self, connection, database: str) -> Optional[pd.DataFrame]:
+        """
+        Retrieves the tables along with their schema (schema.table_name) from the information schema for the specified
+        database.
+
+        Parameters:
+            connection: The database connection object.
+            database (str): The name of the database.
+
+        Returns:
+            DataFrame: A pandas DataFrame containing the table names from the information schema, or None if the
+            query failed (the error is logged).
+        """
+        self.validate_connection(connection)
+        # The name is placed inside [...] delimiters in the query, so a literal ']' must be doubled to keep it
+        # from terminating the identifier early.
+        query = SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY.format(db=database.replace(']', ']]'))
+        return self.execute_sql(connection, query)
+
+    def get_all_ddls(self, connection: any, database: str) -> pd.DataFrame:
+        """
+        A method to get the DDLs for all the tables in the database.
+
+        Parameters:
+            connection (any): The connection object.
+            database (str): The name of the database.
+
+        Returns:
+            DataFrame: A pandas DataFrame with the columns 'Table' and 'DDL', one row per table. It is empty if no
+            tables were returned or the table lookup failed.
+        """
+        df_tables = self.get_table_names(connection, database)
+        if df_tables is None or df_tables.empty:
+            return pd.DataFrame(columns=['Table', 'DDL'])
+        # Collect the rows first and build the frame once, rather than appending row by row.
+        records = [{'Table': table, 'DDL': self.get_ddl(connection, table)} for table in df_tables.iloc[:, 0]]
+        return pd.DataFrame(records, columns=['Table', 'DDL'])
+
+    def validate_connection(self, connection: any) -> None:
+        """
+        A function that validates if the provided connection is a SQL Server connection.
+
+        Parameters:
+            connection: The connection object for accessing the database.
+
+        Raises:
+            ValueError: If no connection was provided or it is not a SQL Server connection.
+
+        Returns:
+            None
+        """
+        if connection is None:
+            raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
+        if not isinstance(connection, pyodbc.Connection):
+            raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("SQL Server"))
+
+    def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
+        """
+        Gets the CREATE TABLE statement generated for a single table.
+
+        Parameters:
+            connection (any): The connection object.
+            table_name (str): The schema-qualified table name (schema.table_name).
+            **kwargs: Additional keyword arguments (unused).
+
+        Returns:
+            str: The generated DDL, or an empty string if no DDL was returned or the query failed.
+
+        Raises:
+            ValueError: If table_name is not schema-qualified, or the connection is invalid.
+        """
+        # Split on the first dot only, so a table name that itself contains dots stays intact.
+        schema_name, separator, table = table_name.partition('.')
+        if not separator:
+            raise ValueError(f"Expected a schema-qualified table name (schema.table_name), got '{table_name}'")
+        # Both values are embedded in single-quoted T-SQL literals, so embedded quotes must be doubled.
+        query = SQLSERVER_SHOW_CREATE_TABLE_QUERY.format(table=table.replace("'", "''"),
+                                                         schema=schema_name.replace("'", "''"))
+        df_ddl = self.execute_sql(connection, query)
+        if df_ddl is None or df_ddl.empty:
+            return ''
+        ddl = df_ddl['SQLQuery'].iloc[0]
+        return '' if ddl is None else ddl
+
+    def get_dialect(self) -> str:
+        """Returns the SQL dialect name used for SQL Server ('tsql')."""
+        return 'tsql'

From 0babeeec24d27f7412f4f3c4428504f41cfff436 Mon Sep 17 00:00:00 2001
From: Szymon Cyranik
Date: Thu, 22 Aug 2024 14:39:24 +0200
Subject: [PATCH 2/4] test(sqlserver): add unit tests for SQLServer class

---
 tests/sqlserver_test.py | 182 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 182 insertions(+)
 create mode 100644 tests/sqlserver_test.py

diff --git a/tests/sqlserver_test.py b/tests/sqlserver_test.py
new file mode 100644
index 0000000..fd68159
--- /dev/null
+++ b/tests/sqlserver_test.py
@@ -0,0 +1,182 @@
+import unittest
+from unittest.mock import patch, MagicMock
+import pyodbc
+import pandas as pd
+from mindsql.databases.sqlserver import SQLServer, ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, \
+    INVALID_DB_CONNECTION_OBJECT, CONNECTION_ESTABLISH_ERROR_CONSTANT
+from mindsql.databases.sqlserver import log as logger
+
+
+class TestSQLServer(unittest.TestCase):
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_create_connection_success(self, mock_connect):
+        mock_connect.return_value = MagicMock(spec=pyodbc.Connection)
+        connection = SQLServer.create_connection(
+            'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
+        self.assertIsInstance(connection, pyodbc.Connection)
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_create_connection_failure(self, mock_connect):
+        mock_connect.side_effect = pyodbc.Error('Connection failed')
+        with self.assertLogs(logger, level='ERROR') as cm:
+            connection = SQLServer.create_connection(
+                'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
+        self.assertIsNone(connection)
+        self.assertTrue(any(
+            ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", 'Connection failed') in message for message in
+            cm.output))
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_execute_sql_success(self, mock_connect):
+        # Mock the connection and cursor
+        mock_connection = MagicMock(spec=pyodbc.Connection)
+        mock_cursor = MagicMock()
+
+        mock_connect.return_value = mock_connection
+        mock_connection.cursor.return_value = mock_cursor
+
+        # Mock cursor behavior
+        mock_cursor.execute.return_value = None
+        mock_cursor.description = [('column1',), ('column2',)]
+        mock_cursor.fetchall.return_value = [(1, 'a'), (2, 'b')]
+
+        connection = SQLServer.create_connection('fake_connection_string')
+        sql = "SELECT * FROM table"
+        sql_server = SQLServer()
+        result = sql_server.execute_sql(connection, sql)
+        expected_df = pd.DataFrame(data=[(1, 'a'), (2, 'b')], columns=['column1', 'column2'])
+        pd.testing.assert_frame_equal(result, expected_df)
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_execute_sql_failure(self, mock_connect):
+        # Mock the connection and cursor
+        mock_connection = MagicMock(spec=pyodbc.Connection)
+        mock_cursor = MagicMock()
+
+        mock_connect.return_value = mock_connection
+        mock_connection.cursor.return_value = mock_cursor
+        mock_cursor.execute.side_effect = pyodbc.Error('Query failed')
+
+        connection = SQLServer.create_connection('fake_connection_string')
+        sql = "SELECT * FROM table"
+        sql_server = SQLServer()
+
+        with self.assertLogs(logger, level='ERROR') as cm:
+            result = sql_server.execute_sql(connection, sql)
+        self.assertIsNone(result)
+        self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output))
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_get_databases_success(self, mock_connect):
+        # Mock the connection and cursor
+        mock_connection = MagicMock(spec=pyodbc.Connection)
+        mock_cursor = MagicMock()
+
+        mock_connect.return_value = mock_connection
+        mock_connection.cursor.return_value = mock_cursor
+
+        # Mock cursor behavior
+        mock_cursor.execute.return_value = None
+        mock_cursor.fetchall.return_value = [('database1',), ('database2',)]
+
+        connection = SQLServer.create_connection('fake_connection_string')
+        sql_server = SQLServer()
+        result = sql_server.get_databases(connection)
+        self.assertEqual(result, ['database1', 'database2'])
+
+    @patch('mindsql.databases.sqlserver.pyodbc.connect')
+    def test_get_databases_failure(self, mock_connect):
+        # Mock the connection and cursor
+        mock_connection = MagicMock(spec=pyodbc.Connection)
+        mock_cursor = MagicMock()
+
+        mock_connect.return_value = mock_connection
+        mock_connection.cursor.return_value = mock_cursor
+        mock_cursor.execute.side_effect = pyodbc.Error('Query failed')
+
+        connection = SQLServer.create_connection('fake_connection_string')
+        sql_server = SQLServer()
+
+        with self.assertLogs(logger, level='ERROR') as cm:
+            result = sql_server.get_databases(connection)
+        self.assertEqual(result, [])
+        self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output))
+
+    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
+    def test_get_table_names_success(self, mock_execute_sql):
+        mock_execute_sql.return_value = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)],
+                                                     columns=['table_name'])
+
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        result = sql_server.get_table_names(connection, 'database_name')
+        expected_df = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)], columns=['table_name'])
+        pd.testing.assert_frame_equal(result, expected_df)
+
+    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
+    def test_get_all_ddls_success(self, mock_execute_sql):
+        mock_execute_sql.side_effect = [
+            pd.DataFrame(data=[('schema1.table1',)], columns=['table_name']),
+            pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery'])
+        ]
+
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        result = sql_server.get_all_ddls(connection, 'database_name')
+
+        expected_df = pd.DataFrame(data=[{'Table': 'schema1.table1', 'DDL': 'CREATE TABLE schema1.table1 (...);'}])
+        pd.testing.assert_frame_equal(result, expected_df)
+
+    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
+    def test_get_all_ddls_table_lookup_failure(self, mock_execute_sql):
+        mock_execute_sql.return_value = None
+
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        result = sql_server.get_all_ddls(connection, 'database_name')
+
+        self.assertTrue(result.empty)
+        self.assertEqual(list(result.columns), ['Table', 'DDL'])
+
+    def test_validate_connection_success(self):
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        # Should not raise any exception
+        sql_server.validate_connection(connection)
+
+    def test_validate_connection_failure(self):
+        sql_server = SQLServer()
+
+        with self.assertRaises(ValueError) as cm:
+            sql_server.validate_connection(None)
+        self.assertEqual(str(cm.exception), CONNECTION_ESTABLISH_ERROR_CONSTANT)
+
+        with self.assertRaises(ValueError) as cm:
+            sql_server.validate_connection("InvalidConnectionObject")
+        self.assertEqual(str(cm.exception), INVALID_DB_CONNECTION_OBJECT.format("SQL Server"))
+
+    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
+    def test_get_ddl_success(self, mock_execute_sql):
+        mock_execute_sql.return_value = pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery'])
+
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        result = sql_server.get_ddl(connection, 'schema1.table1')
+        self.assertEqual(result, 'CREATE TABLE schema1.table1 (...);')
+
+    @patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
+    def test_get_ddl_requires_schema_qualified_name(self, mock_execute_sql):
+        connection = MagicMock(spec=pyodbc.Connection)
+        sql_server = SQLServer()
+        with self.assertRaises(ValueError):
+            sql_server.get_ddl(connection, 'table1')
+        mock_execute_sql.assert_not_called()
+
+    def test_get_dialect(self):
+        sql_server = SQLServer()
+        self.assertEqual(sql_server.get_dialect(), 'tsql')
+
+
+if __name__ == '__main__':
+    unittest.main()

From e6164ce82feb0715a01434a55644efb5da60762c Mon Sep 17 00:00:00 2001
From: Szymon Cyranik
Date: Thu, 22 Aug 2024 14:54:15 +0200
Subject: [PATCH 3/4] feat(sqlserver): add constants for SQL Server integration

---
 mindsql/_utils/constants.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py
index 0d59f58..58bc3fd 100644
--- a/mindsql/_utils/constants.py
+++ b/mindsql/_utils/constants.py
@@ -32,4 +32,8 @@
 OPENAI_VALUE_ERROR = "OpenAI API key is required"
 PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
 POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
-ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
\ No newline at end of file
+ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
+SQLSERVER_SHOW_DATABASE_QUERY = "SELECT name FROM sys.databases;"
+SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'"
+# sys.columns.max_length is in bytes; nchar/nvarchar use two bytes per character, hence the division by 2 below.
+SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' WHEN tp.name IN ('nchar', 'nvarchar') THEN CAST(c.max_length / 2 AS VARCHAR(10)) ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;"

From 845f26dc45cad5e198f0758776a195ee91961cfc Mon Sep 17 00:00:00 2001
From: Szymon Cyranik
Date: Thu, 22 Aug 2024 15:24:21 +0200
Subject: [PATCH 4/4] feat(sqlserver): add import to __init__

---
 mindsql/databases/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindsql/databases/__init__.py b/mindsql/databases/__init__.py
index 109c0ef..0034303 100644
--- a/mindsql/databases/__init__.py
+++ b/mindsql/databases/__init__.py
@@ -2,3 +2,4 @@
 from .mysql import MySql
 from .postgres import Postgres
 from .sqlite import Sqlite
+from .sqlserver import SQLServer