1
0

Merge pull request #10147 from SomberNight/202508_typehint_callables

type-hint some Callables
This commit is contained in:
ghost43
2025-08-18 16:12:36 +00:00
committed by GitHub
5 changed files with 47 additions and 23 deletions

View File

@@ -90,7 +90,7 @@ class ScoredCandidate(NamedTuple):
buckets: List[Bucket] buckets: List[Bucket]
def strip_unneeded(bkts: List[Bucket], sufficient_funds) -> List[Bucket]: def strip_unneeded(bkts: List[Bucket], sufficient_funds: Callable) -> List[Bucket]:
'''Remove buckets that are unnecessary in achieving the spend amount''' '''Remove buckets that are unnecessary in achieving the spend amount'''
if sufficient_funds([], bucket_value_sum=0): if sufficient_funds([], bucket_value_sum=0):
# none of the buckets are needed # none of the buckets are needed
@@ -113,7 +113,12 @@ class CoinChooserBase(Logger):
def keys(self, coins: Sequence[PartialTxInput]) -> Sequence[str]: def keys(self, coins: Sequence[PartialTxInput]) -> Sequence[str]:
raise NotImplementedError raise NotImplementedError
def bucketize_coins(self, coins: Sequence[PartialTxInput], *, fee_estimator_vb): def bucketize_coins(
self,
coins: Sequence[PartialTxInput],
*,
fee_estimator_vb: Callable[[int | float | Decimal], int],
):
keys = self.keys(coins) keys = self.keys(coins)
buckets = defaultdict(list) # type: Dict[str, List[PartialTxInput]] buckets = defaultdict(list) # type: Dict[str, List[PartialTxInput]]
for key, coin in zip(keys, coins): for key, coin in zip(keys, coins):
@@ -151,9 +156,12 @@ class CoinChooserBase(Logger):
return list(map(make_Bucket, buckets.keys(), buckets.values())) return list(map(make_Bucket, buckets.keys(), buckets.values()))
def penalty_func(self, base_tx, *, def penalty_func(
tx_from_buckets: Callable[[List[Bucket]], Tuple[PartialTransaction, List[PartialTxOutput]]]) \ self,
-> Callable[[List[Bucket]], ScoredCandidate]: base_tx: Transaction,
*,
tx_from_buckets: Callable[[List[Bucket]], Tuple[PartialTransaction, List[PartialTxOutput]]],
) -> Callable[[List[Bucket]], ScoredCandidate]:
raise NotImplementedError raise NotImplementedError
def _change_amounts(self, tx: PartialTransaction, count: int, fee_estimator_numchange) -> List[int]: def _change_amounts(self, tx: PartialTransaction, count: int, fee_estimator_numchange) -> List[int]:
@@ -282,7 +290,7 @@ class CoinChooserBase(Logger):
inputs: List[PartialTxInput], inputs: List[PartialTxInput],
outputs: List[PartialTxOutput], outputs: List[PartialTxOutput],
change_addrs: Sequence[str], change_addrs: Sequence[str],
fee_estimator_vb: Callable, fee_estimator_vb: Callable[[int | float | Decimal], int],
dust_threshold: int, dust_threshold: int,
BIP69_sort: bool = True, BIP69_sort: bool = True,
) -> PartialTransaction: ) -> PartialTransaction:
@@ -322,7 +330,7 @@ class CoinChooserBase(Logger):
def fee_estimator_w(weight): def fee_estimator_w(weight):
return fee_estimator_vb(Transaction.virtual_size_from_weight(weight)) return fee_estimator_vb(Transaction.virtual_size_from_weight(weight))
def sufficient_funds(buckets, *, bucket_value_sum): def sufficient_funds(buckets: List[Bucket], *, bucket_value_sum: int) -> bool:
'''Given a list of buckets, return True if it has enough '''Given a list of buckets, return True if it has enough
value to pay for the transaction''' value to pay for the transaction'''
# assert bucket_value_sum == sum(bucket.value for bucket in buckets) # expensive! # assert bucket_value_sum == sum(bucket.value for bucket in buckets) # expensive!
@@ -373,7 +381,11 @@ class CoinChooserBase(Logger):
class CoinChooserRandom(CoinChooserBase): class CoinChooserRandom(CoinChooserBase):
def bucket_candidates_any(self, buckets: List[Bucket], sufficient_funds) -> List[List[Bucket]]: def bucket_candidates_any(
self,
buckets: List[Bucket],
sufficient_funds: Callable,
) -> List[List[Bucket]]:
'''Returns a list of bucket sets.''' '''Returns a list of bucket sets.'''
if not buckets: if not buckets:
if sufficient_funds([], bucket_value_sum=0): if sufficient_funds([], bucket_value_sum=0):
@@ -411,8 +423,11 @@ class CoinChooserRandom(CoinChooserBase):
candidates = [[buckets[n] for n in c] for c in candidates] candidates = [[buckets[n] for n in c] for c in candidates]
return [strip_unneeded(c, sufficient_funds) for c in candidates] return [strip_unneeded(c, sufficient_funds) for c in candidates]
def bucket_candidates_prefer_confirmed(self, buckets: List[Bucket], def bucket_candidates_prefer_confirmed(
sufficient_funds) -> List[List[Bucket]]: self,
buckets: List[Bucket],
sufficient_funds: Callable,
) -> List[List[Bucket]]:
"""Returns a list of bucket sets preferring confirmed coins. """Returns a list of bucket sets preferring confirmed coins.
Any bucket can be: Any bucket can be:
@@ -435,7 +450,7 @@ class CoinChooserRandom(CoinChooserBase):
for bkts_choose_from in bucket_sets: for bkts_choose_from in bucket_sets:
try: try:
def sfunds( def sfunds(
bkts, *, bucket_value_sum, bkts: List[Bucket], *, bucket_value_sum: int,
already_selected_buckets_value_sum=already_selected_buckets_value_sum, already_selected_buckets_value_sum=already_selected_buckets_value_sum,
already_selected_buckets=already_selected_buckets, already_selected_buckets=already_selected_buckets,
): ):

View File

@@ -2352,8 +2352,8 @@ class Peer(Logger, EventListener):
chan: Channel, chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket, processed_onion: ProcessedOnionPacket,
log_fail_reason: Callable, log_fail_reason: Callable[[str], None],
): ) -> tuple[bytes, int, int, OnionRoutingFailure]:
""" """
Perform checks that are invariant (results do not depend on height, network conditions, etc). Perform checks that are invariant (results do not depend on height, network conditions, etc).
May raise OnionRoutingFailure May raise OnionRoutingFailure
@@ -2379,13 +2379,13 @@ class Peer(Logger, EventListener):
code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
data=htlc.cltv_abs.to_bytes(4, byteorder="big")) data=htlc.cltv_abs.to_bytes(4, byteorder="big"))
try: try:
total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"] total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"] # type: int
except Exception: except Exception:
log_fail_reason(f"'total_msat' missing from onion") log_fail_reason(f"'total_msat' missing from onion")
raise exc_incorrect_or_unknown_pd raise exc_incorrect_or_unknown_pd
if chan.opening_fee: if chan.opening_fee:
channel_opening_fee = chan.opening_fee['channel_opening_fee'] channel_opening_fee = chan.opening_fee['channel_opening_fee'] # type: int
total_msat -= channel_opening_fee total_msat -= channel_opening_fee
amt_to_forward -= channel_opening_fee amt_to_forward -= channel_opening_fee
else: else:
@@ -2405,7 +2405,16 @@ class Peer(Logger, EventListener):
return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd
def check_mpp_is_waiting(self, *, payment_secret, short_channel_id, htlc, expected_msat, exc_incorrect_or_unknown_pd, log_fail_reason) -> bool: def check_mpp_is_waiting(
self,
*,
payment_secret: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
exc_incorrect_or_unknown_pd: OnionRoutingFailure,
log_fail_reason: Callable[[str], None],
) -> bool:
from .lnworker import RecvMPPResolution from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status( mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret, payment_secret=payment_secret,

View File

@@ -275,8 +275,8 @@ class CosignerWallet(Logger):
tx: Union['Transaction', 'PartialTransaction'], tx: Union['Transaction', 'PartialTransaction'],
*, *,
label: str = None, label: str = None,
on_failure: Callable = None, on_failure: Callable[[str], None] = None,
on_success: Callable = None on_success: Callable[[], None] = None
) -> None: ) -> None:
try: try:
# TODO: adding tx should be handled more gracefully here: # TODO: adding tx should be handled more gracefully here:

View File

@@ -29,7 +29,7 @@ import struct
import io import io
import base64 import base64
from typing import ( from typing import (
Sequence, Union, NamedTuple, Tuple, Optional, Iterable, Callable, List, Dict, Set, TYPE_CHECKING, Mapping Sequence, Union, NamedTuple, Tuple, Optional, Iterable, Callable, List, Dict, Set, TYPE_CHECKING, Mapping, Any
) )
from collections import defaultdict from collections import defaultdict
from enum import IntEnum from enum import IntEnum
@@ -703,7 +703,7 @@ def script_GetOp(_bytes : bytes):
class OPPushDataGeneric: class OPPushDataGeneric:
def __init__(self, pushlen: Callable=None): def __init__(self, pushlen: Callable[[int], bool] | None = None):
if pushlen is not None: if pushlen is not None:
self.check_data_len = pushlen self.check_data_len = pushlen
@@ -721,7 +721,7 @@ class OPPushDataGeneric:
class OPGeneric: class OPGeneric:
def __init__(self, matcher: Callable = None): def __init__(self, matcher: Callable[[Any], bool] | None = None):
if matcher is not None: if matcher is not None:
self.matcher = matcher self.matcher = matcher
@@ -729,7 +729,7 @@ class OPGeneric:
return self.matcher(op) return self.matcher(op)
@classmethod @classmethod
def is_instance(cls, item): def is_instance(cls, item) -> bool:
# accept objects that are instances of this class # accept objects that are instances of this class
# or other classes that are subclasses # or other classes that are subclasses
return isinstance(item, cls) \ return isinstance(item, cls) \

View File

@@ -1742,7 +1742,7 @@ def _set_custom_task_factory(loop: asyncio.AbstractEventLoop):
loop.set_task_factory(factory) loop.set_task_factory(factory)
def run_sync_function_on_asyncio_thread(func: Callable, *, block: bool) -> None: def run_sync_function_on_asyncio_thread(func: Callable[[], Any], *, block: bool) -> None:
"""Run a non-async fn on the asyncio thread. Can be called from any thread. """Run a non-async fn on the asyncio thread. Can be called from any thread.
If the current thread is already the asyncio thread, func is guaranteed If the current thread is already the asyncio thread, func is guaranteed