diff --git a/electrum/lnworker.py b/electrum/lnworker.py index c66843ee0..63c6132a2 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -20,6 +20,7 @@ import concurrent from concurrent import futures import urllib.parse import itertools +import dataclasses import aiohttp import dns.asyncresolver @@ -106,12 +107,23 @@ class PaymentDirection(IntEnum): FORWARDING = 3 -class PaymentInfo(NamedTuple): +@dataclasses.dataclass(frozen=True, kw_only=True) +class PaymentInfo: + """Information required to handle incoming htlcs for a payment request""" payment_hash: bytes amount_msat: Optional[int] direction: int status: int + def validate(self): + assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32 + assert self.amount_msat is None or isinstance(self.amount_msat, int) + assert isinstance(self.direction, int) + assert isinstance(self.status, int) + + def __post_init__(self): + self.validate() + # Note: these states are persisted in the wallet file. # Do not modify them without performing a wallet db upgrade @@ -1567,7 +1579,12 @@ class LNWallet(LNWorker): raise PaymentFailure(_("A payment was already initiated for this invoice")) if payment_hash in self.get_payments(status='inflight'): raise PaymentFailure(_("A previous attempt to pay this invoice did not clear")) - info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_to_pay, + direction=SENT, + status=PR_UNPAID, + ) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) self.set_invoice_status(key, PR_INFLIGHT) @@ -2302,7 +2319,12 @@ class LNWallet(LNWorker): def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes: payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=RECEIVED, + status=PR_UNPAID, + ) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) if write_to_disk: @@ -2376,12 +2398,22 @@ class LNWallet(LNWorker): with self.lock: if key in self.payment_info: amount_msat, direction, status = self.payment_info[key] - return PaymentInfo(payment_hash, amount_msat, direction, status) + return PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=direction, + status=status, + ) return None def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]): amount = lightning_amount_sat * 1000 if lightning_amount_sat else None - info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount, + direction=RECEIVED, + status=PR_UNPAID, + ) self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): @@ -2396,11 +2428,11 @@ class LNWallet(LNWorker): if old_info := self.get_payment_info(payment_hash=info.payment_hash): if info == old_info: return # already saved - if info != old_info._replace(status=info.status): + if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception("payment_hash already in use") key = info.payment_hash.hex() - self.payment_info[key] = info.amount_msat, info.direction, info.status + self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0 if write_to_disk: self.wallet.save_db() @@ -2577,7 +2609,7 @@ class LNWallet(LNWorker): if info is None: # if we are forwarding return - info = info._replace(status=status) + info = dataclasses.replace(info, status=status) self.save_payment_info(info) def is_forwarded_htlc(self, htlc_key) -> Optional[str]: diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index b1bc6c5af..3eb4313f6 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase): payment_preimage = os.urandom(32) if payment_hash is None: payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=RECEIVED, + status=PR_UNPAID, + ) if payment_preimage: w2.save_preimage(payment_hash, payment_preimage) w2.save_payment_info(info)