1
0

address_synchronizer: apply @with_lock where applicable

This commit is contained in:
наб
2025-06-15 19:14:08 +02:00
parent 4887fb3e7f
commit fdaafd5abf

View File

@@ -128,6 +128,7 @@ class AddressSynchronizer(Logger, EventListener):
def get_addresses(self):
return sorted(self.db.get_history())
@with_lock
def get_address_history(self, addr: str) -> Dict[str, int]:
"""Returns the history for the address, as a txid->height dict.
In addition to what we have from the server, this includes local and future txns.
@@ -136,11 +137,10 @@ class AddressSynchronizer(Logger, EventListener):
so that only includes txns the server sees.
"""
h = {}
with self.lock:
related_txns = self._history_local.get(addr, set())
for tx_hash in related_txns:
tx_height = self.get_tx_height(tx_hash).height
h[tx_hash] = tx_height
related_txns = self._history_local.get(addr, set())
for tx_hash in related_txns:
tx_height = self.get_tx_height(tx_hash).height
h[tx_hash] = tx_height
return h
def get_address_history_len(self, addr: str) -> int:
@@ -201,10 +201,10 @@ class AddressSynchronizer(Logger, EventListener):
self.register_callbacks()
@event_listener
@with_lock
def on_event_blockchain_updated(self, *args):
with self.lock:
self._get_balance_cache = {} # invalidate cache
self.db.put('stored_height', self.get_local_height())
self._get_balance_cache = {} # invalidate cache
self.db.put('stored_height', self.get_local_height())
async def stop(self):
if self.network:
@@ -227,6 +227,7 @@ class AddressSynchronizer(Logger, EventListener):
self.synchronizer.add(address)
self.up_to_date_changed()
@with_lock
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
"""Returns a set of transaction hashes from the wallet history that are
directly conflicting with tx, i.e. they have common outpoints being
@@ -236,27 +237,26 @@ class AddressSynchronizer(Logger, EventListener):
conflict (if already in wallet history)
"""
conflicting_txns = set()
with self.lock:
for txin in tx.inputs():
if txin.is_coinbase_input():
continue
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
if spending_tx_hash is None:
continue
# this outpoint has already been spent, by spending_tx
# annoying assert that has revealed several bugs over time:
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
conflicting_txns |= {spending_tx_hash}
if tx_hash := tx.txid():
if tx_hash in conflicting_txns:
# this tx is already in history, so it conflicts with itself
if len(conflicting_txns) > 1:
raise Exception('Found conflicting transactions already in wallet history.')
if not include_self:
conflicting_txns -= {tx_hash}
return conflicting_txns
for txin in tx.inputs():
if txin.is_coinbase_input():
continue
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
if spending_tx_hash is None:
continue
# this outpoint has already been spent, by spending_tx
# annoying assert that has revealed several bugs over time:
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
conflicting_txns |= {spending_tx_hash}
if tx_hash := tx.txid():
if tx_hash in conflicting_txns:
# this tx is already in history, so it conflicts with itself
if len(conflicting_txns) > 1:
raise Exception('Found conflicting transactions already in wallet history.')
if not include_self:
conflicting_txns -= {tx_hash}
return conflicting_txns
@with_lock
def get_transaction(self, txid: str) -> Optional[Transaction]:
@@ -368,15 +368,15 @@ class AddressSynchronizer(Logger, EventListener):
util.trigger_callback('adb_added_tx', self, tx_hash, tx)
return True
@with_lock
def remove_transaction(self, tx_hash: str) -> None:
"""Removes a transaction AND all its dependents/children
from the wallet history.
"""
with self.lock:
to_remove = {tx_hash}
to_remove |= self.get_depending_transactions(tx_hash)
for txid in to_remove:
self._remove_transaction(txid)
to_remove = {tx_hash}
to_remove |= self.get_depending_transactions(tx_hash)
for txid in to_remove:
self._remove_transaction(txid)
def _remove_transaction(self, tx_hash: str) -> None:
"""Removes a single transaction from the wallet history, and attempts
@@ -419,15 +419,15 @@ class AddressSynchronizer(Logger, EventListener):
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
@with_lock
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
"""Returns all (grand-)children of tx_hash in this wallet."""
with self.lock:
children = set()
for n in self.db.get_spent_outpoints(tx_hash):
other_hash = self.db.get_spent_outpoint(tx_hash, n)
children.add(other_hash)
children |= self.get_depending_transactions(other_hash)
return children
children = set()
for n in self.db.get_spent_outpoints(tx_hash):
other_hash = self.db.get_spent_outpoint(tx_hash, n)
children.add(other_hash)
children |= self.get_depending_transactions(other_hash)
return children
@with_lock
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
@@ -499,19 +499,19 @@ class AddressSynchronizer(Logger, EventListener):
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
self.remove_transaction(txid)
@with_lock
def clear_history(self):
with self.lock:
self.db.clear_history()
self._history_local.clear()
self._get_balance_cache.clear() # invalidate cache
self.db.clear_history()
self._history_local.clear()
self._get_balance_cache.clear() # invalidate cache
@with_lock
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
"""Returns a key to be used for sorting txs."""
with self.lock:
tx_mined_info = self.get_tx_height(tx_hash)
height = self.tx_height_to_sort_height(tx_mined_info.height)
txpos = tx_mined_info.txpos or -1
return height, txpos
tx_mined_info = self.get_tx_height(tx_hash)
height = self.tx_height_to_sort_height(tx_mined_info.height)
txpos = tx_mined_info.txpos or -1
return height, txpos
@classmethod
def tx_height_to_sort_height(cls, height: int = None):
@@ -578,26 +578,26 @@ class AddressSynchronizer(Logger, EventListener):
raise Exception("wallet.get_history() failed balance sanity-check")
return h2
@with_lock
def _add_tx_to_local_history(self, txid):
with self.lock:
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
cur_hist.add(txid)
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
cur_hist.add(txid)
self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr)
@with_lock
def _remove_tx_from_local_history(self, txid):
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
try:
cur_hist.remove(txid)
except KeyError:
pass
else:
self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr)
def _remove_tx_from_local_history(self, txid):
with self.lock:
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
cur_hist = self._history_local.get(addr, set())
try:
cur_hist.remove(txid)
except KeyError:
pass
else:
self._history_local[addr] = cur_hist
self._mark_address_history_changed(addr)
def _mark_address_history_changed(self, addr: str) -> None:
def set_and_clear():
event = self._address_history_changed_events[addr]
@@ -617,27 +617,27 @@ class AddressSynchronizer(Logger, EventListener):
assert self.is_mine(addr), "address needs to be is_mine to be watched"
await self._address_history_changed_events[addr].wait()
@with_lock
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
with self.lock:
if self.db.is_in_verified_tx(tx_hash):
if tx_height <= 0:
# tx was previously SPV-verified but now in mempool (probably reorg)
self.db.remove_verified_tx(tx_hash)
self.unconfirmed_tx[tx_hash] = tx_height
if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash)
if self.db.is_in_verified_tx(tx_hash):
if tx_height <= 0:
# tx was previously SPV-verified but now in mempool (probably reorg)
self.db.remove_verified_tx(tx_hash)
self.unconfirmed_tx[tx_hash] = tx_height
if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash)
else:
if tx_height > 0:
self.unverified_tx[tx_hash] = tx_height
else:
if tx_height > 0:
self.unverified_tx[tx_hash] = tx_height
else:
self.unconfirmed_tx[tx_hash] = tx_height
self.unconfirmed_tx[tx_hash] = tx_height
@with_lock
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
with self.lock:
new_height = self.unverified_tx.get(tx_hash)
if new_height == tx_height:
self.unverified_tx.pop(tx_hash, None)
new_height = self.unverified_tx.get(tx_hash)
if new_height == tx_height:
self.unverified_tx.pop(tx_hash, None)
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
# Remove from the unverified map and add to the verified map
@@ -646,10 +646,10 @@ class AddressSynchronizer(Logger, EventListener):
self.db.add_verified_tx(tx_hash, info)
util.trigger_callback('adb_added_verified_tx', self, tx_hash)
@with_lock
def get_unverified_txs(self) -> Dict[str, int]:
'''Returns a map from tx hash to transaction height'''
with self.lock:
return dict(self.unverified_tx) # copy
return dict(self.unverified_tx) # copy
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
'''Used by the verifier when a reorg has happened'''
@@ -830,20 +830,20 @@ class AddressSynchronizer(Logger, EventListener):
self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
return fee
@with_lock
def get_addr_io(self, address: str):
with self.lock:
h = self.get_address_history(address).items()
received = {}
sent = {}
for tx_hash, height in h:
tx_mined_info = self.get_tx_height(tx_hash)
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
d = self.db.get_txo_addr(tx_hash, address)
for n, (v, is_cb) in d.items():
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
l = self.db.get_txi_addr(tx_hash, address)
for txi, v in l:
sent[txi] = tx_hash, height, txpos
h = self.get_address_history(address).items()
received = {}
sent = {}
for tx_hash, height in h:
tx_mined_info = self.get_tx_height(tx_hash)
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
d = self.db.get_txo_addr(tx_hash, address)
for n, (v, is_cb) in d.items():
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
l = self.db.get_txi_addr(tx_hash, address)
for txi, v in l:
sent[txi] = tx_hash, height, txpos
return received, sent
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]: