lnworker: add support for hold invoices
(invoices for which we do not have the preimage) Callbacks and timeouts are registered with lnworker. If the preimage is not known after the timeout has expired, the payment is failed with MPP_TIMEOUT.
This commit is contained in:
@@ -676,6 +676,7 @@ class LNWallet(LNWorker):
|
||||
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
|
||||
|
||||
def has_deterministic_node_id(self) -> bool:
|
||||
return bool(self.db.get('lightning_xprv'))
|
||||
@@ -1880,6 +1881,14 @@ class LNWallet(LNWorker):
|
||||
amount_msat, direction, status = self.payment_info[key]
|
||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
||||
|
||||
def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat):
|
||||
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
|
||||
def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None):
|
||||
expiry = int(time.time()) + timeout
|
||||
self.hold_invoice_callbacks[payment_hash] = cb, expiry
|
||||
|
||||
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
|
||||
key = info.payment_hash.hex()
|
||||
assert info.status in SAVED_PR_STATUS
|
||||
@@ -1891,13 +1900,22 @@ class LNWallet(LNWorker):
|
||||
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
|
||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
||||
"""
|
||||
payment_hash = htlc.payment_hash
|
||||
preimage = self.get_preimage(payment_hash)
|
||||
callback = self.hold_invoice_callbacks.get(payment_hash)
|
||||
if not preimage and callback:
|
||||
cb, timeout = callback
|
||||
if int(time.time()) < timeout:
|
||||
cb(payment_hash)
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
|
||||
amt_to_forward = htlc.amount_msat # check this
|
||||
if amt_to_forward >= expected_msat:
|
||||
# not multi-part
|
||||
return True
|
||||
|
||||
payment_hash = htlc.payment_hash
|
||||
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
|
||||
if self.get_payment_status(payment_hash) == PR_PAID:
|
||||
# payment_status is persisted
|
||||
|
||||
@@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
self.preimages = {}
|
||||
self.stopping_soon = False
|
||||
self.downstream_htlc_to_upstream_peer_map = {}
|
||||
self.hold_invoice_callbacks = {}
|
||||
|
||||
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
||||
|
||||
@@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
|
||||
_force_close_channel = LNWallet._force_close_channel
|
||||
suggest_splits = LNWallet.suggest_splits
|
||||
|
||||
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
|
||||
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, name):
|
||||
@@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase):
|
||||
async def pay(lnaddr, pay_req):
|
||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
|
||||
result, log = await w1.pay_invoice(pay_req)
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
|
||||
raise PaymentDone()
|
||||
async def f():
|
||||
if result is True:
|
||||
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
|
||||
raise PaymentDone()
|
||||
else:
|
||||
raise PaymentFailure()
|
||||
async def f(test_hold_invoice=False, test_timeout=False):
|
||||
if trampoline:
|
||||
await turn_on_trampoline_alice()
|
||||
async with OldTaskGroup() as group:
|
||||
@@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase):
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||
if test_hold_invoice:
|
||||
payment_hash = lnaddr.paymenthash
|
||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||
def cb(payment_hash):
|
||||
if not test_timeout:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
timeout = 1 if test_timeout else 60
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||
invoice_features = lnaddr.get_features()
|
||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||
await group.spawn(pay(lnaddr, pay_req))
|
||||
@@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase):
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
with self.assertRaises(PaymentDone):
|
||||
await f()
|
||||
await f(test_hold_invoice=False)
|
||||
with self.assertRaises(PaymentDone):
|
||||
await f(test_hold_invoice=True, test_timeout=False)
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await f(test_hold_invoice=True, test_timeout=True)
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_simple_payment(self):
|
||||
|
||||
Reference in New Issue
Block a user