From 1bf1de36cbbab3ef9bd06bf5009074f11f0d456d Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 5 Jun 2025 14:37:05 +0200 Subject: [PATCH] txbatcher: - add base_tx to wallet before broadcasting - remove base_tx in find_base_tx, it is local - add unit test in test_tx_batcher --- electrum/txbatcher.py | 73 +++++++++++++++--------------- tests/test_txbatcher.py | 99 ++++++++++++++++++++++++++++------------- 2 files changed, 103 insertions(+), 69 deletions(-) diff --git a/electrum/txbatcher.py b/electrum/txbatcher.py index a7fb8aca5..9dac30b98 100644 --- a/electrum/txbatcher.py +++ b/electrum/txbatcher.py @@ -294,15 +294,6 @@ class TxBatch(Logger): self.logger.info(f'add_sweep_info: {sweep_info.name} {sweep_info.txin.prevout.to_str()}') self.batch_inputs[txin.prevout] = sweep_info - def _find_confirmed_base_tx(self) -> Optional[Transaction]: - for txid in self._batch_txids: - tx_mined_status = self.wallet.adb.get_tx_height(txid) - if tx_mined_status.conf > 0: - tx = self.wallet.adb.get_transaction(txid) - tx = PartialTransaction.from_tx(tx) - tx.add_info_from_wallet(self.wallet) # needed for txid - return tx - @locked def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]: if not tx: @@ -357,34 +348,40 @@ class TxBatch(Logger): return len(self.batch_inputs) == 0 and len(self.batch_payments) == 0 and len(self._batch_txids) == 0 def find_base_tx(self) -> Optional[PartialTransaction]: - if self._batch_txids: - last_txid = self._batch_txids[-1] - if self._prevout: - prev_txid, index = self._prevout.split(':') - spender_txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)) - tx = self.wallet.adb.get_transaction(spender_txid) - if tx: - if spender_txid == last_txid: - if self._base_tx is None: - # log initialization - self.logger.info(f'found base_tx {last_txid}') - self._base_tx = tx - else: - self.logger.info(f'base tx was replaced by {spender_txid}') - self._new_base_tx(tx) + if not self._prevout: + return + prev_txid, index = self._prevout.split(':') + txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)) + tx = self.wallet.adb.get_transaction(txid) if txid else None + if not tx: + return + tx = PartialTransaction.from_tx(tx) + tx.add_info_from_wallet(self.wallet) # this sets is_change + + if self.is_mine(txid): + if self._base_tx is None: + self.logger.info(f'found base_tx {txid}') + self._base_tx = tx + else: + self.logger.info(f'base tx was replaced by {tx.txid()}') + self._new_base_tx(tx) + # if tx is confirmed or local, we will start a new batch + tx_mined_status = self.wallet.adb.get_tx_height(txid) + if tx_mined_status.conf > 0: + self.logger.info(f'base tx confirmed {txid}') + self._clear_unconfirmed_sweeps(tx) + self._start_new_batch(tx) + elif tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]: + # fixme: adb may return TX_HEIGHT_LOCAL when not up to date + if self.wallet.adb.is_up_to_date(): + self.logger.info(f'removing local base_tx {txid}') + self.wallet.adb.remove_transaction(txid) + self._start_new_batch(None) + return self._base_tx async def run_iteration(self): - conf_tx = self._find_confirmed_base_tx() - if conf_tx: - self.logger.info(f'base tx confirmed {conf_tx.txid()}') - self._clear_unconfirmed_sweeps(conf_tx) - self._start_new_batch(conf_tx) - base_tx = self.find_base_tx() - if base_tx: - base_tx = PartialTransaction.from_tx(base_tx) - base_tx.add_info_from_wallet(self.wallet) # this sets is_change try: tx = self.create_next_transaction(base_tx) except NoDynamicFeeEstimates: @@ -413,9 +410,10 @@ class TxBatch(Logger): self.wallet.adb.remove_transaction(tx.txid()) return - if await self.wallet.network.try_broadcasting(tx, 'batch'): - self._new_base_tx(tx) - else: + # save local base_tx + self._new_base_tx(tx) + + if not await self.wallet.network.try_broadcasting(tx, 'batch'): # most likely reason is that base_tx is not replaceable # this may be the case if it has children (because we don't pay enough fees to replace them) # or if we are trying to sweep unconfirmed inputs (replacement-adds-unconfirmed error) @@ -528,13 +526,12 @@ class TxBatch(Logger): self._batch_txids.clear() self._base_tx = None self._parent_tx = tx if use_change else None + self._prevout = None @locked def _new_base_tx(self, tx: Transaction): self._prevout = tx.inputs()[0].prevout.to_str() self.storage['prevout'] = self._prevout - tx = PartialTransaction.from_tx(tx) - tx.add_info_from_wallet(self.wallet) # this sets is_change if tx.has_change(): self._batch_txids.append(tx.txid()) self._base_tx = tx diff --git a/tests/test_txbatcher.py b/tests/test_txbatcher.py index 7e1a4e882..d257e5799 100644 --- a/tests/test_txbatcher.py +++ b/tests/test_txbatcher.py @@ -2,6 +2,7 @@ import unittest import logging from unittest import mock import asyncio +from aiorpcx import timeout_after from electrum import storage, bitcoin, keystore, wallet from electrum import SimpleConfig @@ -42,7 +43,6 @@ class MockNetwork(Logger): async def try_broadcasting(self, tx, name): for w in self.wallets: w.adb.receive_tx_callback(tx, tx_height=TX_HEIGHT_UNCONFIRMED) - self._tx_queue.put_nowait(tx) return tx.txid() @@ -57,6 +57,35 @@ class MockNetwork(Logger): SWAP_FUNDING_TX = "01000000000101500e9d67647481864edfb020b5c45e1c40d90f06b0130f9faed1a5149c6d26450000000000ffffffff0226080300000000002200205059c44bf57534303ab8f090f06b7bde58f5d2522440247a1ff6b41bdca9348df312c20100000000160014021d4f3b17921d790e1c022367a5bb078ce4deb402483045022100d41331089a2031396a1db8e4dec6dda9cacefe1288644b92f8e08a23325aa19b02204159230691601f7d726e4e6e0b7124d3377620f400d699a01095f0b0a09ee26a012102d60315c72c0cefd41c6d07883c20b88be3fc37aac7912f0052722a95de0de71600000000" SWAP_CLAIM_TX = "02000000000101f9db8580febd5c0f85b6f1576c83f7739109e3a2d772743e3217e9537fea7e89000000000001000000017005030000000000160014b113a47f3718da3fd161339a6681c150fef2cfe30347304402204c6d40103589b1a8177a37a824f0c66a3a7b22bc570b14c9e07965b56f6ace8f02203a35cffe0ab10de00f3e15ecf5aafdd2c7f6c62da11edd9054a1bce7a9e1455c0120f1939b5723155713855d7ebea6e174f77d41d669269e7f138856c3de190e7a366a8201208763a914d7a62ef0270960fe23f0f351b28caadab62c21838821030bfd61153816df786036ea293edce851d3a4b9f4a1c66bdc1a17f00ffef3d6b167750334ef24b1752102fc8128f17f9e666ea281c702171ab16c1dd2a4337b71f08970f5aa10c608a93268ac00000000" +SWAPDATA = SwapData( + is_reverse=True, + locktime=2420532, + onchain_amount=198694, + lightning_amount=200000, + redeem_script=bytes.fromhex('8201208763a914d7a62ef0270960fe23f0f351b28caadab62c21838821030bfd61153816df786036ea293edce851d3a4b9f4a1c66bdc1a17f00ffef3d6b167750334ef24b1752102fc8128f17f9e666ea281c702171ab16c1dd2a4337b71f08970f5aa10c608a93268ac'), + preimage=bytes.fromhex('f1939b5723155713855d7ebea6e174f77d41d669269e7f138856c3de190e7a36'), + prepay_hash=None, + privkey=bytes.fromhex('58fd0018a9a2737d1d6b81d380df96bf0c858473a9592015508a270a7c9b1d8d'), + lockup_address='tb1q2pvugjl4w56rqw4c7zg0q6mmmev0t5jjy3qzg7sl766phh9fxjxsrtl77t', + receive_address='tb1ql0adrj58g88xgz375yct63rclhv29hv03u0mel', + funding_txid='897eea7f53e917323e7472d7a2e3099173f7836c57f1b6850f5cbdfe8085dbf9', + spending_txid=None, + is_redeemed=False, +) + +txin = PartialTxInput( + prevout=TxOutpoint(txid=bytes.fromhex(SWAPDATA.funding_txid), out_idx=0), +) +txin._trusted_value_sats = SWAPDATA.onchain_amount +txin, locktime = SwapManager.create_claim_txin(txin=txin, swap=SWAPDATA) +SWAP_SWEEP_INFO = SweepInfo( + txin=txin, + cltv_abs=locktime, + txout=None, + name='swap claim', + can_be_batched=True, +) + class TestTxBatcher(ElectrumTestCase): @@ -178,37 +207,9 @@ class TestTxBatcher(ElectrumTestCase): self.maxDiff = None # create wallet wallet = self._create_wallet() - # add swap data - swap_data = SwapData( - is_reverse=True, - locktime=2420532, - onchain_amount=198694, - lightning_amount=200000, - redeem_script=bytes.fromhex('8201208763a914d7a62ef0270960fe23f0f351b28caadab62c21838821030bfd61153816df786036ea293edce851d3a4b9f4a1c66bdc1a17f00ffef3d6b167750334ef24b1752102fc8128f17f9e666ea281c702171ab16c1dd2a4337b71f08970f5aa10c608a93268ac'), - preimage=bytes.fromhex('f1939b5723155713855d7ebea6e174f77d41d669269e7f138856c3de190e7a36'), - prepay_hash=None, - privkey=bytes.fromhex('58fd0018a9a2737d1d6b81d380df96bf0c858473a9592015508a270a7c9b1d8d'), - lockup_address='tb1q2pvugjl4w56rqw4c7zg0q6mmmev0t5jjy3qzg7sl766phh9fxjxsrtl77t', - receive_address='tb1ql0adrj58g88xgz375yct63rclhv29hv03u0mel', - funding_txid='897eea7f53e917323e7472d7a2e3099173f7836c57f1b6850f5cbdfe8085dbf9', - spending_txid=None, - is_redeemed=False, - ) - wallet.adb.db.transactions[swap_data.funding_txid] = tx = Transaction(SWAP_FUNDING_TX) + wallet.adb.db.transactions[SWAPDATA.funding_txid] = tx = Transaction(SWAP_FUNDING_TX) wallet.adb.receive_tx_callback(tx, tx_height=1) - txin = PartialTxInput( - prevout=TxOutpoint(txid=bytes.fromhex(swap_data.funding_txid), out_idx=0), - ) - txin._trusted_value_sats = swap_data.onchain_amount - txin, locktime = SwapManager.create_claim_txin(txin=txin, swap=swap_data) - sweep_info = SweepInfo( - txin=txin, - cltv_abs=locktime, - txout=None, - name='swap claim', - can_be_batched=True, - ) - wallet.txbatcher.add_sweep_input('default', sweep_info) + wallet.txbatcher.add_sweep_input('default', SWAP_SWEEP_INFO) tx = await self.network.next_tx() txid = tx.txid() self.assertEqual(SWAP_CLAIM_TX, str(tx)) @@ -220,3 +221,39 @@ class TestTxBatcher(ElectrumTestCase): # check that we batched with previous tx assert new_tx.inputs()[0].prevout == tx.inputs()[0].prevout == txin.prevout assert output1 in new_tx.outputs() + + @mock.patch.object(wallet.Abstract_Wallet, 'save_db') + async def test_remove_local_base_tx(self, mock_save_db): + """ + The swap claim tx does not get broadcast + we test that txbatcher.find_base_tx() removes the local tx + """ + self.maxDiff = None + # create wallet + wallet = self._create_wallet() + # mock is_up_to_date + wallet.adb.is_up_to_date = lambda: True + # do not broadcast, wait forever + async def do_wait(x, y): + await asyncio.sleep(100000000) + self.network.try_broadcasting = do_wait + # add swap data + wallet.adb.db.transactions[SWAPDATA.funding_txid] = tx = Transaction(SWAP_FUNDING_TX) + wallet.adb.receive_tx_callback(tx, tx_height=1) + wallet.txbatcher.add_sweep_input('default', SWAP_SWEEP_INFO) + txbatch = wallet.txbatcher.tx_batches.get('default') + base_tx = await self._wait_for_base_tx(txbatch) + self.assertEqual(base_tx.txid(), '80a8cbc42de74cb48a09644c1e438c8b39144bd3b55c574f21d89d05c85fed34') + await wallet.stop() + txbatch.batch_inputs.clear() + wallet.start_network(self.network) + base_tx = await self._wait_for_base_tx(txbatch, should_be_none=True) + self.assertEqual(base_tx, None) + + async def _wait_for_base_tx(self, txbatch, should_be_none=False): + async with timeout_after(10): + while True: + base_tx = txbatch._base_tx + if (base_tx is not None) ^ should_be_none: + return base_tx + await asyncio.sleep(0.1)