1
0

submarine_swaps: make swaps dict thread-safe

In general many methods of the SwapManager are called both from the asyncio thread
and from the GUI, and hence must be thread-safe.

closes https://github.com/spesmilo/electrum/issues/9887
This commit is contained in:
SomberNight
2025-06-02 17:06:08 +00:00
parent 4539269960
commit 468f496f34

View File

@@ -2,6 +2,7 @@ import asyncio
import json
import os
import ssl
import threading
from typing import TYPE_CHECKING, Optional, Dict, Sequence, Tuple, Iterable, List
from decimal import Decimal
import math
@@ -206,20 +207,22 @@ class SwapManager(Logger):
self.taskgroup = OldTaskGroup()
self.dummy_address = DummyAddress.SWAP
self.swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData]
# note: accessing swaps dicts (besides simple lookup) needs swaps_lock
self.swaps_lock = threading.Lock()
self._swaps = self.wallet.db.get_dict('submarine_swaps') # type: Dict[str, SwapData]
self._swaps_by_funding_outpoint = {} # type: Dict[TxOutpoint, SwapData]
self._swaps_by_lockup_address = {} # type: Dict[str, SwapData]
for payment_hash_hex, swap in self.swaps.items():
for payment_hash_hex, swap in self._swaps.items():
payment_hash = bytes.fromhex(payment_hash_hex)
swap._payment_hash = payment_hash
self._add_or_reindex_swap(swap)
if not swap.is_reverse and not swap.is_redeemed:
self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback)
self.prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash
for k, swap in self.swaps.items():
self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash
for k, swap in self._swaps.items():
if swap.prepay_hash is not None:
self.prepayments[swap.prepay_hash] = bytes.fromhex(k)
self._prepayments[swap.prepay_hash] = bytes.fromhex(k)
self.is_server = False # overridden by swapserver plugin if enabled
self.is_initialized = asyncio.Event()
self.pairs_updated = asyncio.Event()
@@ -231,7 +234,9 @@ class SwapManager(Logger):
return
self.logger.info('start_network: starting main loop')
self.network = network
for k, swap in self.swaps.items():
with self.swaps_lock:
swaps_items = list(self._swaps.items())
for k, swap in swaps_items:
if swap.is_redeemed:
continue
self.add_lnwatcher_callback(swap)
@@ -355,7 +360,9 @@ class SwapManager(Logger):
self.lnworker.save_forwarding_failure(payment_key.hex(), failure_message=e)
self.lnwatcher.remove_callback(swap.lockup_address)
if not swap.is_funded():
self.swaps.pop(swap.payment_hash.hex())
with self.swaps_lock:
self._swaps.pop(swap.payment_hash.hex())
# TODO clean-up other swaps dicts, i.e. undo _add_or_reindex_swap()
@classmethod
def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]:
@@ -481,12 +488,12 @@ class SwapManager(Logger):
def get_swap(self, payment_hash: bytes) -> Optional[SwapData]:
# for history
swap = self.swaps.get(payment_hash.hex())
swap = self._swaps.get(payment_hash.hex())
if swap:
return swap
payment_hash = self.prepayments.get(payment_hash)
payment_hash = self._prepayments.get(payment_hash)
if payment_hash:
return self.swaps.get(payment_hash.hex())
return self._swaps.get(payment_hash.hex())
return None
def add_lnwatcher_callback(self, swap: SwapData) -> None:
@@ -496,7 +503,7 @@ class SwapManager(Logger):
async def hold_invoice_callback(self, payment_hash: bytes) -> None:
# note: this assumes the wallet has been unlocked
key = payment_hash.hex()
if swap := self.swaps.get(key):
if swap := self._swaps.get(key):
if not swap.is_funded():
output = self.create_funding_output(swap)
self.wallet.txbatcher.add_payment_output('swaps', output)
@@ -572,7 +579,7 @@ class SwapManager(Logger):
min_final_cltv_expiry_delta=min_final_cltv_expiry_delta,
)
self.lnworker.bundle_payments([payment_hash, prepay_hash])
self.prepayments[prepay_hash] = payment_hash
self._prepayments[prepay_hash] = payment_hash
else:
prepay_invoice = None
prepay_hash = None
@@ -655,7 +662,7 @@ class SwapManager(Logger):
spending_txid=None,
)
if prepay_hash:
self.prepayments[prepay_hash] = payment_hash
self._prepayments[prepay_hash] = payment_hash
swap._payment_hash = payment_hash
self._add_or_reindex_swap(swap)
self.add_lnwatcher_callback(swap)
@@ -666,8 +673,9 @@ class SwapManager(Logger):
invoice = Invoice.from_bech32(invoice)
key = invoice.rhash
payment_hash = bytes.fromhex(key)
assert key in self.swaps
swap = self.swaps[key]
with self.swaps_lock:
assert key in self._swaps
swap = self._swaps[key]
assert swap.lightning_amount == int(invoice.get_amount_sat())
self.wallet.save_invoice(invoice)
# check that we have the preimage
@@ -799,7 +807,7 @@ class SwapManager(Logger):
await asyncio.sleep(0.1)
return swap.funding_txid
def create_funding_output(self, swap):
def create_funding_output(self, swap: SwapData) -> PartialTxOutput:
return PartialTxOutput.from_address_and_value(swap.lockup_address, swap.onchain_amount)
def create_funding_tx(
@@ -948,11 +956,12 @@ class SwapManager(Logger):
return swap.funding_txid
def _add_or_reindex_swap(self, swap: SwapData) -> None:
if swap.payment_hash.hex() not in self.swaps:
self.swaps[swap.payment_hash.hex()] = swap
if swap._funding_prevout:
self._swaps_by_funding_outpoint[swap._funding_prevout] = swap
self._swaps_by_lockup_address[swap.lockup_address] = swap
with self.swaps_lock:
if swap.payment_hash.hex() not in self._swaps:
self._swaps[swap.payment_hash.hex()] = swap
if swap._funding_prevout:
self._swaps_by_funding_outpoint[swap._funding_prevout] = swap
self._swaps_by_lockup_address[swap.lockup_address] = swap
def server_update_pairs(self) -> None:
""" for server """
@@ -1237,7 +1246,9 @@ class SwapManager(Logger):
d = {}
# add info about submarine swaps
settled_payments = self.lnworker.get_payments(status='settled')
for payment_hash_hex, swap in self.swaps.items():
with self.swaps_lock:
swaps_items = list(self._swaps.items())
for payment_hash_hex, swap in swaps_items:
txid = swap.spending_txid if swap.is_reverse else swap.funding_txid
if txid is None:
continue
@@ -1290,7 +1301,9 @@ class SwapManager(Logger):
def get_pending_swaps(self) -> List[SwapData]:
"""Returns a list of swaps with unconfirmed funding tx (which require us to stay online)."""
pending_swaps: List[SwapData] = []
for swap in self.swaps.values():
with self.swaps_lock:
swaps = list(self._swaps.values())
for swap in swaps:
if swap.is_redeemed:
# adb data might have been removed after is_redeemed was set.
# in that case lnwatcher will no longer fetch the spending tx