address_synchronizer: apply @with_lock where applicable
This commit is contained in:
@@ -128,6 +128,7 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
def get_addresses(self):
|
def get_addresses(self):
|
||||||
return sorted(self.db.get_history())
|
return sorted(self.db.get_history())
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def get_address_history(self, addr: str) -> Dict[str, int]:
|
def get_address_history(self, addr: str) -> Dict[str, int]:
|
||||||
"""Returns the history for the address, as a txid->height dict.
|
"""Returns the history for the address, as a txid->height dict.
|
||||||
In addition to what we have from the server, this includes local and future txns.
|
In addition to what we have from the server, this includes local and future txns.
|
||||||
@@ -136,11 +137,10 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
so that only includes txns the server sees.
|
so that only includes txns the server sees.
|
||||||
"""
|
"""
|
||||||
h = {}
|
h = {}
|
||||||
with self.lock:
|
related_txns = self._history_local.get(addr, set())
|
||||||
related_txns = self._history_local.get(addr, set())
|
for tx_hash in related_txns:
|
||||||
for tx_hash in related_txns:
|
tx_height = self.get_tx_height(tx_hash).height
|
||||||
tx_height = self.get_tx_height(tx_hash).height
|
h[tx_hash] = tx_height
|
||||||
h[tx_hash] = tx_height
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
def get_address_history_len(self, addr: str) -> int:
|
def get_address_history_len(self, addr: str) -> int:
|
||||||
@@ -201,10 +201,10 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
self.register_callbacks()
|
self.register_callbacks()
|
||||||
|
|
||||||
@event_listener
|
@event_listener
|
||||||
|
@with_lock
|
||||||
def on_event_blockchain_updated(self, *args):
|
def on_event_blockchain_updated(self, *args):
|
||||||
with self.lock:
|
self._get_balance_cache = {} # invalidate cache
|
||||||
self._get_balance_cache = {} # invalidate cache
|
self.db.put('stored_height', self.get_local_height())
|
||||||
self.db.put('stored_height', self.get_local_height())
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
if self.network:
|
if self.network:
|
||||||
@@ -227,6 +227,7 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
self.synchronizer.add(address)
|
self.synchronizer.add(address)
|
||||||
self.up_to_date_changed()
|
self.up_to_date_changed()
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
|
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
|
||||||
"""Returns a set of transaction hashes from the wallet history that are
|
"""Returns a set of transaction hashes from the wallet history that are
|
||||||
directly conflicting with tx, i.e. they have common outpoints being
|
directly conflicting with tx, i.e. they have common outpoints being
|
||||||
@@ -236,27 +237,26 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
conflict (if already in wallet history)
|
conflict (if already in wallet history)
|
||||||
"""
|
"""
|
||||||
conflicting_txns = set()
|
conflicting_txns = set()
|
||||||
with self.lock:
|
for txin in tx.inputs():
|
||||||
for txin in tx.inputs():
|
if txin.is_coinbase_input():
|
||||||
if txin.is_coinbase_input():
|
continue
|
||||||
continue
|
prevout_hash = txin.prevout.txid.hex()
|
||||||
prevout_hash = txin.prevout.txid.hex()
|
prevout_n = txin.prevout.out_idx
|
||||||
prevout_n = txin.prevout.out_idx
|
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
|
||||||
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
|
if spending_tx_hash is None:
|
||||||
if spending_tx_hash is None:
|
continue
|
||||||
continue
|
# this outpoint has already been spent, by spending_tx
|
||||||
# this outpoint has already been spent, by spending_tx
|
# annoying assert that has revealed several bugs over time:
|
||||||
# annoying assert that has revealed several bugs over time:
|
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
|
||||||
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
|
conflicting_txns |= {spending_tx_hash}
|
||||||
conflicting_txns |= {spending_tx_hash}
|
if tx_hash := tx.txid():
|
||||||
if tx_hash := tx.txid():
|
if tx_hash in conflicting_txns:
|
||||||
if tx_hash in conflicting_txns:
|
# this tx is already in history, so it conflicts with itself
|
||||||
# this tx is already in history, so it conflicts with itself
|
if len(conflicting_txns) > 1:
|
||||||
if len(conflicting_txns) > 1:
|
raise Exception('Found conflicting transactions already in wallet history.')
|
||||||
raise Exception('Found conflicting transactions already in wallet history.')
|
if not include_self:
|
||||||
if not include_self:
|
conflicting_txns -= {tx_hash}
|
||||||
conflicting_txns -= {tx_hash}
|
return conflicting_txns
|
||||||
return conflicting_txns
|
|
||||||
|
|
||||||
@with_lock
|
@with_lock
|
||||||
def get_transaction(self, txid: str) -> Optional[Transaction]:
|
def get_transaction(self, txid: str) -> Optional[Transaction]:
|
||||||
@@ -368,15 +368,15 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
util.trigger_callback('adb_added_tx', self, tx_hash, tx)
|
util.trigger_callback('adb_added_tx', self, tx_hash, tx)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def remove_transaction(self, tx_hash: str) -> None:
|
def remove_transaction(self, tx_hash: str) -> None:
|
||||||
"""Removes a transaction AND all its dependents/children
|
"""Removes a transaction AND all its dependents/children
|
||||||
from the wallet history.
|
from the wallet history.
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
to_remove = {tx_hash}
|
||||||
to_remove = {tx_hash}
|
to_remove |= self.get_depending_transactions(tx_hash)
|
||||||
to_remove |= self.get_depending_transactions(tx_hash)
|
for txid in to_remove:
|
||||||
for txid in to_remove:
|
self._remove_transaction(txid)
|
||||||
self._remove_transaction(txid)
|
|
||||||
|
|
||||||
def _remove_transaction(self, tx_hash: str) -> None:
|
def _remove_transaction(self, tx_hash: str) -> None:
|
||||||
"""Removes a single transaction from the wallet history, and attempts
|
"""Removes a single transaction from the wallet history, and attempts
|
||||||
@@ -419,15 +419,15 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
|
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
|
||||||
util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
|
util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
|
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
|
||||||
"""Returns all (grand-)children of tx_hash in this wallet."""
|
"""Returns all (grand-)children of tx_hash in this wallet."""
|
||||||
with self.lock:
|
children = set()
|
||||||
children = set()
|
for n in self.db.get_spent_outpoints(tx_hash):
|
||||||
for n in self.db.get_spent_outpoints(tx_hash):
|
other_hash = self.db.get_spent_outpoint(tx_hash, n)
|
||||||
other_hash = self.db.get_spent_outpoint(tx_hash, n)
|
children.add(other_hash)
|
||||||
children.add(other_hash)
|
children |= self.get_depending_transactions(other_hash)
|
||||||
children |= self.get_depending_transactions(other_hash)
|
return children
|
||||||
return children
|
|
||||||
|
|
||||||
@with_lock
|
@with_lock
|
||||||
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
|
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
|
||||||
@@ -499,19 +499,19 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
|
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
|
||||||
self.remove_transaction(txid)
|
self.remove_transaction(txid)
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def clear_history(self):
|
def clear_history(self):
|
||||||
with self.lock:
|
self.db.clear_history()
|
||||||
self.db.clear_history()
|
self._history_local.clear()
|
||||||
self._history_local.clear()
|
self._get_balance_cache.clear() # invalidate cache
|
||||||
self._get_balance_cache.clear() # invalidate cache
|
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
|
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
|
||||||
"""Returns a key to be used for sorting txs."""
|
"""Returns a key to be used for sorting txs."""
|
||||||
with self.lock:
|
tx_mined_info = self.get_tx_height(tx_hash)
|
||||||
tx_mined_info = self.get_tx_height(tx_hash)
|
height = self.tx_height_to_sort_height(tx_mined_info.height)
|
||||||
height = self.tx_height_to_sort_height(tx_mined_info.height)
|
txpos = tx_mined_info.txpos or -1
|
||||||
txpos = tx_mined_info.txpos or -1
|
return height, txpos
|
||||||
return height, txpos
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tx_height_to_sort_height(cls, height: int = None):
|
def tx_height_to_sort_height(cls, height: int = None):
|
||||||
@@ -578,26 +578,26 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
raise Exception("wallet.get_history() failed balance sanity-check")
|
raise Exception("wallet.get_history() failed balance sanity-check")
|
||||||
return h2
|
return h2
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def _add_tx_to_local_history(self, txid):
|
def _add_tx_to_local_history(self, txid):
|
||||||
with self.lock:
|
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
||||||
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
cur_hist = self._history_local.get(addr, set())
|
||||||
cur_hist = self._history_local.get(addr, set())
|
cur_hist.add(txid)
|
||||||
cur_hist.add(txid)
|
self._history_local[addr] = cur_hist
|
||||||
|
self._mark_address_history_changed(addr)
|
||||||
|
|
||||||
|
@with_lock
|
||||||
|
def _remove_tx_from_local_history(self, txid):
|
||||||
|
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
||||||
|
cur_hist = self._history_local.get(addr, set())
|
||||||
|
try:
|
||||||
|
cur_hist.remove(txid)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
self._history_local[addr] = cur_hist
|
self._history_local[addr] = cur_hist
|
||||||
self._mark_address_history_changed(addr)
|
self._mark_address_history_changed(addr)
|
||||||
|
|
||||||
def _remove_tx_from_local_history(self, txid):
|
|
||||||
with self.lock:
|
|
||||||
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
|
||||||
cur_hist = self._history_local.get(addr, set())
|
|
||||||
try:
|
|
||||||
cur_hist.remove(txid)
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self._history_local[addr] = cur_hist
|
|
||||||
self._mark_address_history_changed(addr)
|
|
||||||
|
|
||||||
def _mark_address_history_changed(self, addr: str) -> None:
|
def _mark_address_history_changed(self, addr: str) -> None:
|
||||||
def set_and_clear():
|
def set_and_clear():
|
||||||
event = self._address_history_changed_events[addr]
|
event = self._address_history_changed_events[addr]
|
||||||
@@ -617,27 +617,27 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
assert self.is_mine(addr), "address needs to be is_mine to be watched"
|
assert self.is_mine(addr), "address needs to be is_mine to be watched"
|
||||||
await self._address_history_changed_events[addr].wait()
|
await self._address_history_changed_events[addr].wait()
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
|
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
|
||||||
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
|
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
|
||||||
with self.lock:
|
if self.db.is_in_verified_tx(tx_hash):
|
||||||
if self.db.is_in_verified_tx(tx_hash):
|
if tx_height <= 0:
|
||||||
if tx_height <= 0:
|
# tx was previously SPV-verified but now in mempool (probably reorg)
|
||||||
# tx was previously SPV-verified but now in mempool (probably reorg)
|
self.db.remove_verified_tx(tx_hash)
|
||||||
self.db.remove_verified_tx(tx_hash)
|
self.unconfirmed_tx[tx_hash] = tx_height
|
||||||
self.unconfirmed_tx[tx_hash] = tx_height
|
if self.verifier:
|
||||||
if self.verifier:
|
self.verifier.remove_spv_proof_for_tx(tx_hash)
|
||||||
self.verifier.remove_spv_proof_for_tx(tx_hash)
|
else:
|
||||||
|
if tx_height > 0:
|
||||||
|
self.unverified_tx[tx_hash] = tx_height
|
||||||
else:
|
else:
|
||||||
if tx_height > 0:
|
self.unconfirmed_tx[tx_hash] = tx_height
|
||||||
self.unverified_tx[tx_hash] = tx_height
|
|
||||||
else:
|
|
||||||
self.unconfirmed_tx[tx_hash] = tx_height
|
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
|
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
|
||||||
with self.lock:
|
new_height = self.unverified_tx.get(tx_hash)
|
||||||
new_height = self.unverified_tx.get(tx_hash)
|
if new_height == tx_height:
|
||||||
if new_height == tx_height:
|
self.unverified_tx.pop(tx_hash, None)
|
||||||
self.unverified_tx.pop(tx_hash, None)
|
|
||||||
|
|
||||||
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
|
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
|
||||||
# Remove from the unverified map and add to the verified map
|
# Remove from the unverified map and add to the verified map
|
||||||
@@ -646,10 +646,10 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
self.db.add_verified_tx(tx_hash, info)
|
self.db.add_verified_tx(tx_hash, info)
|
||||||
util.trigger_callback('adb_added_verified_tx', self, tx_hash)
|
util.trigger_callback('adb_added_verified_tx', self, tx_hash)
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def get_unverified_txs(self) -> Dict[str, int]:
|
def get_unverified_txs(self) -> Dict[str, int]:
|
||||||
'''Returns a map from tx hash to transaction height'''
|
'''Returns a map from tx hash to transaction height'''
|
||||||
with self.lock:
|
return dict(self.unverified_tx) # copy
|
||||||
return dict(self.unverified_tx) # copy
|
|
||||||
|
|
||||||
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
|
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
|
||||||
'''Used by the verifier when a reorg has happened'''
|
'''Used by the verifier when a reorg has happened'''
|
||||||
@@ -830,20 +830,20 @@ class AddressSynchronizer(Logger, EventListener):
|
|||||||
self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
|
self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
|
||||||
return fee
|
return fee
|
||||||
|
|
||||||
|
@with_lock
|
||||||
def get_addr_io(self, address: str):
|
def get_addr_io(self, address: str):
|
||||||
with self.lock:
|
h = self.get_address_history(address).items()
|
||||||
h = self.get_address_history(address).items()
|
received = {}
|
||||||
received = {}
|
sent = {}
|
||||||
sent = {}
|
for tx_hash, height in h:
|
||||||
for tx_hash, height in h:
|
tx_mined_info = self.get_tx_height(tx_hash)
|
||||||
tx_mined_info = self.get_tx_height(tx_hash)
|
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
|
||||||
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
|
d = self.db.get_txo_addr(tx_hash, address)
|
||||||
d = self.db.get_txo_addr(tx_hash, address)
|
for n, (v, is_cb) in d.items():
|
||||||
for n, (v, is_cb) in d.items():
|
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
|
||||||
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
|
l = self.db.get_txi_addr(tx_hash, address)
|
||||||
l = self.db.get_txi_addr(tx_hash, address)
|
for txi, v in l:
|
||||||
for txi, v in l:
|
sent[txi] = tx_hash, height, txpos
|
||||||
sent[txi] = tx_hash, height, txpos
|
|
||||||
return received, sent
|
return received, sent
|
||||||
|
|
||||||
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:
|
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:
|
||||||
|
|||||||
Reference in New Issue
Block a user