Skip to content

Commit

Permalink
Merge pull request #20 from cyrannano/sqlserver-integration
Browse files Browse the repository at this point in the history
SQL Server integration
  • Loading branch information
Sammindinventory authored Oct 18, 2024
2 parents b95ca73 + 701c046 commit 3d0ff02
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@
PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
SQLSERVER_SHOW_DATABASE_QUERY= "SELECT name FROM sys.databases;"
SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'"
SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;"
OLLAMA_CONFIG_REQUIRED = "{type} configuration is required."
1 change: 1 addition & 0 deletions mindsql/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .mysql import MySql
from .postgres import Postgres
from .sqlite import Sqlite
from .sqlserver import SQLServer
147 changes: 147 additions & 0 deletions mindsql/databases/sqlserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from typing import List, Optional
from urllib.parse import urlparse

import pandas as pd
import pyodbc

from . import IDatabase
from .._utils import logger
from .._utils.constants import ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, INVALID_DB_CONNECTION_OBJECT, \
CONNECTION_ESTABLISH_ERROR_CONSTANT, SQLSERVER_SHOW_DATABASE_QUERY, SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY, \
SQLSERVER_SHOW_CREATE_TABLE_QUERY

log = logger.init_loggers("SQL Server")


class SQLServer(IDatabase):
@staticmethod
def create_connection(url: str, **kwargs) -> any:
"""
Connects to a SQL Server database using the provided URL.
Parameters:
- url (str): The connection string to the SQL Server database in the format:
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password'
- **kwargs: Additional keyword arguments for the connection
Returns:
- connection: A connection to the SQL Server database
"""

try:
connection = pyodbc.connect(url, **kwargs)
return connection
except pyodbc.Error as e:
log.error(ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", e))

def execute_sql(self, connection, sql:str) -> Optional[pd.DataFrame]:
"""
A function that runs an SQL query using the provided connection and returns the results as a pandas DataFrame.
Parameters:
connection: The connection object for the database.
sql (str): The SQL query to be executed
Returns:
pd.DataFrame: A DataFrame containing the results of the SQL query.
"""
try:
self.validate_connection(connection)
cursor = connection.cursor()
cursor.execute(sql)
columns = [column[0] for column in cursor.description]
data = cursor.fetchall()
data = [list(row) for row in data]
cursor.close()
return pd.DataFrame(data, columns=columns)
except pyodbc.Error as e:
log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
return None

def get_databases(self, connection) -> List[str]:
"""
Get a list of databases from the given connection and SQL query.
Parameters:
connection: The connection object for the database.
Returns:
List[str]: A list of unique database names.
"""
try:
self.validate_connection(connection)
cursor = connection.cursor()
cursor.execute(SQLSERVER_SHOW_DATABASE_QUERY)
databases = [row[0] for row in cursor.fetchall()]
cursor.close()
return databases
except pyodbc.Error as e:
log.error(ERROR_WHILE_RUNNING_QUERY.format(e))
return []

def get_table_names(self, connection, database: str) -> pd.DataFrame:
"""
Retrieves the tables along with their schema (schema.table_name) from the information schema for the specified
database.
Parameters:
connection: The database connection object.
database (str): The name of the database.
Returns:
DataFrame: A pandas DataFrame containing the table names from the information schema.
"""
self.validate_connection(connection)
query = SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY.format(db=database)
return self.execute_sql(connection, query)




def get_all_ddls(self, connection: any, database: str) -> pd.DataFrame:
"""
A method to get the DDLs for all the tables in the database.
Parameters:
connection (any): The connection object.
database (str): The name of the database.
Returns:
DataFrame: A pandas DataFrame containing the DDLs for all the tables in the database.
"""
df_tables = self.get_table_names(connection, database)
ddl_df = pd.DataFrame(columns=['Table', 'DDL'])
for index, row in df_tables.iterrows():
ddl = self.get_ddl(connection, row.iloc[0])
ddl_df = ddl_df._append({'Table': row.iloc[0], 'DDL': ddl}, ignore_index=True)

return ddl_df



def validate_connection(self, connection: any) -> None:
"""
A function that validates if the provided connection is a SQL Server connection.
Parameters:
connection: The connection object for accessing the database.
Raises:
ValueError: If the provided connection is not a SQL Server connection.
Returns:
None
"""
if connection is None:
raise ValueError(CONNECTION_ESTABLISH_ERROR_CONSTANT)
if not isinstance(connection, pyodbc.Connection):
raise ValueError(INVALID_DB_CONNECTION_OBJECT.format("SQL Server"))

def get_ddl(self, connection: any, table_name: str, **kwargs) -> str:
schema_name, table_name = table_name.split('.')
query = SQLSERVER_SHOW_CREATE_TABLE_QUERY.format(table=table_name, schema=schema_name)
df_ddl = self.execute_sql(connection, query)
return df_ddl['SQLQuery'][0]

def get_dialect(self) -> str:
return 'tsql'
163 changes: 163 additions & 0 deletions tests/sqlserver_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import unittest
from unittest.mock import patch, MagicMock
import pyodbc
import pandas as pd
from mindsql.databases.sqlserver import SQLServer, ERROR_WHILE_RUNNING_QUERY, ERROR_CONNECTING_TO_DB_CONSTANT, \
INVALID_DB_CONNECTION_OBJECT, CONNECTION_ESTABLISH_ERROR_CONSTANT
from mindsql.databases.sqlserver import log as logger


class TestSQLServer(unittest.TestCase):

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_create_connection_success(self, mock_connect):
mock_connect.return_value = MagicMock(spec=pyodbc.Connection)
connection = SQLServer.create_connection(
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
self.assertIsInstance(connection, pyodbc.Connection)

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_create_connection_failure(self, mock_connect):
mock_connect.side_effect = pyodbc.Error('Connection failed')
with self.assertLogs(logger, level='ERROR') as cm:
connection = SQLServer.create_connection(
'DRIVER={ODBC Driver 17 for SQL Server};SERVER=server_name;DATABASE=database_name;UID=user;PWD=password')
self.assertIsNone(connection)
self.assertTrue(any(
ERROR_CONNECTING_TO_DB_CONSTANT.format("SQL Server", 'Connection failed') in message for message in
cm.output))

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_execute_sql_success(self, mock_connect):
# Mock the connection and cursor
mock_connection = MagicMock(spec=pyodbc.Connection)
mock_cursor = MagicMock()

mock_connect.return_value = mock_connection
mock_connection.cursor.return_value = mock_cursor

# Mock cursor behavior
mock_cursor.execute.return_value = None
mock_cursor.description = [('column1',), ('column2',)]
mock_cursor.fetchall.return_value = [(1, 'a'), (2, 'b')]

connection = SQLServer.create_connection('fake_connection_string')
sql = "SELECT * FROM table"
sql_server = SQLServer()
result = sql_server.execute_sql(connection, sql)
expected_df = pd.DataFrame(data=[(1, 'a'), (2, 'b')], columns=['column1', 'column2'])
pd.testing.assert_frame_equal(result, expected_df)

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_execute_sql_failure(self, mock_connect):
# Mock the connection and cursor
mock_connection = MagicMock(spec=pyodbc.Connection)
mock_cursor = MagicMock()

mock_connect.return_value = mock_connection
mock_connection.cursor.return_value = mock_cursor
mock_cursor.execute.side_effect = pyodbc.Error('Query failed')

connection = SQLServer.create_connection('fake_connection_string')
sql = "SELECT * FROM table"
sql_server = SQLServer()

with self.assertLogs(logger, level='ERROR') as cm:
result = sql_server.execute_sql(connection, sql)
self.assertIsNone(result)
self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output))

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_get_databases_success(self, mock_connect):
# Mock the connection and cursor
mock_connection = MagicMock(spec=pyodbc.Connection)
mock_cursor = MagicMock()

mock_connect.return_value = mock_connection
mock_connection.cursor.return_value = mock_cursor

# Mock cursor behavior
mock_cursor.execute.return_value = None
mock_cursor.fetchall.return_value = [('database1',), ('database2',)]

connection = SQLServer.create_connection('fake_connection_string')
sql_server = SQLServer()
result = sql_server.get_databases(connection)
self.assertEqual(result, ['database1', 'database2'])

@patch('mindsql.databases.sqlserver.pyodbc.connect')
def test_get_databases_failure(self, mock_connect):
# Mock the connection and cursor
mock_connection = MagicMock(spec=pyodbc.Connection)
mock_cursor = MagicMock()

mock_connect.return_value = mock_connection
mock_connection.cursor.return_value = mock_cursor
mock_cursor.execute.side_effect = pyodbc.Error('Query failed')

connection = SQLServer.create_connection('fake_connection_string')
sql_server = SQLServer()

with self.assertLogs(logger, level='ERROR') as cm:
result = sql_server.get_databases(connection)
self.assertEqual(result, [])
self.assertTrue(any(ERROR_WHILE_RUNNING_QUERY.format('Query failed') in message for message in cm.output))

@patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
def test_get_table_names_success(self, mock_execute_sql):
mock_execute_sql.return_value = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)],
columns=['table_name'])

connection = MagicMock(spec=pyodbc.Connection)
sql_server = SQLServer()
result = sql_server.get_table_names(connection, 'database_name')
expected_df = pd.DataFrame(data=[('schema1.table1',), ('schema2.table2',)], columns=['table_name'])
pd.testing.assert_frame_equal(result, expected_df)

@patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
def test_get_all_ddls_success(self, mock_execute_sql):
mock_execute_sql.side_effect = [
pd.DataFrame(data=[('schema1.table1',)], columns=['table_name']),
pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery'])
]

connection = MagicMock(spec=pyodbc.Connection)
sql_server = SQLServer()
result = sql_server.get_all_ddls(connection, 'database_name')

expected_df = pd.DataFrame(data=[{'Table': 'schema1.table1', 'DDL': 'CREATE TABLE schema1.table1 (...);'}])
pd.testing.assert_frame_equal(result, expected_df)

def test_validate_connection_success(self):
connection = MagicMock(spec=pyodbc.Connection)
sql_server = SQLServer()
# Should not raise any exception
sql_server.validate_connection(connection)

def test_validate_connection_failure(self):
sql_server = SQLServer()

with self.assertRaises(ValueError) as cm:
sql_server.validate_connection(None)
self.assertEqual(str(cm.exception), CONNECTION_ESTABLISH_ERROR_CONSTANT)

with self.assertRaises(ValueError) as cm:
sql_server.validate_connection("InvalidConnectionObject")
self.assertEqual(str(cm.exception), INVALID_DB_CONNECTION_OBJECT.format("SQL Server"))

@patch('mindsql.databases.sqlserver.SQLServer.execute_sql')
def test_get_ddl_success(self, mock_execute_sql):
mock_execute_sql.return_value = pd.DataFrame(data=['CREATE TABLE schema1.table1 (...);'], columns=['SQLQuery'])

connection = MagicMock(spec=pyodbc.Connection)
sql_server = SQLServer()
result = sql_server.get_ddl(connection, 'schema1.table1')
self.assertEqual(result, 'CREATE TABLE schema1.table1 (...);')

def test_get_dialect(self):
sql_server = SQLServer()
self.assertEqual(sql_server.get_dialect(), 'tsql')


if __name__ == '__main__':
unittest.main()

0 comments on commit 3d0ff02

Please sign in to comment.