lnworker: make PaymentInfo dataclass
Move PaymentInfo from NamedTuple to dataclass to allow for easier handling e.g. using dataclasses.astuple etc.
This commit is contained in:
@@ -20,6 +20,7 @@ import concurrent
|
||||
from concurrent import futures
|
||||
import urllib.parse
|
||||
import itertools
|
||||
import dataclasses
|
||||
|
||||
import aiohttp
|
||||
import dns.asyncresolver
|
||||
@@ -106,12 +107,23 @@ class PaymentDirection(IntEnum):
|
||||
FORWARDING = 3
|
||||
|
||||
|
||||
class PaymentInfo(NamedTuple):
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class PaymentInfo:
|
||||
"""Information required to handle incoming htlcs for a payment request"""
|
||||
payment_hash: bytes
|
||||
amount_msat: Optional[int]
|
||||
direction: int
|
||||
status: int
|
||||
|
||||
def validate(self):
|
||||
assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
|
||||
assert self.amount_msat is None or isinstance(self.amount_msat, int)
|
||||
assert isinstance(self.direction, int)
|
||||
assert isinstance(self.status, int)
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
|
||||
|
||||
# Note: these states are persisted in the wallet file.
|
||||
# Do not modify them without performing a wallet db upgrade
|
||||
@@ -1567,7 +1579,12 @@ class LNWallet(LNWorker):
|
||||
raise PaymentFailure(_("A payment was already initiated for this invoice"))
|
||||
if payment_hash in self.get_payments(status='inflight'):
|
||||
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
|
||||
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_to_pay,
|
||||
direction=SENT,
|
||||
status=PR_UNPAID,
|
||||
)
|
||||
self.save_payment_info(info)
|
||||
self.wallet.set_label(key, lnaddr.get_description())
|
||||
self.set_invoice_status(key, PR_INFLIGHT)
|
||||
@@ -2302,7 +2319,12 @@ class LNWallet(LNWorker):
|
||||
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
|
||||
payment_preimage = os.urandom(32)
|
||||
payment_hash = sha256(payment_preimage)
|
||||
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_msat,
|
||||
direction=RECEIVED,
|
||||
status=PR_UNPAID,
|
||||
)
|
||||
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
if write_to_disk:
|
||||
@@ -2376,12 +2398,22 @@ class LNWallet(LNWorker):
|
||||
with self.lock:
|
||||
if key in self.payment_info:
|
||||
amount_msat, direction, status = self.payment_info[key]
|
||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
||||
return PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_msat,
|
||||
direction=direction,
|
||||
status=status,
|
||||
)
|
||||
return None
|
||||
|
||||
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
|
||||
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
|
||||
info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID)
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount,
|
||||
direction=RECEIVED,
|
||||
status=PR_UNPAID,
|
||||
)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
|
||||
def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
|
||||
@@ -2396,11 +2428,11 @@ class LNWallet(LNWorker):
|
||||
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
|
||||
if info == old_info:
|
||||
return # already saved
|
||||
if info != old_info._replace(status=info.status):
|
||||
if info != dataclasses.replace(old_info, status=info.status):
|
||||
# differs more than in status. let's fail
|
||||
raise Exception("payment_hash already in use")
|
||||
key = info.payment_hash.hex()
|
||||
self.payment_info[key] = info.amount_msat, info.direction, info.status
|
||||
self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
|
||||
@@ -2577,7 +2609,7 @@ class LNWallet(LNWorker):
|
||||
if info is None:
|
||||
# if we are forwarding
|
||||
return
|
||||
info = info._replace(status=status)
|
||||
info = dataclasses.replace(info, status=status)
|
||||
self.save_payment_info(info)
|
||||
|
||||
def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
|
||||
|
||||
@@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase):
|
||||
payment_preimage = os.urandom(32)
|
||||
if payment_hash is None:
|
||||
payment_hash = sha256(payment_preimage)
|
||||
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_msat,
|
||||
direction=RECEIVED,
|
||||
status=PR_UNPAID,
|
||||
)
|
||||
if payment_preimage:
|
||||
w2.save_preimage(payment_hash, payment_preimage)
|
||||
w2.save_payment_info(info)
|
||||
|
||||
Reference in New Issue
Block a user