1
0

address_synchronizer: apply @with_lock where applicable

This commit is contained in:
наб
2025-06-15 19:14:08 +02:00
parent 4887fb3e7f
commit fdaafd5abf

View File

@@ -128,6 +128,7 @@ class AddressSynchronizer(Logger, EventListener):
def get_addresses(self): def get_addresses(self):
return sorted(self.db.get_history()) return sorted(self.db.get_history())
@with_lock
def get_address_history(self, addr: str) -> Dict[str, int]: def get_address_history(self, addr: str) -> Dict[str, int]:
"""Returns the history for the address, as a txid->height dict. """Returns the history for the address, as a txid->height dict.
In addition to what we have from the server, this includes local and future txns. In addition to what we have from the server, this includes local and future txns.
@@ -136,11 +137,10 @@ class AddressSynchronizer(Logger, EventListener):
so that only includes txns the server sees. so that only includes txns the server sees.
""" """
h = {} h = {}
with self.lock: related_txns = self._history_local.get(addr, set())
related_txns = self._history_local.get(addr, set()) for tx_hash in related_txns:
for tx_hash in related_txns: tx_height = self.get_tx_height(tx_hash).height
tx_height = self.get_tx_height(tx_hash).height h[tx_hash] = tx_height
h[tx_hash] = tx_height
return h return h
def get_address_history_len(self, addr: str) -> int: def get_address_history_len(self, addr: str) -> int:
@@ -201,10 +201,10 @@ class AddressSynchronizer(Logger, EventListener):
self.register_callbacks() self.register_callbacks()
@event_listener @event_listener
@with_lock
def on_event_blockchain_updated(self, *args): def on_event_blockchain_updated(self, *args):
with self.lock: self._get_balance_cache = {} # invalidate cache
self._get_balance_cache = {} # invalidate cache self.db.put('stored_height', self.get_local_height())
self.db.put('stored_height', self.get_local_height())
async def stop(self): async def stop(self):
if self.network: if self.network:
@@ -227,6 +227,7 @@ class AddressSynchronizer(Logger, EventListener):
self.synchronizer.add(address) self.synchronizer.add(address)
self.up_to_date_changed() self.up_to_date_changed()
@with_lock
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]: def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
"""Returns a set of transaction hashes from the wallet history that are """Returns a set of transaction hashes from the wallet history that are
directly conflicting with tx, i.e. they have common outpoints being directly conflicting with tx, i.e. they have common outpoints being
@@ -236,27 +237,26 @@ class AddressSynchronizer(Logger, EventListener):
conflict (if already in wallet history) conflict (if already in wallet history)
""" """
conflicting_txns = set() conflicting_txns = set()
with self.lock: for txin in tx.inputs():
for txin in tx.inputs(): if txin.is_coinbase_input():
if txin.is_coinbase_input(): continue
continue prevout_hash = txin.prevout.txid.hex()
prevout_hash = txin.prevout.txid.hex() prevout_n = txin.prevout.out_idx
prevout_n = txin.prevout.out_idx spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n) if spending_tx_hash is None:
if spending_tx_hash is None: continue
continue # this outpoint has already been spent, by spending_tx
# this outpoint has already been spent, by spending_tx # annoying assert that has revealed several bugs over time:
# annoying assert that has revealed several bugs over time: assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db" conflicting_txns |= {spending_tx_hash}
conflicting_txns |= {spending_tx_hash} if tx_hash := tx.txid():
if tx_hash := tx.txid(): if tx_hash in conflicting_txns:
if tx_hash in conflicting_txns: # this tx is already in history, so it conflicts with itself
# this tx is already in history, so it conflicts with itself if len(conflicting_txns) > 1:
if len(conflicting_txns) > 1: raise Exception('Found conflicting transactions already in wallet history.')
raise Exception('Found conflicting transactions already in wallet history.') if not include_self:
if not include_self: conflicting_txns -= {tx_hash}
conflicting_txns -= {tx_hash} return conflicting_txns
return conflicting_txns
@with_lock @with_lock
def get_transaction(self, txid: str) -> Optional[Transaction]: def get_transaction(self, txid: str) -> Optional[Transaction]:
@@ -368,15 +368,15 @@ class AddressSynchronizer(Logger, EventListener):
util.trigger_callback('adb_added_tx', self, tx_hash, tx) util.trigger_callback('adb_added_tx', self, tx_hash, tx)
return True return True
@with_lock
def remove_transaction(self, tx_hash: str) -> None: def remove_transaction(self, tx_hash: str) -> None:
"""Removes a transaction AND all its dependents/children """Removes a transaction AND all its dependents/children
from the wallet history. from the wallet history.
""" """
with self.lock: to_remove = {tx_hash}
to_remove = {tx_hash} to_remove |= self.get_depending_transactions(tx_hash)
to_remove |= self.get_depending_transactions(tx_hash) for txid in to_remove:
for txid in to_remove: self._remove_transaction(txid)
self._remove_transaction(txid)
def _remove_transaction(self, tx_hash: str) -> None: def _remove_transaction(self, tx_hash: str) -> None:
"""Removes a single transaction from the wallet history, and attempts """Removes a single transaction from the wallet history, and attempts
@@ -419,15 +419,15 @@ class AddressSynchronizer(Logger, EventListener):
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value) self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
util.trigger_callback('adb_removed_tx', self, tx_hash, tx) util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
@with_lock
def get_depending_transactions(self, tx_hash: str) -> Set[str]: def get_depending_transactions(self, tx_hash: str) -> Set[str]:
"""Returns all (grand-)children of tx_hash in this wallet.""" """Returns all (grand-)children of tx_hash in this wallet."""
with self.lock: children = set()
children = set() for n in self.db.get_spent_outpoints(tx_hash):
for n in self.db.get_spent_outpoints(tx_hash): other_hash = self.db.get_spent_outpoint(tx_hash, n)
other_hash = self.db.get_spent_outpoint(tx_hash, n) children.add(other_hash)
children.add(other_hash) children |= self.get_depending_transactions(other_hash)
children |= self.get_depending_transactions(other_hash) return children
return children
@with_lock @with_lock
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None: def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
@@ -499,19 +499,19 @@ class AddressSynchronizer(Logger, EventListener):
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid): if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
self.remove_transaction(txid) self.remove_transaction(txid)
@with_lock
def clear_history(self): def clear_history(self):
with self.lock: self.db.clear_history()
self.db.clear_history() self._history_local.clear()
self._history_local.clear() self._get_balance_cache.clear() # invalidate cache
self._get_balance_cache.clear() # invalidate cache
@with_lock
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]: def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
"""Returns a key to be used for sorting txs.""" """Returns a key to be used for sorting txs."""
with self.lock: tx_mined_info = self.get_tx_height(tx_hash)
tx_mined_info = self.get_tx_height(tx_hash) height = self.tx_height_to_sort_height(tx_mined_info.height)
height = self.tx_height_to_sort_height(tx_mined_info.height) txpos = tx_mined_info.txpos or -1
txpos = tx_mined_info.txpos or -1 return height, txpos
return height, txpos
@classmethod @classmethod
def tx_height_to_sort_height(cls, height: int = None): def tx_height_to_sort_height(cls, height: int = None):
@@ -578,26 +578,26 @@ class AddressSynchronizer(Logger, EventListener):
raise Exception("wallet.get_history() failed balance sanity-check") raise Exception("wallet.get_history() failed balance sanity-check")
return h2 return h2
@with_lock
def _add_tx_to_local_history(self, txid): def _add_tx_to_local_history(self, txid):
with self.lock: for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): cur_hist = self._history_local.get(addr, set())
cur_hist = self._history_local.get(addr, set()) cur_hist.add(txid)
cur_hist.add(txid) self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr)
@with_lock
def _remove_tx_from_local_history(self, txid):
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
try:
cur_hist.remove(txid)
except KeyError:
pass
else:
self._history_local[addr] = cur_hist self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr) self._mark_address_history_changed(addr)
def _remove_tx_from_local_history(self, txid):
with self.lock:
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
try:
cur_hist.remove(txid)
except KeyError:
pass
else:
self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr)
def _mark_address_history_changed(self, addr: str) -> None: def _mark_address_history_changed(self, addr: str) -> None:
def set_and_clear(): def set_and_clear():
event = self._address_history_changed_events[addr] event = self._address_history_changed_events[addr]
@@ -617,27 +617,27 @@ class AddressSynchronizer(Logger, EventListener):
assert self.is_mine(addr), "address needs to be is_mine to be watched" assert self.is_mine(addr), "address needs to be is_mine to be watched"
await self._address_history_changed_events[addr].wait() await self._address_history_changed_events[addr].wait()
@with_lock
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None: def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
with self.lock: if self.db.is_in_verified_tx(tx_hash):
if self.db.is_in_verified_tx(tx_hash): if tx_height <= 0:
if tx_height <= 0: # tx was previously SPV-verified but now in mempool (probably reorg)
# tx was previously SPV-verified but now in mempool (probably reorg) self.db.remove_verified_tx(tx_hash)
self.db.remove_verified_tx(tx_hash) self.unconfirmed_tx[tx_hash] = tx_height
self.unconfirmed_tx[tx_hash] = tx_height if self.verifier:
if self.verifier: self.verifier.remove_spv_proof_for_tx(tx_hash)
self.verifier.remove_spv_proof_for_tx(tx_hash) else:
if tx_height > 0:
self.unverified_tx[tx_hash] = tx_height
else: else:
if tx_height > 0: self.unconfirmed_tx[tx_hash] = tx_height
self.unverified_tx[tx_hash] = tx_height
else:
self.unconfirmed_tx[tx_hash] = tx_height
@with_lock
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None: def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
with self.lock: new_height = self.unverified_tx.get(tx_hash)
new_height = self.unverified_tx.get(tx_hash) if new_height == tx_height:
if new_height == tx_height: self.unverified_tx.pop(tx_hash, None)
self.unverified_tx.pop(tx_hash, None)
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo): def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
# Remove from the unverified map and add to the verified map # Remove from the unverified map and add to the verified map
@@ -646,10 +646,10 @@ class AddressSynchronizer(Logger, EventListener):
self.db.add_verified_tx(tx_hash, info) self.db.add_verified_tx(tx_hash, info)
util.trigger_callback('adb_added_verified_tx', self, tx_hash) util.trigger_callback('adb_added_verified_tx', self, tx_hash)
@with_lock
def get_unverified_txs(self) -> Dict[str, int]: def get_unverified_txs(self) -> Dict[str, int]:
'''Returns a map from tx hash to transaction height''' '''Returns a map from tx hash to transaction height'''
with self.lock: return dict(self.unverified_tx) # copy
return dict(self.unverified_tx) # copy
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]: def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
'''Used by the verifier when a reorg has happened''' '''Used by the verifier when a reorg has happened'''
@@ -830,20 +830,20 @@ class AddressSynchronizer(Logger, EventListener):
self.db.add_num_inputs_to_tx(txid, len(tx.inputs())) self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
return fee return fee
@with_lock
def get_addr_io(self, address: str): def get_addr_io(self, address: str):
with self.lock: h = self.get_address_history(address).items()
h = self.get_address_history(address).items() received = {}
received = {} sent = {}
sent = {} for tx_hash, height in h:
for tx_hash, height in h: tx_mined_info = self.get_tx_height(tx_hash)
tx_mined_info = self.get_tx_height(tx_hash) txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1 d = self.db.get_txo_addr(tx_hash, address)
d = self.db.get_txo_addr(tx_hash, address) for n, (v, is_cb) in d.items():
for n, (v, is_cb) in d.items(): received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb) l = self.db.get_txi_addr(tx_hash, address)
l = self.db.get_txi_addr(tx_hash, address) for txi, v in l:
for txi, v in l: sent[txi] = tx_hash, height, txpos
sent[txi] = tx_hash, height, txpos
return received, sent return received, sent
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]: def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]: