From 468f496f347860253d5700a0849927a268b3ce90 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 2 Jun 2025 17:06:08 +0000 Subject: [PATCH] submarine_swaps: make swaps dict thread-safe In general many methods of the SwapManager are called both from the asyncio thread and from the GUI, and hence must be thread-safe. closes https://github.com/spesmilo/electrum/issues/9887 --- electrum/submarine_swaps.py | 59 ++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 536b0ecb2..7d47a4b5a 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -2,6 +2,7 @@ import asyncio import json import os import ssl +import threading from typing import TYPE_CHECKING, Optional, Dict, Sequence, Tuple, Iterable, List from decimal import Decimal import math @@ -206,20 +207,22 @@ class SwapManager(Logger): self.taskgroup = OldTaskGroup() self.dummy_address = DummyAddress.SWAP - self.swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData] + # note: accessing swaps dicts (besides simple lookup) needs swaps_lock + self.swaps_lock = threading.Lock() + self._swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData] self._swaps_by_funding_outpoint = {} # type: Dict[TxOutpoint, SwapData] self._swaps_by_lockup_address = {} # type: Dict[str, SwapData] - for payment_hash_hex, swap in self.swaps.items(): + for payment_hash_hex, swap in self._swaps.items(): payment_hash = bytes.fromhex(payment_hash_hex) swap._payment_hash = payment_hash self._add_or_reindex_swap(swap) if not swap.is_reverse and not swap.is_redeemed: self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback) - self.prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash - for k, swap in self.swaps.items(): + self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash + for k, swap in self._swaps.items(): if swap.prepay_hash is not None: - self.prepayments[swap.prepay_hash] = bytes.fromhex(k) + self._prepayments[swap.prepay_hash] = bytes.fromhex(k) self.is_server = False # overridden by swapserver plugin if enabled self.is_initialized = asyncio.Event() self.pairs_updated = asyncio.Event() @@ -231,7 +234,9 @@ class SwapManager(Logger): return self.logger.info('start_network: starting main loop') self.network = network - for k, swap in self.swaps.items(): + with self.swaps_lock: + swaps_items = list(self._swaps.items()) + for k, swap in swaps_items: if swap.is_redeemed: continue self.add_lnwatcher_callback(swap) @@ -355,7 +360,9 @@ class SwapManager(Logger): self.lnworker.save_forwarding_failure(payment_key.hex(), failure_message=e) self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_funded(): - self.swaps.pop(swap.payment_hash.hex()) + with self.swaps_lock: + self._swaps.pop(swap.payment_hash.hex()) + # TODO clean-up other swaps dicts, i.e. undo _add_or_reindex_swap() @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -481,12 +488,12 @@ class SwapManager(Logger): def get_swap(self, payment_hash: bytes) -> Optional[SwapData]: # for history - swap = self.swaps.get(payment_hash.hex()) + swap = self._swaps.get(payment_hash.hex()) if swap: return swap - payment_hash = self.prepayments.get(payment_hash) + payment_hash = self._prepayments.get(payment_hash) if payment_hash: - return self.swaps.get(payment_hash.hex()) + return self._swaps.get(payment_hash.hex()) return None def add_lnwatcher_callback(self, swap: SwapData) -> None: @@ -496,7 +503,7 @@ class SwapManager(Logger): async def hold_invoice_callback(self, payment_hash: bytes) -> None: # note: this assumes the wallet has been unlocked key = payment_hash.hex() - if swap := self.swaps.get(key): + if swap := self._swaps.get(key): if not swap.is_funded(): output = self.create_funding_output(swap) self.wallet.txbatcher.add_payment_output('swaps', output) @@ -572,7 +579,7 @@ class SwapManager(Logger): min_final_cltv_expiry_delta=min_final_cltv_expiry_delta, ) self.lnworker.bundle_payments([payment_hash, prepay_hash]) - self.prepayments[prepay_hash] = payment_hash + self._prepayments[prepay_hash] = payment_hash else: prepay_invoice = None prepay_hash = None @@ -655,7 +662,7 @@ class SwapManager(Logger): spending_txid=None, ) if prepay_hash: - self.prepayments[prepay_hash] = payment_hash + self._prepayments[prepay_hash] = payment_hash swap._payment_hash = payment_hash self._add_or_reindex_swap(swap) self.add_lnwatcher_callback(swap) @@ -666,8 +673,9 @@ class SwapManager(Logger): invoice = Invoice.from_bech32(invoice) key = invoice.rhash payment_hash = bytes.fromhex(key) - assert key in self.swaps - swap = self.swaps[key] + with self.swaps_lock: + assert key in self._swaps + swap = self._swaps[key] assert swap.lightning_amount == int(invoice.get_amount_sat()) self.wallet.save_invoice(invoice) # check that we have the preimage @@ -799,7 +807,7 @@ class SwapManager(Logger): await asyncio.sleep(0.1) return swap.funding_txid - def create_funding_output(self, swap): + def create_funding_output(self, swap: SwapData) -> PartialTxOutput: return PartialTxOutput.from_address_and_value(swap.lockup_address, swap.onchain_amount) def create_funding_tx( @@ -948,11 +956,12 @@ class SwapManager(Logger): return swap.funding_txid def _add_or_reindex_swap(self, swap: SwapData) -> None: - if swap.payment_hash.hex() not in self.swaps: - self.swaps[swap.payment_hash.hex()] = swap - if swap._funding_prevout: - self._swaps_by_funding_outpoint[swap._funding_prevout] = swap - self._swaps_by_lockup_address[swap.lockup_address] = swap + with self.swaps_lock: + if swap.payment_hash.hex() not in self._swaps: + self._swaps[swap.payment_hash.hex()] = swap + if swap._funding_prevout: + self._swaps_by_funding_outpoint[swap._funding_prevout] = swap + self._swaps_by_lockup_address[swap.lockup_address] = swap def server_update_pairs(self) -> None: """ for server """ @@ -1237,7 +1246,9 @@ class SwapManager(Logger): d = {} # add info about submarine swaps settled_payments = self.lnworker.get_payments(status='settled') - for payment_hash_hex, swap in self.swaps.items(): + with self.swaps_lock: + swaps_items = list(self._swaps.items()) + for payment_hash_hex, swap in swaps_items: txid = swap.spending_txid if swap.is_reverse else swap.funding_txid if txid is None: continue @@ -1290,7 +1301,9 @@ class SwapManager(Logger): def get_pending_swaps(self) -> List[SwapData]: """Returns a list of swaps with unconfirmed funding tx (which require us to stay online).""" pending_swaps: List[SwapData] = [] - for swap in self.swaps.values(): + with self.swaps_lock: + swaps = list(self._swaps.values()) + for swap in swaps: if swap.is_redeemed: # adb data might have been removed after is_redeemed was set. # in that case lnwatcher will no longer fetch the spending tx