lnworker: make PaymentInfo dataclass
Move PaymentInfo from NamedTuple to dataclass to allow for easier handling e.g. using dataclasses.astuple etc.
This commit is contained in:
@@ -20,6 +20,7 @@ import concurrent
|
|||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import itertools
|
import itertools
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import dns.asyncresolver
|
import dns.asyncresolver
|
||||||
@@ -106,12 +107,23 @@ class PaymentDirection(IntEnum):
|
|||||||
FORWARDING = 3
|
FORWARDING = 3
|
||||||
|
|
||||||
|
|
||||||
class PaymentInfo(NamedTuple):
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||||
|
class PaymentInfo:
|
||||||
|
"""Information required to handle incoming htlcs for a payment request"""
|
||||||
payment_hash: bytes
|
payment_hash: bytes
|
||||||
amount_msat: Optional[int]
|
amount_msat: Optional[int]
|
||||||
direction: int
|
direction: int
|
||||||
status: int
|
status: int
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
|
||||||
|
assert self.amount_msat is None or isinstance(self.amount_msat, int)
|
||||||
|
assert isinstance(self.direction, int)
|
||||||
|
assert isinstance(self.status, int)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
|
||||||
# Note: these states are persisted in the wallet file.
|
# Note: these states are persisted in the wallet file.
|
||||||
# Do not modify them without performing a wallet db upgrade
|
# Do not modify them without performing a wallet db upgrade
|
||||||
@@ -1567,7 +1579,12 @@ class LNWallet(LNWorker):
|
|||||||
raise PaymentFailure(_("A payment was already initiated for this invoice"))
|
raise PaymentFailure(_("A payment was already initiated for this invoice"))
|
||||||
if payment_hash in self.get_payments(status='inflight'):
|
if payment_hash in self.get_payments(status='inflight'):
|
||||||
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
|
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
|
||||||
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
|
info = PaymentInfo(
|
||||||
|
payment_hash=payment_hash,
|
||||||
|
amount_msat=amount_to_pay,
|
||||||
|
direction=SENT,
|
||||||
|
status=PR_UNPAID,
|
||||||
|
)
|
||||||
self.save_payment_info(info)
|
self.save_payment_info(info)
|
||||||
self.wallet.set_label(key, lnaddr.get_description())
|
self.wallet.set_label(key, lnaddr.get_description())
|
||||||
self.set_invoice_status(key, PR_INFLIGHT)
|
self.set_invoice_status(key, PR_INFLIGHT)
|
||||||
@@ -2302,7 +2319,12 @@ class LNWallet(LNWorker):
|
|||||||
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
|
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
|
||||||
payment_preimage = os.urandom(32)
|
payment_preimage = os.urandom(32)
|
||||||
payment_hash = sha256(payment_preimage)
|
payment_hash = sha256(payment_preimage)
|
||||||
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
|
info = PaymentInfo(
|
||||||
|
payment_hash=payment_hash,
|
||||||
|
amount_msat=amount_msat,
|
||||||
|
direction=RECEIVED,
|
||||||
|
status=PR_UNPAID,
|
||||||
|
)
|
||||||
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
|
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
|
||||||
self.save_payment_info(info, write_to_disk=False)
|
self.save_payment_info(info, write_to_disk=False)
|
||||||
if write_to_disk:
|
if write_to_disk:
|
||||||
@@ -2376,12 +2398,22 @@ class LNWallet(LNWorker):
|
|||||||
with self.lock:
|
with self.lock:
|
||||||
if key in self.payment_info:
|
if key in self.payment_info:
|
||||||
amount_msat, direction, status = self.payment_info[key]
|
amount_msat, direction, status = self.payment_info[key]
|
||||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
return PaymentInfo(
|
||||||
|
payment_hash=payment_hash,
|
||||||
|
amount_msat=amount_msat,
|
||||||
|
direction=direction,
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
|
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
|
||||||
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
|
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
|
||||||
info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID)
|
info = PaymentInfo(
|
||||||
|
payment_hash=payment_hash,
|
||||||
|
amount_msat=amount,
|
||||||
|
direction=RECEIVED,
|
||||||
|
status=PR_UNPAID,
|
||||||
|
)
|
||||||
self.save_payment_info(info, write_to_disk=False)
|
self.save_payment_info(info, write_to_disk=False)
|
||||||
|
|
||||||
def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
|
def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
|
||||||
@@ -2396,11 +2428,11 @@ class LNWallet(LNWorker):
|
|||||||
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
|
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
|
||||||
if info == old_info:
|
if info == old_info:
|
||||||
return # already saved
|
return # already saved
|
||||||
if info != old_info._replace(status=info.status):
|
if info != dataclasses.replace(old_info, status=info.status):
|
||||||
# differs more than in status. let's fail
|
# differs more than in status. let's fail
|
||||||
raise Exception("payment_hash already in use")
|
raise Exception("payment_hash already in use")
|
||||||
key = info.payment_hash.hex()
|
key = info.payment_hash.hex()
|
||||||
self.payment_info[key] = info.amount_msat, info.direction, info.status
|
self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0
|
||||||
if write_to_disk:
|
if write_to_disk:
|
||||||
self.wallet.save_db()
|
self.wallet.save_db()
|
||||||
|
|
||||||
@@ -2577,7 +2609,7 @@ class LNWallet(LNWorker):
|
|||||||
if info is None:
|
if info is None:
|
||||||
# if we are forwarding
|
# if we are forwarding
|
||||||
return
|
return
|
||||||
info = info._replace(status=status)
|
info = dataclasses.replace(info, status=status)
|
||||||
self.save_payment_info(info)
|
self.save_payment_info(info)
|
||||||
|
|
||||||
def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
|
def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
|
||||||
|
|||||||
@@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase):
|
|||||||
payment_preimage = os.urandom(32)
|
payment_preimage = os.urandom(32)
|
||||||
if payment_hash is None:
|
if payment_hash is None:
|
||||||
payment_hash = sha256(payment_preimage)
|
payment_hash = sha256(payment_preimage)
|
||||||
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
|
info = PaymentInfo(
|
||||||
|
payment_hash=payment_hash,
|
||||||
|
amount_msat=amount_msat,
|
||||||
|
direction=RECEIVED,
|
||||||
|
status=PR_UNPAID,
|
||||||
|
)
|
||||||
if payment_preimage:
|
if payment_preimage:
|
||||||
w2.save_preimage(payment_hash, payment_preimage)
|
w2.save_preimage(payment_hash, payment_preimage)
|
||||||
w2.save_payment_info(info)
|
w2.save_payment_info(info)
|
||||||
|
|||||||
Reference in New Issue
Block a user