lnworker: add support for hold invoices
(invoices for which we do not have the preimage) Callbacks and timeouts are registered with lnworker. If the preimage is not known after the timeout has expired, the payment is failed with MPP_TIMEOUT.
This commit is contained in:
@@ -676,6 +676,7 @@ class LNWallet(LNWorker):
|
|||||||
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
||||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||||
|
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
|
||||||
|
|
||||||
def has_deterministic_node_id(self) -> bool:
|
def has_deterministic_node_id(self) -> bool:
|
||||||
return bool(self.db.get('lightning_xprv'))
|
return bool(self.db.get('lightning_xprv'))
|
||||||
@@ -1880,6 +1881,14 @@ class LNWallet(LNWorker):
|
|||||||
amount_msat, direction, status = self.payment_info[key]
|
amount_msat, direction, status = self.payment_info[key]
|
||||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
||||||
|
|
||||||
|
def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat):
|
||||||
|
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
|
||||||
|
self.save_payment_info(info, write_to_disk=False)
|
||||||
|
|
||||||
|
def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None):
|
||||||
|
expiry = int(time.time()) + timeout
|
||||||
|
self.hold_invoice_callbacks[payment_hash] = cb, expiry
|
||||||
|
|
||||||
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
|
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
|
||||||
key = info.payment_hash.hex()
|
key = info.payment_hash.hex()
|
||||||
assert info.status in SAVED_PR_STATUS
|
assert info.status in SAVED_PR_STATUS
|
||||||
@@ -1891,13 +1900,22 @@ class LNWallet(LNWorker):
|
|||||||
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
|
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
|
||||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
||||||
"""
|
"""
|
||||||
|
payment_hash = htlc.payment_hash
|
||||||
|
preimage = self.get_preimage(payment_hash)
|
||||||
|
callback = self.hold_invoice_callbacks.get(payment_hash)
|
||||||
|
if not preimage and callback:
|
||||||
|
cb, timeout = callback
|
||||||
|
if int(time.time()) < timeout:
|
||||||
|
cb(payment_hash)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
amt_to_forward = htlc.amount_msat # check this
|
amt_to_forward = htlc.amount_msat # check this
|
||||||
if amt_to_forward >= expected_msat:
|
if amt_to_forward >= expected_msat:
|
||||||
# not multi-part
|
# not multi-part
|
||||||
return True
|
return True
|
||||||
|
|
||||||
payment_hash = htlc.payment_hash
|
|
||||||
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
|
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
|
||||||
if self.get_payment_status(payment_hash) == PR_PAID:
|
if self.get_payment_status(payment_hash) == PR_PAID:
|
||||||
# payment_status is persisted
|
# payment_status is persisted
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
|||||||
self.preimages = {}
|
self.preimages = {}
|
||||||
self.stopping_soon = False
|
self.stopping_soon = False
|
||||||
self.downstream_htlc_to_upstream_peer_map = {}
|
self.downstream_htlc_to_upstream_peer_map = {}
|
||||||
|
self.hold_invoice_callbacks = {}
|
||||||
|
|
||||||
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
||||||
|
|
||||||
@@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
|||||||
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
|
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
|
||||||
_force_close_channel = LNWallet._force_close_channel
|
_force_close_channel = LNWallet._force_close_channel
|
||||||
suggest_splits = LNWallet.suggest_splits
|
suggest_splits = LNWallet.suggest_splits
|
||||||
|
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
|
||||||
|
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
||||||
|
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
@@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase):
|
|||||||
async def pay(lnaddr, pay_req):
|
async def pay(lnaddr, pay_req):
|
||||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
|
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
|
||||||
result, log = await w1.pay_invoice(pay_req)
|
result, log = await w1.pay_invoice(pay_req)
|
||||||
self.assertTrue(result)
|
if result is True:
|
||||||
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
|
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
|
||||||
raise PaymentDone()
|
raise PaymentDone()
|
||||||
async def f():
|
else:
|
||||||
|
raise PaymentFailure()
|
||||||
|
async def f(test_hold_invoice=False, test_timeout=False):
|
||||||
if trampoline:
|
if trampoline:
|
||||||
await turn_on_trampoline_alice()
|
await turn_on_trampoline_alice()
|
||||||
async with OldTaskGroup() as group:
|
async with OldTaskGroup() as group:
|
||||||
@@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase):
|
|||||||
await group.spawn(p2.htlc_switch())
|
await group.spawn(p2.htlc_switch())
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||||
|
if test_hold_invoice:
|
||||||
|
payment_hash = lnaddr.paymenthash
|
||||||
|
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||||
|
def cb(payment_hash):
|
||||||
|
if not test_timeout:
|
||||||
|
w2.save_preimage(payment_hash, preimage)
|
||||||
|
timeout = 1 if test_timeout else 60
|
||||||
|
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||||
invoice_features = lnaddr.get_features()
|
invoice_features = lnaddr.get_features()
|
||||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||||
await group.spawn(pay(lnaddr, pay_req))
|
await group.spawn(pay(lnaddr, pay_req))
|
||||||
@@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase):
|
|||||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||||
}
|
}
|
||||||
with self.assertRaises(PaymentDone):
|
with self.assertRaises(PaymentDone):
|
||||||
await f()
|
await f(test_hold_invoice=False)
|
||||||
|
with self.assertRaises(PaymentDone):
|
||||||
|
await f(test_hold_invoice=True, test_timeout=False)
|
||||||
|
with self.assertRaises(PaymentFailure):
|
||||||
|
await f(test_hold_invoice=True, test_timeout=True)
|
||||||
|
|
||||||
@needs_test_with_all_chacha20_implementations
|
@needs_test_with_all_chacha20_implementations
|
||||||
async def test_simple_payment(self):
|
async def test_simple_payment(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user