From 4543192e1ad4a2f93dc61c4c4e06392dcb135da7 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 21 May 2025 17:53:25 +0000 Subject: [PATCH] adb: take lock in more places for example, adb.get_utxos() could previously return an inconsistent result --- electrum/address_synchronizer.py | 63 +++++++++++++++++++------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 0bcf7ee48..d60f219e9 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -105,6 +105,7 @@ class AddressSynchronizer(Logger, EventListener): def diagnostic_name(self): return self.name or "" + @with_lock def load_and_cleanup(self): self.load_local_history() self.check_history() @@ -146,6 +147,7 @@ class AddressSynchronizer(Logger, EventListener): """Return number of transactions where address is involved.""" return len(self._history_local.get(addr, ())) + @with_lock def get_txin_address(self, txin: TxInput) -> Optional[str]: if txin.address: return txin.address @@ -160,6 +162,7 @@ class AddressSynchronizer(Logger, EventListener): return tx.outputs()[prevout_n].address return None + @with_lock def get_txin_value(self, txin: TxInput, *, address: str = None) -> Optional[int]: if txin.value_sats() is not None: return txin.value_sats() @@ -179,6 +182,7 @@ class AddressSynchronizer(Logger, EventListener): return tx.outputs()[prevout_n].value return None + @with_lock def load_unverified_transactions(self): # review transactions that are in the history for addr in self.db.get_history(): @@ -198,8 +202,9 @@ class AddressSynchronizer(Logger, EventListener): @event_listener def on_event_blockchain_updated(self, *args): - self._get_balance_cache = {} # invalidate cache - self.db.put('stored_height', self.get_local_height()) + with self.lock: + self._get_balance_cache = {} # invalidate cache + self.db.put('stored_height', self.get_local_height()) async def stop(self): if self.network: @@ -252,6 +257,7 @@ class AddressSynchronizer(Logger, EventListener): conflicting_txns -= {tx_hash} return conflicting_txns + @with_lock def get_transaction(self, txid: str) -> Optional[Transaction]: tx = self.db.get_transaction(txid) if tx: @@ -422,6 +428,7 @@ class AddressSynchronizer(Logger, EventListener): children |= self.get_depending_transactions(other_hash) return children + @with_lock def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None: txid = tx.txid() assert txid is not None @@ -430,18 +437,18 @@ class AddressSynchronizer(Logger, EventListener): self.add_unverified_or_unconfirmed_tx(txid, tx_height) self.add_transaction(tx, allow_unrelated=True) + @with_lock def receive_history_callback(self, addr: str, hist, tx_fees: Dict[str, int]): - with self.lock: - old_hist = self.get_address_history(addr) - for tx_hash, height in old_hist.items(): - if (tx_hash, height) not in hist: - # make tx local - self.unverified_tx.pop(tx_hash, None) - self.unconfirmed_tx.pop(tx_hash, None) - self.db.remove_verified_tx(tx_hash) - if self.verifier: - self.verifier.remove_spv_proof_for_tx(tx_hash) - self.db.set_addr_history(addr, hist) + old_hist = self.get_address_history(addr) + for tx_hash, height in old_hist.items(): + if (tx_hash, height) not in hist: + # make tx local + self.unverified_tx.pop(tx_hash, None) + self.unconfirmed_tx.pop(tx_hash, None) + self.db.remove_verified_tx(tx_hash) + if self.verifier: + self.verifier.remove_spv_proof_for_tx(tx_hash) + self.db.set_addr_history(addr, hist) for tx_hash, tx_height in hist: # add it in case it was previously unconfirmed @@ -460,6 +467,7 @@ class AddressSynchronizer(Logger, EventListener): for tx_hash, fee_sat in tx_fees.items(): self.db.add_tx_fee_from_server(tx_hash, fee_sat) + @with_lock @profiler def load_local_history(self): self._history_local = {} # type: Dict[str, Set[str]] # address -> set(txid) @@ -467,6 +475,7 @@ class AddressSynchronizer(Logger, EventListener): for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()): self._add_tx_to_local_history(txid) + @with_lock @profiler def check_history(self): hist_addrs_mine = list(filter(lambda k: self.is_mine(k), self.db.get_history())) @@ -482,6 +491,7 @@ class AddressSynchronizer(Logger, EventListener): if tx is not None: self.add_transaction(tx, allow_unrelated=True) + @with_lock def remove_local_transactions_we_dont_have(self): for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()): tx_height = self.get_tx_height(txid).height @@ -749,6 +759,7 @@ class AddressSynchronizer(Logger, EventListener): delta += v return delta + @with_lock def get_tx_fee(self, txid: str) -> Optional[int]: """Returns tx_fee or None. Use server fee only if tx is unconfirmed and not mine. @@ -784,16 +795,15 @@ class AddressSynchronizer(Logger, EventListener): return None # compute fee if possible v_in = v_out = 0 - with self.lock: - for txin in tx.inputs(): - addr = self.get_txin_address(txin) - value = self.get_txin_value(txin, address=addr) - if value is None: - v_in = None - elif v_in is not None: - v_in += value - for txout in tx.outputs(): - v_out += txout.value + for txin in tx.inputs(): + addr = self.get_txin_address(txin) + value = self.get_txin_value(txin, address=addr) + if value is None: + v_in = None + elif v_in is not None: + v_in += value + for txout in tx.outputs(): + v_out += txout.value if v_in is not None: fee = v_in - v_out else: @@ -918,6 +928,7 @@ class AddressSynchronizer(Logger, EventListener): self._get_balance_cache[cache_key] = result return result + @with_lock @with_local_height_cached def get_utxos( self, @@ -974,6 +985,7 @@ class AddressSynchronizer(Logger, EventListener): coins = self.get_addr_utxo(address) return not bool(coins) + @with_lock @with_local_height_cached def address_is_old(self, address: str, *, req_conf: int = 3) -> bool: """Returns whether address has any history that is deeply confirmed. @@ -993,7 +1005,8 @@ class AddressSynchronizer(Logger, EventListener): max_conf = max(max_conf, tx_age) return max_conf >= req_conf - def get_spender(self, outpoint: str) -> str: + @with_lock + def get_spender(self, outpoint: str) -> Optional[str]: """ returns txid spending outpoint. subscribes to addresses as a side effect. @@ -1005,7 +1018,7 @@ class AddressSynchronizer(Logger, EventListener): if tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]: spender_txid = None if not spender_txid: - return + return None spender_tx = self.get_transaction(spender_txid) for i, o in enumerate(spender_tx.outputs()): if o.address is None: