diff --git a/electrum/txbatcher.py b/electrum/txbatcher.py index ca48d09ac..2664636ba 100644 --- a/electrum/txbatcher.py +++ b/electrum/txbatcher.py @@ -115,7 +115,7 @@ class TxBatcher(Logger): def _maybe_create_new_batch(self, key, fee_policy_descriptor: str): if key not in self.storage: - self.storage[key] = { 'fee_policy': fee_policy_descriptor, 'txids': [] } + self.storage[key] = { 'fee_policy': fee_policy_descriptor, 'txids': [], 'prevout': None } self.tx_batches[key] = TxBatch(self.wallet, self.storage[key]) elif self.storage[key]['fee_policy'] != fee_policy_descriptor: # maybe update policy? @@ -127,6 +127,11 @@ class TxBatcher(Logger): self.storage.pop(key) self.tx_batches.pop(key) + def find_batch_by_prevout(self, prevout: str) -> Optional['TxBatch']: + for k, v in self.tx_batches.items(): + if v._prevout == prevout: + return v + def find_batch_of_txid(self, txid) -> str: for k, v in self.tx_batches.items(): if v.is_mine(txid): @@ -184,22 +189,15 @@ class TxBatch(Logger): def __init__(self, wallet, storage: StoredDict): Logger.__init__(self) self.wallet = wallet + self.storage = storage self.lock = threading.RLock() self.batch_payments = [] # list of payments we need to make self.batch_inputs = {} # list of inputs we need to sweep # list of tx that were broadcast. Each tx is a RBF replacement of the previous one. Ony one can get mined. + self._prevout = storage.get('prevout') self._batch_txids = storage['txids'] self.fee_policy = FeePolicy(storage['fee_policy']) self._base_tx = None # current batch tx. last element of batch_txids - if self._batch_txids: - last_txid = self._batch_txids[-1] - tx = self.wallet.adb.get_transaction(last_txid) - if tx: - tx = PartialTransaction.from_tx(tx) - tx.add_info_from_wallet(self.wallet) # this adds input amounts - self._base_tx = tx - self.logger.info(f'found base_tx {last_txid}') - self._parent_tx = None self._unconfirmed_sweeps = set() # list of inputs we are sweeping (until spending tx is confirmed) @@ -242,9 +240,6 @@ class TxBatch(Logger): self.logger.info(f'add_sweep_info: {sweep_info.name} {sweep_info.txin.prevout.to_str()}') self.batch_inputs[txin.prevout] = sweep_info - def get_base_tx(self) -> Optional[Transaction]: - return self._base_tx - def _find_confirmed_base_tx(self) -> Optional[Transaction]: for txid in self._batch_txids: tx_mined_status = self.wallet.adb.get_tx_height(txid) @@ -301,6 +296,26 @@ class TxBatch(Logger): # todo: require more than one confirmation return len(self.batch_inputs) == 0 and len(self.batch_payments) == 0 and len(self._batch_txids) == 0 + def find_base_tx(self) -> Optional[PartialTransaction]: + if self._batch_txids: + last_txid = self._batch_txids[-1] + if self._prevout: + prev_txid, index = self._prevout.split(':') + spender_txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)) + tx = self.wallet.adb.get_transaction(spender_txid) + if tx: + tx = PartialTransaction.from_tx(tx) + tx.add_info_from_wallet(self.wallet) # this adds input amounts + if spender_txid == last_txid: + if self._base_tx is None: + # log initialization + self.logger.info(f'found base_tx {last_txid}') + self._base_tx = tx + else: + self.logger.info(f'base tx was replaced by {spender_txid}') + self._new_base_tx(tx) + return self._base_tx + async def run_iteration(self, password): conf_tx = self._find_confirmed_base_tx() if conf_tx: @@ -308,8 +323,7 @@ class TxBatch(Logger): self._clear_unconfirmed_sweeps(conf_tx) self._start_new_batch(conf_tx) - base_tx = self.get_base_tx() - # if base tx has been RBF-replaced, detect it here + base_tx = self.find_base_tx() try: tx = self.create_next_transaction(base_tx, password) except NoDynamicFeeEstimates: @@ -330,12 +344,7 @@ class TxBatch(Logger): if await self.wallet.network.try_broadcasting(tx, 'batch'): self.wallet.adb.add_transaction(tx) - if tx.has_change(): - self._batch_txids.append(tx.txid()) - self._base_tx = tx - else: - self.logger.info(f'starting new batch because current base tx does not have change') - self._start_new_batch(tx) + self._new_base_tx(tx) else: # most likely reason is that base_tx is not replaceable # this may be the case if it has children (because we don't pay enough fees to replace them) @@ -345,7 +354,6 @@ class TxBatch(Logger): self.logger.info(f'starting new batch because could not broadcast') self._start_new_batch(base_tx) - def create_next_transaction(self, base_tx, password): to_pay = self._to_pay_after(base_tx) to_sweep = self._to_sweep_after(base_tx) @@ -370,6 +378,15 @@ class TxBatch(Logger): self.logger.info(f'{str(tx)}') return tx + def add_sweep_info_to_tx(self, base_tx): + for txin in base_tx.inputs(): + if sweep_info := self.batch_inputs.get(txin.prevout): + if hasattr(sweep_info.txin, 'make_witness'): + txin.make_witness = sweep_info.txin.make_witness + txin.privkey = sweep_info.txin.privkey + txin.witness_script = sweep_info.txin.witness_script + txin.script_sig = sweep_info.txin.script_sig + def _create_batch_tx(self, base_tx, to_sweep, to_pay, password): self.logger.info(f'to_sweep: {list(to_sweep.keys())}') self.logger.info(f'to_pay: {to_pay}') @@ -388,15 +405,8 @@ class TxBatch(Logger): self.logger.info(f'locktime: {locktime}') outputs += to_pay inputs += self._create_inputs_from_tx_change(self._parent_tx) if self._parent_tx else [] - # add sweep info base_tx inputs if base_tx: - for txin in base_tx.inputs(): - if sweep_info := self.batch_inputs.get(txin.prevout): - if hasattr(sweep_info.txin, 'make_witness'): - txin.make_witness = sweep_info.txin.make_witness - txin.privkey = sweep_info.txin.privkey - txin.witness_script = sweep_info.txin.witness_script - txin.script_sig = sweep_info.txin.script_sig + self.add_sweep_info_to_tx(base_tx) # create tx tx = self.wallet.make_unsigned_transaction( fee_policy=self.fee_policy, @@ -429,6 +439,17 @@ class TxBatch(Logger): self._base_tx = None self._parent_tx = tx if use_change else None + @locked + def _new_base_tx(self, tx: Transaction): + self._prevout = tx.inputs()[0].prevout.to_str() + self.storage['prevout'] = self._prevout + if tx.has_change(): + self._batch_txids.append(tx.txid()) + self._base_tx = tx + else: + self.logger.info(f'starting new batch because current base tx does not have change') + self._start_new_batch(tx) + def _create_inputs_from_tx_change(self, parent_tx): inputs = [] for o in parent_tx.get_change_outputs(): diff --git a/electrum/wallet.py b/electrum/wallet.py index 423f147c5..12f3e407e 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -2219,8 +2219,6 @@ class Abstract_Wallet(ABC, Logger, EventListener): tx.remove_signatures() if not self.can_rbf_tx(tx): raise CannotBumpFee(_('Transaction is final')) - if self.txbatcher.is_mine(tx.txid()): - raise CannotBumpFee('Transaction managed by txbatcher') new_fee_rate = quantize_feerate(new_fee_rate) # strip excess precision tx.add_info_from_wallet(self) if tx.is_missing_info_from_network(): @@ -2445,8 +2443,6 @@ class Abstract_Wallet(ABC, Logger, EventListener): # do not mutate LN funding txs, as that would change their txid if not is_dscancel and self.is_lightning_funding_tx(tx.txid()): return False - if self.txbatcher.is_mine(tx.txid()): - return False return tx.is_rbf_enabled() def cpfp(self, tx: Transaction, fee: int) -> Optional[PartialTransaction]: @@ -2683,6 +2679,12 @@ class Abstract_Wallet(ABC, Logger, EventListener): if sh_danger.needs_confirm() and not ignore_warnings: raise TransactionPotentiallyDangerousException('Not signing transaction:\n' + sh_danger.get_long_message()) + # find out if we are replacing a txbatcher transaction + prevout_str = tx.inputs()[0].prevout.to_str() + batch = self.txbatcher.find_batch_by_prevout(prevout_str) + if batch: + batch.add_sweep_info_to_tx(tx) + # sign with make_witness for i, txin in enumerate(tx.inputs()): if hasattr(txin, 'make_witness'):