adb/wallet: merge transaction_lock and lock
The distinction was no longer clear.
This commit is contained in:
@@ -89,9 +89,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
# verifier (SPV) and synchronizer are started in start_network
|
||||
self.synchronizer = None
|
||||
self.verifier = None
|
||||
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
|
||||
self.lock = threading.RLock()
|
||||
self.transaction_lock = threading.RLock()
|
||||
self.future_tx = {} # type: Dict[str, int] # txid -> wanted (abs) height
|
||||
# Txs the server claims are mined but still pending verification:
|
||||
self.unverified_tx = defaultdict(int) # type: Dict[str, int] # txid -> height. Access with self.lock.
|
||||
@@ -107,12 +105,6 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
def diagnostic_name(self):
|
||||
return self.name or ""
|
||||
|
||||
def with_transaction_lock(func):
|
||||
def func_wrapper(self: 'AddressSynchronizer', *args, **kwargs):
|
||||
with self.transaction_lock:
|
||||
return func(self, *args, **kwargs)
|
||||
return func_wrapper
|
||||
|
||||
def load_and_cleanup(self):
|
||||
self.load_local_history()
|
||||
self.check_history()
|
||||
@@ -143,9 +135,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
so that only includes txns the server sees.
|
||||
"""
|
||||
h = {}
|
||||
# we need self.transaction_lock but get_tx_height will take self.lock
|
||||
# so we need to take that too here, to enforce order of locks
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
related_txns = self._history_local.get(addr, set())
|
||||
for tx_hash in related_txns:
|
||||
tx_height = self.get_tx_height(tx_hash).height
|
||||
@@ -240,7 +230,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
conflict (if already in wallet history)
|
||||
"""
|
||||
conflicting_txns = set()
|
||||
with self.transaction_lock:
|
||||
with self.lock:
|
||||
for txin in tx.inputs():
|
||||
if txin.is_coinbase_input():
|
||||
continue
|
||||
@@ -287,9 +277,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
raise Exception("cannot add tx without txid to wallet history")
|
||||
# For sanity, try to serialize and deserialize tx early:
|
||||
tx_from_any(str(tx)) # see if raises (no-side-effects)
|
||||
# we need self.transaction_lock but get_tx_height will take self.lock
|
||||
# so we need to take that too here, to enforce order of locks
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
# NOTE: returning if tx in self.transactions might seem like a good idea
|
||||
# BUT we track is_mine inputs in a txn, and during subsequent calls
|
||||
# of add_transaction tx, we might learn of more-and-more inputs of
|
||||
@@ -377,7 +365,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
"""Removes a transaction AND all its dependents/children
|
||||
from the wallet history.
|
||||
"""
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
to_remove = {tx_hash}
|
||||
to_remove |= self.get_depending_transactions(tx_hash)
|
||||
for txid in to_remove:
|
||||
@@ -404,7 +392,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
if spending_txid == tx_hash:
|
||||
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
|
||||
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
self.logger.info(f"removing tx from history {tx_hash}")
|
||||
tx = self.db.remove_transaction(tx_hash)
|
||||
remove_from_spent_outpoints()
|
||||
@@ -426,7 +414,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
|
||||
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
|
||||
"""Returns all (grand-)children of tx_hash in this wallet."""
|
||||
with self.transaction_lock:
|
||||
with self.lock:
|
||||
children = set()
|
||||
for n in self.db.get_spent_outpoints(tx_hash):
|
||||
other_hash = self.db.get_spent_outpoint(tx_hash, n)
|
||||
@@ -502,10 +490,9 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
|
||||
def clear_history(self):
|
||||
with self.lock:
|
||||
with self.transaction_lock:
|
||||
self.db.clear_history()
|
||||
self._history_local.clear()
|
||||
self._get_balance_cache.clear() # invalidate cache
|
||||
self.db.clear_history()
|
||||
self._history_local.clear()
|
||||
self._get_balance_cache.clear() # invalidate cache
|
||||
|
||||
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
|
||||
"""Returns a key to be used for sorting txs."""
|
||||
@@ -544,7 +531,6 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
return f
|
||||
|
||||
@with_lock
|
||||
@with_transaction_lock
|
||||
@with_local_height_cached
|
||||
def get_history(self, domain) -> Sequence[HistoryItem]:
|
||||
domain = set(domain)
|
||||
@@ -582,7 +568,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
return h2
|
||||
|
||||
def _add_tx_to_local_history(self, txid):
|
||||
with self.transaction_lock:
|
||||
with self.lock:
|
||||
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
||||
cur_hist = self._history_local.get(addr, set())
|
||||
cur_hist.add(txid)
|
||||
@@ -590,7 +576,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
self._mark_address_history_changed(addr)
|
||||
|
||||
def _remove_tx_from_local_history(self, txid):
|
||||
with self.transaction_lock:
|
||||
with self.lock:
|
||||
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
|
||||
cur_hist = self._history_local.get(addr, set())
|
||||
try:
|
||||
@@ -621,16 +607,15 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
await self._address_history_changed_events[addr].wait()
|
||||
|
||||
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
|
||||
if self.db.is_in_verified_tx(tx_hash):
|
||||
if tx_height <= 0:
|
||||
# tx was previously SPV-verified but now in mempool (probably reorg)
|
||||
with self.lock:
|
||||
with self.lock:
|
||||
if self.db.is_in_verified_tx(tx_hash):
|
||||
if tx_height <= 0:
|
||||
# tx was previously SPV-verified but now in mempool (probably reorg)
|
||||
self.db.remove_verified_tx(tx_hash)
|
||||
self.unconfirmed_tx[tx_hash] = tx_height
|
||||
if self.verifier:
|
||||
self.verifier.remove_spv_proof_for_tx(tx_hash)
|
||||
else:
|
||||
with self.lock:
|
||||
if self.verifier:
|
||||
self.verifier.remove_spv_proof_for_tx(tx_hash)
|
||||
else:
|
||||
if tx_height > 0:
|
||||
self.unverified_tx[tx_hash] = tx_height
|
||||
else:
|
||||
@@ -750,7 +735,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
nans += n2
|
||||
return nsent, nans
|
||||
|
||||
@with_transaction_lock
|
||||
@with_lock
|
||||
def get_tx_delta(self, tx_hash: str, address: str) -> int:
|
||||
"""effect of tx on address"""
|
||||
delta = 0
|
||||
@@ -799,7 +784,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
return None
|
||||
# compute fee if possible
|
||||
v_in = v_out = 0
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
for txin in tx.inputs():
|
||||
addr = self.get_txin_address(txin)
|
||||
value = self.get_txin_value(txin, address=addr)
|
||||
@@ -819,7 +804,7 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
return fee
|
||||
|
||||
def get_addr_io(self, address: str):
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
h = self.get_address_history(address).items()
|
||||
received = {}
|
||||
sent = {}
|
||||
@@ -868,7 +853,6 @@ class AddressSynchronizer(Logger, EventListener):
|
||||
return sum([value for height, pos, value, is_cb in received.values()])
|
||||
|
||||
@with_lock
|
||||
@with_transaction_lock
|
||||
@with_local_height_cached
|
||||
def get_balance(self, domain, *, excluded_addresses: Set[str] = None,
|
||||
excluded_coins: Set[str] = None) -> Tuple[int, int, int]:
|
||||
|
||||
@@ -401,7 +401,6 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
for addr in self.get_addresses():
|
||||
self.adb.add_address(addr)
|
||||
self.lock = self.adb.lock
|
||||
self.transaction_lock = self.adb.transaction_lock
|
||||
self._last_full_history = None
|
||||
self._tx_parents_cache = {}
|
||||
self._default_labels = {}
|
||||
@@ -568,7 +567,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
return is_mine
|
||||
|
||||
def clear_tx_parents_cache(self):
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
self._tx_parents_cache.clear()
|
||||
self._num_parents.clear()
|
||||
self._last_full_history = None
|
||||
@@ -877,7 +876,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
is_relevant = False # "related to wallet?"
|
||||
num_input_ismine = 0
|
||||
v_in = v_in_mine = v_out = v_out_mine = 0
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
for txin in tx.inputs():
|
||||
addr = self.adb.get_txin_address(txin)
|
||||
value = self.adb.get_txin_value(txin, address=addr)
|
||||
@@ -1015,7 +1014,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
returns a flat dict:
|
||||
txid -> list of parent txids
|
||||
"""
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
if self._last_full_history is None:
|
||||
self._last_full_history = self.get_onchain_history()
|
||||
# populate cache in chronological order (confirmed tx only)
|
||||
@@ -1252,7 +1251,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
if not invoice.is_lightning():
|
||||
if self.is_onchain_invoice_paid(invoice)[0]:
|
||||
_logger.info("saving invoice... but it is already paid!")
|
||||
with self.transaction_lock:
|
||||
with self.lock:
|
||||
for txout in invoice.get_outputs():
|
||||
self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key)
|
||||
self._invoices[key] = invoice
|
||||
@@ -1362,7 +1361,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
relevant_txs = set()
|
||||
is_paid = True
|
||||
conf_needed = None # type: Optional[int]
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
for invoice_scriptpubkey, invoice_amt in invoice_amounts.items():
|
||||
scripthash = bitcoin.script_to_scripthash(invoice_scriptpubkey)
|
||||
prevouts_and_values = self.db.get_prevouts_by_scripthash(scripthash)
|
||||
@@ -2879,7 +2878,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
def get_invoices_and_requests_touched_by_tx(self, tx):
|
||||
request_keys = set()
|
||||
invoice_keys = set()
|
||||
with self.lock, self.transaction_lock:
|
||||
with self.lock:
|
||||
for txo in tx.outputs():
|
||||
addr = txo.address
|
||||
if request := self.get_request_by_addr(addr):
|
||||
|
||||
Reference in New Issue
Block a user