lnworker: keep invoice status INFLIGHT as long as HTLCs are inflight
This commit is contained in:
@@ -657,8 +657,8 @@ class LNWallet(LNWorker):
|
||||
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
|
||||
|
||||
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
|
||||
self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route
|
||||
self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set
|
||||
self.htlc_routes = dict()
|
||||
|
||||
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
|
||||
# detect inflight payments
|
||||
@@ -939,14 +939,13 @@ class LNWallet(LNWorker):
|
||||
|
||||
@log_exceptions
|
||||
async def _open_channel_coroutine(
|
||||
self,
|
||||
*,
|
||||
self, *,
|
||||
connect_str: str,
|
||||
funding_tx: PartialTransaction,
|
||||
funding_sat: int,
|
||||
push_sat: int,
|
||||
password: Optional[str],
|
||||
) -> Tuple[Channel, PartialTransaction]:
|
||||
password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
|
||||
|
||||
peer = await self.add_peer(connect_str)
|
||||
coro = peer.channel_establishment_flow(
|
||||
funding_tx=funding_tx,
|
||||
@@ -1053,7 +1052,6 @@ class LNWallet(LNWorker):
|
||||
random.shuffle(self.trampoline2_list)
|
||||
|
||||
self.set_invoice_status(key, PR_INFLIGHT)
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
try:
|
||||
await self.pay_to_node(
|
||||
node_pubkey=invoice_pubkey,
|
||||
@@ -1071,6 +1069,11 @@ class LNWallet(LNWorker):
|
||||
self.logger.exception('')
|
||||
success = False
|
||||
reason = str(e)
|
||||
# keep invoice status INFLIGHT as long as HTLCs are inflight
|
||||
# maybe we could add an extra state for the waiting time.
|
||||
while payment_hash in self.get_payments(status='inflight'):
|
||||
self.logger.info('waiting for inflight HTLCs...')
|
||||
await self.sent_htlcs[payment_hash].get()
|
||||
if success:
|
||||
self.set_invoice_status(key, PR_PAID)
|
||||
util.trigger_callback('payment_succeeded', self.wallet, key)
|
||||
@@ -1081,8 +1084,7 @@ class LNWallet(LNWorker):
|
||||
return success, log
|
||||
|
||||
async def pay_to_node(
|
||||
self,
|
||||
*,
|
||||
self, *,
|
||||
node_pubkey: bytes,
|
||||
payment_hash: bytes,
|
||||
payment_secret: Optional[bytes],
|
||||
@@ -1095,8 +1097,7 @@ class LNWallet(LNWorker):
|
||||
full_path: LNPaymentPath = None,
|
||||
trampoline_onion=None,
|
||||
trampoline_fee=None,
|
||||
trampoline_cltv_delta=None,
|
||||
) -> None:
|
||||
trampoline_cltv_delta=None) -> None:
|
||||
|
||||
if trampoline_onion:
|
||||
# todo: compare to the fee of the actual route we found
|
||||
@@ -1119,7 +1120,7 @@ class LNWallet(LNWorker):
|
||||
# 2. send htlcs
|
||||
for route, amount_msat in routes:
|
||||
await self.pay_to_route(
|
||||
route,
|
||||
route=route,
|
||||
amount_msat=amount_msat,
|
||||
total_msat=amount_to_pay,
|
||||
payment_hash=payment_hash,
|
||||
@@ -1142,16 +1143,15 @@ class LNWallet(LNWorker):
|
||||
self.handle_error_code_from_failed_htlc(htlc_log)
|
||||
|
||||
async def pay_to_route(
|
||||
self,
|
||||
self, *,
|
||||
route: LNPaymentRoute,
|
||||
*,
|
||||
amount_msat: int,
|
||||
total_msat: int,
|
||||
payment_hash: bytes,
|
||||
payment_secret: Optional[bytes],
|
||||
min_cltv_expiry: int,
|
||||
trampoline_onion: bytes = None,
|
||||
) -> None:
|
||||
trampoline_onion: bytes = None) -> None:
|
||||
|
||||
# send a single htlc
|
||||
short_channel_id = route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
@@ -1168,7 +1168,7 @@ class LNWallet(LNWorker):
|
||||
min_final_cltv_expiry=min_cltv_expiry,
|
||||
payment_secret=payment_secret,
|
||||
fwd_trampoline_onion=trampoline_onion)
|
||||
self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
|
||||
self.sent_htlcs_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
|
||||
util.trigger_callback('htlc_added', chan, htlc, SENT)
|
||||
|
||||
def handle_error_code_from_failed_htlc(self, htlc_log):
|
||||
@@ -1729,6 +1729,7 @@ class LNWallet(LNWorker):
|
||||
self.inflight_payments.remove(key)
|
||||
if status in SAVED_PR_STATUS:
|
||||
self.set_payment_status(bfh(key), status)
|
||||
util.trigger_callback('invoice_status', self.wallet, key)
|
||||
|
||||
def set_payment_status(self, payment_hash: bytes, status):
|
||||
info = self.get_payment_info(payment_hash)
|
||||
@@ -1739,54 +1740,60 @@ class LNWallet(LNWorker):
|
||||
self.save_payment_info(info)
|
||||
|
||||
def htlc_fulfilled(self, chan, payment_hash: bytes, htlc_id:int, amount_msat:int):
|
||||
route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
|
||||
htlc_log = HtlcLog(
|
||||
success=True,
|
||||
route=route,
|
||||
amount_msat=amount_msat)
|
||||
q = self.sent_htlcs[payment_hash]
|
||||
q.put_nowait(htlc_log)
|
||||
util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id)
|
||||
q = self.sent_htlcs.get(payment_hash)
|
||||
if q:
|
||||
route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
|
||||
htlc_log = HtlcLog(
|
||||
success=True,
|
||||
route=route,
|
||||
amount_msat=amount_msat)
|
||||
q.put_nowait(htlc_log)
|
||||
else:
|
||||
if payment_hash not in self.get_payments(status='inflight'):
|
||||
key = payment_hash.hex()
|
||||
self.set_invoice_status(key, PR_PAID)
|
||||
util.trigger_callback('payment_succeeded', self.wallet, key)
|
||||
|
||||
def htlc_failed(
|
||||
self,
|
||||
chan,
|
||||
chan: Channel,
|
||||
payment_hash: bytes,
|
||||
htlc_id: int,
|
||||
amount_msat:int,
|
||||
error_bytes: Optional[bytes],
|
||||
failure_message: Optional['OnionRoutingFailure']):
|
||||
|
||||
route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
|
||||
if not route:
|
||||
self.logger.info(f"received unknown htlc_failed, probably from previous session")
|
||||
return
|
||||
if error_bytes:
|
||||
self.logger.info(f" {(error_bytes, route, htlc_id)}")
|
||||
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
|
||||
try:
|
||||
failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
|
||||
except Exception as e:
|
||||
sender_idx = None
|
||||
failure_message = OnionRoutingFailure(-1, str(e))
|
||||
else:
|
||||
# probably got "update_fail_malformed_htlc". well... who to penalise now?
|
||||
assert failure_message is not None
|
||||
sender_idx = None
|
||||
|
||||
htlc_log = HtlcLog(
|
||||
success=False,
|
||||
route=route,
|
||||
amount_msat=amount_msat,
|
||||
error_bytes=error_bytes,
|
||||
failure_msg=failure_message,
|
||||
sender_idx=sender_idx)
|
||||
|
||||
q = self.sent_htlcs[payment_hash]
|
||||
q.put_nowait(htlc_log)
|
||||
util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
|
||||
|
||||
|
||||
q = self.sent_htlcs.get(payment_hash)
|
||||
if q:
|
||||
route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
|
||||
if error_bytes:
|
||||
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
|
||||
try:
|
||||
failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
|
||||
except Exception as e:
|
||||
sender_idx = None
|
||||
failure_message = OnionRoutingFailure(-1, str(e))
|
||||
else:
|
||||
# probably got "update_fail_malformed_htlc". well... who to penalise now?
|
||||
assert failure_message is not None
|
||||
sender_idx = None
|
||||
self.logger.info(f"htlc_failed {failure_message}")
|
||||
htlc_log = HtlcLog(
|
||||
success=False,
|
||||
route=route,
|
||||
amount_msat=amount_msat,
|
||||
error_bytes=error_bytes,
|
||||
failure_msg=failure_message,
|
||||
sender_idx=sender_idx)
|
||||
q.put_nowait(htlc_log)
|
||||
else:
|
||||
self.logger.info(f"received unknown htlc_failed, probably from previous session")
|
||||
if payment_hash not in self.get_payments(status='inflight'):
|
||||
key = payment_hash.hex()
|
||||
self.set_invoice_status(key, PR_UNPAID)
|
||||
util.trigger_callback('payment_failed', self.wallet, key, '')
|
||||
|
||||
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
|
||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||
|
||||
@@ -165,6 +165,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
|
||||
inflight_payments = set()
|
||||
preimages = {}
|
||||
get_payments = LNWallet.get_payments
|
||||
get_payment_info = LNWallet.get_payment_info
|
||||
save_payment_info = LNWallet.save_payment_info
|
||||
set_invoice_status = LNWallet.set_invoice_status
|
||||
@@ -776,7 +777,7 @@ class TestPeer(ElectrumTestCase):
|
||||
payment_hash = lnaddr.paymenthash
|
||||
payment_secret = lnaddr.payment_secret
|
||||
pay = w1.pay_to_route(
|
||||
route,
|
||||
route=route,
|
||||
amount_msat=amount_msat,
|
||||
total_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
|
||||
Reference in New Issue
Block a user