1
0

lnworker: make PaymentInfo dataclass

Move PaymentInfo from NamedTuple to dataclass to allow for easier
handling e.g. using dataclasses.astuple etc.
This commit is contained in:
f321x
2025-09-26 16:02:50 +02:00
committed by SomberNight
parent acd52da764
commit 4c0155c072
2 changed files with 46 additions and 9 deletions

View File

@@ -20,6 +20,7 @@ import concurrent
from concurrent import futures
import urllib.parse
import itertools
import dataclasses
import aiohttp
import dns.asyncresolver
@@ -106,12 +107,23 @@ class PaymentDirection(IntEnum):
FORWARDING = 3
class PaymentInfo(NamedTuple):
@dataclasses.dataclass(frozen=True, kw_only=True)
class PaymentInfo:
"""Information required to handle incoming htlcs for a payment request"""
payment_hash: bytes
amount_msat: Optional[int]
direction: int
status: int
def validate(self):
assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
assert self.amount_msat is None or isinstance(self.amount_msat, int)
assert isinstance(self.direction, int)
assert isinstance(self.status, int)
def __post_init__(self):
self.validate()
# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
@@ -1567,7 +1579,12 @@ class LNWallet(LNWorker):
raise PaymentFailure(_("A payment was already initiated for this invoice"))
if payment_hash in self.get_payments(status='inflight'):
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_to_pay,
direction=SENT,
status=PR_UNPAID,
)
self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description())
self.set_invoice_status(key, PR_INFLIGHT)
@@ -2302,7 +2319,12 @@ class LNWallet(LNWorker):
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
)
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
self.save_payment_info(info, write_to_disk=False)
if write_to_disk:
@@ -2376,12 +2398,22 @@ class LNWallet(LNWorker):
with self.lock:
if key in self.payment_info:
amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
return PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=direction,
status=status,
)
return None
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount,
direction=RECEIVED,
status=PR_UNPAID,
)
self.save_payment_info(info, write_to_disk=False)
def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
@@ -2396,11 +2428,11 @@ class LNWallet(LNWorker):
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
if info == old_info:
return # already saved
if info != old_info._replace(status=info.status):
if info != dataclasses.replace(old_info, status=info.status):
# differs more than in status. let's fail
raise Exception("payment_hash already in use")
key = info.payment_hash.hex()
self.payment_info[key] = info.amount_msat, info.direction, info.status
self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0
if write_to_disk:
self.wallet.save_db()
@@ -2577,7 +2609,7 @@ class LNWallet(LNWorker):
if info is None:
# if we are forwarding
return
info = info._replace(status=status)
info = dataclasses.replace(info, status=status)
self.save_payment_info(info)
def is_forwarded_htlc(self, htlc_key) -> Optional[str]:

View File

@@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase):
payment_preimage = os.urandom(32)
if payment_hash is None:
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
)
if payment_preimage:
w2.save_preimage(payment_hash, payment_preimage)
w2.save_payment_info(info)