From 5d67936e793e4fec88a03f263cce8ae6f909421c Mon Sep 17 00:00:00 2001 From: Cryp Toon Date: Mon, 5 Feb 2024 17:43:51 +0100 Subject: [PATCH] Fix issues with autoincrement key_id and postgresql --- bitcoinlib/db.py | 5 ++--- bitcoinlib/db_cache.py | 5 ++--- bitcoinlib/wallets.py | 20 ++++++++++++++++---- tests/test_tools.py | 18 +++++++++--------- tests/test_wallets.py | 10 ++++------ 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/bitcoinlib/db.py b/bitcoinlib/db.py index f5d966d7..f8e9b2c2 100644 --- a/bitcoinlib/db.py +++ b/bitcoinlib/db.py @@ -23,7 +23,7 @@ ForeignKey, DateTime, LargeBinary, TypeDecorator) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import sessionmaker, relationship, close_all_sessions +from sqlalchemy.orm import sessionmaker, relationship, session from urllib.parse import urlparse from bitcoinlib.main import * from bitcoinlib.encoding import aes_encrypt, aes_decrypt @@ -97,7 +97,6 @@ def drop_db(self, yes_i_am_sure=False): if yes_i_am_sure: self.session.commit() self.session.close_all() - close_all_sessions() Base.metadata.drop_all(self.engine) @staticmethod @@ -286,7 +285,7 @@ class DbKey(Base): "depth=1 are the masterkeys children.") change = Column(Integer, doc="Change or normal address: Normal=0, Change=1") address_index = Column(BigInteger, doc="Index of address in HD key structure address level") - public = Column(LargeBinary(33), index=True, doc="Bytes representation of public key") + public = Column(LargeBinary(65), index=True, doc="Bytes representation of public key") private = Column(EncryptedBinary(48), doc="Bytes representation of private key") wif = Column(EncryptedString(128), index=True, doc="Public or private WIF (Wallet Import Format) representation") compressed = Column(Boolean, default=True, doc="Is key compressed or not. Default is True") diff --git a/bitcoinlib/db_cache.py b/bitcoinlib/db_cache.py index 36df4cdb..18b99fd1 100644 --- a/bitcoinlib/db_cache.py +++ b/bitcoinlib/db_cache.py @@ -21,7 +21,7 @@ from sqlalchemy import create_engine from sqlalchemy import Column, Integer, BigInteger, String, Boolean, ForeignKey, DateTime, Enum, LargeBinary from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, relationship, close_all_sessions +from sqlalchemy.orm import sessionmaker, relationship, session # try: # import mysql.connector # from parameterized import parameterized_class @@ -90,8 +90,7 @@ def __init__(self, db_uri=None): def drop_db(self): self.session.commit() - # self.session.close_all() - close_all_sessions() + self.session.close_all() Base.metadata.drop_all(self.engine) diff --git a/bitcoinlib/wallets.py b/bitcoinlib/wallets.py index e41f7f50..b80fe327 100644 --- a/bitcoinlib/wallets.py +++ b/bitcoinlib/wallets.py @@ -212,6 +212,7 @@ def wallet_empty(wallet, db_uri=None, db_password=None): else: w = session.query(DbWallet).filter_by(name=wallet) if not w or not w.first(): + session.close() raise WalletError("Wallet '%s' not found" % wallet) wallet_id = w.first().id @@ -375,13 +376,20 @@ def from_key(name, wallet_id, session, key, account_id=0, network=None, change=0 encoding = get_encoding_from_witness(witness_type) script_type = script_type_default(witness_type, multisig) + if not new_key_id: + key_id_max = session.query(func.max(DbKey.id)).scalar() + new_key_id = key_id_max + 1 if key_id_max else None + commit = True + else: + commit = False + if not key_is_address: if key_type != 'single' and k.depth != len(path.split('/'))-1: if path == 'm' and k.depth > 1: path = "M" address = k.address(encoding=encoding, script_type=script_type) - if not new_key_id: + if commit: keyexists = session.query(DbKey).\ filter(DbKey.wallet_id == wallet_id, DbKey.wif == k.wif(witness_type=witness_type, multisig=multisig, is_private=True)).first() @@ -422,10 +430,10 @@ def from_key(name, wallet_id, session, key, account_id=0, network=None, change=0 key_type=key_type, network_name=network, encoding=encoding, cosigner_id=cosigner_id, witness_type=witness_type) - if not new_key_id: + if commit: session.merge(DbNetwork(name=network)) session.add(nk) - if new_key_id is None: + if commit: session.commit() return WalletKey(nk.id, session, k) @@ -488,6 +496,9 @@ def __init__(self, key_id, session, hdkey_object=None): else: raise WalletError("Key with id %s not found" % key_id) + def __del__(self): + self._session.close() + def __repr__(self): return "" % (self.key_id, self.name, self.wif, self.path) @@ -1699,7 +1710,8 @@ def _new_key_multisig(self, public_keys, name, account_id, change, cosigner_id, if not name: name = "Multisig Key " + '/'.join(public_key_ids) - multisig_key = DbKey( + new_key_id = (self._session.query(func.max(DbKey.id)).scalar() or 0) + 1 + multisig_key = DbKey(id=new_key_id, name=name[:80], wallet_id=self.wallet_id, purpose=self.purpose, account_id=account_id, depth=depth, change=change, address_index=address_index, parent_id=0, is_private=False, path=path, public=address.hash_bytes, wif='multisig-%s' % address, address=address.address, cosigner_id=cosigner_id, diff --git a/tests/test_tools.py b/tests/test_tools.py index 9f50b0f7..e3db1c84 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -13,24 +13,21 @@ try: import mysql.connector - import psycopg2 - from psycopg2 import sql - from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + import psycopg + from psycopg import sql except ImportError: pass # Only necessary when mysql or postgres is used -# from bitcoinlib.main import UNITTESTS_FULL_DATABASE_TEST -from bitcoinlib.db import BCL_DATABASE_DIR +from bitcoinlib.db import BCL_DATABASE_DIR, session from bitcoinlib.encoding import normalize_string -SQLITE_DATABASE_FILE = os.path.join(str(BCL_DATABASE_DIR), 'bitcoinlib.unittest.sqlite') DATABASE_NAME = 'bitcoinlib_unittest' def database_init(dbname=DATABASE_NAME): + session.close_all_sessions() if os.getenv('UNITTEST_DATABASE') == 'postgresql': - con = psycopg2.connect(user='postgres', host='localhost', password='postgres') - con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + con = psycopg.connect(user='postgres', host='localhost', password='postgres', autocommit=True) cur = con.cursor() # try: cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format( @@ -48,7 +45,10 @@ def database_init(dbname=DATABASE_NAME): con.close() return 'postgresql://postgres:postgres@localhost:5432/' + dbname elif os.getenv('UNITTEST_DATABASE') == 'mysql': - con = mysql.connector.connect(user='user', host='localhost', password='password') + try: + con = mysql.connector.connect(user='root', host='localhost') + except mysql.connector.errors.ProgrammingError: + con = mysql.connector.connect(user='user', host='localhost', password='password') cur = con.cursor() cur.execute("DROP DATABASE IF EXISTS {}".format(dbname)) cur.execute("CREATE DATABASE {}".format(dbname)) diff --git a/tests/test_wallets.py b/tests/test_wallets.py index 76611072..a9890897 100644 --- a/tests/test_wallets.py +++ b/tests/test_wallets.py @@ -23,9 +23,8 @@ try: import mysql.connector - import psycopg2 - from psycopg2 import sql - from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + import psycopg + from psycopg import sql except ImportError as e: print("Could not import all modules. Error: %s" % e) # from psycopg2cffi import compat # Use for PyPy support @@ -63,10 +62,9 @@ def database_init(dbname=DATABASE_NAME): - close_all_sessions() + session.close_all_sessions() if os.getenv('UNITTEST_DATABASE') == 'postgresql': - con = psycopg2.connect(user='postgres', host='localhost', password='postgres') - con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + con = psycopg.connect(user='postgres', host='localhost', password='postgres', autocommit=True) cur = con.cursor() try: # cur.execute(sql.SQL("ALTER DATABASE {} allow_connections = off").format(sql.Identifier(dbname)))