From 0e40be5fb527832a68741034f280f4060e4070c9 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Mon, 3 Mar 2025 13:49:59 +0100 Subject: [PATCH] swaps: replace request_swap_for_tx with request_swap_for_amount, as this uses less side effects (change backported from batch_payment_manager) --- electrum/gui/qt/confirm_tx_dialog.py | 2 +- electrum/gui/qt/send_tab.py | 9 +++++---- electrum/submarine_swaps.py | 17 +++++------------ electrum/transaction.py | 9 +++++++-- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/electrum/gui/qt/confirm_tx_dialog.py b/electrum/gui/qt/confirm_tx_dialog.py index 6811a3bd2..9abed8ea1 100644 --- a/electrum/gui/qt/confirm_tx_dialog.py +++ b/electrum/gui/qt/confirm_tx_dialog.py @@ -560,7 +560,7 @@ class TxEditor(WindowModalDialog): self.error = long_warning else: messages.append(long_warning) - if self.tx.has_dummy_output(DummyAddress.SWAP): + if self.tx.get_dummy_output(DummyAddress.SWAP): messages.append(_('This transaction will send funds to a submarine swap.')) # warn if spending unconf if any((txin.block_height is not None and txin.block_height<=0) for txin in self.tx.inputs()): diff --git a/electrum/gui/qt/send_tab.py b/electrum/gui/qt/send_tab.py index 356cf90bb..ab2cd7541 100644 --- a/electrum/gui/qt/send_tab.py +++ b/electrum/gui/qt/send_tab.py @@ -339,18 +339,19 @@ class SendTab(QWidget, MessageBoxMixin, Logger): # user cancelled return - if tx.has_dummy_output(DummyAddress.SWAP): + if swap_dummy_output := tx.get_dummy_output(DummyAddress.SWAP): sm = self.wallet.lnworker.swap_manager with self.window.create_sm_transport() as transport: if not self.window.initialize_swap_manager(transport): return - coro = sm.request_swap_for_tx(transport, tx) + coro = sm.request_swap_for_amount(transport, swap_dummy_output.value) try: - swap, invoice, tx = self.window.run_coroutine_dialog(coro, _('Requesting swap invoice...')) + swap, invoice = self.window.run_coroutine_dialog(coro, _('Requesting swap invoice...')) except SwapServerError as e: self.show_error(str(e)) return - assert not tx.has_dummy_output(DummyAddress.SWAP) + tx.replace_output_address(DummyAddress.SWAP, swap.lockup_address) + assert tx.get_dummy_output(DummyAddress.SWAP) is None tx.swap_invoice = invoice tx.swap_payment_hash = swap.payment_hash diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 09159ebb2..9ea5f9d04 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -826,21 +826,14 @@ class SwapManager(Logger): return tx @log_exceptions - async def request_swap_for_tx(self, transport, tx: 'PartialTransaction') -> Optional[Tuple[SwapData, str, PartialTransaction]]: - for o in tx.outputs(): - if o.address == self.dummy_address: - change_amount = o.value - break - else: - return + async def request_swap_for_amount(self, transport, onchain_amount) -> Optional[Tuple[SwapData, str]]: await self.is_initialized.wait() - lightning_amount_sat = self.get_recv_amount(change_amount, is_reverse=False) + lightning_amount_sat = self.get_recv_amount(onchain_amount, is_reverse=False) swap, invoice = await self.request_normal_swap( transport, - lightning_amount_sat = lightning_amount_sat, - expected_onchain_amount_sat=change_amount) - tx.replace_output_address(DummyAddress.SWAP, swap.lockup_address) - return swap, invoice, tx + lightning_amount_sat=lightning_amount_sat, + expected_onchain_amount_sat=onchain_amount) + return swap, invoice @log_exceptions async def broadcast_funding_tx(self, swap: SwapData, tx: Transaction) -> None: diff --git a/electrum/transaction.py b/electrum/transaction.py index 4aebe140c..fcc5f338f 100644 --- a/electrum/transaction.py +++ b/electrum/transaction.py @@ -1247,8 +1247,13 @@ class Transaction: def get_change_outputs(self): return [o for o in self._outputs if o.is_change] - def has_dummy_output(self, dummy_addr: str) -> bool: - return len(self.get_output_idxs_from_address(dummy_addr)) == 1 + def get_dummy_output(self, dummy_addr: str) -> Optional['PartialTxOutput']: + idxs = self.get_output_idxs_from_address(dummy_addr) + if not idxs: + return + assert len(idxs) == 1 + for i in idxs: + return self.outputs()[i] def output_value_for_address(self, addr): # assumes exactly one output has that address