1
0

Merge pull request #9988 from SomberNight/202506_txbatcher_cleanup

txbatcher: add type-hints
This commit is contained in:
ghost43
2025-06-26 14:31:17 +00:00
committed by GitHub
3 changed files with 92 additions and 75 deletions

View File

@@ -2,7 +2,7 @@
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from . import util
from .util import TxMinedInfo, BelowDustLimit
@@ -88,13 +88,13 @@ class LNWatcher(Logger, EventListener):
if chan.need_to_subscribe():
self.add_callback(address, callback)
def unwatch_channel(self, address, funding_outpoint):
def unwatch_channel(self, address: str, funding_outpoint: str) -> None:
self.logger.info(f'unwatching {funding_outpoint}')
self.remove_callback(address)
@ignore_exceptions
@log_exceptions
async def check_onchain_situation(self, address, funding_outpoint):
async def check_onchain_situation(self, address: str, funding_outpoint: str) -> None:
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
@@ -142,17 +142,18 @@ class LNWatcher(Logger, EventListener):
self._pending_force_closes.discard(chan)
await self.lnworker.handle_onchain_state(chan)
async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool:
async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool:
"""This function is called when a channel was closed. In this case
we need to check for redeemable outputs of the commitment transaction
or spenders down the line (HTLC-timeout/success transactions).
Returns whether we should continue to monitor.
Side-effécts:
Side-effects:
- sets defaults labels
- populates wallet._accounting_addresses
"""
assert closing_tx
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return False
@@ -188,7 +189,8 @@ class LNWatcher(Logger, EventListener):
self.maybe_add_accounting_address(spender_txid, sweep_info)
else:
keep_watching |= was_added
self.maybe_add_pending_forceclose(chan, spender_txid, is_local_ctx, sweep_info, was_added)
self.maybe_add_pending_forceclose(
chan=chan, spender_txid=spender_txid, is_local_ctx=is_local_ctx, sweep_info=sweep_info, was_added=was_added)
return keep_watching
def get_pending_force_closes(self):
@@ -242,7 +244,15 @@ class LNWatcher(Logger, EventListener):
txout = prev_tx.outputs()[int(prev_index)]
self.lnworker.wallet._accounting_addresses.add(txout.address)
def maybe_add_pending_forceclose(self, chan, spender_txid, is_local_ctx, sweep_info, was_added):
def maybe_add_pending_forceclose(
self,
*,
chan: 'AbstractChannel',
spender_txid: Optional[str],
is_local_ctx: bool,
sweep_info: 'SweepInfo',
was_added: bool,
):
""" we are waiting for ctx to be confirmed and there are received htlcs """
if was_added and is_local_ctx and sweep_info.name == 'received-htlc' and chan.has_anchors():
tx_mined_status = self.adb.get_tx_height(spender_txid)

View File

@@ -1211,6 +1211,7 @@ class LNWallet(LNWorker):
for chan in self.channel_backups.values():
if chan.funding_outpoint.to_str() == txo:
return chan
return None
async def handle_onchain_state(self, chan: Channel):
if self.network is None:
@@ -1511,6 +1512,7 @@ class LNWallet(LNWorker):
return chan
if chan.get_local_scid_alias() == short_channel_id:
return chan
return None
def can_pay_invoice(self, invoice: Invoice) -> bool:
assert invoice.is_lightning()

View File

@@ -1,20 +1,3 @@
import asyncio
import threading
import copy
from typing import Dict, Sequence
from . import util
from .bitcoin import dust_threshold
from .logging import Logger
from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup
from .transaction import PartialTransaction, PartialTxOutput, Transaction
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE
from .lnsweep import SweepInfo
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .wallet import Abstract_Wallet
# This class batches outgoing payments and incoming utxo sweeps.
# It ensures that we do not send a payment twice.
#
@@ -74,9 +57,23 @@ if TYPE_CHECKING:
# In order to generalize that logic to payments, callers would need to pass a unique ID along with
# the payment output, so that we can prevent paying twice.
import asyncio
import threading
import copy
from typing import Dict, Sequence, Optional, TYPE_CHECKING, Mapping, Set, List, Tuple
from . import util
from .bitcoin import dust_threshold
from .logging import Logger
from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup
from .transaction import PartialTransaction, PartialTxOutput, Transaction, TxOutpoint, PartialTxInput
from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE
from .lnsweep import SweepInfo
from .json_db import locked, StoredDict
from .fee_policy import FeePolicy
if TYPE_CHECKING:
from .wallet import Abstract_Wallet
class TxBatcher(Logger):
@@ -87,21 +84,21 @@ class TxBatcher(Logger):
Logger.__init__(self)
self.lock = threading.RLock()
self.storage = wallet.db.get_stored_item("tx_batches", {})
self.tx_batches = {}
self.tx_batches = {} # type: Dict[str, TxBatch]
self.wallet = wallet
for key, item_storage in self.storage.items():
self.tx_batches[key] = TxBatch(self.wallet, item_storage)
self._legacy_htlcs = {}
self.taskgroup = None
self.password_future = None
self._legacy_htlcs = {} # type: Dict[TxOutpoint, SweepInfo]
self.taskgroup = None # type: Optional[OldTaskGroup]
self.password_future = None # type: Optional[asyncio.Future[Optional[str]]]
@locked
def add_payment_output(self, key: str, output: 'PartialTxOutput'):
def add_payment_output(self, key: str, output: 'PartialTxOutput') -> None:
batch = self._maybe_create_new_batch(key, fee_policy_name=key)
batch.add_payment_output(output)
@locked
def add_sweep_input(self, key: str, sweep_info: 'SweepInfo'):
def add_sweep_input(self, key: str, sweep_info: 'SweepInfo') -> None:
if sweep_info.txin and sweep_info.txout:
# detect legacy htlc using name and csv delay
if sweep_info.name in ['received-htlc', 'offered-htlc'] and sweep_info.csv_delay == 0:
@@ -116,7 +113,7 @@ class TxBatcher(Logger):
batch = self._maybe_create_new_batch(key, fee_policy_name)
batch.add_sweep_input(sweep_info)
def _maybe_create_new_batch(self, key, fee_policy_name: str):
def _maybe_create_new_batch(self, key: str, fee_policy_name: str) -> 'TxBatch':
assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!"
if key not in self.storage:
self.logger.info(f'creating new batch: {key}')
@@ -127,7 +124,7 @@ class TxBatcher(Logger):
return self.tx_batches[key]
@locked
def delete_batch(self, key):
def delete_batch(self, key: str) -> None:
self.logger.info(f'deleting TxBatch {key}')
self.storage.pop(key)
self.tx_batches.pop(key)
@@ -136,17 +133,19 @@ class TxBatcher(Logger):
for k, v in self.tx_batches.items():
if v._prevout == prevout:
return v
return None
def find_batch_of_txid(self, txid) -> str:
def find_batch_of_txid(self, txid: str) -> Optional[str]:
for k, v in self.tx_batches.items():
if v.is_mine(txid):
return k
return None
def is_mine(self, txid):
def is_mine(self, txid: str) -> bool:
# used to prevent GUI from interfering
return bool(self.find_batch_of_txid(txid))
async def run_batch(self, key, batch):
async def run_batch(self, key: str, batch: 'TxBatch') -> None:
await batch.run()
self.delete_batch(key)
@@ -158,13 +157,13 @@ class TxBatcher(Logger):
async with self.taskgroup as group:
await group.spawn(self.redeem_legacy_htlcs())
async def redeem_legacy_htlcs(self):
async def redeem_legacy_htlcs(self) -> None:
while True:
await asyncio.sleep(self.SLEEP_INTERVAL)
for sweep_info in self._legacy_htlcs.values():
await self._maybe_redeem_legacy_htlcs(sweep_info)
async def _maybe_redeem_legacy_htlcs(self, sweep_info):
async def _maybe_redeem_legacy_htlcs(self, sweep_info: 'SweepInfo') -> None:
assert sweep_info.csv_delay == 0
local_height = self.wallet.network.get_local_height()
wanted_height = sweep_info.cltv_abs
@@ -184,7 +183,7 @@ class TxBatcher(Logger):
if await self.wallet.network.try_broadcasting(tx, sweep_info.name):
self.wallet.adb.add_transaction(tx)
async def get_password(self, txid:str):
async def get_password(self, txid: str) -> Optional[str]:
# daemon, android have password in memory
password = self.wallet.get_unlocked_password()
if password:
@@ -194,12 +193,12 @@ class TxBatcher(Logger):
await future
except asyncio.CancelledError as e:
return
return None
password = future.result()
return password
@locked
def set_password_future(self, password: Optional[str]):
def set_password_future(self, password: Optional[str]) -> None:
if self.password_future is not None:
if password is not None:
self.password_future.set_result(password)
@@ -226,25 +225,25 @@ class TxBatch(Logger):
self.wallet = wallet
self.storage = storage
self.lock = threading.RLock()
self.batch_payments = [] # list of payments we need to make
self.batch_inputs = {} # list of inputs we need to sweep
self.batch_payments = [] # type: List[PartialTxOutput] # payments we need to make
self.batch_inputs = {} # type: Dict[TxOutpoint, SweepInfo] # inputs we need to sweep
# list of tx that were broadcast. Each tx is a RBF replacement of the previous one. Ony one can get mined.
self._prevout = storage.get('prevout')
self._batch_txids = storage['txids']
self._fee_policy_name = storage.get('fee_policy_name', 'default')
self._base_tx = None # current batch tx. last element of batch_txids
self._parent_tx = None
self._unconfirmed_sweeps = set() # list of inputs we are sweeping (until spending tx is confirmed)
self._prevout = storage.get('prevout') # type: Optional[str]
self._batch_txids = storage['txids'] # type: List[str]
self._fee_policy_name = storage.get('fee_policy_name', 'default') # type: str
self._base_tx = None # type: Optional[PartialTransaction] # current batch tx. last element of batch_txids
self._parent_tx = None # type: Optional[PartialTransaction]
self._unconfirmed_sweeps = set() # type: Set[TxOutpoint] # inputs we are sweeping (until spending tx is confirmed)
@property
def fee_policy(self):
def fee_policy(self) -> FeePolicy:
# this assumes the descriptor is in config.fee_policy
cv_name = 'fee_policy' + '.' + self._fee_policy_name
descriptor = self.wallet.config.get(cv_name, 'eta:2')
return FeePolicy(descriptor)
@log_exceptions
async def run(self):
async def run(self) -> None:
while not self.is_done():
await asyncio.sleep(self.wallet.txbatcher.SLEEP_INTERVAL)
if not (self.wallet.network and self.wallet.network.is_connected()):
@@ -255,15 +254,15 @@ class TxBatch(Logger):
self.logger.exception(f'TxBatch error: {repr(e)}')
break
def is_mine(self, txid):
def is_mine(self, txid: str) -> bool:
return txid in self._batch_txids
@locked
def add_payment_output(self, output: 'PartialTxOutput'):
def add_payment_output(self, output: 'PartialTxOutput') -> None:
# todo: maybe we should raise NotEnoughFunds here
self.batch_payments.append(output)
def is_dust(self, sweep_info):
def is_dust(self, sweep_info: SweepInfo) -> bool:
if sweep_info.is_anchor():
return False
if sweep_info.txout is not None:
@@ -276,7 +275,7 @@ class TxBatch(Logger):
return value - fee <= dust_threshold()
@locked
def add_sweep_input(self, sweep_info: 'SweepInfo'):
def add_sweep_input(self, sweep_info: 'SweepInfo') -> None:
if self.is_dust(sweep_info):
raise BelowDustLimit
txin = sweep_info.txin
@@ -295,7 +294,7 @@ class TxBatch(Logger):
self.batch_inputs[txin.prevout] = sweep_info
@locked
def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]:
def _to_pay_after(self, tx: Optional[PartialTransaction]) -> Sequence[PartialTxOutput]:
if not tx:
return self.batch_payments
to_pay = []
@@ -308,7 +307,7 @@ class TxBatch(Logger):
return to_pay
@locked
def _to_sweep_after(self, tx) -> Dict[str, SweepInfo]:
def _to_sweep_after(self, tx: Optional[PartialTransaction]) -> Dict[str, SweepInfo]:
tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set()
result = []
for k, v in list(self.batch_inputs.items()):
@@ -331,7 +330,7 @@ class TxBatch(Logger):
result.append((k,v))
return dict(result)
def _should_bump_fee(self, base_tx) -> bool:
def _should_bump_fee(self, base_tx: Optional[PartialTransaction]) -> bool:
if base_tx is None:
return False
if not self.is_mine(base_tx.txid()):
@@ -349,12 +348,12 @@ class TxBatch(Logger):
def find_base_tx(self) -> Optional[PartialTransaction]:
if not self._prevout:
return
return None
prev_txid, index = self._prevout.split(':')
txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index))
tx = self.wallet.adb.get_transaction(txid) if txid else None
if not tx:
return
return None
tx = PartialTransaction.from_tx(tx)
tx.add_info_from_wallet(self.wallet) # this sets is_change
@@ -380,7 +379,7 @@ class TxBatch(Logger):
return self._base_tx
async def run_iteration(self):
async def run_iteration(self) -> None:
base_tx = self.find_base_tx()
try:
tx = self.create_next_transaction(base_tx)
@@ -424,18 +423,18 @@ class TxBatch(Logger):
self._start_new_batch(base_tx)
async def sign_transaction(self, tx):
async def sign_transaction(self, tx: PartialTransaction) -> Optional[PartialTransaction]:
tx.add_info_from_wallet(self.wallet) # this adds input amounts
self.add_sweep_info_to_tx(tx)
pw_required = self.wallet.has_keystore_encryption() and tx.requires_keystore()
password = await self.wallet.txbatcher.get_password(tx.txid()) if pw_required else None
if password is None and pw_required:
return
return None
self.wallet.sign_transaction(tx, password)
assert tx.is_complete()
return tx
def create_next_transaction(self, base_tx):
def create_next_transaction(self, base_tx: Optional[PartialTransaction]) -> Optional[PartialTransaction]:
to_pay = self._to_pay_after(base_tx)
to_sweep = self._to_sweep_after(base_tx)
to_sweep_now = {}
@@ -447,9 +446,9 @@ class TxBatch(Logger):
self.wallet.add_future_tx(v, wanted_height)
while True:
if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx):
return
return None
try:
tx = self._create_batch_tx(base_tx, to_sweep_now, to_pay)
tx = self._create_batch_tx(base_tx=base_tx, to_sweep=to_sweep_now, to_pay=to_pay)
except NotEnoughFunds:
if to_pay:
k = max(to_pay, key=lambda x: x.value)
@@ -458,7 +457,7 @@ class TxBatch(Logger):
continue
else:
self.logger.info(f'Not enough funds, waiting')
return
return None
# 100 kb max standardness rule
if tx.estimated_size() < 100_000:
break
@@ -468,7 +467,7 @@ class TxBatch(Logger):
self.logger.info(f'created tx {tx.txid()} with {len(tx.inputs())} inputs and {len(tx.outputs())} outputs')
return tx
def add_sweep_info_to_tx(self, base_tx):
def add_sweep_info_to_tx(self, base_tx: PartialTransaction) -> None:
for txin in base_tx.inputs():
if sweep_info := self.batch_inputs.get(txin.prevout):
if hasattr(sweep_info.txin, 'make_witness'):
@@ -477,11 +476,17 @@ class TxBatch(Logger):
txin.witness_script = sweep_info.txin.witness_script
txin.script_sig = sweep_info.txin.script_sig
def _create_batch_tx(self, base_tx, to_sweep, to_pay):
def _create_batch_tx(
self,
*,
base_tx: Optional[PartialTransaction],
to_sweep: Mapping[str, SweepInfo],
to_pay: Sequence[PartialTxOutput],
) -> PartialTransaction:
self.logger.info(f'to_sweep: {list(to_sweep.keys())}')
self.logger.info(f'to_pay: {to_pay}')
inputs = []
outputs = []
inputs = [] # type: List[PartialTxInput]
outputs = [] # type: List[PartialTxOutput]
locktime = base_tx.locktime if base_tx else None
# sort inputs so that txin-txout pairs are first
for sweep_info in sorted(to_sweep.values(), key=lambda x: not bool(x.txout)):
@@ -511,7 +516,7 @@ class TxBatch(Logger):
for o in outputs: assert o in tx.outputs()
return tx
def _clear_unconfirmed_sweeps(self, tx):
def _clear_unconfirmed_sweeps(self, tx: PartialTransaction) -> None:
# this ensures that we can accept an input again,
# in case the sweeping tx has been removed from the blockchain after a reorg
for txin in tx.inputs():
@@ -519,7 +524,7 @@ class TxBatch(Logger):
self._unconfirmed_sweeps.remove(txin.prevout)
@locked
def _start_new_batch(self, tx):
def _start_new_batch(self, tx: Optional[PartialTransaction]) -> None:
use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()])
self.batch_payments = self._to_pay_after(tx)
self.batch_inputs = self._to_sweep_after(tx)
@@ -529,7 +534,7 @@ class TxBatch(Logger):
self._prevout = None
@locked
def _new_base_tx(self, tx: Transaction):
def _new_base_tx(self, tx: Transaction) -> None:
self._prevout = tx.inputs()[0].prevout.to_str()
self.storage['prevout'] = self._prevout
if tx.has_change():
@@ -539,7 +544,7 @@ class TxBatch(Logger):
self.logger.info(f'starting new batch because current base tx does not have change')
self._start_new_batch(tx)
def _create_inputs_from_tx_change(self, parent_tx):
def _create_inputs_from_tx_change(self, parent_tx: PartialTransaction) -> List[PartialTxInput]:
inputs = []
for o in parent_tx.get_change_outputs():
coins = self.wallet.adb.get_addr_utxo(o.address)
@@ -548,7 +553,7 @@ class TxBatch(Logger):
txin.nsequence = 0xffffffff - 2
return inputs
def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction'):
def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction') -> Tuple[bool, Optional[int]]:
prevout = sweep_info.txin.prevout.to_str()
name = sweep_info.name
prev_txid, index = prevout.split(':')