diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 8ec7f3290..3d6cc367a 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -22,6 +22,7 @@ # SOFTWARE. import asyncio +import copy import threading import itertools from collections import defaultdict @@ -99,9 +100,15 @@ class AddressSynchronizer(Logger, EventListener): self.threadlocal_cache = threading.local() self._get_balance_cache = {} + self._get_utxos_cache = {} self.load_and_cleanup() + @with_lock + def invalidate_cache(self): + self._get_balance_cache.clear() + self._get_utxos_cache.clear() + def diagnostic_name(self): return self.name or "" @@ -128,6 +135,7 @@ class AddressSynchronizer(Logger, EventListener): def get_addresses(self): return sorted(self.db.get_history()) + @with_lock def get_address_history(self, addr: str) -> Dict[str, int]: """Returns the history for the address, as a txid->height dict. In addition to what we have from the server, this includes local and future txns. @@ -136,11 +144,10 @@ class AddressSynchronizer(Logger, EventListener): so that only includes txns the server sees. """ h = {} - with self.lock: - related_txns = self._history_local.get(addr, set()) - for tx_hash in related_txns: - tx_height = self.get_tx_height(tx_hash).height - h[tx_hash] = tx_height + related_txns = self._history_local.get(addr, set()) + for tx_hash in related_txns: + tx_height = self.get_tx_height(tx_hash).height + h[tx_hash] = tx_height return h def get_address_history_len(self, addr: str) -> int: @@ -201,10 +208,10 @@ class AddressSynchronizer(Logger, EventListener): self.register_callbacks() @event_listener + @with_lock def on_event_blockchain_updated(self, *args): - with self.lock: - self._get_balance_cache = {} # invalidate cache - self.db.put('stored_height', self.get_local_height()) + self.invalidate_cache() + self.db.put('stored_height', self.get_local_height()) async def stop(self): if self.network: @@ -227,6 +234,7 @@ class AddressSynchronizer(Logger, EventListener): self.synchronizer.add(address) self.up_to_date_changed() + @with_lock def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]: """Returns a set of transaction hashes from the wallet history that are directly conflicting with tx, i.e. they have common outpoints being @@ -236,27 +244,26 @@ class AddressSynchronizer(Logger, EventListener): conflict (if already in wallet history) """ conflicting_txns = set() - with self.lock: - for txin in tx.inputs(): - if txin.is_coinbase_input(): - continue - prevout_hash = txin.prevout.txid.hex() - prevout_n = txin.prevout.out_idx - spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n) - if spending_tx_hash is None: - continue - # this outpoint has already been spent, by spending_tx - # annoying assert that has revealed several bugs over time: - assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db" - conflicting_txns |= {spending_tx_hash} - if tx_hash := tx.txid(): - if tx_hash in conflicting_txns: - # this tx is already in history, so it conflicts with itself - if len(conflicting_txns) > 1: - raise Exception('Found conflicting transactions already in wallet history.') - if not include_self: - conflicting_txns -= {tx_hash} - return conflicting_txns + for txin in tx.inputs(): + if txin.is_coinbase_input(): + continue + prevout_hash = txin.prevout.txid.hex() + prevout_n = txin.prevout.out_idx + spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n) + if spending_tx_hash is None: + continue + # this outpoint has already been spent, by spending_tx + # annoying assert that has revealed several bugs over time: + assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db" + conflicting_txns |= {spending_tx_hash} + if tx_hash := tx.txid(): + if tx_hash in conflicting_txns: + # this tx is already in history, so it conflicts with itself + if len(conflicting_txns) > 1: + raise Exception('Found conflicting transactions already in wallet history.') + if not include_self: + conflicting_txns -= {tx_hash} + return conflicting_txns @with_lock def get_transaction(self, txid: str) -> Optional[Transaction]: @@ -335,7 +342,7 @@ class AddressSynchronizer(Logger, EventListener): pass else: self.db.add_txi_addr(tx_hash, addr, ser, v) - self._get_balance_cache.clear() # invalidate cache + self.invalidate_cache() for txi in tx.inputs(): if txi.is_coinbase_input(): continue @@ -353,7 +360,7 @@ class AddressSynchronizer(Logger, EventListener): addr = txo.address if addr and self.is_mine(addr): self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase) - self._get_balance_cache.clear() # invalidate cache + self.invalidate_cache() # give v to txi that spends me next_tx = self.db.get_spent_outpoint(tx_hash, n) if next_tx is not None: @@ -368,15 +375,15 @@ class AddressSynchronizer(Logger, EventListener): util.trigger_callback('adb_added_tx', self, tx_hash, tx) return True + @with_lock def remove_transaction(self, tx_hash: str) -> None: """Removes a transaction AND all its dependents/children from the wallet history. """ - with self.lock: - to_remove = {tx_hash} - to_remove |= self.get_depending_transactions(tx_hash) - for txid in to_remove: - self._remove_transaction(txid) + to_remove = {tx_hash} + to_remove |= self.get_depending_transactions(tx_hash) + for txid in to_remove: + self._remove_transaction(txid) def _remove_transaction(self, tx_hash: str) -> None: """Removes a single transaction from the wallet history, and attempts @@ -405,7 +412,7 @@ class AddressSynchronizer(Logger, EventListener): remove_from_spent_outpoints() self._remove_tx_from_local_history(tx_hash) for addr in itertools.chain(self.db.get_txi_addresses(tx_hash), self.db.get_txo_addresses(tx_hash)): - self._get_balance_cache.clear() # invalidate cache + self.invalidate_cache() self.db.remove_txi(tx_hash) self.db.remove_txo(tx_hash) self.db.remove_tx_fee(tx_hash) @@ -419,15 +426,15 @@ class AddressSynchronizer(Logger, EventListener): self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value) util.trigger_callback('adb_removed_tx', self, tx_hash, tx) + @with_lock def get_depending_transactions(self, tx_hash: str) -> Set[str]: """Returns all (grand-)children of tx_hash in this wallet.""" - with self.lock: - children = set() - for n in self.db.get_spent_outpoints(tx_hash): - other_hash = self.db.get_spent_outpoint(tx_hash, n) - children.add(other_hash) - children |= self.get_depending_transactions(other_hash) - return children + children = set() + for n in self.db.get_spent_outpoints(tx_hash): + other_hash = self.db.get_spent_outpoint(tx_hash, n) + children.add(other_hash) + children |= self.get_depending_transactions(other_hash) + return children @with_lock def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None: @@ -499,19 +506,19 @@ class AddressSynchronizer(Logger, EventListener): if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid): self.remove_transaction(txid) + @with_lock def clear_history(self): - with self.lock: - self.db.clear_history() - self._history_local.clear() - self._get_balance_cache.clear() # invalidate cache + self.db.clear_history() + self._history_local.clear() + self.invalidate_cache() + @with_lock def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]: """Returns a key to be used for sorting txs.""" - with self.lock: - tx_mined_info = self.get_tx_height(tx_hash) - height = self.tx_height_to_sort_height(tx_mined_info.height) - txpos = tx_mined_info.txpos or -1 - return height, txpos + tx_mined_info = self.get_tx_height(tx_hash) + height = self.tx_height_to_sort_height(tx_mined_info.height) + txpos = tx_mined_info.txpos or -1 + return height, txpos @classmethod def tx_height_to_sort_height(cls, height: int = None): @@ -578,26 +585,26 @@ class AddressSynchronizer(Logger, EventListener): raise Exception("wallet.get_history() failed balance sanity-check") return h2 + @with_lock def _add_tx_to_local_history(self, txid): - with self.lock: - for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): - cur_hist = self._history_local.get(addr, set()) - cur_hist.add(txid) + for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): + cur_hist = self._history_local.get(addr, set()) + cur_hist.add(txid) + self._history_local[addr] = cur_hist + self._mark_address_history_changed(addr) + + @with_lock + def _remove_tx_from_local_history(self, txid): + for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): + cur_hist = self._history_local.get(addr, set()) + try: + cur_hist.remove(txid) + except KeyError: + pass + else: self._history_local[addr] = cur_hist self._mark_address_history_changed(addr) - def _remove_tx_from_local_history(self, txid): - with self.lock: - for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): - cur_hist = self._history_local.get(addr, set()) - try: - cur_hist.remove(txid) - except KeyError: - pass - else: - self._history_local[addr] = cur_hist - self._mark_address_history_changed(addr) - def _mark_address_history_changed(self, addr: str) -> None: def set_and_clear(): event = self._address_history_changed_events[addr] @@ -617,27 +624,27 @@ class AddressSynchronizer(Logger, EventListener): assert self.is_mine(addr), "address needs to be is_mine to be watched" await self._address_history_changed_events[addr].wait() + @with_lock def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None: assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here - with self.lock: - if self.db.is_in_verified_tx(tx_hash): - if tx_height <= 0: - # tx was previously SPV-verified but now in mempool (probably reorg) - self.db.remove_verified_tx(tx_hash) - self.unconfirmed_tx[tx_hash] = tx_height - if self.verifier: - self.verifier.remove_spv_proof_for_tx(tx_hash) + if self.db.is_in_verified_tx(tx_hash): + if tx_height <= 0: + # tx was previously SPV-verified but now in mempool (probably reorg) + self.db.remove_verified_tx(tx_hash) + self.unconfirmed_tx[tx_hash] = tx_height + if self.verifier: + self.verifier.remove_spv_proof_for_tx(tx_hash) + else: + if tx_height > 0: + self.unverified_tx[tx_hash] = tx_height else: - if tx_height > 0: - self.unverified_tx[tx_hash] = tx_height - else: - self.unconfirmed_tx[tx_hash] = tx_height + self.unconfirmed_tx[tx_hash] = tx_height + @with_lock def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None: - with self.lock: - new_height = self.unverified_tx.get(tx_hash) - if new_height == tx_height: - self.unverified_tx.pop(tx_hash, None) + new_height = self.unverified_tx.get(tx_hash) + if new_height == tx_height: + self.unverified_tx.pop(tx_hash, None) def add_verified_tx(self, tx_hash: str, info: TxMinedInfo): # Remove from the unverified map and add to the verified map @@ -646,10 +653,10 @@ class AddressSynchronizer(Logger, EventListener): self.db.add_verified_tx(tx_hash, info) util.trigger_callback('adb_added_verified_tx', self, tx_hash) + @with_lock def get_unverified_txs(self) -> Dict[str, int]: '''Returns a map from tx hash to transaction height''' - with self.lock: - return dict(self.unverified_tx) # copy + return dict(self.unverified_tx) # copy def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]: '''Used by the verifier when a reorg has happened''' @@ -830,20 +837,20 @@ class AddressSynchronizer(Logger, EventListener): self.db.add_num_inputs_to_tx(txid, len(tx.inputs())) return fee + @with_lock def get_addr_io(self, address: str): - with self.lock: - h = self.get_address_history(address).items() - received = {} - sent = {} - for tx_hash, height in h: - tx_mined_info = self.get_tx_height(tx_hash) - txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1 - d = self.db.get_txo_addr(tx_hash, address) - for n, (v, is_cb) in d.items(): - received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb) - l = self.db.get_txi_addr(tx_hash, address) - for txi, v in l: - sent[txi] = tx_hash, height, txpos + h = self.get_address_history(address).items() + received = {} + sent = {} + for tx_hash, height in h: + tx_mined_info = self.get_tx_height(tx_hash) + txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1 + d = self.db.get_txo_addr(tx_hash, address) + for n, (v, is_cb) in d.items(): + received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb) + l = self.db.get_txi_addr(tx_hash, address) + for txi, v in l: + sent[txi] = tx_hash, height, txpos return received, sent def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]: @@ -970,6 +977,13 @@ class AddressSynchronizer(Logger, EventListener): if excluded_addresses: domain = set(domain) - set(excluded_addresses) mempool_height = block_height + 1 # height of next block + cache_key = sha256( + ','.join(sorted(domain)) + + f";{mature_only};{confirmed_funding_only};{confirmed_spending_only};{nonlocal_only};{block_height}" + ) + cached = self._get_utxos_cache.get(cache_key) + if cached is not None: + return copy.deepcopy(cached) for addr in domain: txos = self.get_addr_outputs(addr) for txo in txos.values(): @@ -987,6 +1001,7 @@ class AddressSynchronizer(Logger, EventListener): continue coins.append(txo) continue + self._get_utxos_cache[cache_key] = copy.deepcopy(coins) return coins def is_used(self, address: str) -> bool: