Separate pay_to_node logic from pay_invoice:
- pay_to_node will be needed to forward trampoline onions. - pay_to_node either is successful or raises - pay_invoice handles invoice status
This commit is contained in:
@@ -922,12 +922,10 @@ class LNWallet(LNWorker):
|
||||
chan, funding_tx = fut.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise Exception(_("open_channel timed out"))
|
||||
|
||||
# at this point the channel opening was successful
|
||||
# if this is the first channel that got opened, we start gossiping
|
||||
if self.channels:
|
||||
self.network.start_gossip()
|
||||
|
||||
return chan, funding_tx
|
||||
|
||||
def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
|
||||
@@ -935,6 +933,15 @@ class LNWallet(LNWorker):
|
||||
if chan.short_channel_id == short_channel_id:
|
||||
return chan
|
||||
|
||||
def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None):
|
||||
return self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
r_tags=decoded_invoice.get_routing_info('r'),
|
||||
invoice_features=decoded_invoice.get_tag('9') or 0,
|
||||
full_path=full_path)
|
||||
|
||||
@log_exceptions
|
||||
async def pay_invoice(
|
||||
self, invoice: str, *,
|
||||
@@ -943,8 +950,13 @@ class LNWallet(LNWorker):
|
||||
full_path: LNPaymentPath = None) -> Tuple[bool, List[HtlcLog]]:
|
||||
|
||||
lnaddr = self._check_invoice(invoice, amount_msat=amount_msat)
|
||||
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
|
||||
payment_hash = lnaddr.paymenthash
|
||||
key = payment_hash.hex()
|
||||
payment_secret = lnaddr.payment_secret
|
||||
invoice_pubkey = lnaddr.pubkey.serialize()
|
||||
invoice_features = lnaddr.get_tag('9') or 0
|
||||
r_tags = lnaddr.get_routing_info('r')
|
||||
amount_to_pay = lnaddr.get_amount_msat()
|
||||
status = self.get_payment_status(payment_hash)
|
||||
if status == PR_PAID:
|
||||
@@ -954,69 +966,68 @@ class LNWallet(LNWorker):
|
||||
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
|
||||
self.save_payment_info(info)
|
||||
self.wallet.set_label(key, lnaddr.get_description())
|
||||
self.logs[key] = log = []
|
||||
success = False
|
||||
reason = ''
|
||||
amount_inflight = 0 # what we sent in htlcs
|
||||
|
||||
self.set_invoice_status(key, PR_INFLIGHT)
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
while True:
|
||||
amount_to_send = amount_to_pay - amount_inflight
|
||||
if amount_to_send > 0:
|
||||
# 1. create a set of routes for remaining amount.
|
||||
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
|
||||
# graph updates might occur during the computation
|
||||
try:
|
||||
routes = await run_in_thread(partial(self.create_routes_from_invoice, amount_to_send, lnaddr, full_path=full_path))
|
||||
except NoPathFound:
|
||||
# catch this exception because we still want to return the htlc log
|
||||
reason = 'No path found'
|
||||
break
|
||||
# 2. send htlcs
|
||||
for route, amount_msat in routes:
|
||||
await self.pay_to_route(route, amount_msat, lnaddr)
|
||||
amount_inflight += amount_msat
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
# 3. await a queue
|
||||
htlc_log = await self.sent_htlcs[payment_hash].get()
|
||||
amount_inflight -= htlc_log.amount_msat
|
||||
log.append(htlc_log)
|
||||
if htlc_log.success:
|
||||
success = True
|
||||
break
|
||||
# htlc failed
|
||||
# if we get a tmp channel failure, it might work to split the amount and try more routes
|
||||
# if we get a channel update, we might retry the same route and amount
|
||||
if len(log) >= attempts:
|
||||
reason = 'Giving up after %d attempts'%len(log)
|
||||
break
|
||||
if htlc_log.sender_idx is not None:
|
||||
# apply channel update here
|
||||
should_continue = self.handle_error_code_from_failed_htlc(htlc_log)
|
||||
if not should_continue:
|
||||
break
|
||||
else:
|
||||
# probably got "update_fail_malformed_htlc". well... who to penalise now?
|
||||
reason = 'sender idx missing'
|
||||
break
|
||||
|
||||
# MPP: should we await all the inflight htlcs, or have another state?
|
||||
try:
|
||||
await self.pay_to_node(
|
||||
invoice_pubkey, payment_hash, payment_secret, amount_to_pay,
|
||||
min_cltv_expiry, r_tags, invoice_features,
|
||||
attempts=attempts, full_path=full_path)
|
||||
success = True
|
||||
except PaymentFailure as e:
|
||||
self.logger.exception('')
|
||||
success = False
|
||||
reason = str(e)
|
||||
if success:
|
||||
self.set_invoice_status(key, PR_PAID)
|
||||
util.trigger_callback('payment_succeeded', self.wallet, key)
|
||||
else:
|
||||
self.set_invoice_status(key, PR_UNPAID)
|
||||
util.trigger_callback('payment_failed', self.wallet, key, reason)
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
log = self.logs[key]
|
||||
return success, log
|
||||
|
||||
async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, lnaddr: LnAddr):
|
||||
|
||||
async def pay_to_node(
|
||||
self, node_pubkey, payment_hash, payment_secret, amount_to_pay,
|
||||
min_cltv_expiry, r_tags, invoice_features, *, attempts: int = 1,
|
||||
full_path: LNPaymentPath = None):
|
||||
|
||||
self.logs[payment_hash.hex()] = log = []
|
||||
amount_inflight = 0 # what we sent in htlcs
|
||||
while True:
|
||||
amount_to_send = amount_to_pay - amount_inflight
|
||||
if amount_to_send > 0:
|
||||
# 1. create a set of routes for remaining amount.
|
||||
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
|
||||
# graph updates might occur during the computation
|
||||
routes = await run_in_thread(partial(
|
||||
self.create_routes_for_payment, amount_to_send, node_pubkey,
|
||||
min_cltv_expiry, r_tags, invoice_features, full_path=full_path))
|
||||
# 2. send htlcs
|
||||
for route, amount_msat in routes:
|
||||
await self.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
|
||||
amount_inflight += amount_msat
|
||||
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
|
||||
# 3. await a queue
|
||||
htlc_log = await self.sent_htlcs[payment_hash].get()
|
||||
amount_inflight -= htlc_log.amount_msat
|
||||
log.append(htlc_log)
|
||||
if htlc_log.success:
|
||||
return
|
||||
# htlc failed
|
||||
if len(log) >= attempts:
|
||||
raise PaymentFailure('Giving up after %d attempts'%len(log))
|
||||
# if we get a tmp channel failure, it might work to split the amount and try more routes
|
||||
# if we get a channel update, we might retry the same route and amount
|
||||
self.handle_error_code_from_failed_htlc(htlc_log)
|
||||
|
||||
|
||||
async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, payment_hash:bytes, payment_secret:bytes, min_cltv_expiry:int):
|
||||
# send a single htlc
|
||||
short_channel_id = route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
peer = self._peers.get(route[0].node_id)
|
||||
payment_hash = lnaddr.paymenthash
|
||||
if not peer:
|
||||
raise Exception('Dropped peer')
|
||||
await peer.initialized
|
||||
@@ -1025,8 +1036,8 @@ class LNWallet(LNWorker):
|
||||
chan=chan,
|
||||
amount_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr.payment_secret)
|
||||
min_final_cltv_expiry=min_cltv_expiry,
|
||||
payment_secret=payment_secret)
|
||||
self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
|
||||
util.trigger_callback('htlc_added', chan, htlc, SENT)
|
||||
|
||||
@@ -1037,8 +1048,6 @@ class LNWallet(LNWorker):
|
||||
code, data = failure_msg.code, failure_msg.data
|
||||
self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}")
|
||||
self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}")
|
||||
if code == OnionFailureCode.MPP_TIMEOUT:
|
||||
return False
|
||||
# handle some specific error codes
|
||||
failure_codes = {
|
||||
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0,
|
||||
@@ -1048,6 +1057,8 @@ class LNWallet(LNWorker):
|
||||
OnionFailureCode.EXPIRY_TOO_SOON: 0,
|
||||
OnionFailureCode.CHANNEL_DISABLED: 2,
|
||||
}
|
||||
blacklist = False
|
||||
update = False
|
||||
if code in failure_codes:
|
||||
offset = failure_codes[code]
|
||||
channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big")
|
||||
@@ -1058,7 +1069,6 @@ class LNWallet(LNWorker):
|
||||
blacklist = True
|
||||
else:
|
||||
r = self.channel_db.add_channel_update(payload)
|
||||
blacklist = False
|
||||
short_channel_id = ShortChannelID(payload['short_channel_id'])
|
||||
if r == UpdateStatus.GOOD:
|
||||
self.logger.info(f"applied channel update to {short_channel_id}")
|
||||
@@ -1066,11 +1076,13 @@ class LNWallet(LNWorker):
|
||||
for chan in self.channels.values():
|
||||
if chan.short_channel_id == short_channel_id:
|
||||
chan.set_remote_update(payload['raw'])
|
||||
update = True
|
||||
elif r == UpdateStatus.ORPHANED:
|
||||
# maybe it is a private channel (and data in invoice was outdated)
|
||||
self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
|
||||
start_node_id = route[sender_idx].node_id
|
||||
self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
|
||||
#update = True # FIXME: we need to check if we actually updated something
|
||||
elif r == UpdateStatus.EXPIRED:
|
||||
blacklist = True
|
||||
elif r == UpdateStatus.DEPRECATED:
|
||||
@@ -1080,22 +1092,25 @@ class LNWallet(LNWorker):
|
||||
blacklist = True
|
||||
else:
|
||||
blacklist = True
|
||||
# blacklist channel after reporter node
|
||||
# TODO this should depend on the error (even more granularity)
|
||||
# also, we need finer blacklisting (directed edges; nodes)
|
||||
if blacklist and sender_idx:
|
||||
|
||||
if blacklist:
|
||||
# blacklist channel after reporter node
|
||||
# TODO this should depend on the error (even more granularity)
|
||||
# also, we need finer blacklisting (directed edges; nodes)
|
||||
if htlc_log.sender_idx is None:
|
||||
raise PaymentFailure(htlc_log.failure_msg.code_name())
|
||||
try:
|
||||
short_chan_id = route[sender_idx + 1].short_channel_id
|
||||
except IndexError:
|
||||
self.logger.info("payment destination reported error")
|
||||
short_chan_id = None
|
||||
else:
|
||||
# TODO: for MPP we need to save the amount for which
|
||||
# we saw temporary channel failure
|
||||
self.logger.info(f'blacklisting channel {short_chan_id}')
|
||||
self.network.channel_blacklist.add(short_chan_id)
|
||||
return True
|
||||
return False
|
||||
raise PaymentFailure('payment destination reported error')
|
||||
# TODO: for MPP we need to save the amount for which
|
||||
# we saw temporary channel failure
|
||||
self.logger.info(f'blacklisting channel {short_chan_id}')
|
||||
self.network.channel_blacklist.add(short_chan_id)
|
||||
|
||||
# we should not continue if we did not blacklist or update anything
|
||||
if not (blacklist or update):
|
||||
raise PaymentFailure(htlc_log.failure_msg.code_name())
|
||||
|
||||
|
||||
@classmethod
|
||||
@@ -1137,16 +1152,17 @@ class LNWallet(LNWorker):
|
||||
return addr
|
||||
|
||||
@profiler
|
||||
def create_routes_from_invoice(
|
||||
def create_routes_for_payment(
|
||||
self,
|
||||
amount_msat: int,
|
||||
decoded_invoice: 'LnAddr',
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags,
|
||||
invoice_features,
|
||||
*, full_path: LNPaymentPath = None) -> LNPaymentRoute:
|
||||
# TODO: return multiples routes if we know that a single one will not work
|
||||
# initially, try with less htlcs
|
||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||
r_tags = decoded_invoice.get_routing_info('r')
|
||||
route = None # type: Optional[LNPaymentRoute]
|
||||
route = None
|
||||
channels = list(self.channels.values())
|
||||
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
|
||||
if chan.short_channel_id is not None}
|
||||
@@ -1201,7 +1217,7 @@ class LNWallet(LNWorker):
|
||||
node_features=node_info.features if node_info else 0))
|
||||
prev_node_id = node_pubkey
|
||||
# test sanity
|
||||
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
|
||||
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
|
||||
self.logger.info(f"rejecting insane route {route}")
|
||||
route = None
|
||||
continue
|
||||
@@ -1213,14 +1229,13 @@ class LNWallet(LNWorker):
|
||||
path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist)
|
||||
if not route:
|
||||
raise NoPathFound()
|
||||
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
|
||||
if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry):
|
||||
self.logger.info(f"rejecting insane route {route}")
|
||||
raise NoPathFound()
|
||||
assert len(route) > 0
|
||||
if route[-1].node_id != invoice_pubkey:
|
||||
raise LNPathInconsistent("last node_id != invoice pubkey")
|
||||
# add features from invoice
|
||||
invoice_features = decoded_invoice.get_tag('9') or 0
|
||||
route[-1].node_features |= invoice_features
|
||||
# return a list of routes
|
||||
return [(route, amount_msat)]
|
||||
@@ -1367,7 +1382,10 @@ class LNWallet(LNWorker):
|
||||
failure_message: Optional['OnionRoutingFailureMessage']):
|
||||
|
||||
route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
|
||||
if error_bytes and route:
|
||||
if not route:
|
||||
self.logger.info(f"received unknown htlc_failed, probably from previous session")
|
||||
return
|
||||
if error_bytes:
|
||||
self.logger.info(f" {(error_bytes, route, htlc_id)}")
|
||||
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
|
||||
try:
|
||||
|
||||
@@ -175,9 +175,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
htlc_failed = LNWallet.htlc_failed
|
||||
save_preimage = LNWallet.save_preimage
|
||||
get_preimage = LNWallet.get_preimage
|
||||
create_routes_for_payment = LNWallet.create_routes_for_payment
|
||||
create_routes_from_invoice = LNWallet.create_routes_from_invoice
|
||||
_check_invoice = staticmethod(LNWallet._check_invoice)
|
||||
pay_to_route = LNWallet.pay_to_route
|
||||
pay_to_node = LNWallet.pay_to_node
|
||||
pay_invoice = LNWallet.pay_invoice
|
||||
force_close_channel = LNWallet.force_close_channel
|
||||
try_force_closing = LNWallet.try_force_closing
|
||||
@@ -766,7 +768,11 @@ class TestPeer(ElectrumTestCase):
|
||||
# AssertionError is ok since we shouldn't use old routes, and the
|
||||
# route finding should fail when channel is closed
|
||||
async def f():
|
||||
await asyncio.gather(w1.pay_to_route(route, amount_msat, lnaddr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
|
||||
payment_hash = lnaddr.paymenthash
|
||||
payment_secret = lnaddr.payment_secret
|
||||
pay = w1.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
|
||||
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
with self.assertRaises(PaymentFailure):
|
||||
run(f())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user