1
0

Merge pull request #10166 from spesmilo/txbatcher_fix_to_sweep_now

txbatcher: to_sweep_now should be a list
This commit is contained in:
ghost43
2025-08-21 21:03:49 +00:00
committed by GitHub

View File

@@ -315,19 +315,19 @@ class TxBatch(Logger):
return to_pay return to_pay
@locked @locked
def _to_sweep_after(self, tx: Optional[PartialTransaction]) -> Dict[str, SweepInfo]: def _to_sweep_after(self, tx: Optional[PartialTransaction]) -> Dict[TxOutpoint, SweepInfo]:
tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set() tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set()
result = [] result = [] # type: list[tuple[TxOutpoint, SweepInfo]]
for k, v in list(self.batch_inputs.items()): for prevout, sweep_info in list(self.batch_inputs.items()):
prevout = v.txin.prevout assert prevout == sweep_info.txin.prevout
prev_txid, index = prevout.to_str().split(':') prev_txid, index = prevout.to_str().split(':')
if not self.wallet.adb.db.get_transaction(prev_txid): if not self.wallet.adb.db.get_transaction(prev_txid):
continue continue
if v.is_anchor(): if sweep_info.is_anchor():
prev_tx_mined_status = self.wallet.adb.get_tx_height(prev_txid) prev_tx_mined_status = self.wallet.adb.get_tx_height(prev_txid)
if prev_tx_mined_status.conf > 0: if prev_tx_mined_status.conf > 0:
self.logger.info(f"anchor not needed {k}") self.logger.info(f"anchor not needed {prevout}")
self.batch_inputs.pop(k) # note: if the input is already in a batch tx, this will trigger assert error self.batch_inputs.pop(prevout) # note: if the input is already in a batch tx, this will trigger assert error
continue continue
if spender_txid := self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)): if spender_txid := self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)):
tx_mined_status = self.wallet.adb.get_tx_height(spender_txid) tx_mined_status = self.wallet.adb.get_tx_height(spender_txid)
@@ -335,7 +335,7 @@ class TxBatch(Logger):
continue continue
if prevout in tx_prevouts: if prevout in tx_prevouts:
continue continue
result.append((k,v)) result.append((prevout, sweep_info))
return dict(result) return dict(result)
def _should_bump_fee(self, base_tx: Optional[PartialTransaction]) -> bool: def _should_bump_fee(self, base_tx: Optional[PartialTransaction]) -> bool:
@@ -462,11 +462,11 @@ class TxBatch(Logger):
def create_next_transaction(self, base_tx: Optional[PartialTransaction]) -> Optional[PartialTransaction]: def create_next_transaction(self, base_tx: Optional[PartialTransaction]) -> Optional[PartialTransaction]:
to_pay = self._to_pay_after(base_tx) to_pay = self._to_pay_after(base_tx)
to_sweep = self._to_sweep_after(base_tx) to_sweep = self._to_sweep_after(base_tx)
to_sweep_now = {} to_sweep_now = [] # type: list[SweepInfo]
for k, v in to_sweep.items(): for k, v in to_sweep.items():
can_broadcast, wanted_height = self._can_broadcast(v, base_tx) can_broadcast, wanted_height = self._can_broadcast(v, base_tx)
if can_broadcast: if can_broadcast:
to_sweep_now[k] = v to_sweep_now.append(v)
else: else:
self.wallet.add_future_tx(v, wanted_height) self.wallet.add_future_tx(v, wanted_height)
while True: while True:
@@ -505,16 +505,16 @@ class TxBatch(Logger):
self, self,
*, *,
base_tx: Optional[PartialTransaction], base_tx: Optional[PartialTransaction],
to_sweep: Mapping[str, SweepInfo], to_sweep: Sequence[SweepInfo],
to_pay: Sequence[PartialTxOutput], to_pay: Sequence[PartialTxOutput],
) -> PartialTransaction: ) -> PartialTransaction:
self.logger.info(f'to_sweep: {list(to_sweep.keys())}') self.logger.info(f'to_sweep: {[x.txin.prevout.to_str() for x in to_sweep]}')
self.logger.info(f'to_pay: {to_pay}') self.logger.info(f'to_pay: {to_pay}')
inputs = [] # type: List[PartialTxInput] inputs = [] # type: List[PartialTxInput]
outputs = [] # type: List[PartialTxOutput] outputs = [] # type: List[PartialTxOutput]
locktime = base_tx.locktime if base_tx else None locktime = base_tx.locktime if base_tx else None
# sort inputs so that txin-txout pairs are first # sort inputs so that txin-txout pairs are first
for sweep_info in sorted(to_sweep.values(), key=lambda x: not bool(x.txout)): for sweep_info in sorted(to_sweep, key=lambda x: not bool(x.txout)):
if sweep_info.cltv_abs is not None: if sweep_info.cltv_abs is not None:
if locktime is None or locktime < sweep_info.cltv_abs: if locktime is None or locktime < sweep_info.cltv_abs:
# nLockTime must be greater than or equal to the stack operand. # nLockTime must be greater than or equal to the stack operand.