1
0

txbatcher: be careful when removing local transactions

1. Do not remove local transaction in find_base_tx.

This logic was intended to cleanup claim transactions that are
never broadcast (for example, if the counterparty gets a refund)
(see 1bf1de36cb)

However, this code is too unspecific and may result in fund loss,
because the transaction being removed may contain outgoing payments.
For example, if the electrum server is not responsive, the tx will
be seen as local and deleted. In that case, another payment will
be attempted, thus paying twice.

2. Do not remove tx after try_broadcasting returns False.

The server might be lying to us. We can only remove the local tx
if there is a base_tx, because the next tx we create will try to
spend the same output.
This commit is contained in:
ThomasV
2025-08-18 14:43:47 +02:00
parent e85a3f2d3f
commit 5f30f2a0c0
2 changed files with 30 additions and 41 deletions

View File

@@ -222,34 +222,6 @@ class TestTxBatcher(ElectrumTestCase):
assert new_tx.inputs()[0].prevout == tx.inputs()[0].prevout == txin.prevout
assert output1 in new_tx.outputs()
@mock.patch.object(wallet.Abstract_Wallet, 'save_db')
async def test_remove_local_base_tx(self, mock_save_db):
"""
The swap claim tx does not get broadcast
we test that txbatcher.find_base_tx() removes the local tx
"""
self.maxDiff = None
# create wallet
wallet = self._create_wallet()
# mock is_up_to_date
wallet.adb.is_up_to_date = lambda: True
# do not broadcast, wait forever
async def do_wait(x, y):
await asyncio.sleep(100000000)
self.network.try_broadcasting = do_wait
# add swap data
wallet.adb.db.transactions[SWAPDATA.funding_txid] = tx = Transaction(SWAP_FUNDING_TX)
wallet.adb.receive_tx_callback(tx, tx_height=1)
wallet.txbatcher.add_sweep_input('default', SWAP_SWEEP_INFO)
txbatch = wallet.txbatcher.tx_batches.get('default')
base_tx = await self._wait_for_base_tx(txbatch)
self.assertEqual(base_tx.txid(), '80a8cbc42de74cb48a09644c1e438c8b39144bd3b55c574f21d89d05c85fed34')
await wallet.stop()
txbatch.batch_inputs.clear()
wallet.start_network(self.network)
base_tx = await self._wait_for_base_tx(txbatch, should_be_none=True)
self.assertEqual(base_tx, None)
async def _wait_for_base_tx(self, txbatch, should_be_none=False):
async with timeout_after(10):
while True: