diff --git a/lib/db_utils.py b/lib/db_utils.py index e71f619..ea88368 100644 --- a/lib/db_utils.py +++ b/lib/db_utils.py @@ -209,14 +209,14 @@ def get_channel_messages(cursor, channel_name): cursor.execute(query, (channel_name,)) messages = cursor.fetchall() return messages - except sqlite3.IntegrityError as e: - raise sqlite3.IntegrityError(f"Integrity error retriving from db: {e}") - except sqlite3.OperationalError as e: - raise sqlite3.OperationalError(f"Operational error retrieving from db: {e}") - except sqlite3.ProgrammingError as e: - raise sqlite3.ProgrammingError(f"Programming error occurred: {e}") - except sqlite3.DatabaseError as e: - raise sqlite3.DatabaseError(f"Database error occurred: {e}") + except sqlite3.IntegrityError: + raise + except sqlite3.OperationalError: + raise + except sqlite3.ProgrammingError: + raise + except sqlite3.DatabaseError: + raise def exists_translation_for_message(cursor, message_id, translation_parameters_id): @@ -244,14 +244,14 @@ def exists_translation_for_message(cursor, message_id, translation_parameters_id try: cursor.execute(query, (message_id, translation_parameters_id,)) return bool(cursor.fetchone()[0] > 0) - except sqlite3.IntegrityError as e: - raise sqlite3.IntegrityError(f"Integrity error occurred: {e}") - except sqlite3.OperationalError as e: - raise sqlite3.OperationalError(f"Operational error occurred: {e}") - except sqlite3.ProgrammingError as e: - raise sqlite3.ProgrammingError(f"Programming error occurred: {e}") - except sqlite3.DatabaseError as e: - raise sqlite3.DatabaseError(f"Database error occurred: {e}") + except sqlite3.IntegrityError: + raise + except sqlite3.OperationalError: + raise + except sqlite3.ProgrammingError: + raise + except sqlite3.DatabaseError: + raise def upsert_message_translation(cursor, message_id, translation_parameters_id, translation_text): @@ -292,9 +292,9 @@ def upsert_message_translation(cursor, message_id, translation_parameters_id, tr cursor.execute(query, params) return cursor.lastrowid - except sqlite3.IntegrityError as e: - raise sqlite3.IntegrityError(f"Integrity error occurred while upserting message translation: {e}") - except sqlite3.OperationalError as e: - raise sqlite3.OperationalError(f"Operational error occurred while upserting message translation: {e}") - except sqlite3.DatabaseError as e: - raise sqlite3.DatabaseError(f"Database error occurred while upserting message translation: {e}") + except sqlite3.IntegrityError: + raise + except sqlite3.OperationalError: + raise + except sqlite3.DatabaseError: + raise diff --git a/tests/test_db_utils.py b/tests/test_db_utils.py index bbd495a..b3babc8 100644 --- a/tests/test_db_utils.py +++ b/tests/test_db_utils.py @@ -3,10 +3,12 @@ import pytest import sys import sqlite3 +import tempfile from os import path from datetime import datetime -import tempfile from os import remove +from unittest.mock import patch +from unittest.mock import MagicMock sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) from lib.db_utils import get_db_connection from lib.db_utils import check_channel_exists @@ -64,7 +66,7 @@ def setup_database(): connection.close() -def test_channel_exists(setup_database): +def test_check_channel_exists(setup_database): """ Test that check_channel_exists returns the correct channel_id when the channel exists. @@ -76,7 +78,7 @@ def test_channel_exists(setup_database): assert actual_channel_id == expected_channel_id[0] -def test_channel_does_not_exist(setup_database): +def test_check_channel_does_not_exist(setup_database): """ Test that check_channel_exists returns None when the channel does not exist. @@ -85,6 +87,28 @@ def test_channel_does_not_exist(setup_database): assert check_channel_exists(cursor, 'nonexistent_channel') is None +@pytest.mark.parametrize("exception", [ + sqlite3.OperationalError, + sqlite3.IntegrityError, + sqlite3.ProgrammingError, + sqlite3.DatabaseError +]) +def test_check_channel_exceptions(exception): + """ + Test that check_channel_exists correctly raises sqlite3 exceptions. + """ + cursor = MagicMock() + channel_name = 'test_channel' + + # Mock to return a valid channel ID + with patch('lib.db_utils.check_channel_exists', return_value=1): + # Then mock cursor.execute to raise the specified exception + cursor.execute.side_effect = exception + + with pytest.raises(exception): + has_channel_messages(cursor, channel_name) + + def test_has_channel_messages_true(setup_database): """Test that has_channel_messages returns True when messages exist.""" cursor = setup_database @@ -105,6 +129,28 @@ def test_has_channel_messages_nonexistent_channel(setup_database): assert has_channel_messages(cursor, 'nonexistent_channel') is False +@pytest.mark.parametrize("exception", [ + sqlite3.OperationalError, + sqlite3.IntegrityError, + sqlite3.ProgrammingError, + sqlite3.DatabaseError +]) +def test_has_channel_messages_exceptions(exception): + """ + Test that has_channel_messages correctly raises sqlite3 exceptions. + """ + cursor = MagicMock() + channel_name = 'existing_channel' + + # Mock check_channel_exists to return a valid channel ID + with patch('lib.db_utils.has_channel_messages', return_value=1): + # Then mock cursor.execute to raise the specified exception + cursor.execute.side_effect = exception + + with pytest.raises(exception): + has_channel_messages(cursor, channel_name) + + def test_schema_validity(): # Connect to an in-memory SQLite database conn = sqlite3.connect(':memory:') @@ -140,6 +186,20 @@ def test_check_table_exists_false(setup_database): assert check_table_exists(cursor, "non_existent_table") is False, "Should return False for non-existent table" +@pytest.mark.parametrize("exception", [sqlite3.OperationalError]) +def test_check_table_exists_exceptions(exception): + """ + Test that check_table_exists correctly raises sqlite3 exceptions. + """ + cursor = MagicMock() + table_name = 'channels' + + cursor.execute.side_effect = exception + + with pytest.raises(exception): + check_table_exists(cursor, table_name) + + def test_read_sql_from_file(): # Create a temporary file with known SQL content expected_sql_content = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);" @@ -183,6 +243,28 @@ def test_create_tables_from_schema(setup_database): connection.close() +@pytest.mark.parametrize("exception", [sqlite3.OperationalError]) +def test_create_tables_from_schema_exceptions(exception): + """ + Test that create_tables_from_schema correctly handles exceptions. + """ + connection = MagicMock() + cursor = MagicMock() + schema_file_path = 'path/to/schema.sql' + + # Mock read_sql_from_file to return a specific SQL command + with patch('lib.db_utils.read_sql_from_file', return_value="CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY);"): + # Set the side_effect of cursor.executescript to raise the specified exception + cursor.executescript.side_effect = exception + + # Ensure connection.rollback is called and sqlite3.OperationalError is raised + with pytest.raises(sqlite3.OperationalError): + create_tables_from_schema(connection, cursor, schema_file_path) + + # Verify that rollback was called in the event of an exception + connection.rollback.assert_called_once() + + @pytest.fixture def db_cursor(): # Create an in-memory SQLite database and cursor @@ -232,6 +314,26 @@ def test_insert_translation_parameters(db_cursor): assert result[5] == translation_config +@pytest.mark.parametrize("exception,expected_exception", [ + (sqlite3.IntegrityError, sqlite3.IntegrityError), + (sqlite3.OperationalError, sqlite3.OperationalError), +]) +def test_insert_translation_parameters_exceptions(exception, expected_exception): + """ + Test that insert_translation_parameters correctly handles sqlite3 exceptions. + """ + cursor = MagicMock() + args = ('tool_name', 'commit_hash', 'model_name', 'config_sha256', 'config_data') + + # Configure the cursor.execute to raise the specified exception + cursor.execute.side_effect = exception("Simulated database error") + + with pytest.raises(expected_exception) as exc_info: + insert_translation_parameters(cursor, *args) + + assert "inserting into database" in str(exc_info.value) + + def test_get_channel_messages(setup_database): """Test that get_channel_messages returns the correct messages for a channel.""" cursor = setup_database @@ -246,6 +348,26 @@ def test_get_channel_messages_no_channel(setup_database): assert messages == [], "Messages were retrieved for a non-existent channel." +@pytest.mark.parametrize("exception", [ + sqlite3.IntegrityError, + sqlite3.OperationalError, + sqlite3.ProgrammingError, + sqlite3.DatabaseError, +]) +def test_get_channel_messages_exceptions(exception): + """ + Test that get_channel_messages correctly handles various sqlite3 exceptions. + """ + cursor = MagicMock() + channel_name = 'test_channel' + + # Configure the cursor.execute to raise the specified exception + cursor.execute.side_effect = exception("Simulated database error") + + with pytest.raises(exception): + get_channel_messages(cursor, channel_name) + + def test_exists_translation_for_message_true(setup_database): """Test that exists_translation_for_message returns True when a translation exists.""" cursor = setup_database @@ -262,6 +384,26 @@ def test_exists_translation_for_message_false(setup_database): assert exists_translation_for_message(cursor, message_id, translation_parameters_id) is False, "Translation should not exist but was reported as found." +@pytest.mark.parametrize("exception", [ + sqlite3.IntegrityError, + sqlite3.OperationalError, + sqlite3.ProgrammingError, + sqlite3.DatabaseError, +]) +def test_exists_translation_for_message_exceptions(exception): + """ + Test that exists_translation_for_message correctly re-raises sqlite3 exceptions. + """ + cursor = MagicMock() + message_id = 1 + translation_parameters_id = 1 + + # Configure the cursor.execute to raise the specified exception + cursor.execute.side_effect = exception("Simulated database error") + + with pytest.raises(exception): + exists_translation_for_message(cursor, message_id, translation_parameters_id) + def test_insert_new_translation(setup_database): """Test inserting a new translation.""" @@ -291,3 +433,24 @@ def test_update_existing_translation(setup_database): cursor.execute("SELECT translation_text FROM message_translation WHERE message_id = ? AND translation_parameters_id = ?", (message_id, translation_parameters_id)) translation_text = cursor.fetchone()[0] assert translation_text == updated_translation_text, "Translation text should have been updated." + + +@pytest.mark.parametrize("exception", [ + sqlite3.IntegrityError, + sqlite3.OperationalError, + sqlite3.DatabaseError, +]) +def test_upsert_message_translation_exceptions(exception): + """ + Test that upsert_message_translation correctly re-raises sqlite3 exceptions. + """ + cursor = MagicMock() + message_id = 1 + translation_parameters_id = 2 + translation_text = "Sample translation text" + + # Configure the cursor.execute to raise the specified exception + cursor.execute.side_effect = exception("Simulated database error") + + with pytest.raises(exception): + upsert_message_translation(cursor, message_id, translation_parameters_id, translation_text) diff --git a/tests/test_hermeneisGPT.py b/tests/test_hermeneisGPT.py index 94a39d9..e2a84d2 100644 --- a/tests/test_hermeneisGPT.py +++ b/tests/test_hermeneisGPT.py @@ -1,12 +1,14 @@ # pylint: disable=missing-docstring +import argparse import sys import pytest import logging from os import path from unittest.mock import patch - sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) from hermeneisGPT import load_and_parse_config +from hermeneisGPT import main + def test_load_and_parse_config_success(tmp_path): directory = tmp_path / "sub" @@ -40,6 +42,7 @@ def test_load_and_parse_config_success(tmp_path): 'log': 'output.log', } + # Example of a failure to load due to file not found or other IO issues def test_load_and_parse_config_failure(tmp_path): non_existent_file_path = tmp_path / "does_not_exist.yaml" @@ -52,3 +55,44 @@ def test_load_and_parse_config_failure(tmp_path): # Check if 'error' was called at least once mock_error.assert_called_once() + +def test_argument_parsing(): + test_args = [ + "hermeneisGPT.py", + "--verbose", + "--debug", + "--yaml_config", "path/to/config.yml", + "--env", "path/to/.env", + "--mode", "auto-sqlite", + "--channel_name", "example_channel", + "--max_limit", "5", + "--sqlite_db", "path/to/database.db", + "--sqlite_schema", "path/to/schema.sql", + "--sqlite_chn_table", "custom_channels", + "--sqlite_chn_field", "custom_channel_name", + "--sqlite_msg_table", "custom_messages", + "--sqlite_msg_field", "custom_message_text" + ] + + with patch('sys.argv', test_args): + with patch('argparse.ArgumentParser.parse_args') as mock_parse: + # Set the return value of parse_args to simulate parsed arguments + mock_parse.return_value = argparse.Namespace( + verbose=True, + debug=True, + yaml_config="path/to/config.yml", + env="path/to/.env", + mode="auto-sqlite", + channel_name="example_channel", + max_limit="5", + sqlite_db="path/to/database.db", + sqlite_schema="path/to/schema.sql", + sqlite_chn_table="custom_channels", + sqlite_chn_field="custom_channel_name", + sqlite_msg_table="custom_messages", + sqlite_msg_field="custom_messages_text", + ) + main() + + # Verify parse_args was called indicating arguments were parsed + mock_parse.assert_called()