lnworker: async gen create_routes_for_payments
This commit is contained in:
@@ -8,7 +8,7 @@ from decimal import Decimal
|
||||
import random
|
||||
import time
|
||||
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
|
||||
NamedTuple, Union, Mapping, Any, Iterable)
|
||||
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator)
|
||||
import threading
|
||||
import socket
|
||||
import aiohttp
|
||||
@@ -1073,20 +1073,6 @@ class LNWallet(LNWorker):
|
||||
if chan.short_channel_id == short_channel_id:
|
||||
return chan
|
||||
|
||||
def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
|
||||
return self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
final_total_msat=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
r_tags=decoded_invoice.get_routing_info('r'),
|
||||
invoice_features=decoded_invoice.get_features(),
|
||||
trampoline_fee_level=0,
|
||||
use_two_trampolines=False,
|
||||
payment_hash=decoded_invoice.paymenthash,
|
||||
payment_secret=decoded_invoice.payment_secret,
|
||||
full_path=full_path)
|
||||
|
||||
@log_exceptions
|
||||
async def pay_invoice(
|
||||
self, invoice: str, *,
|
||||
@@ -1173,8 +1159,7 @@ class LNWallet(LNWorker):
|
||||
# 1. create a set of routes for remaining amount.
|
||||
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
|
||||
# graph updates might occur during the computation
|
||||
routes = await run_in_thread(partial(
|
||||
self.create_routes_for_payment,
|
||||
routes = self.create_routes_for_payment(
|
||||
amount_msat=amount_to_send,
|
||||
final_total_msat=amount_to_pay,
|
||||
invoice_pubkey=node_pubkey,
|
||||
@@ -1186,9 +1171,10 @@ class LNWallet(LNWorker):
|
||||
payment_secret=payment_secret,
|
||||
trampoline_fee_level=trampoline_fee_level,
|
||||
use_two_trampolines=use_two_trampolines,
|
||||
fwd_trampoline_onion=fwd_trampoline_onion))
|
||||
fwd_trampoline_onion=fwd_trampoline_onion
|
||||
)
|
||||
# 2. send htlcs
|
||||
for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes:
|
||||
async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes:
|
||||
amount_inflight += amount_receiver_msat
|
||||
if amount_inflight > amount_to_pay: # safety belts
|
||||
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
|
||||
@@ -1432,8 +1418,7 @@ class LNWallet(LNWorker):
|
||||
else:
|
||||
return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
|
||||
|
||||
@profiler
|
||||
def create_routes_for_payment(
|
||||
async def create_routes_for_payment(
|
||||
self, *,
|
||||
amount_msat: int, # part of payment amount we want routes for now
|
||||
final_total_msat: int, # total payment amount final receiver will get
|
||||
@@ -1446,7 +1431,7 @@ class LNWallet(LNWorker):
|
||||
trampoline_fee_level: int,
|
||||
use_two_trampolines: bool,
|
||||
fwd_trampoline_onion = None,
|
||||
full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
|
||||
full_path: LNPaymentPath = None) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]:
|
||||
|
||||
"""Creates multiple routes for splitting a payment over the available
|
||||
private channels.
|
||||
@@ -1502,20 +1487,24 @@ class LNWallet(LNWorker):
|
||||
cltv_expiry_delta=0,
|
||||
node_features=trampoline_features)
|
||||
]
|
||||
routes = [(route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion)]
|
||||
yield route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion
|
||||
break
|
||||
else:
|
||||
raise NoPathFound()
|
||||
else:
|
||||
route = self.create_route_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
channels=active_channels,
|
||||
full_path=full_path)
|
||||
routes = [(route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion)]
|
||||
route = await run_in_thread(
|
||||
partial(
|
||||
self.create_route_for_payment,
|
||||
amount_msat=amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
channels=active_channels,
|
||||
full_path=full_path
|
||||
)
|
||||
)
|
||||
yield route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
|
||||
except NoPathFound:
|
||||
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
|
||||
raise
|
||||
@@ -1532,7 +1521,6 @@ class LNWallet(LNWorker):
|
||||
|
||||
for s in split_configurations:
|
||||
self.logger.info(f"trying split configuration: {s[0].values()} rating: {s[1]}")
|
||||
routes = []
|
||||
try:
|
||||
if not self.channel_db:
|
||||
buckets = defaultdict(list)
|
||||
@@ -1577,7 +1565,7 @@ class LNWallet(LNWorker):
|
||||
node_features=trampoline_features)
|
||||
]
|
||||
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
|
||||
routes.append((route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion))
|
||||
yield route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion
|
||||
if bucket_fees != 0:
|
||||
self.logger.info('not enough margin to pay trampoline fee')
|
||||
raise NoPathFound()
|
||||
@@ -1585,23 +1573,27 @@ class LNWallet(LNWorker):
|
||||
for (chan_id, _), part_amount_msat in s[0].items():
|
||||
if part_amount_msat:
|
||||
channel = self.channels[chan_id]
|
||||
route = self.create_route_for_payment(
|
||||
amount_msat=part_amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
channels=[channel],
|
||||
full_path=None)
|
||||
routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion))
|
||||
route = await run_in_thread(
|
||||
partial(
|
||||
self.create_route_for_payment,
|
||||
amount_msat=part_amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
channels=[channel],
|
||||
full_path=None
|
||||
)
|
||||
)
|
||||
yield route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
|
||||
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
|
||||
break
|
||||
except NoPathFound:
|
||||
continue
|
||||
else:
|
||||
raise NoPathFound()
|
||||
return routes
|
||||
|
||||
@profiler
|
||||
def create_route_for_payment(
|
||||
self, *,
|
||||
amount_msat: int,
|
||||
@@ -1610,7 +1602,7 @@ class LNWallet(LNWorker):
|
||||
r_tags,
|
||||
invoice_features: int,
|
||||
channels: List[Channel],
|
||||
full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
|
||||
full_path: Optional[LNPaymentPath]) -> LNPaymentRoute:
|
||||
|
||||
scid_to_my_channels = {
|
||||
chan.short_channel_id: chan for chan in channels
|
||||
|
||||
@@ -192,6 +192,20 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self.channel_db.stop()
|
||||
await self.channel_db.stopped_event.wait()
|
||||
|
||||
async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
|
||||
return [r async for r in self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
final_total_msat=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
r_tags=decoded_invoice.get_routing_info('r'),
|
||||
invoice_features=decoded_invoice.get_features(),
|
||||
trampoline_fee_level=0,
|
||||
use_two_trampolines=False,
|
||||
payment_hash=decoded_invoice.paymenthash,
|
||||
payment_secret=decoded_invoice.payment_secret,
|
||||
full_path=full_path)]
|
||||
|
||||
get_payments = LNWallet.get_payments
|
||||
get_payment_info = LNWallet.get_payment_info
|
||||
save_payment_info = LNWallet.save_payment_info
|
||||
@@ -206,7 +220,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
get_preimage = LNWallet.get_preimage
|
||||
create_route_for_payment = LNWallet.create_route_for_payment
|
||||
create_routes_for_payment = LNWallet.create_routes_for_payment
|
||||
create_routes_from_invoice = LNWallet.create_routes_from_invoice
|
||||
_check_invoice = staticmethod(LNWallet._check_invoice)
|
||||
pay_to_route = LNWallet.pay_to_route
|
||||
pay_to_node = LNWallet.pay_to_node
|
||||
@@ -598,7 +611,7 @@ class TestPeer(TestCaseForTestnet):
|
||||
q2 = w2.sent_htlcs[lnaddr1.paymenthash]
|
||||
# alice sends htlc BUT NOT COMMITMENT_SIGNED
|
||||
p1.maybe_send_commitment = lambda x: None
|
||||
route1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0][0]
|
||||
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0]
|
||||
amount_msat = lnaddr2.get_amount_msat()
|
||||
await w1.pay_to_route(
|
||||
route=route1,
|
||||
@@ -612,7 +625,7 @@ class TestPeer(TestCaseForTestnet):
|
||||
p1.maybe_send_commitment = _maybe_send_commitment1
|
||||
# bob sends htlc BUT NOT COMMITMENT_SIGNED
|
||||
p2.maybe_send_commitment = lambda x: None
|
||||
route2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0][0]
|
||||
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0]
|
||||
amount_msat = lnaddr1.get_amount_msat()
|
||||
await w2.pay_to_route(
|
||||
route=route2,
|
||||
@@ -982,14 +995,14 @@ class TestPeer(TestCaseForTestnet):
|
||||
await asyncio.wait_for(p1.initialized, 1)
|
||||
await asyncio.wait_for(p2.initialized, 1)
|
||||
# alice sends htlc
|
||||
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
|
||||
htlc = p1.pay(route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=lnaddr.get_amount_msat(),
|
||||
total_msat=lnaddr.get_amount_msat(),
|
||||
payment_hash=lnaddr.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr.payment_secret)
|
||||
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
|
||||
p1.pay(route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=lnaddr.get_amount_msat(),
|
||||
total_msat=lnaddr.get_amount_msat(),
|
||||
payment_hash=lnaddr.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr.payment_secret)
|
||||
# alice closes
|
||||
await p1.close_channel(alice_channel.channel_id)
|
||||
gath.cancel()
|
||||
@@ -1078,7 +1091,7 @@ class TestPeer(TestCaseForTestnet):
|
||||
lnaddr, pay_req = run(self.prepare_invoice(w2))
|
||||
|
||||
lnaddr = w1._check_invoice(pay_req)
|
||||
route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2]
|
||||
route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
|
||||
assert amount_msat == lnaddr.get_amount_msat()
|
||||
|
||||
run(w1.force_close_channel(alice_channel.channel_id))
|
||||
@@ -1086,7 +1099,7 @@ class TestPeer(TestCaseForTestnet):
|
||||
assert q1.qsize() == 1
|
||||
|
||||
with self.assertRaises(NoPathFound) as e:
|
||||
w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)
|
||||
run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))
|
||||
|
||||
peer = w1.peers[route[0].node_id]
|
||||
# AssertionError is ok since we shouldn't use old routes, and the
|
||||
|
||||
Reference in New Issue
Block a user