diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 46f3a438a..8b55b27e8 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -2,7 +2,7 @@ # Distributed under the MIT software license, see the accompanying # file LICENCE or http://www.opensource.org/licenses/mit-license.php -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from . import util from .util import TxMinedInfo, BelowDustLimit @@ -88,13 +88,13 @@ class LNWatcher(Logger, EventListener): if chan.need_to_subscribe(): self.add_callback(address, callback) - def unwatch_channel(self, address, funding_outpoint): + def unwatch_channel(self, address: str, funding_outpoint: str) -> None: self.logger.info(f'unwatching {funding_outpoint}') self.remove_callback(address) @ignore_exceptions @log_exceptions - async def check_onchain_situation(self, address, funding_outpoint): + async def check_onchain_situation(self, address: str, funding_outpoint: str) -> None: # early return if address has not been added yet if not self.adb.is_mine(address): return @@ -142,17 +142,18 @@ class LNWatcher(Logger, EventListener): self._pending_force_closes.discard(chan) await self.lnworker.handle_onchain_state(chan) - async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool: + async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool: """This function is called when a channel was closed. In this case we need to check for redeemable outputs of the commitment transaction or spenders down the line (HTLC-timeout/success transactions). Returns whether we should continue to monitor. - Side-effécts: + Side-effects: - sets defaults labels - populates wallet._accounting_addresses """ + assert closing_tx chan = self.lnworker.channel_by_txo(funding_outpoint) if not chan: return False @@ -188,7 +189,8 @@ class LNWatcher(Logger, EventListener): self.maybe_add_accounting_address(spender_txid, sweep_info) else: keep_watching |= was_added - self.maybe_add_pending_forceclose(chan, spender_txid, is_local_ctx, sweep_info, was_added) + self.maybe_add_pending_forceclose( + chan=chan, spender_txid=spender_txid, is_local_ctx=is_local_ctx, sweep_info=sweep_info, was_added=was_added) return keep_watching def get_pending_force_closes(self): @@ -242,7 +244,15 @@ class LNWatcher(Logger, EventListener): txout = prev_tx.outputs()[int(prev_index)] self.lnworker.wallet._accounting_addresses.add(txout.address) - def maybe_add_pending_forceclose(self, chan, spender_txid, is_local_ctx, sweep_info, was_added): + def maybe_add_pending_forceclose( + self, + *, + chan: 'AbstractChannel', + spender_txid: Optional[str], + is_local_ctx: bool, + sweep_info: 'SweepInfo', + was_added: bool, + ): """ we are waiting for ctx to be confirmed and there are received htlcs """ if was_added and is_local_ctx and sweep_info.name == 'received-htlc' and chan.has_anchors(): tx_mined_status = self.adb.get_tx_height(spender_txid) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 4c78719b1..1333a4914 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1211,6 +1211,7 @@ class LNWallet(LNWorker): for chan in self.channel_backups.values(): if chan.funding_outpoint.to_str() == txo: return chan + return None async def handle_onchain_state(self, chan: Channel): if self.network is None: @@ -1511,6 +1512,7 @@ class LNWallet(LNWorker): return chan if chan.get_local_scid_alias() == short_channel_id: return chan + return None def can_pay_invoice(self, invoice: Invoice) -> bool: assert invoice.is_lightning() diff --git a/electrum/txbatcher.py b/electrum/txbatcher.py index 96dd8367e..2cb3b0e86 100644 --- a/electrum/txbatcher.py +++ b/electrum/txbatcher.py @@ -1,20 +1,3 @@ -import asyncio -import threading -import copy - -from typing import Dict, Sequence -from . import util -from .bitcoin import dust_threshold -from .logging import Logger -from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup -from .transaction import PartialTransaction, PartialTxOutput, Transaction -from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE -from .lnsweep import SweepInfo -from typing import Optional, TYPE_CHECKING - -if TYPE_CHECKING: - from .wallet import Abstract_Wallet - # This class batches outgoing payments and incoming utxo sweeps. # It ensures that we do not send a payment twice. # @@ -74,9 +57,23 @@ if TYPE_CHECKING: # In order to generalize that logic to payments, callers would need to pass a unique ID along with # the payment output, so that we can prevent paying twice. +import asyncio +import threading +import copy +from typing import Dict, Sequence, Optional, TYPE_CHECKING, Mapping, Set, List, Tuple + +from . import util +from .bitcoin import dust_threshold +from .logging import Logger +from .util import log_exceptions, NotEnoughFunds, BelowDustLimit, NoDynamicFeeEstimates, OldTaskGroup +from .transaction import PartialTransaction, PartialTxOutput, Transaction, TxOutpoint, PartialTxInput +from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE +from .lnsweep import SweepInfo from .json_db import locked, StoredDict from .fee_policy import FeePolicy +if TYPE_CHECKING: + from .wallet import Abstract_Wallet class TxBatcher(Logger): @@ -87,21 +84,21 @@ class TxBatcher(Logger): Logger.__init__(self) self.lock = threading.RLock() self.storage = wallet.db.get_stored_item("tx_batches", {}) - self.tx_batches = {} + self.tx_batches = {} # type: Dict[str, TxBatch] self.wallet = wallet for key, item_storage in self.storage.items(): self.tx_batches[key] = TxBatch(self.wallet, item_storage) - self._legacy_htlcs = {} - self.taskgroup = None - self.password_future = None + self._legacy_htlcs = {} # type: Dict[TxOutpoint, SweepInfo] + self.taskgroup = None # type: Optional[OldTaskGroup] + self.password_future = None # type: Optional[asyncio.Future[Optional[str]]] @locked - def add_payment_output(self, key: str, output: 'PartialTxOutput'): + def add_payment_output(self, key: str, output: 'PartialTxOutput') -> None: batch = self._maybe_create_new_batch(key, fee_policy_name=key) batch.add_payment_output(output) @locked - def add_sweep_input(self, key: str, sweep_info: 'SweepInfo'): + def add_sweep_input(self, key: str, sweep_info: 'SweepInfo') -> None: if sweep_info.txin and sweep_info.txout: # detect legacy htlc using name and csv delay if sweep_info.name in ['received-htlc', 'offered-htlc'] and sweep_info.csv_delay == 0: @@ -116,7 +113,7 @@ class TxBatcher(Logger): batch = self._maybe_create_new_batch(key, fee_policy_name) batch.add_sweep_input(sweep_info) - def _maybe_create_new_batch(self, key, fee_policy_name: str): + def _maybe_create_new_batch(self, key: str, fee_policy_name: str) -> 'TxBatch': assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" if key not in self.storage: self.logger.info(f'creating new batch: {key}') @@ -127,7 +124,7 @@ class TxBatcher(Logger): return self.tx_batches[key] @locked - def delete_batch(self, key): + def delete_batch(self, key: str) -> None: self.logger.info(f'deleting TxBatch {key}') self.storage.pop(key) self.tx_batches.pop(key) @@ -136,17 +133,19 @@ class TxBatcher(Logger): for k, v in self.tx_batches.items(): if v._prevout == prevout: return v + return None - def find_batch_of_txid(self, txid) -> str: + def find_batch_of_txid(self, txid: str) -> Optional[str]: for k, v in self.tx_batches.items(): if v.is_mine(txid): return k + return None - def is_mine(self, txid): + def is_mine(self, txid: str) -> bool: # used to prevent GUI from interfering return bool(self.find_batch_of_txid(txid)) - async def run_batch(self, key, batch): + async def run_batch(self, key: str, batch: 'TxBatch') -> None: await batch.run() self.delete_batch(key) @@ -158,13 +157,13 @@ class TxBatcher(Logger): async with self.taskgroup as group: await group.spawn(self.redeem_legacy_htlcs()) - async def redeem_legacy_htlcs(self): + async def redeem_legacy_htlcs(self) -> None: while True: await asyncio.sleep(self.SLEEP_INTERVAL) for sweep_info in self._legacy_htlcs.values(): await self._maybe_redeem_legacy_htlcs(sweep_info) - async def _maybe_redeem_legacy_htlcs(self, sweep_info): + async def _maybe_redeem_legacy_htlcs(self, sweep_info: 'SweepInfo') -> None: assert sweep_info.csv_delay == 0 local_height = self.wallet.network.get_local_height() wanted_height = sweep_info.cltv_abs @@ -184,7 +183,7 @@ class TxBatcher(Logger): if await self.wallet.network.try_broadcasting(tx, sweep_info.name): self.wallet.adb.add_transaction(tx) - async def get_password(self, txid:str): + async def get_password(self, txid: str) -> Optional[str]: # daemon, android have password in memory password = self.wallet.get_unlocked_password() if password: @@ -194,12 +193,12 @@ class TxBatcher(Logger): await future except asyncio.CancelledError as e: - return + return None password = future.result() return password @locked - def set_password_future(self, password: Optional[str]): + def set_password_future(self, password: Optional[str]) -> None: if self.password_future is not None: if password is not None: self.password_future.set_result(password) @@ -226,25 +225,25 @@ class TxBatch(Logger): self.wallet = wallet self.storage = storage self.lock = threading.RLock() - self.batch_payments = [] # list of payments we need to make - self.batch_inputs = {} # list of inputs we need to sweep + self.batch_payments = [] # type: List[PartialTxOutput] # payments we need to make + self.batch_inputs = {} # type: Dict[TxOutpoint, SweepInfo] # inputs we need to sweep # list of tx that were broadcast. Each tx is a RBF replacement of the previous one. Ony one can get mined. - self._prevout = storage.get('prevout') - self._batch_txids = storage['txids'] - self._fee_policy_name = storage.get('fee_policy_name', 'default') - self._base_tx = None # current batch tx. last element of batch_txids - self._parent_tx = None - self._unconfirmed_sweeps = set() # list of inputs we are sweeping (until spending tx is confirmed) + self._prevout = storage.get('prevout') # type: Optional[str] + self._batch_txids = storage['txids'] # type: List[str] + self._fee_policy_name = storage.get('fee_policy_name', 'default') # type: str + self._base_tx = None # type: Optional[PartialTransaction] # current batch tx. last element of batch_txids + self._parent_tx = None # type: Optional[PartialTransaction] + self._unconfirmed_sweeps = set() # type: Set[TxOutpoint] # inputs we are sweeping (until spending tx is confirmed) @property - def fee_policy(self): + def fee_policy(self) -> FeePolicy: # this assumes the descriptor is in config.fee_policy cv_name = 'fee_policy' + '.' + self._fee_policy_name descriptor = self.wallet.config.get(cv_name, 'eta:2') return FeePolicy(descriptor) @log_exceptions - async def run(self): + async def run(self) -> None: while not self.is_done(): await asyncio.sleep(self.wallet.txbatcher.SLEEP_INTERVAL) if not (self.wallet.network and self.wallet.network.is_connected()): @@ -255,15 +254,15 @@ class TxBatch(Logger): self.logger.exception(f'TxBatch error: {repr(e)}') break - def is_mine(self, txid): + def is_mine(self, txid: str) -> bool: return txid in self._batch_txids @locked - def add_payment_output(self, output: 'PartialTxOutput'): + def add_payment_output(self, output: 'PartialTxOutput') -> None: # todo: maybe we should raise NotEnoughFunds here self.batch_payments.append(output) - def is_dust(self, sweep_info): + def is_dust(self, sweep_info: SweepInfo) -> bool: if sweep_info.is_anchor(): return False if sweep_info.txout is not None: @@ -276,7 +275,7 @@ class TxBatch(Logger): return value - fee <= dust_threshold() @locked - def add_sweep_input(self, sweep_info: 'SweepInfo'): + def add_sweep_input(self, sweep_info: 'SweepInfo') -> None: if self.is_dust(sweep_info): raise BelowDustLimit txin = sweep_info.txin @@ -295,7 +294,7 @@ class TxBatch(Logger): self.batch_inputs[txin.prevout] = sweep_info @locked - def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]: + def _to_pay_after(self, tx: Optional[PartialTransaction]) -> Sequence[PartialTxOutput]: if not tx: return self.batch_payments to_pay = [] @@ -308,7 +307,7 @@ class TxBatch(Logger): return to_pay @locked - def _to_sweep_after(self, tx) -> Dict[str, SweepInfo]: + def _to_sweep_after(self, tx: Optional[PartialTransaction]) -> Dict[str, SweepInfo]: tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set() result = [] for k, v in list(self.batch_inputs.items()): @@ -331,7 +330,7 @@ class TxBatch(Logger): result.append((k,v)) return dict(result) - def _should_bump_fee(self, base_tx) -> bool: + def _should_bump_fee(self, base_tx: Optional[PartialTransaction]) -> bool: if base_tx is None: return False if not self.is_mine(base_tx.txid()): @@ -349,12 +348,12 @@ class TxBatch(Logger): def find_base_tx(self) -> Optional[PartialTransaction]: if not self._prevout: - return + return None prev_txid, index = self._prevout.split(':') txid = self.wallet.adb.db.get_spent_outpoint(prev_txid, int(index)) tx = self.wallet.adb.get_transaction(txid) if txid else None if not tx: - return + return None tx = PartialTransaction.from_tx(tx) tx.add_info_from_wallet(self.wallet) # this sets is_change @@ -380,7 +379,7 @@ class TxBatch(Logger): return self._base_tx - async def run_iteration(self): + async def run_iteration(self) -> None: base_tx = self.find_base_tx() try: tx = self.create_next_transaction(base_tx) @@ -424,18 +423,18 @@ class TxBatch(Logger): self._start_new_batch(base_tx) - async def sign_transaction(self, tx): + async def sign_transaction(self, tx: PartialTransaction) -> Optional[PartialTransaction]: tx.add_info_from_wallet(self.wallet) # this adds input amounts self.add_sweep_info_to_tx(tx) pw_required = self.wallet.has_keystore_encryption() and tx.requires_keystore() password = await self.wallet.txbatcher.get_password(tx.txid()) if pw_required else None if password is None and pw_required: - return + return None self.wallet.sign_transaction(tx, password) assert tx.is_complete() return tx - def create_next_transaction(self, base_tx): + def create_next_transaction(self, base_tx: Optional[PartialTransaction]) -> Optional[PartialTransaction]: to_pay = self._to_pay_after(base_tx) to_sweep = self._to_sweep_after(base_tx) to_sweep_now = {} @@ -447,9 +446,9 @@ class TxBatch(Logger): self.wallet.add_future_tx(v, wanted_height) while True: if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx): - return + return None try: - tx = self._create_batch_tx(base_tx, to_sweep_now, to_pay) + tx = self._create_batch_tx(base_tx=base_tx, to_sweep=to_sweep_now, to_pay=to_pay) except NotEnoughFunds: if to_pay: k = max(to_pay, key=lambda x: x.value) @@ -458,7 +457,7 @@ class TxBatch(Logger): continue else: self.logger.info(f'Not enough funds, waiting') - return + return None # 100 kb max standardness rule if tx.estimated_size() < 100_000: break @@ -468,7 +467,7 @@ class TxBatch(Logger): self.logger.info(f'created tx {tx.txid()} with {len(tx.inputs())} inputs and {len(tx.outputs())} outputs') return tx - def add_sweep_info_to_tx(self, base_tx): + def add_sweep_info_to_tx(self, base_tx: PartialTransaction) -> None: for txin in base_tx.inputs(): if sweep_info := self.batch_inputs.get(txin.prevout): if hasattr(sweep_info.txin, 'make_witness'): @@ -477,11 +476,17 @@ class TxBatch(Logger): txin.witness_script = sweep_info.txin.witness_script txin.script_sig = sweep_info.txin.script_sig - def _create_batch_tx(self, base_tx, to_sweep, to_pay): + def _create_batch_tx( + self, + *, + base_tx: Optional[PartialTransaction], + to_sweep: Mapping[str, SweepInfo], + to_pay: Sequence[PartialTxOutput], + ) -> PartialTransaction: self.logger.info(f'to_sweep: {list(to_sweep.keys())}') self.logger.info(f'to_pay: {to_pay}') - inputs = [] - outputs = [] + inputs = [] # type: List[PartialTxInput] + outputs = [] # type: List[PartialTxOutput] locktime = base_tx.locktime if base_tx else None # sort inputs so that txin-txout pairs are first for sweep_info in sorted(to_sweep.values(), key=lambda x: not bool(x.txout)): @@ -511,7 +516,7 @@ class TxBatch(Logger): for o in outputs: assert o in tx.outputs() return tx - def _clear_unconfirmed_sweeps(self, tx): + def _clear_unconfirmed_sweeps(self, tx: PartialTransaction) -> None: # this ensures that we can accept an input again, # in case the sweeping tx has been removed from the blockchain after a reorg for txin in tx.inputs(): @@ -519,7 +524,7 @@ class TxBatch(Logger): self._unconfirmed_sweeps.remove(txin.prevout) @locked - def _start_new_batch(self, tx): + def _start_new_batch(self, tx: Optional[PartialTransaction]) -> None: use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()]) self.batch_payments = self._to_pay_after(tx) self.batch_inputs = self._to_sweep_after(tx) @@ -529,7 +534,7 @@ class TxBatch(Logger): self._prevout = None @locked - def _new_base_tx(self, tx: Transaction): + def _new_base_tx(self, tx: Transaction) -> None: self._prevout = tx.inputs()[0].prevout.to_str() self.storage['prevout'] = self._prevout if tx.has_change(): @@ -539,7 +544,7 @@ class TxBatch(Logger): self.logger.info(f'starting new batch because current base tx does not have change') self._start_new_batch(tx) - def _create_inputs_from_tx_change(self, parent_tx): + def _create_inputs_from_tx_change(self, parent_tx: PartialTransaction) -> List[PartialTxInput]: inputs = [] for o in parent_tx.get_change_outputs(): coins = self.wallet.adb.get_addr_utxo(o.address) @@ -548,7 +553,7 @@ class TxBatch(Logger): txin.nsequence = 0xffffffff - 2 return inputs - def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction'): + def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx: 'Transaction') -> Tuple[bool, Optional[int]]: prevout = sweep_info.txin.prevout.to_str() name = sweep_info.name prev_txid, index = prevout.split(':')