lnworker: introduce PaySession cls, refactor pay_to_node
This commit is contained in:
@@ -17,7 +17,7 @@ import socket
|
||||
import aiohttp
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from functools import partial, cached_property
|
||||
from collections import defaultdict
|
||||
import concurrent
|
||||
from concurrent import futures
|
||||
@@ -655,6 +655,105 @@ class LNGossip(LNWorker):
|
||||
self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
|
||||
|
||||
|
||||
class PaySession(Logger):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payment_hash: bytes,
|
||||
payment_secret: bytes,
|
||||
initial_trampoline_fee_level: int,
|
||||
invoice_features: int,
|
||||
r_tags,
|
||||
min_cltv_expiry: int,
|
||||
amount_to_pay: int, # total payment amount final receiver will get
|
||||
invoice_pubkey: bytes,
|
||||
):
|
||||
assert payment_hash
|
||||
assert payment_secret
|
||||
self.payment_hash = payment_hash
|
||||
self.payment_secret = payment_secret
|
||||
self.payment_key = payment_hash + payment_secret
|
||||
Logger.__init__(self)
|
||||
|
||||
self.invoice_features = LnFeatures(invoice_features)
|
||||
self.r_tags = r_tags
|
||||
self.min_cltv_expiry = min_cltv_expiry
|
||||
self.amount_to_pay = amount_to_pay
|
||||
self.invoice_pubkey = invoice_pubkey
|
||||
|
||||
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
|
||||
self.start_time = time.time()
|
||||
|
||||
self.trampoline_fee_level = initial_trampoline_fee_level
|
||||
self.failed_trampoline_routes = []
|
||||
self.use_two_trampolines = True
|
||||
|
||||
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
||||
self._nhtlcs_inflight = 0
|
||||
|
||||
def diagnostic_name(self):
|
||||
pkey = sha256(self.payment_key)
|
||||
return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
|
||||
|
||||
def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
|
||||
if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
|
||||
self.trampoline_fee_level += 1
|
||||
self.failed_trampoline_routes = []
|
||||
self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
|
||||
else:
|
||||
self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
|
||||
|
||||
def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
|
||||
# FIXME The trampoline nodes in the path are chosen randomly.
|
||||
# Some of the errors might depend on how we have chosen them.
|
||||
# Having more attempts is currently useful in part because of the randomness,
|
||||
# instead we should give feedback to create_routes_for_payment.
|
||||
# Sometimes the trampoline node fails to send a payment and returns
|
||||
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
|
||||
if failure_msg.code in (
|
||||
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
|
||||
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
|
||||
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
|
||||
# TODO: parse the node policy here (not returned by eclair yet)
|
||||
# TODO: erring node is always the first trampoline even if second
|
||||
# trampoline demands more fees, we can't influence this
|
||||
self.maybe_raise_trampoline_fee(htlc_log)
|
||||
elif self.use_two_trampolines:
|
||||
self.use_two_trampolines = False
|
||||
elif failure_msg.code in (
|
||||
OnionFailureCode.UNKNOWN_NEXT_PEER,
|
||||
OnionFailureCode.TEMPORARY_NODE_FAILURE):
|
||||
trampoline_route = htlc_log.route
|
||||
r = [hop.end_node.hex() for hop in trampoline_route]
|
||||
self.logger.info(f'failed trampoline route: {r}')
|
||||
if r not in self.failed_trampoline_routes:
|
||||
self.failed_trampoline_routes.append(r)
|
||||
else:
|
||||
pass # maybe the route was reused between different MPP parts
|
||||
else:
|
||||
raise PaymentFailure(failure_msg.code_name())
|
||||
|
||||
async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
|
||||
self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
|
||||
htlc_log = await self.sent_htlcs_q.get()
|
||||
self._amount_inflight -= htlc_log.amount_msat
|
||||
self._nhtlcs_inflight -= 1
|
||||
if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
|
||||
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
|
||||
return htlc_log
|
||||
|
||||
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo:
|
||||
self._nhtlcs_inflight += 1
|
||||
self._amount_inflight += sent_htlc_info.amount_receiver_msat
|
||||
if self._amount_inflight > self.amount_to_pay: # safety belts
|
||||
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
|
||||
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
|
||||
return sent_htlc_info
|
||||
|
||||
def get_outstanding_amount_to_send(self) -> int:
|
||||
return self.amount_to_pay - self._amount_inflight
|
||||
|
||||
|
||||
class LNWallet(LNWorker):
|
||||
|
||||
lnwatcher: Optional['LNWalletWatcher']
|
||||
@@ -694,9 +793,9 @@ class LNWallet(LNWorker):
|
||||
for channel_id, storage in channel_backups.items():
|
||||
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
|
||||
|
||||
self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
|
||||
self._paysessions = dict() # type: Dict[bytes, PaySession]
|
||||
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
|
||||
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
|
||||
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
|
||||
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
|
||||
|
||||
# detect inflight payments
|
||||
@@ -1274,9 +1373,9 @@ class LNWallet(LNWorker):
|
||||
invoice_features: int,
|
||||
attempts: int = None,
|
||||
full_path: LNPaymentPath = None,
|
||||
fwd_trampoline_onion=None,
|
||||
fwd_trampoline_fee=None,
|
||||
fwd_trampoline_cltv_delta=None,
|
||||
fwd_trampoline_onion: OnionPacket = None,
|
||||
fwd_trampoline_fee: int = None,
|
||||
fwd_trampoline_cltv_delta: int = None,
|
||||
channels: Optional[Sequence[Channel]] = None,
|
||||
) -> None:
|
||||
|
||||
@@ -1288,46 +1387,37 @@ class LNWallet(LNWorker):
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
|
||||
|
||||
payment_key = payment_hash + payment_secret
|
||||
#assert payment_key not in self._paysessions # FIXME
|
||||
self._paysessions[payment_key] = paysession = PaySession(
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
initial_trampoline_fee_level=self.INITIAL_TRAMPOLINE_FEE_LEVEL,
|
||||
invoice_features=invoice_features,
|
||||
r_tags=r_tags,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
amount_to_pay=amount_to_pay,
|
||||
invoice_pubkey=node_pubkey,
|
||||
)
|
||||
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
|
||||
|
||||
# when encountering trampoline forwarding difficulties in the legacy case, we
|
||||
# sometimes need to fall back to a single trampoline forwarder, at the expense
|
||||
# of privacy
|
||||
use_two_trampolines = True
|
||||
trampoline_fee_level = self.INITIAL_TRAMPOLINE_FEE_LEVEL
|
||||
failed_trampoline_routes = []
|
||||
start_time = time.time()
|
||||
amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
||||
nhtlcs_inflight = 0
|
||||
while True:
|
||||
amount_to_send = amount_to_pay - amount_inflight
|
||||
if amount_to_send > 0:
|
||||
if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
|
||||
# 1. create a set of routes for remaining amount.
|
||||
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
|
||||
# graph updates might occur during the computation
|
||||
routes = self.create_routes_for_payment(
|
||||
paysession=paysession,
|
||||
amount_msat=amount_to_send,
|
||||
final_total_msat=amount_to_pay,
|
||||
invoice_pubkey=node_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
full_path=full_path,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
trampoline_fee_level=trampoline_fee_level,
|
||||
failed_trampoline_routes=failed_trampoline_routes,
|
||||
use_two_trampolines=use_two_trampolines,
|
||||
fwd_trampoline_onion=fwd_trampoline_onion,
|
||||
channels=channels,
|
||||
)
|
||||
# 2. send htlcs
|
||||
async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
|
||||
nhtlcs_inflight += 1
|
||||
amount_inflight += sent_htlc_info.amount_receiver_msat
|
||||
if amount_inflight > amount_to_pay: # safety belts
|
||||
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
|
||||
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=trampoline_fee_level)
|
||||
sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
|
||||
await self.pay_to_route(
|
||||
sent_htlc_info=sent_htlc_info,
|
||||
payment_hash=payment_hash,
|
||||
@@ -1339,12 +1429,7 @@ class LNWallet(LNWorker):
|
||||
# (e.g. attempt counter)
|
||||
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
|
||||
# 3. await a queue
|
||||
self.logger.info(f"paysession for RHASH {payment_hash.hex()} waiting... {amount_inflight=}. {nhtlcs_inflight=}")
|
||||
htlc_log = await self.sent_htlcs_q[payment_key].get() # TODO maybe wait a bit, more failures might come
|
||||
amount_inflight -= htlc_log.amount_msat
|
||||
nhtlcs_inflight -= 1
|
||||
if amount_inflight < 0 or nhtlcs_inflight < 0:
|
||||
raise Exception(f"{amount_inflight=}, {nhtlcs_inflight=}. both should be >= 0 !")
|
||||
htlc_log = await paysession.wait_for_one_htlc_to_resolve() # TODO maybe wait a bit, more failures might come
|
||||
log.append(htlc_log)
|
||||
if htlc_log.success:
|
||||
if self.network.path_finder:
|
||||
@@ -1357,7 +1442,7 @@ class LNWallet(LNWorker):
|
||||
self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
|
||||
return
|
||||
# htlc failed
|
||||
if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - start_time > self.PAYMENT_TIMEOUT):
|
||||
if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
|
||||
raise PaymentFailure('Giving up after %d attempts'%len(log))
|
||||
# if we get a tmp channel failure, it might work to split the amount and try more routes
|
||||
# if we get a channel update, we might retry the same route and amount
|
||||
@@ -1373,45 +1458,8 @@ class LNWallet(LNWorker):
|
||||
raise PaymentFailure(failure_msg.code_name())
|
||||
# trampoline
|
||||
if self.uses_trampoline():
|
||||
def maybe_raise_trampoline_fee(htlc_log):
|
||||
nonlocal trampoline_fee_level
|
||||
nonlocal failed_trampoline_routes
|
||||
if htlc_log.trampoline_fee_level == trampoline_fee_level:
|
||||
trampoline_fee_level += 1
|
||||
failed_trampoline_routes = []
|
||||
self.logger.info(f'raising trampoline fee level {trampoline_fee_level}')
|
||||
else:
|
||||
self.logger.info(f'NOT raising trampoline fee level, already at {trampoline_fee_level}')
|
||||
# FIXME The trampoline nodes in the path are chosen randomly.
|
||||
# Some of the errors might depend on how we have chosen them.
|
||||
# Having more attempts is currently useful in part because of the randomness,
|
||||
# instead we should give feedback to create_routes_for_payment.
|
||||
# Sometimes the trampoline node fails to send a payment and returns
|
||||
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
|
||||
if code in (
|
||||
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
|
||||
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
|
||||
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
|
||||
# TODO: parse the node policy here (not returned by eclair yet)
|
||||
# TODO: erring node is always the first trampoline even if second
|
||||
# trampoline demands more fees, we can't influence this
|
||||
maybe_raise_trampoline_fee(htlc_log)
|
||||
continue
|
||||
elif use_two_trampolines:
|
||||
use_two_trampolines = False
|
||||
elif code in (
|
||||
OnionFailureCode.UNKNOWN_NEXT_PEER,
|
||||
OnionFailureCode.TEMPORARY_NODE_FAILURE):
|
||||
trampoline_route = htlc_log.route
|
||||
r = [hop.end_node.hex() for hop in trampoline_route]
|
||||
self.logger.info(f'failed trampoline route: {r}')
|
||||
if r not in failed_trampoline_routes:
|
||||
failed_trampoline_routes.append(r)
|
||||
else:
|
||||
pass # maybe the route was reused between different MPP parts
|
||||
continue
|
||||
else:
|
||||
raise PaymentFailure(failure_msg.code_name())
|
||||
paysession.handle_failed_trampoline_htlc(
|
||||
htlc_log=htlc_log, failure_msg=failure_msg)
|
||||
else:
|
||||
self.handle_error_code_from_failed_htlc(
|
||||
route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
|
||||
@@ -1654,18 +1702,9 @@ class LNWallet(LNWorker):
|
||||
|
||||
async def create_routes_for_payment(
|
||||
self, *,
|
||||
paysession: PaySession,
|
||||
amount_msat: int, # part of payment amount we want routes for now
|
||||
final_total_msat: int, # total payment amount final receiver will get
|
||||
invoice_pubkey,
|
||||
min_cltv_expiry,
|
||||
r_tags,
|
||||
invoice_features: int,
|
||||
payment_hash: bytes,
|
||||
payment_secret: bytes,
|
||||
trampoline_fee_level: int,
|
||||
failed_trampoline_routes: Iterable[Sequence[str]],
|
||||
use_two_trampolines: bool,
|
||||
fwd_trampoline_onion=None,
|
||||
fwd_trampoline_onion: OnionPacket = None,
|
||||
full_path: LNPaymentPath = None,
|
||||
channels: Optional[Sequence[Channel]] = None,
|
||||
) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
|
||||
@@ -1675,7 +1714,6 @@ class LNWallet(LNWorker):
|
||||
|
||||
We first try to conduct the payment over a single channel. If that fails
|
||||
and mpp is supported by the receiver, we will split the payment."""
|
||||
invoice_features = LnFeatures(invoice_features)
|
||||
trampoline_features = LnFeatures.VAR_ONION_OPT
|
||||
local_height = self.network.get_local_height()
|
||||
if channels:
|
||||
@@ -1688,15 +1726,15 @@ class LNWallet(LNWorker):
|
||||
random.shuffle(my_active_channels)
|
||||
split_configurations = self.suggest_splits(
|
||||
amount_msat=amount_msat,
|
||||
final_total_msat=final_total_msat,
|
||||
final_total_msat=paysession.amount_to_pay,
|
||||
my_active_channels=my_active_channels,
|
||||
invoice_features=invoice_features,
|
||||
r_tags=r_tags,
|
||||
invoice_features=paysession.invoice_features,
|
||||
r_tags=paysession.r_tags,
|
||||
)
|
||||
for sc in split_configurations:
|
||||
is_multichan_mpp = len(sc.config.items()) > 1
|
||||
is_mpp = sum(len(x) for x in list(sc.config.values())) > 1
|
||||
if is_mpp and not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
|
||||
if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
|
||||
continue
|
||||
if not is_mpp and self.config.TEST_FORCE_MPP:
|
||||
continue
|
||||
@@ -1715,33 +1753,33 @@ class LNWallet(LNWorker):
|
||||
# for each trampoline forwarder, construct mpp trampoline
|
||||
for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
|
||||
per_trampoline_amount = sum([x[1] for x in trampoline_parts])
|
||||
if trampoline_node_id == invoice_pubkey:
|
||||
if trampoline_node_id == paysession.invoice_pubkey:
|
||||
trampoline_route = None
|
||||
trampoline_onion = None
|
||||
per_trampoline_secret = payment_secret
|
||||
per_trampoline_secret = paysession.payment_secret
|
||||
per_trampoline_amount_with_fees = amount_msat
|
||||
per_trampoline_cltv_delta = min_cltv_expiry
|
||||
per_trampoline_cltv_delta = paysession.min_cltv_expiry
|
||||
per_trampoline_fees = 0
|
||||
else:
|
||||
trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
|
||||
amount_msat=per_trampoline_amount,
|
||||
total_msat=final_total_msat,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
total_msat=paysession.amount_to_pay,
|
||||
min_cltv_expiry=paysession.min_cltv_expiry,
|
||||
my_pubkey=self.node_keypair.pubkey,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
invoice_features=invoice_features,
|
||||
invoice_pubkey=paysession.invoice_pubkey,
|
||||
invoice_features=paysession.invoice_features,
|
||||
node_id=trampoline_node_id,
|
||||
r_tags=r_tags,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
r_tags=paysession.r_tags,
|
||||
payment_hash=paysession.payment_hash,
|
||||
payment_secret=paysession.payment_secret,
|
||||
local_height=local_height,
|
||||
trampoline_fee_level=trampoline_fee_level,
|
||||
use_two_trampolines=use_two_trampolines,
|
||||
failed_routes=failed_trampoline_routes)
|
||||
trampoline_fee_level=paysession.trampoline_fee_level,
|
||||
use_two_trampolines=paysession.use_two_trampolines,
|
||||
failed_routes=paysession.failed_trampoline_routes)
|
||||
# node_features is only used to determine is_tlv
|
||||
per_trampoline_secret = os.urandom(32)
|
||||
per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
|
||||
self.logger.info(f'created route with trampoline fee level={trampoline_fee_level}')
|
||||
self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
|
||||
self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
|
||||
self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
|
||||
for chan_id, part_amount_msat in trampoline_parts:
|
||||
@@ -1764,7 +1802,7 @@ class LNWallet(LNWorker):
|
||||
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
|
||||
shi = SentHtlcInfo(
|
||||
route=route,
|
||||
payment_secret_orig=payment_secret,
|
||||
payment_secret_orig=paysession.payment_secret,
|
||||
payment_secret_bucket=per_trampoline_secret,
|
||||
amount_msat=part_amount_msat_with_fees,
|
||||
bucket_msat=per_trampoline_amount_with_fees,
|
||||
@@ -1786,25 +1824,25 @@ class LNWallet(LNWorker):
|
||||
partial(
|
||||
self.create_route_for_payment,
|
||||
amount_msat=part_amount_msat,
|
||||
invoice_pubkey=invoice_pubkey,
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
r_tags=r_tags,
|
||||
invoice_features=invoice_features,
|
||||
invoice_pubkey=paysession.invoice_pubkey,
|
||||
min_cltv_expiry=paysession.min_cltv_expiry,
|
||||
r_tags=paysession.r_tags,
|
||||
invoice_features=paysession.invoice_features,
|
||||
my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
|
||||
full_path=full_path,
|
||||
)
|
||||
)
|
||||
shi = SentHtlcInfo(
|
||||
route=route,
|
||||
payment_secret_orig=payment_secret,
|
||||
payment_secret_bucket=payment_secret,
|
||||
payment_secret_orig=paysession.payment_secret,
|
||||
payment_secret_bucket=paysession.payment_secret,
|
||||
amount_msat=part_amount_msat,
|
||||
bucket_msat=final_total_msat,
|
||||
bucket_msat=paysession.amount_to_pay,
|
||||
amount_receiver_msat=part_amount_msat,
|
||||
trampoline_fee_level=None,
|
||||
trampoline_route=None,
|
||||
)
|
||||
routes.append((shi, min_cltv_expiry, fwd_trampoline_onion))
|
||||
routes.append((shi, paysession.min_cltv_expiry, fwd_trampoline_onion))
|
||||
except NoPathFound:
|
||||
continue
|
||||
for route in routes:
|
||||
@@ -2159,7 +2197,9 @@ class LNWallet(LNWorker):
|
||||
q = None
|
||||
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
|
||||
payment_key = payment_hash + shi.payment_secret_orig
|
||||
q = self.sent_htlcs_q.get(payment_key)
|
||||
paysession = self._paysessions.get(payment_key)
|
||||
if paysession:
|
||||
q = paysession.sent_htlcs_q
|
||||
if q:
|
||||
htlc_log = HtlcLog(
|
||||
success=True,
|
||||
@@ -2185,7 +2225,9 @@ class LNWallet(LNWorker):
|
||||
q = None
|
||||
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
|
||||
payment_okey = payment_hash + shi.payment_secret_orig
|
||||
q = self.sent_htlcs_q.get(payment_okey)
|
||||
paysession = self._paysessions.get(payment_okey)
|
||||
if paysession:
|
||||
q = paysession.sent_htlcs_q
|
||||
if q:
|
||||
# detect if it is part of a bucket
|
||||
# if yes, wait until the bucket completely failed
|
||||
|
||||
@@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
|
||||
from electrum.lnchannel import ChannelState, PeerState, Channel
|
||||
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
|
||||
from electrum.channel_db import ChannelDB
|
||||
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo
|
||||
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
|
||||
from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum import lnmsg
|
||||
from electrum.logging import console_stderr_handler, Logger
|
||||
@@ -168,7 +168,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
self.enable_htlc_settle = True
|
||||
self.enable_htlc_forwarding = True
|
||||
self.received_mpp_htlcs = dict()
|
||||
self.sent_htlcs_q = defaultdict(asyncio.Queue)
|
||||
self._paysessions = dict()
|
||||
self.sent_htlcs_info = dict()
|
||||
self.sent_buckets = defaultdict(set)
|
||||
self.final_onion_forwardings = set()
|
||||
@@ -232,18 +232,22 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
await self.channel_db.stopped_event.wait()
|
||||
|
||||
async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
|
||||
return [r async for r in self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
final_total_msat=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
r_tags=decoded_invoice.get_routing_info('r'),
|
||||
invoice_features=decoded_invoice.get_features(),
|
||||
trampoline_fee_level=0,
|
||||
failed_trampoline_routes=[],
|
||||
use_two_trampolines=False,
|
||||
paysession = PaySession(
|
||||
payment_hash=decoded_invoice.paymenthash,
|
||||
payment_secret=decoded_invoice.payment_secret,
|
||||
initial_trampoline_fee_level=0,
|
||||
invoice_features=decoded_invoice.get_features(),
|
||||
r_tags=decoded_invoice.get_routing_info('r'),
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
amount_to_pay=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
)
|
||||
paysession.use_two_trampolines = False
|
||||
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
|
||||
self._paysessions[payment_key] = paysession
|
||||
return [r async for r in self.create_routes_for_payment(
|
||||
amount_msat=amount_msat,
|
||||
paysession=paysession,
|
||||
full_path=full_path)]
|
||||
|
||||
get_payments = LNWallet.get_payments
|
||||
@@ -854,9 +858,6 @@ class TestPeer(ElectrumTestCase):
|
||||
_maybe_send_commitment2 = p2.maybe_send_commitment
|
||||
lnaddr2, pay_req2 = self.prepare_invoice(w2)
|
||||
lnaddr1, pay_req1 = self.prepare_invoice(w1)
|
||||
# create the htlc queues now (side-effecting defaultdict)
|
||||
q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret]
|
||||
q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret]
|
||||
# alice sends htlc BUT NOT COMMITMENT_SIGNED
|
||||
p1.maybe_send_commitment = lambda x: None
|
||||
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
|
||||
@@ -901,9 +902,9 @@ class TestPeer(ElectrumTestCase):
|
||||
p1.maybe_send_commitment(alice_channel)
|
||||
p2.maybe_send_commitment(bob_channel)
|
||||
|
||||
htlc_log1 = await q1.get()
|
||||
htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get()
|
||||
self.assertTrue(htlc_log1.success)
|
||||
htlc_log2 = await q2.get()
|
||||
htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get()
|
||||
self.assertTrue(htlc_log2.success)
|
||||
raise PaymentDone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user