1
0

txbatcher: add type-hints

This commit is contained in:
SomberNight
2025-06-26 13:52:07 +00:00
parent 56684c049a
commit 2b92e8a97a
3 changed files with 92 additions and 75 deletions

View File

@@ -2,7 +2,7 @@
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php # file LICENCE or http://www.opensource.org/licenses/mit-license.php
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from . import util from . import util
from .util import TxMinedInfo, BelowDustLimit from .util import TxMinedInfo, BelowDustLimit
@@ -88,13 +88,13 @@ class LNWatcher(Logger, EventListener):
if chan.need_to_subscribe(): if chan.need_to_subscribe():
self.add_callback(address, callback) self.add_callback(address, callback)
def unwatch_channel(self, address, funding_outpoint): def unwatch_channel(self, address: str, funding_outpoint: str) -> None:
self.logger.info(f'unwatching {funding_outpoint}') self.logger.info(f'unwatching {funding_outpoint}')
self.remove_callback(address) self.remove_callback(address)
@ignore_exceptions @ignore_exceptions
@log_exceptions @log_exceptions
async def check_onchain_situation(self, address, funding_outpoint): async def check_onchain_situation(self, address: str, funding_outpoint: str) -> None:
# early return if address has not been added yet # early return if address has not been added yet
if not self.adb.is_mine(address): if not self.adb.is_mine(address):
return return
@@ -142,17 +142,18 @@ class LNWatcher(Logger, EventListener):
self._pending_force_closes.discard(chan) self._pending_force_closes.discard(chan)
await self.lnworker.handle_onchain_state(chan) await self.lnworker.handle_onchain_state(chan)
async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool: async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool:
"""This function is called when a channel was closed. In this case """This function is called when a channel was closed. In this case
we need to check for redeemable outputs of the commitment transaction we need to check for redeemable outputs of the commitment transaction
or spenders down the line (HTLC-timeout/success transactions). or spenders down the line (HTLC-timeout/success transactions).
Returns whether we should continue to monitor. Returns whether we should continue to monitor.
Side-effécts: Side-effects:
- sets defaults labels - sets defaults labels
- populates wallet._accounting_addresses - populates wallet._accounting_addresses
""" """
assert closing_tx
chan = self.lnworker.channel_by_txo(funding_outpoint) chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan: if not chan:
return False return False
@@ -188,7 +189,8 @@ class LNWatcher(Logger, EventListener):
self.maybe_add_accounting_address(spender_txid, sweep_info) self.maybe_add_accounting_address(spender_txid, sweep_info)
else: else:
keep_watching |= was_added keep_watching |= was_added
self.maybe_add_pending_forceclose(chan, spender_txid, is_local_ctx, sweep_info, was_added) self.maybe_add_pending_forceclose(
chan=chan, spender_txid=spender_txid, is_local_ctx=is_local_ctx, sweep_info=sweep_info, was_added=was_added)
return keep_watching return keep_watching
def get_pending_force_closes(self): def get_pending_force_closes(self):
@@ -242,7 +244,15 @@ class LNWatcher(Logger, EventListener):
txout = prev_tx.outputs()[int(prev_index)] txout = prev_tx.outputs()[int(prev_index)]
self.lnworker.wallet._accounting_addresses.add(txout.address) self.lnworker.wallet._accounting_addresses.add(txout.address)
def maybe_add_pending_forceclose(self, chan, spender_txid, is_local_ctx, sweep_info, was_added): def maybe_add_pending_forceclose(
self,
*,
chan: 'AbstractChannel',
spender_txid: Optional[str],
is_local_ctx: bool,
sweep_info: 'SweepInfo',
was_added: bool,
):
""" we are waiting for ctx to be confirmed and there are received htlcs """ """ we are waiting for ctx to be confirmed and there are received htlcs """
if was_added and is_local_ctx and sweep_info.name == 'received-htlc' and chan.has_anchors(): if was_added and is_local_ctx and sweep_info.name == 'received-htlc' and chan.has_anchors():
tx_mined_status = self.adb.get_tx_height(spender_txid) tx_mined_status = self.adb.get_tx_height(spender_txid)

View File

@@ -1211,6 +1211,7 @@ class LNWallet(LNWorker):
for chan in self.channel_backups.values(): for chan in self.channel_backups.values():
if chan.funding_outpoint.to_str() == txo: if chan.funding_outpoint.to_str() == txo:
return chan return chan
return None
async def handle_onchain_state(self, chan: Channel): async def handle_onchain_state(self, chan: Channel):
if self.network is None: if self.network is None:
@@ -1511,6 +1512,7 @@ class LNWallet(LNWorker):
return chan return chan
if chan.get_local_scid_alias() == short_channel_id: if chan.get_local_scid_alias() == short_channel_id:
return chan return chan
return None
def can_pay_invoice(self, invoice: Invoice) -> bool: def can_pay_invoice(self, invoice: Invoice) -> bool:
assert invoice.is_lightning() assert invoice.is_lightning()

View File

@@ -1,20 +1,3 @@
import asyncio
import threading
import copy
from typing import Dict, Sequence
from . import util
from .bitcoin import dust_threshold
from .logging import Logger
from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup
from .transaction import PartialTransaction, PartialTxOutput, Transaction
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE
from .lnsweep import SweepInfo
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .wallet import Abstract_Wallet
# This class batches outgoing payments and incoming utxo sweeps. # This class batches outgoing payments and incoming utxo sweeps.
# It ensures that we do not send a payment twice. # It ensures that we do not send a payment twice.
# #
@@ -74,9 +57,23 @@ if TYPE_CHECKING:
# In order to generalize that logic to payments, callers would need to pass a unique ID along with # In order to generalize that logic to payments, callers would need to pass a unique ID along with
# the payment output, so that we can prevent paying twice. # the payment output, so that we can prevent paying twice.
import asyncio
import threading
import copy
from typing import Dict, Sequence, Optional, TYPE_CHECKING, Mapping, Set, List, Tuple
from . import util
from .bitcoin import dust_threshold
from .logging import Logger
from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup
from .transaction import PartialTransaction, PartialTxOutput, Transaction, TxOutpoint, PartialTxInput
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE
from .lnsweep import SweepInfo
from .json_db import locked, StoredDict from .json_db import locked, StoredDict
from .fee_policy import FeePolicy from .fee_policy import FeePolicy
if TYPE_CHECKING:
from .wallet import Abstract_Wallet
class TxBatcher(Logger): class TxBatcher(Logger):
@@ -87,21 +84,21 @@ class TxBatcher(Logger):
Logger.__init__(self) Logger.__init__(self)
self.lock = threading.RLock() self.lock = threading.RLock()
self.storage = wallet.db.get_stored_item("tx_batches", {}) self.storage = wallet.db.get_stored_item("tx_batches", {})
self.tx_batches = {} self.tx_batches = {} # type: Dict[str, TxBatch]
self.wallet = wallet self.wallet = wallet
for key, item_storage in self.storage.items(): for key, item_storage in self.storage.items():
self.tx_batches[key] = TxBatch(self.wallet, item_storage) self.tx_batches[key] = TxBatch(self.wallet, item_storage)
self._legacy_htlcs = {} self._legacy_htlcs = {} # type: Dict[TxOutpoint, SweepInfo]
self.taskgroup = None self.taskgroup = None # type: Optional[OldTaskGroup]
self.password_future = None self.password_future = None # type: Optional[asyncio.Future[Optional[str]]]
@locked @locked
def add_payment_output(self, key: str, output: 'PartialTxOutput'): def add_payment_output(self, key: str, output: 'PartialTxOutput') -> None:
batch = self._maybe_create_new_batch(key, fee_policy_name=key) batch = self._maybe_create_new_batch(key, fee_policy_name=key)
batch.add_payment_output(output) batch.add_payment_output(output)
@locked @locked
def add_sweep_input(self, key: str, sweep_info: 'SweepInfo'): def add_sweep_input(self, key: str, sweep_info: 'SweepInfo') -> None:
if sweep_info.txin and sweep_info.txout: if sweep_info.txin and sweep_info.txout:
# detect legacy htlc using name and csv delay # detect legacy htlc using name and csv delay
if sweep_info.name in ['received-htlc', 'offered-htlc'] and sweep_info.csv_delay == 0: if sweep_info.name in ['received-htlc', 'offered-htlc'] and sweep_info.csv_delay == 0:
@@ -116,7 +113,7 @@ class TxBatcher(Logger):
batch = self._maybe_create_new_batch(key, fee_policy_name) batch = self._maybe_create_new_batch(key, fee_policy_name)
batch.add_sweep_input(sweep_info) batch.add_sweep_input(sweep_info)
def _maybe_create_new_batch(self, key, fee_policy_name: str): def _maybe_create_new_batch(self, key: str, fee_policy_name: str) -> 'TxBatch':
assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!"
if key not in self.storage: if key not in self.storage:
self.logger.info(f'creating new batch: {key}') self.logger.info(f'creating new batch: {key}')
@@ -127,7 +124,7 @@ class TxBatcher(Logger):
return self.tx_batches[key] return self.tx_batches[key]
@locked @locked
def delete_batch(self, key): def delete_batch(self, key: str) -> None:
self.logger.info(f'deleting TxBatch {key}') self.logger.info(f'deleting TxBatch {key}')
self.storage.pop(key) self.storage.pop(key)
self.tx_batches.pop(key) self.tx_batches.pop(key)
@@ -136,17 +133,19 @@ class TxBatcher(Logger):
for k, v in self.tx_batches.items(): for k, v in self.tx_batches.items():
if v._prevout == prevout: if v._prevout == prevout:
return v return v
return None
def find_batch_of_txid(self, txid) -> str: def find_batch_of_txid(self, txid: str) -> Optional[str]:
for k, v in self.tx_batches.items(): for k, v in self.tx_batches.items():
if v.is_mine(txid): if v.is_mine(txid):
return k return k
return None
def is_mine(self, txid): def is_mine(self, txid: str) -> bool:
# used to prevent GUI from interfering # used to prevent GUI from interfering
return bool(self.find_batch_of_txid(txid)) return bool(self.find_batch_of_txid(txid))
async def run_batch(self, key, batch): async def run_batch(self, key: str, batch: 'TxBatch') -> None:
await batch.run() await batch.run()
self.delete_batch(key) self.delete_batch(key)
@@ -158,13 +157,13 @@ class TxBatcher(Logger):
async with self.taskgroup as group: async with self.taskgroup as group:
await group.spawn(self.redeem_legacy_htlcs()) await group.spawn(self.redeem_legacy_htlcs())
async def redeem_legacy_htlcs(self): async def redeem_legacy_htlcs(self) -> None:
while True: while True:
await asyncio.sleep(self.SLEEP_INTERVAL) await asyncio.sleep(self.SLEEP_INTERVAL)
for sweep_info in self._legacy_htlcs.values(): for sweep_info in self._legacy_htlcs.values():
await self._maybe_redeem_legacy_htlcs(sweep_info) await self._maybe_redeem_legacy_htlcs(sweep_info)
async def _maybe_redeem_legacy_htlcs(self, sweep_info): async def _maybe_redeem_legacy_htlcs(self, sweep_info: 'SweepInfo') -> None:
assert sweep_info.csv_delay == 0 assert sweep_info.csv_delay == 0
local_height = self.wallet.network.get_local_height() local_height = self.wallet.network.get_local_height()
wanted_height = sweep_info.cltv_abs wanted_height = sweep_info.cltv_abs
@@ -184,7 +183,7 @@ class TxBatcher(Logger):
if await self.wallet.network.try_broadcasting(tx, sweep_info.name): if await self.wallet.network.try_broadcasting(tx, sweep_info.name):
self.wallet.adb.add_transaction(tx) self.wallet.adb.add_transaction(tx)
async def get_password(self, txid:str): async def get_password(self, txid: str) -> Optional[str]:
# daemon, android have password in memory # daemon, android have password in memory
password = self.wallet.get_unlocked_password() password = self.wallet.get_unlocked_password()
if password: if password:
@@ -194,12 +193,12 @@ class TxBatcher(Logger):
await future await future
except asyncio.CancelledError as e: except asyncio.CancelledError as e:
return return None
password = future.result() password = future.result()
return password return password
@locked @locked
def set_password_future(self, password: Optional[str]): def set_password_future(self, password: Optional[str]) -> None:
if self.password_future is not None: if self.password_future is not None:
if password is not None: if password is not None:
self.password_future.set_result(password) self.password_future.set_result(password)
@@ -226,25 +225,25 @@ class TxBatch(Logger):
self.wallet = wallet self.wallet = wallet
self.storage = storage self.storage = storage
self.lock = threading.RLock() self.lock = threading.RLock()
self.batch_payments = [] # list of payments we need to make self.batch_payments = [] # type: List[PartialTxOutput] # payments we need to make
self.batch_inputs = {} # list of inputs we need to sweep self.batch_inputs = {} # type: Dict[TxOutpoint, SweepInfo] # inputs we need to sweep
# list of tx that were broadcast. Each tx is a RBF replacement of the previous one. Ony one can get mined. # list of tx that were broadcast. Each tx is a RBF replacement of the previous one. Ony one can get mined.
self._prevout = storage.get('prevout') self._prevout = storage.get('prevout') # type: Optional[str]
self._batch_txids = storage['txids'] self._batch_txids = storage['txids'] # type: List[str]
self._fee_policy_name = storage.get('fee_policy_name', 'default') self._fee_policy_name = storage.get('fee_policy_name', 'default') # type: str
self._base_tx = None # current batch tx. last element of batch_txids self._base_tx = None # type: Optional[PartialTransaction] # current batch tx. last element of batch_txids
self._parent_tx = None self._parent_tx = None # type: Optional[PartialTransaction]
self._unconfirmed_sweeps = set() # list of inputs we are sweeping (until spending tx is confirmed) self._unconfirmed_sweeps = set() # type: Set[TxOutpoint] # inputs we are sweeping (until spending tx is confirmed)
@property @property
def fee_policy(self): def fee_policy(self) -> FeePolicy:
# this assumes the descriptor is in config.fee_policy # this assumes the descriptor is in config.fee_policy
cv_name = 'fee_policy' + '.' + self._fee_policy_name cv_name = 'fee_policy' + '.' + self._fee_policy_name
descriptor = self.wallet.config.get(cv_name, 'eta:2') descriptor = self.wallet.config.get(cv_name, 'eta:2')
return FeePolicy(descriptor) return FeePolicy(descriptor)
@log_exceptions @log_exceptions
async def run(self): async def run(self) -> None:
while not self.is_done(): while not self.is_done():
await asyncio.sleep(self.wallet.txbatcher.SLEEP_INTERVAL) await asyncio.sleep(self.wallet.txbatcher.SLEEP_INTERVAL)
if not (self.wallet.network and self.wallet.network.is_connected()): if not (self.wallet.network and self.wallet.network.is_connected()):
@@ -255,15 +254,15 @@ class TxBatch(Logger):
self.logger.exception(f'TxBatch error: {repr(e)}') self.logger.exception(f'TxBatch error: {repr(e)}')
break break
def is_mine(self, txid): def is_mine(self, txid: str) -> bool:
return txid in self._batch_txids return txid in self._batch_txids
@locked @locked
def add_payment_output(self, output: 'PartialTxOutput'): def add_payment_output(self, output: 'PartialTxOutput') -> None:
# todo: maybe we should raise NotEnoughFunds here # todo: maybe we should raise NotEnoughFunds here
self.batch_payments.append(output) self.batch_payments.append(output)
def is_dust(self, sweep_info): def is_dust(self, sweep_info: SweepInfo) -> bool:
if sweep_info.is_anchor(): if sweep_info.is_anchor():
return False return False
if sweep_info.txout is not None: if sweep_info.txout is not None:
@@ -276,7 +275,7 @@ class TxBatch(Logger):
return value - fee <= dust_threshold() return value - fee <= dust_threshold()
@locked @locked
def add_sweep_input(self, sweep_info: 'SweepInfo'): def add_sweep_input(self, sweep_info: 'SweepInfo') -> None:
if self.is_dust(sweep_info): if self.is_dust(sweep_info):
raise BelowDustLimit raise BelowDustLimit
txin = sweep_info.txin txin = sweep_info.txin
@@ -295,7 +294,7 @@ class TxBatch(Logger):
self.batch_inputs[txin.prevout] = sweep_info self.batch_inputs[txin.prevout] = sweep_info
@locked @locked
def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]: def _to_pay_after(self, tx: Optional[PartialTransaction]) -> Sequence[PartialTxOutput]:
if not tx: if not tx:
return self.batch_payments return self.batch_payments
to_pay = [] to_pay = []
@@ -308,7 +307,7 @@ class TxBatch(Logger):
return to_pay return to_pay
@locked @locked
def _to_sweep_after(self, tx) -> Dict[str, SweepInfo]: def _to_sweep_after(self, tx: Optional[PartialTransaction]) -> Dict[str, SweepInfo]:
tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set() tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set()
result = [] result = []
for k, v in list(self.batch_inputs.items()): for k, v in list(self.batch_inputs.items()):
@@ -331,7 +330,7 @@ class TxBatch(Logger):
result.append((k,v)) result.append((k,v))
return dict(result) return dict(result)
def _should_bump_fee(self, base_tx) -> bool: def _should_bump_fee(self, base_tx: Optional[PartialTransaction]) -> bool:
if base_tx is None: if base_tx is None:
return False return False
if not self.is_mine(base_tx.txid()): if not self.is_mine(base_tx.txid()):
@@ -349,12 +348,12 @@ class TxBatch(Logger):
def find_base_tx(self) -> Optional[PartialTransaction]: def find_base_tx(self) -> Optional[PartialTransaction]:
if not self._prevout: if not self._prevout:
return return None
prev_txid, index = self._prevout.split(':') prev_txid, index = self._prevout.split(':')
txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)) txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index))
tx = self.wallet.adb.get_transaction(txid) if txid else None tx = self.wallet.adb.get_transaction(txid) if txid else None
if not tx: if not tx:
return return None
tx = PartialTransaction.from_tx(tx) tx = PartialTransaction.from_tx(tx)
tx.add_info_from_wallet(self.wallet) # this sets is_change tx.add_info_from_wallet(self.wallet) # this sets is_change
@@ -380,7 +379,7 @@ class TxBatch(Logger):
return self._base_tx return self._base_tx
async def run_iteration(self): async def run_iteration(self) -> None:
base_tx = self.find_base_tx() base_tx = self.find_base_tx()
try: try:
tx = self.create_next_transaction(base_tx) tx = self.create_next_transaction(base_tx)
@@ -424,18 +423,18 @@ class TxBatch(Logger):
self._start_new_batch(base_tx) self._start_new_batch(base_tx)
async def sign_transaction(self, tx): async def sign_transaction(self, tx: PartialTransaction) -> Optional[PartialTransaction]:
tx.add_info_from_wallet(self.wallet) # this adds input amounts tx.add_info_from_wallet(self.wallet) # this adds input amounts
self.add_sweep_info_to_tx(tx) self.add_sweep_info_to_tx(tx)
pw_required = self.wallet.has_keystore_encryption() and tx.requires_keystore() pw_required = self.wallet.has_keystore_encryption() and tx.requires_keystore()
password = await self.wallet.txbatcher.get_password(tx.txid()) if pw_required else None password = await self.wallet.txbatcher.get_password(tx.txid()) if pw_required else None
if password is None and pw_required: if password is None and pw_required:
return return None
self.wallet.sign_transaction(tx, password) self.wallet.sign_transaction(tx, password)
assert tx.is_complete() assert tx.is_complete()
return tx return tx
def create_next_transaction(self, base_tx): def create_next_transaction(self, base_tx: Optional[PartialTransaction]) -> Optional[PartialTransaction]:
to_pay = self._to_pay_after(base_tx) to_pay = self._to_pay_after(base_tx)
to_sweep = self._to_sweep_after(base_tx) to_sweep = self._to_sweep_after(base_tx)
to_sweep_now = {} to_sweep_now = {}
@@ -447,9 +446,9 @@ class TxBatch(Logger):
self.wallet.add_future_tx(v, wanted_height) self.wallet.add_future_tx(v, wanted_height)
while True: while True:
if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx): if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx):
return return None
try: try:
tx = self._create_batch_tx(base_tx, to_sweep_now, to_pay) tx = self._create_batch_tx(base_tx=base_tx, to_sweep=to_sweep_now, to_pay=to_pay)
except NotEnoughFunds: except NotEnoughFunds:
if to_pay: if to_pay:
k = max(to_pay, key=lambda x: x.value) k = max(to_pay, key=lambda x: x.value)
@@ -458,7 +457,7 @@ class TxBatch(Logger):
continue continue
else: else:
self.logger.info(f'Not enough funds, waiting') self.logger.info(f'Not enough funds, waiting')
return return None
# 100 kb max standardness rule # 100 kb max standardness rule
if tx.estimated_size() < 100_000: if tx.estimated_size() < 100_000:
break break
@@ -468,7 +467,7 @@ class TxBatch(Logger):
self.logger.info(f'created tx {tx.txid()} with {len(tx.inputs())} inputs and {len(tx.outputs())} outputs') self.logger.info(f'created tx {tx.txid()} with {len(tx.inputs())} inputs and {len(tx.outputs())} outputs')
return tx return tx
def add_sweep_info_to_tx(self, base_tx): def add_sweep_info_to_tx(self, base_tx: PartialTransaction) -> None:
for txin in base_tx.inputs(): for txin in base_tx.inputs():
if sweep_info := self.batch_inputs.get(txin.prevout): if sweep_info := self.batch_inputs.get(txin.prevout):
if hasattr(sweep_info.txin, 'make_witness'): if hasattr(sweep_info.txin, 'make_witness'):
@@ -477,11 +476,17 @@ class TxBatch(Logger):
txin.witness_script = sweep_info.txin.witness_script txin.witness_script = sweep_info.txin.witness_script
txin.script_sig = sweep_info.txin.script_sig txin.script_sig = sweep_info.txin.script_sig
def _create_batch_tx(self, base_tx, to_sweep, to_pay): def _create_batch_tx(
self,
*,
base_tx: Optional[PartialTransaction],
to_sweep: Mapping[str, SweepInfo],
to_pay: Sequence[PartialTxOutput],
) -> PartialTransaction:
self.logger.info(f'to_sweep: {list(to_sweep.keys())}') self.logger.info(f'to_sweep: {list(to_sweep.keys())}')
self.logger.info(f'to_pay: {to_pay}') self.logger.info(f'to_pay: {to_pay}')
inputs = [] inputs = [] # type: List[PartialTxInput]
outputs = [] outputs = [] # type: List[PartialTxOutput]
locktime = base_tx.locktime if base_tx else None locktime = base_tx.locktime if base_tx else None
# sort inputs so that txin-txout pairs are first # sort inputs so that txin-txout pairs are first
for sweep_info in sorted(to_sweep.values(), key=lambda x: not bool(x.txout)): for sweep_info in sorted(to_sweep.values(), key=lambda x: not bool(x.txout)):
@@ -511,7 +516,7 @@ class TxBatch(Logger):
for o in outputs: assert o in tx.outputs() for o in outputs: assert o in tx.outputs()
return tx return tx
def _clear_unconfirmed_sweeps(self, tx): def _clear_unconfirmed_sweeps(self, tx: PartialTransaction) -> None:
# this ensures that we can accept an input again, # this ensures that we can accept an input again,
# in case the sweeping tx has been removed from the blockchain after a reorg # in case the sweeping tx has been removed from the blockchain after a reorg
for txin in tx.inputs(): for txin in tx.inputs():
@@ -519,7 +524,7 @@ class TxBatch(Logger):
self._unconfirmed_sweeps.remove(txin.prevout) self._unconfirmed_sweeps.remove(txin.prevout)
@locked @locked
def _start_new_batch(self, tx): def _start_new_batch(self, tx: Optional[PartialTransaction]) -> None:
use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()]) use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()])
self.batch_payments = self._to_pay_after(tx) self.batch_payments = self._to_pay_after(tx)
self.batch_inputs = self._to_sweep_after(tx) self.batch_inputs = self._to_sweep_after(tx)
@@ -529,7 +534,7 @@ class TxBatch(Logger):
self._prevout = None self._prevout = None
@locked @locked
def _new_base_tx(self, tx: Transaction): def _new_base_tx(self, tx: Transaction) -> None:
self._prevout = tx.inputs()[0].prevout.to_str() self._prevout = tx.inputs()[0].prevout.to_str()
self.storage['prevout'] = self._prevout self.storage['prevout'] = self._prevout
if tx.has_change(): if tx.has_change():
@@ -539,7 +544,7 @@ class TxBatch(Logger):
self.logger.info(f'starting new batch because current base tx does not have change') self.logger.info(f'starting new batch because current base tx does not have change')
self._start_new_batch(tx) self._start_new_batch(tx)
def _create_inputs_from_tx_change(self, parent_tx): def _create_inputs_from_tx_change(self, parent_tx: PartialTransaction) -> List[PartialTxInput]:
inputs = [] inputs = []
for o in parent_tx.get_change_outputs(): for o in parent_tx.get_change_outputs():
coins = self.wallet.adb.get_addr_utxo(o.address) coins = self.wallet.adb.get_addr_utxo(o.address)
@@ -548,7 +553,7 @@ class TxBatch(Logger):
txin.nsequence = 0xffffffff - 2 txin.nsequence = 0xffffffff - 2
return inputs return inputs
def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction'): def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction') -> Tuple[bool, Optional[int]]:
prevout = sweep_info.txin.prevout.to_str() prevout = sweep_info.txin.prevout.to_str()
name = sweep_info.name name = sweep_info.name
prev_txid, index = prevout.split(':') prev_txid, index = prevout.split(':')