1
0

adb: take lock in more places

for example, adb.get_utxos() could previously return an inconsistent result
This commit is contained in:
SomberNight
2025-05-21 17:53:25 +00:00
parent 3b37a920d6
commit 4543192e1a

View File

@@ -105,6 +105,7 @@ class AddressSynchronizer(Logger, EventListener):
def diagnostic_name(self):
return self.name or ""
@with_lock
def load_and_cleanup(self):
self.load_local_history()
self.check_history()
@@ -146,6 +147,7 @@ class AddressSynchronizer(Logger, EventListener):
"""Return number of transactions where address is involved."""
return len(self._history_local.get(addr, ()))
@with_lock
def get_txin_address(self, txin: TxInput) -> Optional[str]:
if txin.address:
return txin.address
@@ -160,6 +162,7 @@ class AddressSynchronizer(Logger, EventListener):
return tx.outputs()[prevout_n].address
return None
@with_lock
def get_txin_value(self, txin: TxInput, *, address: str = None) -> Optional[int]:
if txin.value_sats() is not None:
return txin.value_sats()
@@ -179,6 +182,7 @@ class AddressSynchronizer(Logger, EventListener):
return tx.outputs()[prevout_n].value
return None
@with_lock
def load_unverified_transactions(self):
# review transactions that are in the history
for addr in self.db.get_history():
@@ -198,8 +202,9 @@ class AddressSynchronizer(Logger, EventListener):
@event_listener
def on_event_blockchain_updated(self, *args):
self._get_balance_cache = {} # invalidate cache
self.db.put('stored_height', self.get_local_height())
with self.lock:
self._get_balance_cache = {} # invalidate cache
self.db.put('stored_height', self.get_local_height())
async def stop(self):
if self.network:
@@ -252,6 +257,7 @@ class AddressSynchronizer(Logger, EventListener):
conflicting_txns -= {tx_hash}
return conflicting_txns
@with_lock
def get_transaction(self, txid: str) -> Optional[Transaction]:
tx = self.db.get_transaction(txid)
if tx:
@@ -422,6 +428,7 @@ class AddressSynchronizer(Logger, EventListener):
children |= self.get_depending_transactions(other_hash)
return children
@with_lock
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
txid = tx.txid()
assert txid is not None
@@ -430,18 +437,18 @@ class AddressSynchronizer(Logger, EventListener):
self.add_unverified_or_unconfirmed_tx(txid, tx_height)
self.add_transaction(tx, allow_unrelated=True)
@with_lock
def receive_history_callback(self, addr: str, hist, tx_fees: Dict[str, int]):
with self.lock:
old_hist = self.get_address_history(addr)
for tx_hash, height in old_hist.items():
if (tx_hash, height) not in hist:
# make tx local
self.unverified_tx.pop(tx_hash, None)
self.unconfirmed_tx.pop(tx_hash, None)
self.db.remove_verified_tx(tx_hash)
if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash)
self.db.set_addr_history(addr, hist)
old_hist = self.get_address_history(addr)
for tx_hash, height in old_hist.items():
if (tx_hash, height) not in hist:
# make tx local
self.unverified_tx.pop(tx_hash, None)
self.unconfirmed_tx.pop(tx_hash, None)
self.db.remove_verified_tx(tx_hash)
if self.verifier:
self.verifier.remove_spv_proof_for_tx(tx_hash)
self.db.set_addr_history(addr, hist)
for tx_hash, tx_height in hist:
# add it in case it was previously unconfirmed
@@ -460,6 +467,7 @@ class AddressSynchronizer(Logger, EventListener):
for tx_hash, fee_sat in tx_fees.items():
self.db.add_tx_fee_from_server(tx_hash, fee_sat)
@with_lock
@profiler
def load_local_history(self):
self._history_local = {} # type: Dict[str, Set[str]] # address -> set(txid)
@@ -467,6 +475,7 @@ class AddressSynchronizer(Logger, EventListener):
for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()):
self._add_tx_to_local_history(txid)
@with_lock
@profiler
def check_history(self):
hist_addrs_mine = list(filter(lambda k: self.is_mine(k), self.db.get_history()))
@@ -482,6 +491,7 @@ class AddressSynchronizer(Logger, EventListener):
if tx is not None:
self.add_transaction(tx, allow_unrelated=True)
@with_lock
def remove_local_transactions_we_dont_have(self):
for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()):
tx_height = self.get_tx_height(txid).height
@@ -749,6 +759,7 @@ class AddressSynchronizer(Logger, EventListener):
delta += v
return delta
@with_lock
def get_tx_fee(self, txid: str) -> Optional[int]:
"""Returns tx_fee or None. Use server fee only if tx is unconfirmed and not mine.
@@ -784,16 +795,15 @@ class AddressSynchronizer(Logger, EventListener):
return None
# compute fee if possible
v_in = v_out = 0
with self.lock:
for txin in tx.inputs():
addr = self.get_txin_address(txin)
value = self.get_txin_value(txin, address=addr)
if value is None:
v_in = None
elif v_in is not None:
v_in += value
for txout in tx.outputs():
v_out += txout.value
for txin in tx.inputs():
addr = self.get_txin_address(txin)
value = self.get_txin_value(txin, address=addr)
if value is None:
v_in = None
elif v_in is not None:
v_in += value
for txout in tx.outputs():
v_out += txout.value
if v_in is not None:
fee = v_in - v_out
else:
@@ -918,6 +928,7 @@ class AddressSynchronizer(Logger, EventListener):
self._get_balance_cache[cache_key] = result
return result
@with_lock
@with_local_height_cached
def get_utxos(
self,
@@ -974,6 +985,7 @@ class AddressSynchronizer(Logger, EventListener):
coins = self.get_addr_utxo(address)
return not bool(coins)
@with_lock
@with_local_height_cached
def address_is_old(self, address: str, *, req_conf: int = 3) -> bool:
"""Returns whether address has any history that is deeply confirmed.
@@ -993,7 +1005,8 @@ class AddressSynchronizer(Logger, EventListener):
max_conf = max(max_conf, tx_age)
return max_conf >= req_conf
def get_spender(self, outpoint: str) -> str:
@with_lock
def get_spender(self, outpoint: str) -> Optional[str]:
"""
returns txid spending outpoint.
subscribes to addresses as a side effect.
@@ -1005,7 +1018,7 @@ class AddressSynchronizer(Logger, EventListener):
if tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]:
spender_txid = None
if not spender_txid:
return
return None
spender_tx = self.get_transaction(spender_txid)
for i, o in enumerate(spender_tx.outputs()):
if o.address is None: