lnworker: move sent_buckets into PaySession
This commit is contained in:
@@ -667,6 +667,7 @@ class PaySession(Logger):
|
||||
min_cltv_expiry: int,
|
||||
amount_to_pay: int, # total payment amount final receiver will get
|
||||
invoice_pubkey: bytes,
|
||||
uses_trampoline: bool, # whether sender uses trampoline or gossip
|
||||
):
|
||||
assert payment_hash
|
||||
assert payment_secret
|
||||
@@ -684,9 +685,11 @@ class PaySession(Logger):
|
||||
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
|
||||
self.start_time = time.time()
|
||||
|
||||
self.uses_trampoline = uses_trampoline
|
||||
self.trampoline_fee_level = initial_trampoline_fee_level
|
||||
self.failed_trampoline_routes = []
|
||||
self.use_two_trampolines = True
|
||||
self._sent_buckets = dict() # psecret_bucket -> (amount_sent, amount_failed)
|
||||
|
||||
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
|
||||
self._nhtlcs_inflight = 0
|
||||
@@ -742,13 +745,36 @@ class PaySession(Logger):
|
||||
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
|
||||
return htlc_log
|
||||
|
||||
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo:
|
||||
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
|
||||
self._nhtlcs_inflight += 1
|
||||
self._amount_inflight += sent_htlc_info.amount_receiver_msat
|
||||
if self._amount_inflight > self.amount_to_pay: # safety belts
|
||||
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
|
||||
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
|
||||
return sent_htlc_info
|
||||
shi = sent_htlc_info
|
||||
bkey = shi.payment_secret_bucket
|
||||
# if we sent MPP to a trampoline, add item to sent_buckets
|
||||
if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
|
||||
if bkey not in self._sent_buckets:
|
||||
self._sent_buckets[bkey] = (0, 0)
|
||||
amount_sent, amount_failed = self._sent_buckets[bkey]
|
||||
amount_sent += shi.amount_receiver_msat
|
||||
self._sent_buckets[bkey] = amount_sent, amount_failed
|
||||
|
||||
def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
|
||||
shi = sent_htlc_info
|
||||
# check sent_buckets if we use trampoline
|
||||
bkey = shi.payment_secret_bucket
|
||||
if self.uses_trampoline and bkey in self._sent_buckets:
|
||||
amount_sent, amount_failed = self._sent_buckets[bkey]
|
||||
amount_failed += shi.amount_receiver_msat
|
||||
self._sent_buckets[bkey] = amount_sent, amount_failed
|
||||
if amount_sent != amount_failed:
|
||||
self.logger.info('bucket still active...')
|
||||
return None
|
||||
self.logger.info('bucket failed')
|
||||
return amount_sent
|
||||
# not using trampoline buckets
|
||||
return shi.amount_receiver_msat
|
||||
|
||||
def get_outstanding_amount_to_send(self) -> int:
|
||||
return self.amount_to_pay - self._amount_inflight
|
||||
@@ -795,7 +821,6 @@ class LNWallet(LNWorker):
|
||||
|
||||
self._paysessions = dict() # type: Dict[bytes, PaySession]
|
||||
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
|
||||
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
|
||||
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
|
||||
|
||||
# detect inflight payments
|
||||
@@ -1397,6 +1422,7 @@ class LNWallet(LNWorker):
|
||||
min_cltv_expiry=min_cltv_expiry,
|
||||
amount_to_pay=amount_to_pay,
|
||||
invoice_pubkey=node_pubkey,
|
||||
uses_trampoline=self.uses_trampoline(),
|
||||
)
|
||||
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
|
||||
|
||||
@@ -1417,10 +1443,9 @@ class LNWallet(LNWorker):
|
||||
)
|
||||
# 2. send htlcs
|
||||
async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
|
||||
sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
|
||||
await self.pay_to_route(
|
||||
paysession=paysession,
|
||||
sent_htlc_info=sent_htlc_info,
|
||||
payment_hash=payment_hash,
|
||||
min_cltv_expiry=cltv_delta,
|
||||
trampoline_onion=trampoline_onion,
|
||||
)
|
||||
@@ -1466,8 +1491,8 @@ class LNWallet(LNWorker):
|
||||
|
||||
async def pay_to_route(
|
||||
self, *,
|
||||
paysession: PaySession,
|
||||
sent_htlc_info: SentHtlcInfo,
|
||||
payment_hash: bytes,
|
||||
min_cltv_expiry: int,
|
||||
trampoline_onion: bytes = None,
|
||||
) -> None:
|
||||
@@ -1486,21 +1511,14 @@ class LNWallet(LNWorker):
|
||||
chan=chan,
|
||||
amount_msat=shi.amount_msat,
|
||||
total_msat=shi.bucket_msat,
|
||||
payment_hash=payment_hash,
|
||||
payment_hash=paysession.payment_hash,
|
||||
min_final_cltv_expiry=min_cltv_expiry,
|
||||
payment_secret=shi.payment_secret_bucket,
|
||||
trampoline_onion=trampoline_onion)
|
||||
|
||||
key = (payment_hash, short_channel_id, htlc.htlc_id)
|
||||
key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
|
||||
self.sent_htlcs_info[key] = shi
|
||||
payment_key = payment_hash + shi.payment_secret_bucket
|
||||
# if we sent MPP to a trampoline, add item to sent_buckets
|
||||
if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat:
|
||||
if payment_key not in self.sent_buckets:
|
||||
self.sent_buckets[payment_key] = (0, 0)
|
||||
amount_sent, amount_failed = self.sent_buckets[payment_key]
|
||||
amount_sent += shi.amount_receiver_msat
|
||||
self.sent_buckets[payment_key] = amount_sent, amount_failed
|
||||
paysession.add_new_htlc(shi)
|
||||
if self.network.path_finder:
|
||||
# add inflight htlcs to liquidity hints
|
||||
self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
|
||||
@@ -1807,7 +1825,7 @@ class LNWallet(LNWorker):
|
||||
amount_msat=part_amount_msat_with_fees,
|
||||
bucket_msat=per_trampoline_amount_with_fees,
|
||||
amount_receiver_msat=part_amount_msat,
|
||||
trampoline_fee_level=None,
|
||||
trampoline_fee_level=paysession.trampoline_fee_level,
|
||||
trampoline_route=trampoline_route,
|
||||
)
|
||||
routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
|
||||
@@ -2232,7 +2250,6 @@ class LNWallet(LNWorker):
|
||||
# detect if it is part of a bucket
|
||||
# if yes, wait until the bucket completely failed
|
||||
shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
|
||||
amount_receiver_msat = shi.amount_receiver_msat
|
||||
route = shi.route
|
||||
if error_bytes:
|
||||
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
|
||||
@@ -2247,18 +2264,9 @@ class LNWallet(LNWorker):
|
||||
sender_idx = None
|
||||
self.logger.info(f"htlc_failed {failure_message}")
|
||||
|
||||
# check sent_buckets if we use trampoline
|
||||
payment_bkey = payment_hash + shi.payment_secret_bucket
|
||||
if self.uses_trampoline() and payment_bkey in self.sent_buckets:
|
||||
amount_sent, amount_failed = self.sent_buckets[payment_bkey]
|
||||
amount_failed += amount_receiver_msat
|
||||
self.sent_buckets[payment_bkey] = amount_sent, amount_failed
|
||||
if amount_sent != amount_failed:
|
||||
self.logger.info('bucket still active...')
|
||||
return
|
||||
self.logger.info('bucket failed')
|
||||
amount_receiver_msat = amount_sent
|
||||
|
||||
amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
|
||||
if amount_receiver_msat is None:
|
||||
return
|
||||
if shi.trampoline_route:
|
||||
route = shi.trampoline_route
|
||||
htlc_log = HtlcLog(
|
||||
|
||||
@@ -241,6 +241,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
|
||||
amount_to_pay=amount_msat,
|
||||
invoice_pubkey=decoded_invoice.pubkey.serialize(),
|
||||
uses_trampoline=False,
|
||||
)
|
||||
paysession.use_two_trampolines = False
|
||||
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
|
||||
@@ -861,6 +862,7 @@ class TestPeer(ElectrumTestCase):
|
||||
# alice sends htlc BUT NOT COMMITMENT_SIGNED
|
||||
p1.maybe_send_commitment = lambda x: None
|
||||
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
|
||||
paysession1 = w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret]
|
||||
shi1 = SentHtlcInfo(
|
||||
route=route1,
|
||||
payment_secret_orig=lnaddr2.payment_secret,
|
||||
@@ -873,13 +875,14 @@ class TestPeer(ElectrumTestCase):
|
||||
)
|
||||
await w1.pay_to_route(
|
||||
sent_htlc_info=shi1,
|
||||
payment_hash=lnaddr2.paymenthash,
|
||||
paysession=paysession1,
|
||||
min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
|
||||
)
|
||||
p1.maybe_send_commitment = _maybe_send_commitment1
|
||||
# bob sends htlc BUT NOT COMMITMENT_SIGNED
|
||||
p2.maybe_send_commitment = lambda x: None
|
||||
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route
|
||||
paysession2 = w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret]
|
||||
shi2 = SentHtlcInfo(
|
||||
route=route2,
|
||||
payment_secret_orig=lnaddr1.payment_secret,
|
||||
@@ -892,7 +895,7 @@ class TestPeer(ElectrumTestCase):
|
||||
)
|
||||
await w2.pay_to_route(
|
||||
sent_htlc_info=shi2,
|
||||
payment_hash=lnaddr1.paymenthash,
|
||||
paysession=paysession2,
|
||||
min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
|
||||
)
|
||||
p2.maybe_send_commitment = _maybe_send_commitment2
|
||||
@@ -902,9 +905,9 @@ class TestPeer(ElectrumTestCase):
|
||||
p1.maybe_send_commitment(alice_channel)
|
||||
p2.maybe_send_commitment(bob_channel)
|
||||
|
||||
htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get()
|
||||
htlc_log1 = await paysession1.sent_htlcs_q.get()
|
||||
self.assertTrue(htlc_log1.success)
|
||||
htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get()
|
||||
htlc_log2 = await paysession2.sent_htlcs_q.get()
|
||||
self.assertTrue(htlc_log2.success)
|
||||
raise PaymentDone()
|
||||
|
||||
@@ -1603,9 +1606,10 @@ class TestPeer(ElectrumTestCase):
|
||||
trampoline_fee_level=None,
|
||||
trampoline_route=None,
|
||||
)
|
||||
paysession = w1._paysessions[lnaddr.paymenthash + lnaddr.payment_secret]
|
||||
pay = w1.pay_to_route(
|
||||
sent_htlc_info=shi,
|
||||
payment_hash=lnaddr.paymenthash,
|
||||
paysession=paysession,
|
||||
min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
|
||||
)
|
||||
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
|
||||
Reference in New Issue
Block a user