Skip to content

Commit

Permalink
Merge pull request #6 from stratosphereips/add_tests_db_utils
Browse files Browse the repository at this point in the history
Increase test coverage
  • Loading branch information
verovaleros authored Feb 23, 2024
2 parents 5495887 + 307d14e commit 37ab932
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 26 deletions.
44 changes: 22 additions & 22 deletions lib/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ def get_channel_messages(cursor, channel_name):
cursor.execute(query, (channel_name,))
messages = cursor.fetchall()
return messages
except sqlite3.IntegrityError as e:
raise sqlite3.IntegrityError(f"Integrity error retriving from db: {e}")
except sqlite3.OperationalError as e:
raise sqlite3.OperationalError(f"Operational error retrieving from db: {e}")
except sqlite3.ProgrammingError as e:
raise sqlite3.ProgrammingError(f"Programming error occurred: {e}")
except sqlite3.DatabaseError as e:
raise sqlite3.DatabaseError(f"Database error occurred: {e}")
except sqlite3.IntegrityError:
raise
except sqlite3.OperationalError:
raise
except sqlite3.ProgrammingError:
raise
except sqlite3.DatabaseError:
raise


def exists_translation_for_message(cursor, message_id, translation_parameters_id):
Expand Down Expand Up @@ -244,14 +244,14 @@ def exists_translation_for_message(cursor, message_id, translation_parameters_id
try:
cursor.execute(query, (message_id, translation_parameters_id,))
return bool(cursor.fetchone()[0] > 0)
except sqlite3.IntegrityError as e:
raise sqlite3.IntegrityError(f"Integrity error occurred: {e}")
except sqlite3.OperationalError as e:
raise sqlite3.OperationalError(f"Operational error occurred: {e}")
except sqlite3.ProgrammingError as e:
raise sqlite3.ProgrammingError(f"Programming error occurred: {e}")
except sqlite3.DatabaseError as e:
raise sqlite3.DatabaseError(f"Database error occurred: {e}")
except sqlite3.IntegrityError:
raise
except sqlite3.OperationalError:
raise
except sqlite3.ProgrammingError:
raise
except sqlite3.DatabaseError:
raise


def upsert_message_translation(cursor, message_id, translation_parameters_id, translation_text):
Expand Down Expand Up @@ -292,9 +292,9 @@ def upsert_message_translation(cursor, message_id, translation_parameters_id, tr
cursor.execute(query, params)

return cursor.lastrowid
except sqlite3.IntegrityError as e:
raise sqlite3.IntegrityError(f"Integrity error occurred while upserting message translation: {e}")
except sqlite3.OperationalError as e:
raise sqlite3.OperationalError(f"Operational error occurred while upserting message translation: {e}")
except sqlite3.DatabaseError as e:
raise sqlite3.DatabaseError(f"Database error occurred while upserting message translation: {e}")
except sqlite3.IntegrityError:
raise
except sqlite3.OperationalError:
raise
except sqlite3.DatabaseError:
raise
169 changes: 166 additions & 3 deletions tests/test_db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import pytest
import sys
import sqlite3
import tempfile
from os import path
from datetime import datetime
import tempfile
from os import remove
from unittest.mock import patch
from unittest.mock import MagicMock
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from lib.db_utils import get_db_connection
from lib.db_utils import check_channel_exists
Expand Down Expand Up @@ -64,7 +66,7 @@ def setup_database():
connection.close()


def test_channel_exists(setup_database):
def test_check_channel_exists(setup_database):
"""
Test that check_channel_exists returns the correct
channel_id when the channel exists.
Expand All @@ -76,7 +78,7 @@ def test_channel_exists(setup_database):
assert actual_channel_id == expected_channel_id[0]


def test_channel_does_not_exist(setup_database):
def test_check_channel_does_not_exist(setup_database):
"""
Test that check_channel_exists returns None when the
channel does not exist.
Expand All @@ -85,6 +87,28 @@ def test_channel_does_not_exist(setup_database):
assert check_channel_exists(cursor, 'nonexistent_channel') is None


@pytest.mark.parametrize("exception", [
sqlite3.OperationalError,
sqlite3.IntegrityError,
sqlite3.ProgrammingError,
sqlite3.DatabaseError
])
def test_check_channel_exceptions(exception):
"""
Test that check_channel_exists correctly raises sqlite3 exceptions.
"""
cursor = MagicMock()
channel_name = 'test_channel'

# Mock to return a valid channel ID
with patch('lib.db_utils.check_channel_exists', return_value=1):
# Then mock cursor.execute to raise the specified exception
cursor.execute.side_effect = exception

with pytest.raises(exception):
has_channel_messages(cursor, channel_name)


def test_has_channel_messages_true(setup_database):
"""Test that has_channel_messages returns True when messages exist."""
cursor = setup_database
Expand All @@ -105,6 +129,28 @@ def test_has_channel_messages_nonexistent_channel(setup_database):
assert has_channel_messages(cursor, 'nonexistent_channel') is False


@pytest.mark.parametrize("exception", [
sqlite3.OperationalError,
sqlite3.IntegrityError,
sqlite3.ProgrammingError,
sqlite3.DatabaseError
])
def test_has_channel_messages_exceptions(exception):
"""
Test that has_channel_messages correctly raises sqlite3 exceptions.
"""
cursor = MagicMock()
channel_name = 'existing_channel'

# Mock check_channel_exists to return a valid channel ID
with patch('lib.db_utils.has_channel_messages', return_value=1):
# Then mock cursor.execute to raise the specified exception
cursor.execute.side_effect = exception

with pytest.raises(exception):
has_channel_messages(cursor, channel_name)


def test_schema_validity():
# Connect to an in-memory SQLite database
conn = sqlite3.connect(':memory:')
Expand Down Expand Up @@ -140,6 +186,20 @@ def test_check_table_exists_false(setup_database):
assert check_table_exists(cursor, "non_existent_table") is False, "Should return False for non-existent table"


@pytest.mark.parametrize("exception", [sqlite3.OperationalError])
def test_check_table_exists_exceptions(exception):
"""
Test that check_table_exists correctly raises sqlite3 exceptions.
"""
cursor = MagicMock()
table_name = 'channels'

cursor.execute.side_effect = exception

with pytest.raises(exception):
check_table_exists(cursor, table_name)


def test_read_sql_from_file():
# Create a temporary file with known SQL content
expected_sql_content = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);"
Expand Down Expand Up @@ -183,6 +243,28 @@ def test_create_tables_from_schema(setup_database):
connection.close()


@pytest.mark.parametrize("exception", [sqlite3.OperationalError])
def test_create_tables_from_schema_exceptions(exception):
"""
Test that create_tables_from_schema correctly handles exceptions.
"""
connection = MagicMock()
cursor = MagicMock()
schema_file_path = 'path/to/schema.sql'

# Mock read_sql_from_file to return a specific SQL command
with patch('lib.db_utils.read_sql_from_file', return_value="CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY);"):
# Set the side_effect of cursor.executescript to raise the specified exception
cursor.executescript.side_effect = exception

# Ensure connection.rollback is called and sqlite3.OperationalError is raised
with pytest.raises(sqlite3.OperationalError):
create_tables_from_schema(connection, cursor, schema_file_path)

# Verify that rollback was called in the event of an exception
connection.rollback.assert_called_once()


@pytest.fixture
def db_cursor():
# Create an in-memory SQLite database and cursor
Expand Down Expand Up @@ -232,6 +314,26 @@ def test_insert_translation_parameters(db_cursor):
assert result[5] == translation_config


@pytest.mark.parametrize("exception,expected_exception", [
(sqlite3.IntegrityError, sqlite3.IntegrityError),
(sqlite3.OperationalError, sqlite3.OperationalError),
])
def test_insert_translation_parameters_exceptions(exception, expected_exception):
"""
Test that insert_translation_parameters correctly handles sqlite3 exceptions.
"""
cursor = MagicMock()
args = ('tool_name', 'commit_hash', 'model_name', 'config_sha256', 'config_data')

# Configure the cursor.execute to raise the specified exception
cursor.execute.side_effect = exception("Simulated database error")

with pytest.raises(expected_exception) as exc_info:
insert_translation_parameters(cursor, *args)

assert "inserting into database" in str(exc_info.value)


def test_get_channel_messages(setup_database):
"""Test that get_channel_messages returns the correct messages for a channel."""
cursor = setup_database
Expand All @@ -246,6 +348,26 @@ def test_get_channel_messages_no_channel(setup_database):
assert messages == [], "Messages were retrieved for a non-existent channel."


@pytest.mark.parametrize("exception", [
sqlite3.IntegrityError,
sqlite3.OperationalError,
sqlite3.ProgrammingError,
sqlite3.DatabaseError,
])
def test_get_channel_messages_exceptions(exception):
"""
Test that get_channel_messages correctly handles various sqlite3 exceptions.
"""
cursor = MagicMock()
channel_name = 'test_channel'

# Configure the cursor.execute to raise the specified exception
cursor.execute.side_effect = exception("Simulated database error")

with pytest.raises(exception):
get_channel_messages(cursor, channel_name)


def test_exists_translation_for_message_true(setup_database):
"""Test that exists_translation_for_message returns True when a translation exists."""
cursor = setup_database
Expand All @@ -262,6 +384,26 @@ def test_exists_translation_for_message_false(setup_database):
assert exists_translation_for_message(cursor, message_id, translation_parameters_id) is False, "Translation should not exist but was reported as found."


@pytest.mark.parametrize("exception", [
sqlite3.IntegrityError,
sqlite3.OperationalError,
sqlite3.ProgrammingError,
sqlite3.DatabaseError,
])
def test_exists_translation_for_message_exceptions(exception):
"""
Test that exists_translation_for_message correctly re-raises sqlite3 exceptions.
"""
cursor = MagicMock()
message_id = 1
translation_parameters_id = 1

# Configure the cursor.execute to raise the specified exception
cursor.execute.side_effect = exception("Simulated database error")

with pytest.raises(exception):
exists_translation_for_message(cursor, message_id, translation_parameters_id)


def test_insert_new_translation(setup_database):
"""Test inserting a new translation."""
Expand Down Expand Up @@ -291,3 +433,24 @@ def test_update_existing_translation(setup_database):
cursor.execute("SELECT translation_text FROM message_translation WHERE message_id = ? AND translation_parameters_id = ?", (message_id, translation_parameters_id))
translation_text = cursor.fetchone()[0]
assert translation_text == updated_translation_text, "Translation text should have been updated."


@pytest.mark.parametrize("exception", [
sqlite3.IntegrityError,
sqlite3.OperationalError,
sqlite3.DatabaseError,
])
def test_upsert_message_translation_exceptions(exception):
"""
Test that upsert_message_translation correctly re-raises sqlite3 exceptions.
"""
cursor = MagicMock()
message_id = 1
translation_parameters_id = 2
translation_text = "Sample translation text"

# Configure the cursor.execute to raise the specified exception
cursor.execute.side_effect = exception("Simulated database error")

with pytest.raises(exception):
upsert_message_translation(cursor, message_id, translation_parameters_id, translation_text)
46 changes: 45 additions & 1 deletion tests/test_hermeneisGPT.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# pylint: disable=missing-docstring
import argparse
import sys
import pytest
import logging
from os import path
from unittest.mock import patch

sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from hermeneisGPT import load_and_parse_config
from hermeneisGPT import main


def test_load_and_parse_config_success(tmp_path):
directory = tmp_path / "sub"
Expand Down Expand Up @@ -40,6 +42,7 @@ def test_load_and_parse_config_success(tmp_path):
'log': 'output.log',
}


# Example of a failure to load due to file not found or other IO issues
def test_load_and_parse_config_failure(tmp_path):
non_existent_file_path = tmp_path / "does_not_exist.yaml"
Expand All @@ -52,3 +55,44 @@ def test_load_and_parse_config_failure(tmp_path):
# Check if 'error' was called at least once
mock_error.assert_called_once()


def test_argument_parsing():
test_args = [
"hermeneisGPT.py",
"--verbose",
"--debug",
"--yaml_config", "path/to/config.yml",
"--env", "path/to/.env",
"--mode", "auto-sqlite",
"--channel_name", "example_channel",
"--max_limit", "5",
"--sqlite_db", "path/to/database.db",
"--sqlite_schema", "path/to/schema.sql",
"--sqlite_chn_table", "custom_channels",
"--sqlite_chn_field", "custom_channel_name",
"--sqlite_msg_table", "custom_messages",
"--sqlite_msg_field", "custom_message_text"
]

with patch('sys.argv', test_args):
with patch('argparse.ArgumentParser.parse_args') as mock_parse:
# Set the return value of parse_args to simulate parsed arguments
mock_parse.return_value = argparse.Namespace(
verbose=True,
debug=True,
yaml_config="path/to/config.yml",
env="path/to/.env",
mode="auto-sqlite",
channel_name="example_channel",
max_limit="5",
sqlite_db="path/to/database.db",
sqlite_schema="path/to/schema.sql",
sqlite_chn_table="custom_channels",
sqlite_chn_field="custom_channel_name",
sqlite_msg_table="custom_messages",
sqlite_msg_field="custom_messages_text",
)
main()

# Verify parse_args was called indicating arguments were parsed
mock_parse.assert_called()

0 comments on commit 37ab932

Please sign in to comment.