basic_mpp: receive multi-part payments
This commit is contained in:
@@ -969,10 +969,6 @@ class Channel(AbstractChannel):
|
||||
raise Exception("refusing to revoke as remote sig does not fit")
|
||||
with self.db_lock:
|
||||
self.hm.send_rev()
|
||||
if self.lnworker:
|
||||
received = self.hm.received_in_ctn(new_ctn)
|
||||
for htlc in received:
|
||||
self.lnworker.payment_received(self, htlc.payment_hash)
|
||||
last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1)
|
||||
next_secret, next_point = self.get_secret_and_point(LOCAL, new_ctn + 1)
|
||||
return RevokeAndAck(last_secret, next_point)
|
||||
@@ -1054,7 +1050,7 @@ class Channel(AbstractChannel):
|
||||
if is_sent:
|
||||
self.lnworker.payment_sent(self, payment_hash)
|
||||
else:
|
||||
self.lnworker.payment_received(self, payment_hash)
|
||||
self.lnworker.payment_received(payment_hash)
|
||||
|
||||
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
|
||||
assert type(whose) is HTLCOwner
|
||||
|
||||
@@ -498,6 +498,7 @@ class OnionFailureCode(IntEnum):
|
||||
CHANNEL_DISABLED = UPDATE | 20
|
||||
EXPIRY_TOO_FAR = 21
|
||||
INVALID_ONION_PAYLOAD = PERM | 22
|
||||
MPP_TIMEOUT = 23
|
||||
|
||||
|
||||
# don't use these elsewhere, the names are ambiguous without context
|
||||
|
||||
@@ -1389,10 +1389,6 @@ class Peer(Logger):
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
|
||||
return None, reason
|
||||
expected_received_msat = info.amount_msat
|
||||
if expected_received_msat is not None and \
|
||||
not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
|
||||
return None, reason
|
||||
# Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
|
||||
# We should not release the preimage for an HTLC that its sender could already time out as
|
||||
# then they might try to force-close and it becomes a race.
|
||||
@@ -1415,20 +1411,34 @@ class Peer(Logger):
|
||||
data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
|
||||
return None, reason
|
||||
try:
|
||||
amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
|
||||
amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
|
||||
except:
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
|
||||
return None, reason
|
||||
try:
|
||||
amount_from_onion = processed_onion.hop_data.payload["payment_data"]["total_msat"]
|
||||
total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"]
|
||||
except:
|
||||
pass # fall back to "amt_to_forward"
|
||||
if amount_from_onion > htlc.amount_msat:
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
|
||||
data=htlc.amount_msat.to_bytes(8, byteorder="big"))
|
||||
total_msat = amt_to_forward # fall back to "amt_to_forward"
|
||||
|
||||
if amt_to_forward != htlc.amount_msat:
|
||||
reason = OnionRoutingFailureMessage(
|
||||
code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
|
||||
data=total_msat.to_bytes(8, byteorder="big"))
|
||||
return None, reason
|
||||
# all good
|
||||
return preimage, None
|
||||
if expected_received_msat is None:
|
||||
return preimage, None
|
||||
if not (expected_received_msat <= total_msat <= 2 * expected_received_msat):
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
|
||||
return None, reason
|
||||
accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, expected_received_msat)
|
||||
if accepted:
|
||||
return preimage, None
|
||||
elif expired:
|
||||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.MPP_TIMEOUT)
|
||||
return None, reason
|
||||
else:
|
||||
# waiting for more htlcs
|
||||
return None, None
|
||||
|
||||
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
|
||||
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
|
||||
@@ -1669,7 +1679,7 @@ class Peer(Logger):
|
||||
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
|
||||
if not chan.hm.is_add_htlc_irrevocably_committed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
|
||||
continue
|
||||
chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
|
||||
#chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
|
||||
htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
|
||||
payment_hash = htlc.payment_hash
|
||||
error_reason = None # type: Optional[OnionRoutingFailureMessage]
|
||||
@@ -1694,7 +1704,6 @@ class Peer(Logger):
|
||||
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_VERSION, data=sha256(onion_packet_bytes))
|
||||
if self.network.config.get('test_fail_htlcs_with_temp_node_failure'):
|
||||
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
|
||||
|
||||
if not error_reason:
|
||||
if processed_onion.are_we_final:
|
||||
preimage, error_reason = self.maybe_fulfill_htlc(
|
||||
|
||||
@@ -86,6 +86,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted
|
||||
|
||||
|
||||
NUM_PEERS_TARGET = 4
|
||||
MPP_EXPIRY = 120
|
||||
|
||||
|
||||
FALLBACK_NODE_LIST_TESTNET = (
|
||||
@@ -164,7 +165,8 @@ BASE_FEATURES = LnFeatures(0)\
|
||||
LNWALLET_FEATURES = BASE_FEATURES\
|
||||
| LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ\
|
||||
| LnFeatures.OPTION_STATIC_REMOTEKEY_REQ\
|
||||
| LnFeatures.GOSSIP_QUERIES_REQ
|
||||
| LnFeatures.GOSSIP_QUERIES_REQ\
|
||||
| LnFeatures.BASIC_MPP_OPT
|
||||
|
||||
LNGOSSIP_FEATURES = BASE_FEATURES\
|
||||
| LnFeatures.GOSSIP_QUERIES_OPT\
|
||||
@@ -581,6 +583,7 @@ class LNWallet(LNWorker):
|
||||
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
|
||||
|
||||
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
|
||||
self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set]
|
||||
|
||||
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
|
||||
# detect inflight payments
|
||||
@@ -1284,6 +1287,24 @@ class LNWallet(LNWorker):
|
||||
self.payments[key] = info.amount_msat, info.direction, info.status
|
||||
self.wallet.save_db()
|
||||
|
||||
def htlc_received(self, short_channel_id, htlc, expected_msat):
|
||||
status = self.get_payment_status(htlc.payment_hash)
|
||||
if status == PR_PAID:
|
||||
return True, None
|
||||
s = self.pending_htlcs[htlc.payment_hash]
|
||||
if (short_channel_id, htlc) not in s:
|
||||
s.add((short_channel_id, htlc))
|
||||
total = sum([htlc.amount_msat for scid, htlc in s])
|
||||
first_timestamp = min([htlc.timestamp for scid, htlc in s])
|
||||
expired = time.time() - first_timestamp > MPP_EXPIRY
|
||||
if total >= expected_msat and not expired:
|
||||
# status must be persisted
|
||||
self.payment_received(htlc.payment_hash)
|
||||
return True, None
|
||||
if expired:
|
||||
return None, True
|
||||
return None, None
|
||||
|
||||
def get_payment_status(self, payment_hash):
|
||||
info = self.get_payment_info(payment_hash)
|
||||
return info.status if info else PR_UNPAID
|
||||
@@ -1359,10 +1380,10 @@ class LNWallet(LNWorker):
|
||||
util.trigger_callback('payment_succeeded', self.wallet, key)
|
||||
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
|
||||
|
||||
def payment_received(self, chan, payment_hash: bytes):
|
||||
def payment_received(self, payment_hash: bytes):
|
||||
self.set_payment_status(payment_hash, PR_PAID)
|
||||
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
|
||||
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
|
||||
#util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
|
||||
|
||||
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
|
||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||
|
||||
@@ -132,6 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
# used in tests
|
||||
self.enable_htlc_settle = asyncio.Event()
|
||||
self.enable_htlc_settle.set()
|
||||
self.pending_htlcs = defaultdict(set)
|
||||
|
||||
def get_invoice_status(self, key):
|
||||
pass
|
||||
@@ -167,6 +168,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
set_invoice_status = LNWallet.set_invoice_status
|
||||
set_payment_status = LNWallet.set_payment_status
|
||||
get_payment_status = LNWallet.get_payment_status
|
||||
htlc_received = LNWallet.htlc_received
|
||||
await_payment = LNWallet.await_payment
|
||||
payment_received = LNWallet.payment_received
|
||||
payment_sent = LNWallet.payment_sent
|
||||
|
||||
Reference in New Issue
Block a user