From 6e214f61c80900d3b3728a007b51535b3b72faaf Mon Sep 17 00:00:00 2001
From: Lennart Jongeneel <lennart@1200wd.com>
Date: Thu, 8 Feb 2024 22:34:14 +0100
Subject: [PATCH] Add encrypted field test for mysql and postgresql

---
 tests/test_security.py | 42 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 34 insertions(+), 8 deletions(-)

diff --git a/tests/test_security.py b/tests/test_security.py
index 6cb23f20..ea242ce3 100644
--- a/tests/test_security.py
+++ b/tests/test_security.py
@@ -25,16 +25,39 @@
 from bitcoinlib.wallets import Wallet
 from bitcoinlib.config.config import DATABASE_ENCRYPTION_ENABLED
 
-DATABASEFILE_UNITTESTS_ENCRYPTED = os.path.join(str(BCL_DATABASE_DIR), 'bitcoinlib.unittest_security.sqlite')
-# DATABASEFILE_UNITTESTS_ENCRYPTED = 'postgresql://postgres:postgres@localhost:5432/bitcoinlib_security'
 
+try:
+    import mysql.connector
+    import psycopg
+    from psycopg import sql
+except ImportError:
+    pass  # Only necessary when mysql or postgres is used
 
-class TestSecurity(TestCase):
 
-    @classmethod
-    def setUpClass(cls):
-        if os.path.isfile(DATABASEFILE_UNITTESTS_ENCRYPTED):
-            os.remove(DATABASEFILE_UNITTESTS_ENCRYPTED)
+if os.getenv('UNITTEST_DATABASE') == 'postgresql':
+    con = psycopg.connect(user='postgres', host='localhost', password='postgres', autocommit=True)
+    cur = con.cursor()
+    cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier('bitcoinlib_security')))
+    cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier('bitcoinlib_security')))
+    cur.close()
+    con.close()
+    DATABASEFILE_UNITTESTS_ENCRYPTED = 'postgresql://postgres:postgres@localhost:5432/bitcoinlib_security'
+elif os.getenv('UNITTEST_DATABASE') == 'mysql':
+    con = mysql.connector.connect(user='user', host='localhost', password='password')
+    cur = con.cursor()
+    cur.execute("DROP DATABASE IF EXISTS {}".format('bitcoinlib_security'))
+    cur.execute("CREATE DATABASE {}".format('bitcoinlib_security'))
+    con.commit()
+    cur.close()
+    con.close()
+    DATABASEFILE_UNITTESTS_ENCRYPTED = 'mysql://user:password@localhost:3306/bitcoinlib_security'
+else:
+    DATABASEFILE_UNITTESTS_ENCRYPTED = os.path.join(str(BCL_DATABASE_DIR), 'bitcoinlib.unittest_security.sqlite')
+    if os.path.isfile(DATABASEFILE_UNITTESTS_ENCRYPTED):
+        os.remove(DATABASEFILE_UNITTESTS_ENCRYPTED)
+
+
+class TestSecurity(TestCase):
 
     def test_security_wallet_field_encryption(self):
         pk = 'xprv9s21ZrQH143K2HrtPWvqgD8mUhMrrfE1ZME43baM8ti3hWgJwWX1wjHc25y2x11seT5G3KeHFY28MyTRxceeW22kMDAWsMDn7' \
@@ -54,7 +77,10 @@ def test_security_wallet_field_encryption(self):
         wallet.new_key()
         self.assertEqual(wallet.main_key.wif, pk)
 
-        db_query = text('SELECT wif, private FROM keys WHERE id=%d' % wallet._dbwallet.main_key_id)
+        if os.getenv('UNITTEST_DATABASE') == 'mysql':
+            db_query = text("SELECT wif, private FROM `keys` WHERE id=%d" % wallet._dbwallet.main_key_id)
+        else:
+            db_query = text("SELECT wif, private FROM keys WHERE id=%d" % wallet._dbwallet.main_key_id)
         encrypted_main_key_wif = wallet._session.execute(db_query).fetchone()[0]
         encrypted_main_key_private = wallet._session.execute(db_query).fetchone()[1]
         self.assertIn(type(encrypted_main_key_wif), (bytes, memoryview), "Encryption of database private key failed!")