From d76117b6758776129dd4130570ecc1fbaaec8ab5 Mon Sep 17 00:00:00 2001 From: Lennart Jongeneel Date: Sun, 18 Feb 2024 18:04:19 +0100 Subject: [PATCH] Move wallet db session to separate method --- bitcoinlib/wallets.py | 174 ++++++++++++++++++++++------------------- tests/test_security.py | 4 +- tests/test_wallets.py | 21 ++--- 3 files changed, 105 insertions(+), 94 deletions(-) diff --git a/bitcoinlib/wallets.py b/bitcoinlib/wallets.py index b80fe327..ed9c6f58 100644 --- a/bitcoinlib/wallets.py +++ b/bitcoinlib/wallets.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # BitcoinLib - Python Cryptocurrency Library # WALLETS - HD wallet Class for Key and Transaction management -# © 2016 - 2023 May - 1200 Web Development +# © 2016 - 2024 February - 1200 Web Development # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -311,7 +311,7 @@ def from_key(name, wallet_id, session, key, account_id=0, network=None, change=0 >>> w = wallet_create_or_open('hdwalletkey_test') >>> wif = 'xprv9s21ZrQH143K2mcs9jcK4EjALbu2z1N9qsMTUG1frmnXM3NNCSGR57yLhwTccfNCwdSQEDftgjCGm96P29wGGcbBsPqZH85iqpoHA7LrqVy' - >>> wk = WalletKey.from_key('import_key', w.wallet_id, w._session, wif) + >>> wk = WalletKey.from_key('import_key', w.wallet_id, w.session, wif) >>> wk.address '1MwVEhGq6gg1eeSrEdZom5bHyPqXtJSnPg' >>> wk # doctest:+ELLIPSIS @@ -439,9 +439,9 @@ def from_key(name, wallet_id, session, key, account_id=0, network=None, change=0 def _commit(self): try: - self._session.commit() + self.session.commit() except Exception: - self._session.rollback() + self.session.rollback() raise def __init__(self, key_id, session, hdkey_object=None): @@ -457,7 +457,7 @@ def __init__(self, key_id, session, hdkey_object=None): """ - self._session = session + self.session = session wk = session.query(DbKey).filter_by(id=key_id).first() if wk: self._dbkey = wk @@ -497,7 +497,7 @@ def __init__(self, key_id, session, hdkey_object=None): raise WalletError("Key with id %s not found" % key_id) def __del__(self): - self._session.close() + self.session.close() def __repr__(self): return "" % (self.key_id, self.name, self.wif, self.path) @@ -692,7 +692,7 @@ def from_txid(cls, hdwallet, txid): :return WalletClass: """ - sess = hdwallet._session + sess = hdwallet.session # If txid is unknown add it to database, else update db_tx_query = sess.query(DbTransaction). \ filter(DbTransaction.wallet_id == hdwallet.wallet_id, DbTransaction.txid == to_bytes(txid)) @@ -838,7 +838,7 @@ def send(self, offline=False): # Update db: Update spent UTXO's, add transaction to database for inp in self.inputs: txid = inp.prev_txid - utxos = self.hdwallet._session.query(DbTransactionOutput).join(DbTransaction).\ + utxos = self.hdwallet.session.query(DbTransactionOutput).join(DbTransaction).\ filter(DbTransaction.txid == txid, DbTransactionOutput.output_n == inp.output_n_int, DbTransactionOutput.spent.is_(False)).all() @@ -857,7 +857,7 @@ def store(self): :return int: Transaction index number """ - sess = self.hdwallet._session + sess = self.hdwallet.session # If txid is unknown add it to database, else update db_tx_query = sess.query(DbTransaction). \ filter(DbTransaction.wallet_id == self.hdwallet.wallet_id, DbTransaction.txid == bytes.fromhex(self.txid)) @@ -1013,7 +1013,7 @@ def delete(self): :return int: Number of deleted transactions """ - session = self.hdwallet._session + session = self.hdwallet.session txid = bytes.fromhex(self.txid) tx_query = session.query(DbTransaction).filter_by(txid=txid) tx = tx_query.scalar() @@ -1113,9 +1113,9 @@ def _create(cls, name, key, owner, network, account_id, purpose, scheme, parent_ def _commit(self): try: - self._session.commit() + self.session.commit() except Exception: - self._session.rollback() + self.session.rollback() raise @classmethod @@ -1331,10 +1331,10 @@ def create(cls, name, keys=None, owner='', network=None, account_id=0, purpose=0 db_uri=db_uri, db_cache_uri=db_cache_uri, db_password=db_password) hdpm.cosigner.append(w) wlt_cos_id += 1 - # hdpm._dbwallet = hdpm._session.query(DbWallet).filter(DbWallet.id == hdpm.wallet_id) + # hdpm._dbwallet = hdpm.session.query(DbWallet).filter(DbWallet.id == hdpm.wallet_id) # hdpm._dbwallet.update({DbWallet.cosigner_id: hdpm.cosigner_id}) # hdpm._dbwallet.update({DbWallet.key_path: hdpm.key_path}) - # hdpm._session.commit() + # hdpm.session.commit() return hdpm @@ -1357,18 +1357,20 @@ def __init__(self, wallet, db_uri=None, db_cache_uri=None, session=None, main_ke :type main_key_object: HDKey """ + self._session = None if session: self._session = session - else: - dbinit = Db(db_uri=db_uri, password=db_password) - self._session = dbinit.session - self._engine = dbinit.engine + # else: + # dbinit = Db(db_uri=db_uri, password=db_password) + # self.session = dbinit.session + # self._engine = dbinit.engine + self._db_password = db_password self.db_uri = db_uri self.db_cache_uri = db_cache_uri if isinstance(wallet, int) or wallet.isdigit(): - db_wlt = self._session.query(DbWallet).filter_by(id=wallet).scalar() + db_wlt = self.session.query(DbWallet).filter_by(id=wallet).scalar() else: - db_wlt = self._session.query(DbWallet).filter_by(name=wallet).scalar() + db_wlt = self.session.query(DbWallet).filter_by(name=wallet).scalar() if db_wlt: self._dbwallet = db_wlt self.wallet_id = db_wlt.id @@ -1383,12 +1385,12 @@ def __init__(self, wallet, db_uri=None, db_cache_uri=None, session=None, main_ke self.main_key = None self._default_account_id = db_wlt.default_account_id self.multisig_n_required = db_wlt.multisig_n_required - co_sign_wallets = self._session.query(DbWallet).\ + co_sign_wallets = self.session.query(DbWallet).\ filter(DbWallet.parent_id == self.wallet_id).order_by(DbWallet.name).all() self.cosigner = [Wallet(w.id, db_uri=db_uri, db_cache_uri=db_cache_uri) for w in co_sign_wallets] self.sort_keys = db_wlt.sort_keys if db_wlt.main_key_id: - self.main_key = WalletKey(self.main_key_id, session=self._session, hdkey_object=main_key_object) + self.main_key = WalletKey(self.main_key_id, session=self.session, hdkey_object=main_key_object) if self._default_account_id is None: self._default_account_id = 0 if self.main_key: @@ -1420,14 +1422,14 @@ def __init__(self, wallet, db_uri=None, db_cache_uri=None, session=None, main_ke def __exit__(self, exception_type, exception_value, traceback): try: - self._session.close() + self.session.close() self._engine.dispose() except Exception: pass def __del__(self): try: - self._session.close() + self.session.close() self._engine.dispose() except Exception: pass @@ -1465,7 +1467,7 @@ def _get_account_defaults(self, network=None, account_id=None, key_id=None): network = self.network.name if account_id is None and network == self.network.name: account_id = self.default_account_id - qr = self._session.query(DbKey).\ + qr = self.session.query(DbKey).\ filter_by(wallet_id=self.wallet_id, purpose=self.purpose, depth=self.depth_public_master, network_name=network) if account_id is not None: @@ -1488,7 +1490,7 @@ def default_account_id(self): @default_account_id.setter def default_account_id(self, value): self._default_account_id = value - self._dbwallet = self._session.query(DbWallet).filter(DbWallet.id == self.wallet_id). \ + self._dbwallet = self.session.query(DbWallet).filter(DbWallet.id == self.wallet_id). \ update({DbWallet.default_account_id: value}) self._commit() @@ -1514,7 +1516,7 @@ def owner(self, value): """ self._owner = value - self._dbwallet = self._session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ + self._dbwallet = self.session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ update({DbWallet.owner: value}) self._commit() @@ -1542,14 +1544,22 @@ def name(self, value): if wallet_exists(value, db_uri=self.db_uri): raise WalletError("Wallet with name '%s' already exists" % value) self._name = value - self._session.query(DbWallet).filter(DbWallet.id == self.wallet_id).update({DbWallet.name: value}) + self.session.query(DbWallet).filter(DbWallet.id == self.wallet_id).update({DbWallet.name: value}) self._commit() + @property + def session(self): + if not self._session: + dbinit = Db(db_uri=self.db_uri, password=self._db_password) + self._session = dbinit.session + self._engine = dbinit.engine + return self._session + def default_network_set(self, network): if not isinstance(network, Network): network = Network(network) self.network = network - self._session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ + self.session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ update({DbWallet.network_name: network.name}) self._commit() @@ -1592,11 +1602,11 @@ def import_master_key(self, hdkey, name='Masterkey (imported)'): # self.key_path = ks[0]['key_path'] self.key_path, _, _ = get_key_structure_data(self.witness_type, self.multisig) self.main_key = WalletKey.from_key( - key=hdkey, name=name, session=self._session, wallet_id=self.wallet_id, network=network, + key=hdkey, name=name, session=self.session, wallet_id=self.wallet_id, network=network, account_id=account_id, purpose=self.purpose, key_type='bip32', witness_type=self.witness_type) self.main_key_id = self.main_key.key_id self._key_objects.update({self.main_key_id: self.main_key}) - self._session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ + self.session.query(DbWallet).filter(DbWallet.id == self.wallet_id).\ update({DbWallet.main_key_id: self.main_key_id}) for key in self.keys(is_private=False): @@ -1662,7 +1672,7 @@ def import_key(self, key, account_id=0, name='', network=None, purpose=84, key_t if key_type == 'single': # Create path for unrelated import keys hdkey.depth = self.key_depth - last_import_key = self._session.query(DbKey).filter(DbKey.path.like("import_key_%")).\ + last_import_key = self.session.query(DbKey).filter(DbKey.path.like("import_key_%")).\ order_by(DbKey.path.desc()).first() if last_import_key: ik_path = "import_key_" + str(int(last_import_key.path[-5:]) + 1).zfill(5) @@ -1673,7 +1683,7 @@ def import_key(self, key, account_id=0, name='', network=None, purpose=84, key_t mk = WalletKey.from_key( key=hdkey, name=name, wallet_id=self.wallet_id, network=network, key_type=key_type, - account_id=account_id, purpose=purpose, session=self._session, path=ik_path, + account_id=account_id, purpose=purpose, session=self.session, path=ik_path, witness_type=self.witness_type) self._key_objects.update({mk.key_id: mk}) if mk.key_id == self.main_key.key_id: @@ -1701,7 +1711,7 @@ def _new_key_multisig(self, public_keys, name, account_id, change, cosigner_id, if witness_type == 'p2sh-segwit': script_type = 'p2sh_p2wsh' address = Address(redeemscript, script_type=script_type, network=network, witness_type=witness_type) - already_found_key = self._session.query(DbKey).filter_by(wallet_id=self.wallet_id, + already_found_key = self.session.query(DbKey).filter_by(wallet_id=self.wallet_id, address=address.address).first() if already_found_key: return self.key(already_found_key.id) @@ -1710,16 +1720,16 @@ def _new_key_multisig(self, public_keys, name, account_id, change, cosigner_id, if not name: name = "Multisig Key " + '/'.join(public_key_ids) - new_key_id = (self._session.query(func.max(DbKey.id)).scalar() or 0) + 1 + new_key_id = (self.session.query(func.max(DbKey.id)).scalar() or 0) + 1 multisig_key = DbKey(id=new_key_id, name=name[:80], wallet_id=self.wallet_id, purpose=self.purpose, account_id=account_id, depth=depth, change=change, address_index=address_index, parent_id=0, is_private=False, path=path, public=address.hash_bytes, wif='multisig-%s' % address, address=address.address, cosigner_id=cosigner_id, key_type='multisig', witness_type=witness_type, network_name=network) - self._session.add(multisig_key) + self.session.add(multisig_key) self._commit() for child_id in public_key_ids: - self._session.add(DbKeyMultisigChildren(key_order=public_key_ids.index(child_id), parent_id=multisig_key.id, + self.session.add(DbKeyMultisigChildren(key_order=public_key_ids.index(child_id), parent_id=multisig_key.id, child_id=int(child_id))) self._commit() return self.key(multisig_key.id) @@ -1797,7 +1807,7 @@ def new_keys(self, name='', account_id=None, change=0, cosigner_id=None, witness self.cosigner[cosigner_id].key_path == ['m'])): req_path = [] else: - prevkey = self._session.query(DbKey).\ + prevkey = self.session.query(DbKey).\ filter_by(wallet_id=self.wallet_id, purpose=purpose, network_name=network, account_id=account_id, witness_type=witness_type, change=change, cosigner_id=cosigner_id, depth=self.key_depth).\ order_by(DbKey.address_index.desc()).first() @@ -1891,7 +1901,7 @@ def scan(self, scan_gap_limit=5, account_id=None, change=None, rescan_used=False self.transactions_update_confirmations() # Check unconfirmed transactions - db_txs = self._session.query(DbTransaction). \ + db_txs = self.session.query(DbTransaction). \ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.network_name == network, DbTransaction.confirmations == 0).all() for db_tx in db_txs: @@ -1939,14 +1949,14 @@ def _get_key(self, account_id=None, witness_type=None, network=None, cosigner_id (cosigner_id, len(self.cosigner))) witness_type = witness_type if witness_type else self.witness_type - last_used_qr = self._session.query(DbKey.id).\ + last_used_qr = self.session.query(DbKey.id).\ filter_by(wallet_id=self.wallet_id, account_id=account_id, network_name=network, cosigner_id=cosigner_id, used=True, change=change, depth=self.key_depth, witness_type=witness_type).\ order_by(DbKey.id.desc()).first() last_used_key_id = 0 if last_used_qr: last_used_key_id = last_used_qr.id - dbkey = (self._session.query(DbKey.id). + dbkey = (self.session.query(DbKey.id). filter_by(wallet_id=self.wallet_id, account_id=account_id, network_name=network, cosigner_id=cosigner_id, used=False, change=change, depth=self.key_depth, witness_type=witness_type). filter(DbKey.id > last_used_key_id). @@ -2099,7 +2109,7 @@ def new_account(self, name='', account_id=None, witness_type=None, network=None) # Determine account_id and name if account_id is None: account_id = 0 - qr = self._session.query(DbKey). \ + qr = self.session.query(DbKey). \ filter_by(wallet_id=self.wallet_id, witness_type=witness_type, network_name=network). \ order_by(DbKey.account_id.desc()).first() if qr: @@ -2277,7 +2287,7 @@ def keys_for_path(self, path, level_offset=None, name=None, account_id=None, cos wpath = ["M"] + fullpath[self.main_key.depth + 1:] dbkey = None while wpath and not dbkey: - qr = self._session.query(DbKey).filter_by(path=normalize_path('/'.join(wpath)), wallet_id=self.wallet_id) + qr = self.session.query(DbKey).filter_by(path=normalize_path('/'.join(wpath)), wallet_id=self.wallet_id) if recreate: qr = qr.filter_by(is_private=True) dbkey = qr.first() @@ -2325,7 +2335,7 @@ def keys_for_path(self, path, level_offset=None, name=None, account_id=None, cos nkey = WalletKey.from_key(key=ck, name=key_name, wallet_id=self.wallet_id, account_id=account_id, change=change, purpose=purpose, path=newpath, parent_id=parent_id, encoding=encoding, witness_type=witness_type, - cosigner_id=cosigner_id, network=network, session=self._session) + cosigner_id=cosigner_id, network=network, session=self.session) self._key_objects.update({nkey.key_id: nkey}) parent_id = nkey.key_id if nkey: @@ -2333,7 +2343,7 @@ def keys_for_path(self, path, level_offset=None, name=None, account_id=None, cos if len(new_keys) < number_of_keys: topkey = self._key_objects[new_keys[0].parent_id] parent_key = topkey.key() - new_key_id = self._session.query(DbKey.id).order_by(DbKey.id.desc()).first()[0] + 1 + new_key_id = self.session.query(DbKey.id).order_by(DbKey.id.desc()).first()[0] + 1 keys_to_add = [str(k_id) for k_id in range(int(fullpath[-1]) + len(new_keys), int(fullpath[-1]) + number_of_keys)] @@ -2346,8 +2356,8 @@ def keys_for_path(self, path, level_offset=None, name=None, account_id=None, cos key=ck, name=key_name, wallet_id=self.wallet_id, account_id=account_id, change=change, purpose=purpose, path=newpath, parent_id=parent_id, encoding=encoding, witness_type=witness_type, new_key_id=new_key_id, - cosigner_id=cosigner_id, network=network, session=self._session)) - self._session.commit() + cosigner_id=cosigner_id, network=network, session=self.session)) + self.session.commit() return new_keys @@ -2393,7 +2403,7 @@ def keys(self, account_id=None, name=None, key_id=None, change=None, depth=None, :return list of DbKey: List of Keys """ - qr = self._session.query(DbKey).filter_by(wallet_id=self.wallet_id).order_by(DbKey.id) + qr = self.session.query(DbKey).filter_by(wallet_id=self.wallet_id).order_by(DbKey.id) if network is not None: qr = qr.filter(DbKey.network_name == network) if witness_type is not None: @@ -2612,7 +2622,7 @@ def key(self, term): """ dbkey = None - qr = self._session.query(DbKey).filter_by(wallet_id=self.wallet_id) + qr = self.session.query(DbKey).filter_by(wallet_id=self.wallet_id) if isinstance(term, numbers.Number): dbkey = qr.filter_by(id=term).scalar() if not dbkey: @@ -2625,7 +2635,7 @@ def key(self, term): if dbkey.id in self._key_objects.keys(): return self._key_objects[dbkey.id] else: - hdwltkey = WalletKey(key_id=dbkey.id, session=self._session) + hdwltkey = WalletKey(key_id=dbkey.id, session=self.session) self._key_objects.update({dbkey.id: hdwltkey}) return hdwltkey else: @@ -2648,7 +2658,7 @@ def account(self, account_id): if "account'" not in self.key_path: raise WalletError("Accounts are not supported for this wallet. Account not found in key path %s" % self.key_path) - qr = self._session.query(DbKey).\ + qr = self.session.query(DbKey).\ filter_by(wallet_id=self.wallet_id, purpose=self.purpose, network_name=self.network.name, account_id=account_id, depth=3).scalar() if not qr: @@ -2690,7 +2700,7 @@ def witness_types(self, account_id=None, network=None): """ # network, account_id, _ = self._get_account_defaults(network, account_id) - qr = self._session.query(DbKey.witness_type).filter_by(wallet_id=self.wallet_id) + qr = self.session.query(DbKey.witness_type).filter_by(wallet_id=self.wallet_id) if network is not None: qr = qr.filter(DbKey.network_name == network) if account_id is not None: @@ -2711,7 +2721,7 @@ def networks(self, as_dict=False): nw_list = [self.network] if self.multisig and self.cosigner: - keys_qr = self._session.query(DbKey.network_name).\ + keys_qr = self.session.query(DbKey.network_name).\ filter_by(wallet_id=self.wallet_id, depth=self.key_depth).\ group_by(DbKey.network_name).all() nw_list += [Network(nw[0]) for nw in keys_qr] @@ -2824,7 +2834,7 @@ def _balance_update(self, account_id=None, network=None, key_id=None, min_confir :return: Updated balance """ - qr = self._session.query(DbTransactionOutput, func.sum(DbTransactionOutput.value), DbTransaction.network_name, + qr = self.session.query(DbTransactionOutput, func.sum(DbTransactionOutput.value), DbTransaction.network_name, DbTransaction.account_id).\ join(DbTransaction). \ filter(DbTransactionOutput.spent.is_(False), @@ -2896,7 +2906,7 @@ def _balance_update(self, account_id=None, network=None, key_id=None, min_confir for kb in key_balance_list: if kb['id'] in self._key_objects: self._key_objects[kb['id']]._balance = kb['balance'] - self._session.bulk_update_mappings(DbKey, key_balance_list) + self.session.bulk_update_mappings(DbKey, key_balance_list) self._commit() _logger.info("Got balance for %d key(s)" % len(key_balance_list)) return self._balances @@ -2948,7 +2958,7 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d single_key = None if key_id: - single_key = self._session.query(DbKey).filter_by(id=key_id).scalar() + single_key = self.session.query(DbKey).filter_by(id=key_id).scalar() networks = [single_key.network_name] account_id = single_key.account_id rescan_all = False @@ -2963,14 +2973,14 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d for network in networks: # Remove current UTXO's if rescan_all: - cur_utxos = self._session.query(DbTransactionOutput). \ + cur_utxos = self.session.query(DbTransactionOutput). \ join(DbTransaction). \ filter(DbTransactionOutput.spent.is_(False), DbTransaction.account_id == account_id, DbTransaction.wallet_id == self.wallet_id, DbTransaction.network_name == network).all() for u in cur_utxos: - self._session.query(DbTransactionOutput).filter_by( + self.session.query(DbTransactionOutput).filter_by( transaction_id=u.transaction_id, output_n=u.output_n).update({DbTransactionOutput.spent: True}) self._commit() @@ -3008,7 +3018,7 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d for utxo in utxos: key = single_key if not single_key: - key = self._session.query(DbKey).\ + key = self.session.query(DbKey).\ filter_by(wallet_id=self.wallet_id, address=utxo['address']).scalar() if not key: raise WalletError("Key with address %s not found in this wallet" % utxo['address']) @@ -3018,14 +3028,14 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d status = 'confirmed' # Update confirmations in db if utxo was already imported - transaction_in_db = self._session.query(DbTransaction).\ + transaction_in_db = self.session.query(DbTransaction).\ filter_by(wallet_id=self.wallet_id, txid=bytes.fromhex(utxo['txid']), network_name=network) - utxo_in_db = self._session.query(DbTransactionOutput).join(DbTransaction).\ + utxo_in_db = self.session.query(DbTransactionOutput).join(DbTransaction).\ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.txid == bytes.fromhex(utxo['txid']), DbTransactionOutput.output_n == utxo['output_n']) - spent_in_db = self._session.query(DbTransactionInput).join(DbTransaction).\ + spent_in_db = self.session.query(DbTransactionInput).join(DbTransaction).\ filter(DbTransaction.wallet_id == self.wallet_id, DbTransactionInput.prev_txid == bytes.fromhex(utxo['txid']), DbTransactionInput.output_n == utxo['output_n']) @@ -3049,7 +3059,7 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d wallet_id=self.wallet_id, txid=bytes.fromhex(utxo['txid']), status=status, is_complete=False, block_height=block_height, account_id=account_id, confirmations=utxo['confirmations'], network_name=network) - self._session.add(new_tx) + self.session.add(new_tx) # TODO: Get unique id before inserting to increase performance for large utxo-sets self._commit() tid = new_tx.id @@ -3063,7 +3073,7 @@ def utxos_update(self, account_id=None, used=None, networks=None, key_id=None, d script=bytes.fromhex(utxo['script']), script_type=script_type, spent=bool(spent_in_db.count())) - self._session.add(new_utxo) + self.session.add(new_utxo) count_utxos += 1 self._commit() @@ -3100,7 +3110,7 @@ def utxos(self, account_id=None, network=None, min_confirms=0, key_id=None): first_key_id = key_id[0] network, account_id, acckey = self._get_account_defaults(network, account_id, first_key_id) - qr = self._session.query(DbTransactionOutput, DbKey.address, DbTransaction.confirmations, DbTransaction.txid, + qr = self.session.query(DbTransactionOutput, DbKey.address, DbTransaction.confirmations, DbTransaction.txid, DbKey.network_name).\ join(DbTransaction).join(DbKey). \ filter(DbTransactionOutput.spent.is_(False), @@ -3172,7 +3182,7 @@ def utxo_last(self, address): :return str: """ - to = self._session.query( + to = self.session.query( DbTransaction.txid, DbTransaction.confirmations). \ join(DbTransactionOutput).join(DbKey). \ filter(DbKey.address == address, DbTransaction.wallet_id == self.wallet_id, @@ -3189,11 +3199,11 @@ def transactions_update_confirmations(self): network = self.network.name srv = Service(network=network, providers=self.providers, cache_uri=self.db_cache_uri) blockcount = srv.blockcount() - db_txs = self._session.query(DbTransaction). \ + db_txs = self.session.query(DbTransaction). \ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.network_name == network, DbTransaction.block_height > 0).all() for db_tx in db_txs: - self._session.query(DbTransaction).filter_by(id=db_tx.id). \ + self.session.query(DbTransaction).filter_by(id=db_tx.id). \ update({DbTransaction.status: 'confirmed', DbTransaction.confirmations: (blockcount - DbTransaction.block_height) + 1}) self._commit() @@ -3227,7 +3237,7 @@ def transactions_update_by_txids(self, txids): utxo_set.update(utxos) for utxo in list(utxo_set): - tos = self._session.query(DbTransactionOutput).join(DbTransaction). \ + tos = self.session.query(DbTransactionOutput).join(DbTransaction). \ filter(DbTransaction.txid == bytes.fromhex(utxo[0]), DbTransactionOutput.output_n == utxo[1], DbTransactionOutput.spent.is_(False)).all() for u in tos: @@ -3270,11 +3280,11 @@ def transactions_update(self, account_id=None, used=None, network=None, key_id=N srv = Service(network=network, providers=self.providers, cache_uri=self.db_cache_uri) blockcount = srv.blockcount() - db_txs = self._session.query(DbTransaction).\ + db_txs = self.session.query(DbTransaction).\ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.network_name == network, DbTransaction.block_height > 0).all() for db_tx in db_txs: - self._session.query(DbTransaction).filter_by(id=db_tx.id).\ + self.session.query(DbTransaction).filter_by(id=db_tx.id).\ update({DbTransaction.status: 'confirmed', DbTransaction.confirmations: (blockcount - DbTransaction.block_height) + 1}) self._commit() @@ -3290,7 +3300,7 @@ def transactions_update(self, account_id=None, used=None, network=None, key_id=N if txs and txs[-1].date and txs[-1].date < last_updated: last_updated = txs[-1].date if txs and txs[-1].confirmations: - dbkey = self._session.query(DbKey).filter(DbKey.address == address, DbKey.wallet_id == self.wallet_id) + dbkey = self.session.query(DbKey).filter(DbKey.address == address, DbKey.wallet_id == self.wallet_id) if not dbkey.update({DbKey.latest_txid: bytes.fromhex(txs[-1].txid)}): raise WalletError("Failed to update latest transaction id for key with address %s" % address) self._commit() @@ -3305,7 +3315,7 @@ def transactions_update(self, account_id=None, used=None, network=None, key_id=N utxos = [(ti.prev_txid.hex(), ti.output_n_int) for ti in wt.inputs] utxo_set.update(utxos) for utxo in list(utxo_set): - tos = self._session.query(DbTransactionOutput).join(DbTransaction).\ + tos = self.session.query(DbTransactionOutput).join(DbTransaction).\ filter(DbTransaction.txid == bytes.fromhex(utxo[0]), DbTransactionOutput.output_n == utxo[1], DbTransactionOutput.spent.is_(False), DbTransaction.wallet_id == self.wallet_id).all() for u in tos: @@ -3326,7 +3336,7 @@ def transaction_last(self, address): :return str: """ - txid = self._session.query(DbKey.latest_txid).\ + txid = self.session.query(DbKey.latest_txid).\ filter(DbKey.address == address, DbKey.wallet_id == self.wallet_id).scalar() return '' if not txid else txid.hex() @@ -3357,7 +3367,7 @@ def transactions(self, account_id=None, network=None, include_new=False, key_id= network, account_id, acckey = self._get_account_defaults(network, account_id, key_id) # Transaction inputs - qr = self._session.query(DbTransactionInput, DbTransactionInput.address, DbTransaction.confirmations, + qr = self.session.query(DbTransactionInput, DbTransactionInput.address, DbTransaction.confirmations, DbTransaction.txid, DbTransaction.network_name, DbTransaction.status). \ join(DbTransaction).join(DbKey). \ filter(DbTransaction.account_id == account_id, @@ -3370,7 +3380,7 @@ def transactions(self, account_id=None, network=None, include_new=False, key_id= qr = qr.filter(or_(DbTransaction.status == 'confirmed', DbTransaction.status == 'unconfirmed')) txs = qr.all() # Transaction outputs - qr = self._session.query(DbTransactionOutput, DbTransactionOutput.address, DbTransaction.confirmations, + qr = self.session.query(DbTransactionOutput, DbTransactionOutput.address, DbTransaction.confirmations, DbTransaction.txid, DbTransaction.network_name, DbTransaction.status). \ join(DbTransaction).join(DbKey). \ filter(DbTransaction.account_id == account_id, @@ -3430,7 +3440,7 @@ def transactions_full(self, network=None, include_new=False, limit=0, offset=0): :return list of WalletTransaction: """ network, _, _ = self._get_account_defaults(network) - qr = self._session.query(DbTransaction.txid, DbTransaction.network_name, DbTransaction.status). \ + qr = self.session.query(DbTransaction.txid, DbTransaction.network_name, DbTransaction.status). \ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.network_name == network) if not include_new: @@ -3510,7 +3520,7 @@ def transaction_spent(self, txid, output_n): txid = to_bytes(txid) if isinstance(output_n, bytes): output_n = int.from_bytes(output_n, 'big') - qr = self._session.query(DbTransactionInput, DbTransaction.confirmations, + qr = self.session.query(DbTransactionInput, DbTransaction.confirmations, DbTransaction.txid, DbTransaction.status). \ join(DbTransaction). \ filter(DbTransaction.wallet_id == self.wallet_id, @@ -3519,7 +3529,7 @@ def transaction_spent(self, txid, output_n): return qr.transaction.txid.hex() def _objects_by_key_id(self, key_id): - key = self._session.query(DbKey).filter_by(id=key_id).scalar() + key = self.session.query(DbKey).filter_by(id=key_id).scalar() if not key: raise WalletError("Key '%s' not found in this wallet" % key_id) if key.key_type == 'multisig': @@ -3571,7 +3581,7 @@ def select_inputs(self, amount, variance=None, input_key_id=None, account_id=Non if variance is None: variance = dust_amount - utxo_query = self._session.query(DbTransactionOutput).join(DbTransaction).join(DbKey). \ + utxo_query = self.session.query(DbTransactionOutput).join(DbTransaction).join(DbKey). \ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.account_id == account_id, DbTransaction.network_name == network, DbKey.public != b'', DbTransactionOutput.spent.is_(False), DbTransaction.confirmations >= min_confirms) @@ -3796,7 +3806,7 @@ def transaction_create(self, output_arr, input_arr=None, input_key_id=None, acco if not (key_id and value and unlocking_script_type): if not isinstance(output_n, TYPE_INT): output_n = int.from_bytes(output_n, 'big') - inp_utxo = self._session.query(DbTransactionOutput).join(DbTransaction). \ + inp_utxo = self.session.query(DbTransactionOutput).join(DbTransaction). \ filter(DbTransaction.wallet_id == self.wallet_id, DbTransaction.txid == to_bytes(prev_txid), DbTransactionOutput.output_n == output_n).first() @@ -3809,7 +3819,7 @@ def transaction_create(self, output_arr, input_arr=None, input_key_id=None, acco else: _logger.info("UTXO %s not found in this wallet. Please update UTXO's if this is not an " "offline wallet" % to_hexstring(prev_txid)) - key_id = self._session.query(DbKey.id).\ + key_id = self.session.query(DbKey.id).\ filter(DbKey.wallet_id == self.wallet_id, DbKey.address == address).scalar() if not key_id: raise WalletError("UTXO %s and key with address %s not found in this wallet" % ( diff --git a/tests/test_security.py b/tests/test_security.py index 48e400c7..4235971c 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -81,8 +81,8 @@ def test_security_wallet_field_encryption(self): db_query = text("SELECT wif, private FROM `keys` WHERE id=%d" % wallet._dbwallet.main_key_id) else: db_query = text("SELECT wif, private FROM keys WHERE id=%d" % wallet._dbwallet.main_key_id) - encrypted_main_key_wif = wallet._session.execute(db_query).fetchone()[0] - encrypted_main_key_private = wallet._session.execute(db_query).fetchone()[1] + encrypted_main_key_wif = wallet.session.execute(db_query).fetchone()[0] + encrypted_main_key_private = wallet.session.execute(db_query).fetchone()[1] self.assertIn(type(encrypted_main_key_wif), (bytes, memoryview), "Encryption of database private key failed!") self.assertEqual(encrypted_main_key_wif.hex(), pk_wif_enc_hex) self.assertEqual(encrypted_main_key_private.hex(), pk_enc_hex) diff --git a/tests/test_wallets.py b/tests/test_wallets.py index 3e8f55dd..cf40abaf 100644 --- a/tests/test_wallets.py +++ b/tests/test_wallets.py @@ -44,6 +44,7 @@ print("DATABASE USED: %s" % os.getenv('UNITTEST_DATABASE')) + def database_init(dbname=DATABASE_NAME): session.close_all_sessions() if os.getenv('UNITTEST_DATABASE') == 'postgresql': @@ -639,24 +640,24 @@ def test_wallet_key_create_from_key(self): k1 = HDKey(network='testnet') k2 = HDKey(network='testnet') w1 = Wallet.create('network_mixup_test_wallet', network='litecoin', db_uri=self.database_uri) - wk1 = WalletKey.from_key('key1', w1.wallet_id, w1._session, key=k1.address_obj) + wk1 = WalletKey.from_key('key1', w1.wallet_id, w1.session, key=k1.address_obj) self.assertEqual(wk1.network.name, 'testnet') self.assertRaisesRegex(WalletError, "Specified network and key network should be the same", - WalletKey.from_key, 'key2', w1.wallet_id, w1._session, key=k2.address_obj, + WalletKey.from_key, 'key2', w1.wallet_id, w1.session, key=k2.address_obj, network='bitcoin') w2 = Wallet.create('network_mixup_test_wallet2', network='litecoin', db_uri=self.database_uri) - wk2 = WalletKey.from_key('key1', w2.wallet_id, w2._session, key=k1) + wk2 = WalletKey.from_key('key1', w2.wallet_id, w2.session, key=k1) self.assertEqual(wk2.network.name, 'testnet') self.assertRaisesRegex(WalletError, "Specified network and key network should be the same", - WalletKey.from_key, 'key2', w2.wallet_id, w2._session, key=k2, + WalletKey.from_key, 'key2', w2.wallet_id, w2.session, key=k2, network='bitcoin') - wk3 = WalletKey.from_key('key3', w2.wallet_id, w2._session, key=k1) + wk3 = WalletKey.from_key('key3', w2.wallet_id, w2.session, key=k1) self.assertEqual(wk3.name, 'key1') - wk4 = WalletKey.from_key('key4', w2.wallet_id, w2._session, key=k1.address_obj) + wk4 = WalletKey.from_key('key4', w2.wallet_id, w2.session, key=k1.address_obj) self.assertEqual(wk4.name, 'key1') k = HDKey().public_master() w = Wallet.create('pmtest', network='litecoin', db_uri=self.database_uri) - wk1 = WalletKey.from_key('key', w.wallet_id, w._session, key=k) + wk1 = WalletKey.from_key('key', w.wallet_id, w.session, key=k) self.assertEqual(wk1.path, 'M') # Test __repr__ method self.assertIn("