submarine_swaps: make swaps dict thread-safe
In general many methods of the SwapManager are called both from the asyncio thread and from the GUI, and hence must be thread-safe. closes https://github.com/spesmilo/electrum/issues/9887
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional, Dict, Sequence, Tuple, Iterable, List
|
||||
from decimal import Decimal
|
||||
import math
|
||||
@@ -206,20 +207,22 @@ class SwapManager(Logger):
|
||||
self.taskgroup = OldTaskGroup()
|
||||
self.dummy_address = DummyAddress.SWAP
|
||||
|
||||
self.swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData]
|
||||
# note: accessing swaps dicts (besides simple lookup) needs swaps_lock
|
||||
self.swaps_lock = threading.Lock()
|
||||
self._swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData]
|
||||
self._swaps_by_funding_outpoint = {} # type: Dict[TxOutpoint, SwapData]
|
||||
self._swaps_by_lockup_address = {} # type: Dict[str, SwapData]
|
||||
for payment_hash_hex, swap in self.swaps.items():
|
||||
for payment_hash_hex, swap in self._swaps.items():
|
||||
payment_hash = bytes.fromhex(payment_hash_hex)
|
||||
swap._payment_hash = payment_hash
|
||||
self._add_or_reindex_swap(swap)
|
||||
if not swap.is_reverse and not swap.is_redeemed:
|
||||
self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback)
|
||||
|
||||
self.prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash
|
||||
for k, swap in self.swaps.items():
|
||||
self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash
|
||||
for k, swap in self._swaps.items():
|
||||
if swap.prepay_hash is not None:
|
||||
self.prepayments[swap.prepay_hash] = bytes.fromhex(k)
|
||||
self._prepayments[swap.prepay_hash] = bytes.fromhex(k)
|
||||
self.is_server = False # overridden by swapserver plugin if enabled
|
||||
self.is_initialized = asyncio.Event()
|
||||
self.pairs_updated = asyncio.Event()
|
||||
@@ -231,7 +234,9 @@ class SwapManager(Logger):
|
||||
return
|
||||
self.logger.info('start_network: starting main loop')
|
||||
self.network = network
|
||||
for k, swap in self.swaps.items():
|
||||
with self.swaps_lock:
|
||||
swaps_items = list(self._swaps.items())
|
||||
for k, swap in swaps_items:
|
||||
if swap.is_redeemed:
|
||||
continue
|
||||
self.add_lnwatcher_callback(swap)
|
||||
@@ -355,7 +360,9 @@ class SwapManager(Logger):
|
||||
self.lnworker.save_forwarding_failure(payment_key.hex(), failure_message=e)
|
||||
self.lnwatcher.remove_callback(swap.lockup_address)
|
||||
if not swap.is_funded():
|
||||
self.swaps.pop(swap.payment_hash.hex())
|
||||
with self.swaps_lock:
|
||||
self._swaps.pop(swap.payment_hash.hex())
|
||||
# TODO clean-up other swaps dicts, i.e. undo _add_or_reindex_swap()
|
||||
|
||||
@classmethod
|
||||
def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]:
|
||||
@@ -481,12 +488,12 @@ class SwapManager(Logger):
|
||||
|
||||
def get_swap(self, payment_hash: bytes) -> Optional[SwapData]:
|
||||
# for history
|
||||
swap = self.swaps.get(payment_hash.hex())
|
||||
swap = self._swaps.get(payment_hash.hex())
|
||||
if swap:
|
||||
return swap
|
||||
payment_hash = self.prepayments.get(payment_hash)
|
||||
payment_hash = self._prepayments.get(payment_hash)
|
||||
if payment_hash:
|
||||
return self.swaps.get(payment_hash.hex())
|
||||
return self._swaps.get(payment_hash.hex())
|
||||
return None
|
||||
|
||||
def add_lnwatcher_callback(self, swap: SwapData) -> None:
|
||||
@@ -496,7 +503,7 @@ class SwapManager(Logger):
|
||||
async def hold_invoice_callback(self, payment_hash: bytes) -> None:
|
||||
# note: this assumes the wallet has been unlocked
|
||||
key = payment_hash.hex()
|
||||
if swap := self.swaps.get(key):
|
||||
if swap := self._swaps.get(key):
|
||||
if not swap.is_funded():
|
||||
output = self.create_funding_output(swap)
|
||||
self.wallet.txbatcher.add_payment_output('swaps', output)
|
||||
@@ -572,7 +579,7 @@ class SwapManager(Logger):
|
||||
min_final_cltv_expiry_delta=min_final_cltv_expiry_delta,
|
||||
)
|
||||
self.lnworker.bundle_payments([payment_hash, prepay_hash])
|
||||
self.prepayments[prepay_hash] = payment_hash
|
||||
self._prepayments[prepay_hash] = payment_hash
|
||||
else:
|
||||
prepay_invoice = None
|
||||
prepay_hash = None
|
||||
@@ -655,7 +662,7 @@ class SwapManager(Logger):
|
||||
spending_txid=None,
|
||||
)
|
||||
if prepay_hash:
|
||||
self.prepayments[prepay_hash] = payment_hash
|
||||
self._prepayments[prepay_hash] = payment_hash
|
||||
swap._payment_hash = payment_hash
|
||||
self._add_or_reindex_swap(swap)
|
||||
self.add_lnwatcher_callback(swap)
|
||||
@@ -666,8 +673,9 @@ class SwapManager(Logger):
|
||||
invoice = Invoice.from_bech32(invoice)
|
||||
key = invoice.rhash
|
||||
payment_hash = bytes.fromhex(key)
|
||||
assert key in self.swaps
|
||||
swap = self.swaps[key]
|
||||
with self.swaps_lock:
|
||||
assert key in self._swaps
|
||||
swap = self._swaps[key]
|
||||
assert swap.lightning_amount == int(invoice.get_amount_sat())
|
||||
self.wallet.save_invoice(invoice)
|
||||
# check that we have the preimage
|
||||
@@ -799,7 +807,7 @@ class SwapManager(Logger):
|
||||
await asyncio.sleep(0.1)
|
||||
return swap.funding_txid
|
||||
|
||||
def create_funding_output(self, swap):
|
||||
def create_funding_output(self, swap: SwapData) -> PartialTxOutput:
|
||||
return PartialTxOutput.from_address_and_value(swap.lockup_address, swap.onchain_amount)
|
||||
|
||||
def create_funding_tx(
|
||||
@@ -948,11 +956,12 @@ class SwapManager(Logger):
|
||||
return swap.funding_txid
|
||||
|
||||
def _add_or_reindex_swap(self, swap: SwapData) -> None:
|
||||
if swap.payment_hash.hex() not in self.swaps:
|
||||
self.swaps[swap.payment_hash.hex()] = swap
|
||||
if swap._funding_prevout:
|
||||
self._swaps_by_funding_outpoint[swap._funding_prevout] = swap
|
||||
self._swaps_by_lockup_address[swap.lockup_address] = swap
|
||||
with self.swaps_lock:
|
||||
if swap.payment_hash.hex() not in self._swaps:
|
||||
self._swaps[swap.payment_hash.hex()] = swap
|
||||
if swap._funding_prevout:
|
||||
self._swaps_by_funding_outpoint[swap._funding_prevout] = swap
|
||||
self._swaps_by_lockup_address[swap.lockup_address] = swap
|
||||
|
||||
def server_update_pairs(self) -> None:
|
||||
""" for server """
|
||||
@@ -1237,7 +1246,9 @@ class SwapManager(Logger):
|
||||
d = {}
|
||||
# add info about submarine swaps
|
||||
settled_payments = self.lnworker.get_payments(status='settled')
|
||||
for payment_hash_hex, swap in self.swaps.items():
|
||||
with self.swaps_lock:
|
||||
swaps_items = list(self._swaps.items())
|
||||
for payment_hash_hex, swap in swaps_items:
|
||||
txid = swap.spending_txid if swap.is_reverse else swap.funding_txid
|
||||
if txid is None:
|
||||
continue
|
||||
@@ -1290,7 +1301,9 @@ class SwapManager(Logger):
|
||||
def get_pending_swaps(self) -> List[SwapData]:
|
||||
"""Returns a list of swaps with unconfirmed funding tx (which require us to stay online)."""
|
||||
pending_swaps: List[SwapData] = []
|
||||
for swap in self.swaps.values():
|
||||
with self.swaps_lock:
|
||||
swaps = list(self._swaps.values())
|
||||
for swap in swaps:
|
||||
if swap.is_redeemed:
|
||||
# adb data might have been removed after is_redeemed was set.
|
||||
# in that case lnwatcher will no longer fetch the spending tx
|
||||
|
||||
Reference in New Issue
Block a user