-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from cyrannano/sqlserver-integration
SQL Server integration
- Loading branch information
Showing
4 changed files
with
314 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from typing import List, Optional | ||
from urllib.parse import urlparse | ||
|
||
import pandas as pd | ||
import pyodbc | ||
|
||
from . import IDatabase | ||
from .._utils import logger | ||
from .._utils.constants import ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, INVALID_DB_CONNECTION_OBJECT, \ | ||
CONNECTION_ESTABLISH_ERROR_CONSTANT, SQLSERVER_SHOW_DATABASE_QUERY, SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY, \ | ||
SQLSERVER_SHOW_CREATE_TABLE_QUERY | ||
|
||
log = logger.init_loggers("SQL Server") | ||
|
||
|
||
class SQLServer(IDatabase): | ||
@staticmethod | ||
def create_connection(url: str, **kwargs) -> any: | ||
""" | ||
Connects to a SQL Server database using the provided URL. | ||
Parameters: | ||
- url (str): The connection string to the SQL Server database in the format: | ||
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password' | ||
- **kwargs: Additional keyword arguments for the connection | ||
Returns: | ||
- connection: A connection to the SQL Server database | ||
""" | ||
|
||
try: | ||
connection = pyodbc.connect(url, **kwargs) | ||
return connection | ||
except pyodbc.Error as e: | ||
log.error(ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", e)) | ||
|
||
def execute_sql(self, connection, sql:str) -> Optional[pd.DataFrame]: | ||
""" | ||
A function that runs an SQL query using the provided connection and returns the results as a pandas DataFrame. | ||
Parameters: | ||
connection: The connection object for the database. | ||
sql (str): The SQL query to be executed | ||
Returns: | ||
pd.DataFrame: A DataFrame containing the results of the SQL query. | ||
""" | ||
try: | ||
self.validate_connection(connection) | ||
cursor = connection.cursor() | ||
cursor.execute(sql) | ||
columns = [column[0] for column in cursor.description] | ||
data = cursor.fetchall() | ||
data = [list(row) for row in data] | ||
cursor.close() | ||
return pd.DataFrame(data, columns=columns) | ||
except pyodbc.Error as e: | ||
log.error(ERROR_WHILE_RUNNING_QUERY.format(e)) | ||
return None | ||
|
||
def get_databases(self, connection) -> List[str]: | ||
""" | ||
Get a list of databases from the given connection and SQL query. | ||
Parameters: | ||
connection: The connection object for the database. | ||
Returns: | ||
List[str]: A list of unique database names. | ||
""" | ||
try: | ||
self.validate_connection(connection) | ||
cursor = connection.cursor() | ||
cursor.execute(SQLSERVER_SHOW_DATABASE_QUERY) | ||
databases = [row[0] for row in cursor.fetchall()] | ||
cursor.close() | ||
return databases | ||
except pyodbc.Error as e: | ||
log.error(ERROR_WHILE_RUNNING_QUERY.format(e)) | ||
return [] | ||
|
||
def get_table_names(self, connection, database: str) -> pd.DataFrame: | ||
""" | ||
Retrieves the tables along with their schema (schema.table_name) from the information schema for the specified | ||
database. | ||
Parameters: | ||
connection: The database connection object. | ||
database (str): The name of the database. | ||
Returns: | ||
DataFrame: A pandas DataFrame containing the table names from the information schema. | ||
""" | ||
self.validate_connection(connection) | ||
query = SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY.format(db=database) | ||
return self.execute_sql(connection, query) | ||
|
||
|
||
|
||
|
||
def get_all_ddls(self, connection: any, database: str) -> pd.DataFrame: | ||
""" | ||
A method to get the DDLs for all the tables in the database. | ||
Parameters: | ||
connection (any): The connection object. | ||
database (str): The name of the database. | ||
Returns: | ||
DataFrame: A pandas DataFrame containing the DDLs for all the tables in the database. | ||
""" | ||
df_tables = self.get_table_names(connection, database) | ||
ddl_df = pd.DataFrame(columns=['Table', 'DDL']) | ||
for index, row in df_tables.iterrows(): | ||
ddl = self.get_ddl(connection, row.iloc[0]) | ||
ddl_df = ddl_df._append({'Table': row.iloc[0], 'DDL': ddl}, ignore_index=True) | ||
|
||
return ddl_df | ||
|
||
|
||
|
||
def validate_connection(self, connection: any) -> None: | ||
""" | ||
A function that validates if the provided connection is a SQL Server connection. | ||
Parameters: | ||
connection: The connection object for accessing the database. | ||
Raises: | ||
ValueError: If the provided connection is not a SQL Server connection. | ||
Returns: | ||
None | ||
""" | ||
if connection is None: | ||
raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT) | ||
if not isinstance(connection, pyodbc.Connection): | ||
raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("SQL Server")) | ||
|
||
def get_ddl(self, connection: any, table_name: str, **kwargs) -> str: | ||
schema_name, table_name = table_name.split('.') | ||
query = SQLSERVER_SHOW_CREATE_TABLE_QUERY.format(table=table_name, schema=schema_name) | ||
df_ddl = self.execute_sql(connection, query) | ||
return df_ddl['SQLQuery'][0] | ||
|
||
def get_dialect(self) -> str: | ||
return 'tsql' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
import pyodbc | ||
import pandas as pd | ||
from mindsql.databases.sqlserver import SQLServer, ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, \ | ||
INVALID_DB_CONNECTION_OBJECT, CONNECTION_ESTABLISH_ERROR_CONSTANT | ||
from mindsql.databases.sqlserver import log as logger | ||
|
||
|
||
class TestSQLServer(unittest.TestCase): | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_create_connection_success(self, mock_connect): | ||
mock_connect.return_value = MagicMock(spec=pyodbc.Connection) | ||
connection = SQLServer.create_connection( | ||
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password') | ||
self.assertIsInstance(connection, pyodbc.Connection) | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_create_connection_failure(self, mock_connect): | ||
mock_connect.side_effect = pyodbc.Error('Connection failed') | ||
with self.assertLogs(logger, level='ERROR') as cm: | ||
connection = SQLServer.create_connection( | ||
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password') | ||
self.assertIsNone(connection) | ||
self.assertTrue(any( | ||
ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", 'Connection failed') in message for message in | ||
cm.output)) | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_execute_sql_success(self, mock_connect): | ||
# Mock the connection and cursor | ||
mock_connection = MagicMock(spec=pyodbc.Connection) | ||
mock_cursor = MagicMock() | ||
|
||
mock_connect.return_value = mock_connection | ||
mock_connection.cursor.return_value = mock_cursor | ||
|
||
# Mock cursor behavior | ||
mock_cursor.execute.return_value = None | ||
mock_cursor.description = [('column1',), ('column2',)] | ||
mock_cursor.fetchall.return_value = [(1, 'a'), (2, 'b')] | ||
|
||
connection = SQLServer.create_connection('fake_connection_string') | ||
sql = "SELECT * FROM table" | ||
sql_server = SQLServer() | ||
result = sql_server.execute_sql(connection, sql) | ||
expected_df = pd.DataFrame(data=[(1, 'a'), (2, 'b')], columns=['column1', 'column2']) | ||
pd.testing.assert_frame_equal(result, expected_df) | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_execute_sql_failure(self, mock_connect): | ||
# Mock the connection and cursor | ||
mock_connection = MagicMock(spec=pyodbc.Connection) | ||
mock_cursor = MagicMock() | ||
|
||
mock_connect.return_value = mock_connection | ||
mock_connection.cursor.return_value = mock_cursor | ||
mock_cursor.execute.side_effect = pyodbc.Error('Query failed') | ||
|
||
connection = SQLServer.create_connection('fake_connection_string') | ||
sql = "SELECT * FROM table" | ||
sql_server = SQLServer() | ||
|
||
with self.assertLogs(logger, level='ERROR') as cm: | ||
result = sql_server.execute_sql(connection, sql) | ||
self.assertIsNone(result) | ||
self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output)) | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_get_databases_success(self, mock_connect): | ||
# Mock the connection and cursor | ||
mock_connection = MagicMock(spec=pyodbc.Connection) | ||
mock_cursor = MagicMock() | ||
|
||
mock_connect.return_value = mock_connection | ||
mock_connection.cursor.return_value = mock_cursor | ||
|
||
# Mock cursor behavior | ||
mock_cursor.execute.return_value = None | ||
mock_cursor.fetchall.return_value = [('database1',), ('database2',)] | ||
|
||
connection = SQLServer.create_connection('fake_connection_string') | ||
sql_server = SQLServer() | ||
result = sql_server.get_databases(connection) | ||
self.assertEqual(result, ['database1', 'database2']) | ||
|
||
@patch('mindsql.databases.sqlserver.pyodbc.connect') | ||
def test_get_databases_failure(self, mock_connect): | ||
# Mock the connection and cursor | ||
mock_connection = MagicMock(spec=pyodbc.Connection) | ||
mock_cursor = MagicMock() | ||
|
||
mock_connect.return_value = mock_connection | ||
mock_connection.cursor.return_value = mock_cursor | ||
mock_cursor.execute.side_effect = pyodbc.Error('Query failed') | ||
|
||
connection = SQLServer.create_connection('fake_connection_string') | ||
sql_server = SQLServer() | ||
|
||
with self.assertLogs(logger, level='ERROR') as cm: | ||
result = sql_server.get_databases(connection) | ||
self.assertEqual(result, []) | ||
self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output)) | ||
|
||
@patch('mindsql.databases.sqlserver.SQLServer.execute_sql') | ||
def test_get_table_names_success(self, mock_execute_sql): | ||
mock_execute_sql.return_value = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)], | ||
columns=['table_name']) | ||
|
||
connection = MagicMock(spec=pyodbc.Connection) | ||
sql_server = SQLServer() | ||
result = sql_server.get_table_names(connection, 'database_name') | ||
expected_df = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)], columns=['table_name']) | ||
pd.testing.assert_frame_equal(result, expected_df) | ||
|
||
@patch('mindsql.databases.sqlserver.SQLServer.execute_sql') | ||
def test_get_all_ddls_success(self, mock_execute_sql): | ||
mock_execute_sql.side_effect = [ | ||
pd.DataFrame(data=[('schema1.table1',)], columns=['table_name']), | ||
pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery']) | ||
] | ||
|
||
connection = MagicMock(spec=pyodbc.Connection) | ||
sql_server = SQLServer() | ||
result = sql_server.get_all_ddls(connection, 'database_name') | ||
|
||
expected_df = pd.DataFrame(data=[{'Table': 'schema1.table1', 'DDL': 'CREATE TABLE schema1.table1 (...);'}]) | ||
pd.testing.assert_frame_equal(result, expected_df) | ||
|
||
def test_validate_connection_success(self): | ||
connection = MagicMock(spec=pyodbc.Connection) | ||
sql_server = SQLServer() | ||
# Should not raise any exception | ||
sql_server.validate_connection(connection) | ||
|
||
def test_validate_connection_failure(self): | ||
sql_server = SQLServer() | ||
|
||
with self.assertRaises(ValueError) as cm: | ||
sql_server.validate_connection(None) | ||
self.assertEqual(str(cm.exception), CONNECTION_ESTABLISH_ERROR_CONSTANT) | ||
|
||
with self.assertRaises(ValueError) as cm: | ||
sql_server.validate_connection("InvalidConnectionObject") | ||
self.assertEqual(str(cm.exception), INVALID_DB_CONNECTION_OBJECT.format("SQL Server")) | ||
|
||
@patch('mindsql.databases.sqlserver.SQLServer.execute_sql') | ||
def test_get_ddl_success(self, mock_execute_sql): | ||
mock_execute_sql.return_value = pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery']) | ||
|
||
connection = MagicMock(spec=pyodbc.Connection) | ||
sql_server = SQLServer() | ||
result = sql_server.get_ddl(connection, 'schema1.table1') | ||
self.assertEqual(result, 'CREATE TABLE schema1.table1 (...);') | ||
|
||
def test_get_dialect(self): | ||
sql_server = SQLServer() | ||
self.assertEqual(sql_server.get_dialect(), 'tsql') | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |