lnworker: use NamedTuple for received_mpp_htlcs. add/fix type hints
try to avoid long plain tuples
This commit is contained in:
@@ -10,7 +10,7 @@ import time
|
||||
import operator
|
||||
from enum import IntEnum
|
||||
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
|
||||
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict)
|
||||
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable)
|
||||
import threading
|
||||
import socket
|
||||
import aiohttp
|
||||
@@ -167,6 +167,13 @@ class PaymentInfo(NamedTuple):
|
||||
status: int
|
||||
|
||||
|
||||
class ReceivedMPPStatus(NamedTuple):
|
||||
is_expired: bool
|
||||
is_accepted: bool
|
||||
expected_msat: int
|
||||
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
|
||||
|
||||
|
||||
class ErrorAddingPeer(Exception): pass
|
||||
|
||||
|
||||
@@ -665,7 +672,7 @@ class LNWallet(LNWorker):
|
||||
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
|
||||
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
|
||||
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
|
||||
self.received_mpp_htlcs = dict() # RHASH -> mpp_status, htlc_set
|
||||
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus
|
||||
|
||||
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
|
||||
# detect inflight payments
|
||||
@@ -676,7 +683,8 @@ class LNWallet(LNWorker):
|
||||
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
|
||||
# payment_hash -> callback, timeout:
|
||||
self.hold_invoice_callbacks = {} # type: Dict[bytes, Tuple[Callable[[bytes], None], int]]
|
||||
self.payment_bundles = [] # lists of hashes. todo:persist
|
||||
|
||||
|
||||
@@ -1891,11 +1899,13 @@ class LNWallet(LNWorker):
|
||||
amount_msat, direction, status = self.payment_info[key]
|
||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
||||
|
||||
def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat):
|
||||
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: int):
|
||||
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
|
||||
def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None):
|
||||
def register_callback_for_hold_invoice(
|
||||
self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int,
|
||||
):
|
||||
expiry = int(time.time()) + timeout
|
||||
self.hold_invoice_callbacks[payment_hash] = cb, expiry
|
||||
|
||||
@@ -1907,7 +1917,12 @@ class LNWallet(LNWorker):
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
|
||||
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
|
||||
def check_received_htlc(
|
||||
self, payment_secret: bytes,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
expected_msat: int,
|
||||
) -> Optional[bool]:
|
||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
||||
"""
|
||||
payment_hash = htlc.payment_hash
|
||||
@@ -1952,47 +1967,64 @@ class LNWallet(LNWorker):
|
||||
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc)
|
||||
return True if is_accepted else (False if is_expired else None)
|
||||
|
||||
def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat):
|
||||
def update_mpp_with_received_htlc(
|
||||
self,
|
||||
payment_secret: bytes,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
expected_msat: int,
|
||||
):
|
||||
# add new htlc to set
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set()))
|
||||
assert expected_msat == _expected_msat
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
||||
if mpp_status is None:
|
||||
mpp_status = ReceivedMPPStatus(
|
||||
is_expired=False,
|
||||
is_accepted=False,
|
||||
expected_msat=expected_msat,
|
||||
htlc_set=set(),
|
||||
)
|
||||
assert expected_msat == mpp_status.expected_msat
|
||||
key = (short_channel_id, htlc)
|
||||
if key not in htlc_set:
|
||||
htlc_set.add(key)
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
if key not in mpp_status.htlc_set:
|
||||
mpp_status.htlc_set.add(key) # side-effecting htlc_set
|
||||
self.received_mpp_htlcs[payment_secret] = mpp_status
|
||||
|
||||
def get_mpp_status(self, payment_secret):
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
return is_expired, is_accepted
|
||||
def get_mpp_status(self, payment_secret: bytes) -> Tuple[bool, bool]:
|
||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
||||
return mpp_status.is_expired, mpp_status.is_accepted
|
||||
|
||||
def set_mpp_status(self, payment_secret, is_expired, is_accepted):
|
||||
_is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
def set_mpp_status(self, payment_secret: bytes, is_expired: bool, is_accepted: bool):
|
||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
||||
self.received_mpp_htlcs[payment_secret] = mpp_status._replace(
|
||||
is_expired=is_expired,
|
||||
is_accepted=is_accepted,
|
||||
)
|
||||
|
||||
def is_mpp_amount_reached(self, payment_secret):
|
||||
mpp = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp:
|
||||
def is_mpp_amount_reached(self, payment_secret: bytes) -> bool:
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp_status:
|
||||
return False
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = mpp
|
||||
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
|
||||
return total >= _expected_msat
|
||||
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
|
||||
return total >= mpp_status.expected_msat
|
||||
|
||||
def get_first_timestamp_of_mpp(self, payment_secret):
|
||||
mpp = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp:
|
||||
def get_first_timestamp_of_mpp(self, payment_secret: bytes) -> int:
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp_status:
|
||||
return int(time.time())
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = mpp
|
||||
return min([_htlc.timestamp for scid, _htlc in htlc_set])
|
||||
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
|
||||
|
||||
def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc):
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
if not is_accepted and not is_expired:
|
||||
def maybe_cleanup_mpp_status(
|
||||
self,
|
||||
payment_secret: bytes,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
) -> None:
|
||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
||||
if not mpp_status.is_accepted and not mpp_status.is_expired:
|
||||
return
|
||||
key = (short_channel_id, htlc)
|
||||
htlc_set.remove(key)
|
||||
if len(htlc_set) > 0:
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
elif payment_secret in self.received_mpp_htlcs:
|
||||
mpp_status.htlc_set.remove(key) # side-effecting htlc_set
|
||||
if not mpp_status.htlc_set and payment_secret in self.received_mpp_htlcs:
|
||||
self.received_mpp_htlcs.pop(payment_secret)
|
||||
|
||||
def get_payment_status(self, payment_hash: bytes) -> int:
|
||||
|
||||
Reference in New Issue
Block a user