Merge pull request #10351 from f321x/jit_htlc_switch_fixes
lnpeer/lnutil: fail mpp if we didn't signal mpp in invoice
This commit is contained in:
@@ -2152,7 +2152,7 @@ class Peer(Logger, EventListener):
|
||||
Does additional checks on the incoming htlc and return the payment key if the tests pass,
|
||||
otherwise raises OnionRoutingError which will get the htlc failed.
|
||||
"""
|
||||
_log_fail_reason = self._log_htlc_fail_reason_cb(chan.short_channel_id, htlc, processed_onion.hop_data.payload)
|
||||
_log_fail_reason = self._log_htlc_fail_reason_cb(chan.channel_id, htlc, processed_onion.hop_data.payload)
|
||||
|
||||
# Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
|
||||
# We should not release the preimage for an HTLC that its sender could already time out as
|
||||
@@ -2242,6 +2242,10 @@ class Peer(Logger, EventListener):
|
||||
elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts
|
||||
_log_fail_reason(f"not accepting htlc for expired invoice")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
elif not info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and total_msat > htlc.amount_msat:
|
||||
# in _check_unfulfilled_htlc_set we check the count to prevent mpp through overpayment
|
||||
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {total_msat=} > {htlc.amount_msat=}")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
|
||||
expected_payment_secret = self.lnworker.get_payment_secret(payment_hash)
|
||||
if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret):
|
||||
@@ -2271,7 +2275,7 @@ class Peer(Logger, EventListener):
|
||||
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore
|
||||
for mpp_htlc in list(htlc_set.htlcs):
|
||||
htlc_id = mpp_htlc.htlc.htlc_id
|
||||
chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid)
|
||||
chan = self.lnworker.channels[mpp_htlc.channel_id]
|
||||
if chan.channel_id not in self.channels:
|
||||
# this htlc belongs to another peer and has to be settled in their htlc_switch
|
||||
continue
|
||||
@@ -2313,7 +2317,7 @@ class Peer(Logger, EventListener):
|
||||
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None)
|
||||
self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed
|
||||
for mpp_htlc in list(htlc_set.htlcs):
|
||||
chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid)
|
||||
chan = self.lnworker.channels[mpp_htlc.channel_id]
|
||||
htlc_id = mpp_htlc.htlc.htlc_id
|
||||
if chan.channel_id not in self.channels:
|
||||
# this htlc belongs to another peer and has to be settled in their htlc_switch
|
||||
@@ -2856,7 +2860,7 @@ class Peer(Logger, EventListener):
|
||||
)
|
||||
self.lnworker.update_or_create_mpp_with_received_htlc(
|
||||
payment_key=payment_key,
|
||||
scid=chan.short_channel_id,
|
||||
channel_id=chan.channel_id,
|
||||
htlc=htlc,
|
||||
unprocessed_onion_packet=onion_packet_hex, # outer onion if trampoline
|
||||
)
|
||||
@@ -2940,11 +2944,12 @@ class Peer(Logger, EventListener):
|
||||
|
||||
def _log_htlc_fail_reason_cb(
|
||||
self,
|
||||
scid: ShortChannelID,
|
||||
channel_id: bytes,
|
||||
htlc: UpdateAddHtlc,
|
||||
onion_payload: dict
|
||||
) -> Callable[[str], None]:
|
||||
def _log_fail_reason(reason: str) -> None:
|
||||
scid = self.lnworker.channels[channel_id].short_channel_id
|
||||
self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}")
|
||||
return _log_fail_reason
|
||||
|
||||
@@ -2962,7 +2967,7 @@ class Peer(Logger, EventListener):
|
||||
onion_payload = {}
|
||||
|
||||
self._log_htlc_fail_reason_cb(
|
||||
mpp_htlc.scid,
|
||||
mpp_htlc.channel_id,
|
||||
mpp_htlc.htlc,
|
||||
onion_payload,
|
||||
)(f"mpp set {id(mpp_set)} failed: {reason}")
|
||||
@@ -3075,8 +3080,9 @@ class Peer(Logger, EventListener):
|
||||
return OnionFailureCode.MPP_TIMEOUT, None, None
|
||||
|
||||
if mpp_set.resolution == RecvMPPResolution.WAITING:
|
||||
# calculate the sum of just in time channel opening fees
|
||||
htlc_channels = [self.lnworker.get_channel_by_short_id(scid) for scid in set(h.scid for h in mpp_set.htlcs)]
|
||||
# calculate the sum of just in time channel opening fees, note jit only supports
|
||||
# single part payments for now, this is enforced by checking against the invoice features
|
||||
htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)]
|
||||
jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels)
|
||||
|
||||
# check if set is first stage multi-trampoline payment to us
|
||||
@@ -3096,15 +3102,21 @@ class Peer(Logger, EventListener):
|
||||
trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex()
|
||||
|
||||
if trampoline_payment_key and trampoline_payment_key != payment_key:
|
||||
if jit_opening_fees_msat:
|
||||
# for jit openings we only accept a single htlc
|
||||
expected_amount_first_stage = any_trampoline_onion.total_msat - jit_opening_fees_msat
|
||||
else:
|
||||
expected_amount_first_stage = any_trampoline_onion.amt_to_forward
|
||||
|
||||
# first stage of trampoline payment, the first stage must never get set COMPLETE
|
||||
if amount_msat >= (any_trampoline_onion.amt_to_forward - jit_opening_fees_msat):
|
||||
if amount_msat >= expected_amount_first_stage:
|
||||
# setting the parent key will mark the htlcs to be moved to the parent set
|
||||
self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, "
|
||||
f"{amount_msat=}. setting parent key: {trampoline_payment_key}")
|
||||
self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace(
|
||||
parent_set_key=trampoline_payment_key,
|
||||
)
|
||||
elif amount_msat >= (total_msat - jit_opening_fees_msat):
|
||||
elif amount_msat >= (total_msat - jit_opening_fees_msat): # regular mpp or 2nd stage trampoline
|
||||
# set mpp_set as completed as we have received the full total_msat
|
||||
mpp_set = self.lnworker.set_mpp_resolution(
|
||||
payment_key=payment_key,
|
||||
@@ -3128,6 +3140,11 @@ class Peer(Logger, EventListener):
|
||||
if payment_info is None:
|
||||
_log_fail_reason(f"payment info has been deleted")
|
||||
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
|
||||
elif not payment_info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and len(mpp_set.htlcs) > 1:
|
||||
# in _check_unfulfilled_htlc we already check amount == total_amount, however someone could
|
||||
# send us multiple htlcs that all pay the full amount, so we also check the htlc count
|
||||
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {len(mpp_set.htlcs)=}")
|
||||
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
|
||||
|
||||
# check invoice expiry, fail set if the invoice has expired before it was completed
|
||||
if mpp_set.resolution == RecvMPPResolution.WAITING:
|
||||
|
||||
@@ -1968,18 +1968,18 @@ del r
|
||||
|
||||
|
||||
class ReceivedMPPHtlc(NamedTuple):
|
||||
scid: ShortChannelID
|
||||
channel_id: bytes
|
||||
htlc: UpdateAddHtlc
|
||||
unprocessed_onion: str
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.scid}, {self.htlc=}, {self.unprocessed_onion[:15]=}..."
|
||||
return f"chan_id={self.channel_id.hex()}, {self.htlc=}, {self.unprocessed_onion[:15]=}..."
|
||||
|
||||
@staticmethod
|
||||
def from_tuple(scid, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc':
|
||||
assert is_hex_str(unprocessed_onion) and is_hex_str(scid)
|
||||
def from_tuple(channel_id, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc':
|
||||
assert is_hex_str(unprocessed_onion) and is_hex_str(channel_id)
|
||||
return ReceivedMPPHtlc(
|
||||
scid=ShortChannelID(bytes.fromhex(scid)),
|
||||
channel_id=bytes.fromhex(channel_id),
|
||||
htlc=UpdateAddHtlc.from_tuple(*htlc),
|
||||
unprocessed_onion=unprocessed_onion,
|
||||
)
|
||||
|
||||
@@ -134,6 +134,7 @@ class PaymentInfo:
|
||||
min_final_cltv_delta: int
|
||||
expiry_delay: int
|
||||
creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time()))
|
||||
invoice_features: LnFeatures
|
||||
|
||||
@property
|
||||
def expiration_ts(self):
|
||||
@@ -147,6 +148,7 @@ class PaymentInfo:
|
||||
assert isinstance(self.min_final_cltv_delta, int)
|
||||
assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0
|
||||
assert isinstance(self.creation_ts, int)
|
||||
assert isinstance(self.invoice_features, LnFeatures)
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
@@ -903,8 +905,8 @@ class LNWallet(LNWorker):
|
||||
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
|
||||
self.lnwatcher = LNWatcher(self)
|
||||
self.lnrater: LNRater = None
|
||||
# lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts
|
||||
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]]
|
||||
# "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features
|
||||
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]]
|
||||
self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
|
||||
self._bolt11_cache = {}
|
||||
# note: this sweep_address is only used as fallback; as it might result in address-reuse
|
||||
@@ -1574,6 +1576,7 @@ class LNWallet(LNWorker):
|
||||
return chan, funding_tx
|
||||
|
||||
def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]:
|
||||
assert short_channel_id and isinstance(short_channel_id, bytes), repr(short_channel_id)
|
||||
# First check against *real* SCIDs.
|
||||
# This e.g. protects against maliciously chosen SCID aliases, and accidental collisions.
|
||||
for chan in self.channels.values():
|
||||
@@ -1627,6 +1630,7 @@ class LNWallet(LNWorker):
|
||||
status=PR_UNPAID,
|
||||
min_final_cltv_delta=min_final_cltv_delta,
|
||||
expiry_delay=LN_EXPIRY_NEVER,
|
||||
invoice_features=invoice_features,
|
||||
)
|
||||
self.save_payment_info(info)
|
||||
self.wallet.set_label(key, lnaddr.get_description())
|
||||
@@ -2331,6 +2335,16 @@ class LNWallet(LNWorker):
|
||||
route[-1].node_features |= invoice_features
|
||||
return route
|
||||
|
||||
def _get_invoice_features(self, amount_msat: Optional[int]) -> LnFeatures:
|
||||
invoice_features = self.features.for_invoice()
|
||||
if not self.uses_trampoline():
|
||||
invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
|
||||
needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
|
||||
if needs_jit:
|
||||
# jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
|
||||
invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
|
||||
return invoice_features
|
||||
|
||||
def clear_invoices_cache(self):
|
||||
self._bolt11_cache.clear()
|
||||
|
||||
@@ -2350,15 +2364,8 @@ class LNWallet(LNWorker):
|
||||
|
||||
assert amount_msat is None or amount_msat > 0
|
||||
timestamp = int(time.time())
|
||||
needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
|
||||
routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit)
|
||||
self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}")
|
||||
invoice_features = self.features.for_invoice()
|
||||
if not self.uses_trampoline():
|
||||
invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
|
||||
if needs_jit:
|
||||
# jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
|
||||
invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
|
||||
routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels)
|
||||
self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, sat: {(amount_msat or 0) // 1000}")
|
||||
payment_secret = self.get_payment_secret(payment_info.payment_hash)
|
||||
amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
|
||||
min_final_cltv_delta = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE
|
||||
@@ -2369,7 +2376,7 @@ class LNWallet(LNWorker):
|
||||
('d', message),
|
||||
('c', min_final_cltv_delta),
|
||||
('x', payment_info.expiry_delay),
|
||||
('9', invoice_features),
|
||||
('9', payment_info.invoice_features),
|
||||
('f', fallback_address),
|
||||
] + routing_hints,
|
||||
date=timestamp,
|
||||
@@ -2400,13 +2407,15 @@ class LNWallet(LNWorker):
|
||||
payment_preimage = os.urandom(32)
|
||||
payment_hash = sha256(payment_preimage)
|
||||
min_final_cltv_delta = min_final_cltv_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED
|
||||
invoice_features = self._get_invoice_features(amount_msat)
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_msat,
|
||||
direction=RECEIVED,
|
||||
status=PR_UNPAID,
|
||||
min_final_cltv_delta=min_final_cltv_delta,
|
||||
expiry_delay=exp_delay
|
||||
expiry_delay=exp_delay,
|
||||
invoice_features=invoice_features,
|
||||
)
|
||||
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
@@ -2513,7 +2522,7 @@ class LNWallet(LNWorker):
|
||||
with self.lock:
|
||||
if key in self.payment_info:
|
||||
stored_tuple = self.payment_info[key]
|
||||
amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple
|
||||
amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features = stored_tuple
|
||||
return PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount_msat,
|
||||
@@ -2522,6 +2531,7 @@ class LNWallet(LNWorker):
|
||||
min_final_cltv_delta=min_final_cltv_delta,
|
||||
expiry_delay=expiry_delay,
|
||||
creation_ts=creation_ts,
|
||||
invoice_features=LnFeatures(invoice_features),
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -2532,14 +2542,15 @@ class LNWallet(LNWorker):
|
||||
min_final_cltv_delta: int,
|
||||
exp_delay: int,
|
||||
):
|
||||
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
|
||||
amount_msat = lightning_amount_sat * 1000 if lightning_amount_sat else None
|
||||
info = PaymentInfo(
|
||||
payment_hash=payment_hash,
|
||||
amount_msat=amount,
|
||||
amount_msat=amount_msat,
|
||||
direction=RECEIVED,
|
||||
status=PR_UNPAID,
|
||||
min_final_cltv_delta=min_final_cltv_delta,
|
||||
expiry_delay=exp_delay,
|
||||
invoice_features=self._get_invoice_features(amount_msat),
|
||||
)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
|
||||
@@ -2573,7 +2584,7 @@ class LNWallet(LNWorker):
|
||||
if info != dataclasses.replace(old_info, status=info.status):
|
||||
# differs more than in status. let's fail
|
||||
raise Exception(f"payment_hash already in use: {info=} != {old_info=}")
|
||||
v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts
|
||||
v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts, int(info.invoice_features)
|
||||
self.payment_info[info.db_key] = v
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
@@ -2582,7 +2593,7 @@ class LNWallet(LNWorker):
|
||||
self,
|
||||
*,
|
||||
payment_key: str,
|
||||
scid: ShortChannelID,
|
||||
channel_id: bytes,
|
||||
htlc: UpdateAddHtlc,
|
||||
unprocessed_onion_packet: str,
|
||||
):
|
||||
@@ -2609,7 +2620,7 @@ class LNWallet(LNWorker):
|
||||
|
||||
if mpp_status.resolution > RecvMPPResolution.WAITING:
|
||||
# we are getting a htlc for a set that is not in WAITING state, it cannot be safely added
|
||||
self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}")
|
||||
self.logger.info(f"htlc set cannot accept htlc, failing htlc: {channel_id=} {htlc.htlc_id=}")
|
||||
if mpp_status == RecvMPPResolution.EXPIRED:
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
|
||||
raise OnionRoutingFailure(
|
||||
@@ -2618,7 +2629,7 @@ class LNWallet(LNWorker):
|
||||
)
|
||||
|
||||
new_htlc = ReceivedMPPHtlc(
|
||||
scid=scid,
|
||||
channel_id=channel_id,
|
||||
htlc=htlc,
|
||||
unprocessed_onion=unprocessed_onion_packet,
|
||||
)
|
||||
@@ -2706,7 +2717,7 @@ class LNWallet(LNWorker):
|
||||
# only cleanup when channel is REDEEMED as mpp set is still required for lnsweep
|
||||
assert chan._state == ChannelState.REDEEMED
|
||||
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
|
||||
htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id]
|
||||
htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.channel_id == chan.channel_id]
|
||||
for stale_mpp_htlc in htlcs_to_remove:
|
||||
assert mpp_status.resolution != RecvMPPResolution.WAITING
|
||||
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
|
||||
@@ -2906,10 +2917,11 @@ class LNWallet(LNWorker):
|
||||
else:
|
||||
self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})')
|
||||
|
||||
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False):
|
||||
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
|
||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||
routing_hints = []
|
||||
if needs_jit:
|
||||
if self.receive_requires_jit_channel(amount_msat):
|
||||
self.logger.debug(f"will request just-in-time channel")
|
||||
node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
|
||||
alias_or_scid = self.get_static_jit_scid_alias()
|
||||
routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
|
||||
@@ -3092,7 +3104,7 @@ class LNWallet(LNWorker):
|
||||
# check if zeroconf is accepted and client has trusted zeroconf node configured
|
||||
return False
|
||||
try:
|
||||
node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0]
|
||||
node_id = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)[0]
|
||||
except ConnStringFormatError:
|
||||
# invalid connection string
|
||||
return False
|
||||
@@ -3546,7 +3558,7 @@ class LNWallet(LNWorker):
|
||||
assert not any_outer_onion.are_we_final
|
||||
assert len(processed_htlc_set) == 1, processed_htlc_set
|
||||
forward_htlc = any_mpp_htlc.htlc
|
||||
incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid)
|
||||
incoming_chan = self.channels[any_mpp_htlc.channel_id]
|
||||
next_htlc = await self._maybe_forward_htlc(
|
||||
incoming_chan=incoming_chan,
|
||||
htlc=forward_htlc,
|
||||
|
||||
@@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException):
|
||||
# seed_version is now used for the version of the wallet file
|
||||
OLD_SEED_VERSION = 4 # electrum versions < 2.0
|
||||
NEW_SEED_VERSION = 11 # electrum versions >= 2.0
|
||||
FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent
|
||||
FINAL_SEED_VERSION = 66 # electrum >= 2.7 will set this to prevent
|
||||
# old versions from overwriting new format
|
||||
|
||||
|
||||
@@ -236,6 +236,8 @@ class WalletDBUpgrader(Logger):
|
||||
self._convert_version_62()
|
||||
self._convert_version_63()
|
||||
self._convert_version_64()
|
||||
self._convert_version_65()
|
||||
self._convert_version_66()
|
||||
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
|
||||
|
||||
def _convert_wallet_type(self):
|
||||
@@ -1288,6 +1290,48 @@ class WalletDBUpgrader(Logger):
|
||||
self.data['lightning_payments'] = new_payment_infos
|
||||
self.data['seed_version'] = 64
|
||||
|
||||
def _convert_version_65(self):
|
||||
"""Store channel_id instead of short_channel_id in ReceivedMPPHtlc"""
|
||||
if not self._is_upgrade_method_needed(64, 64):
|
||||
return
|
||||
|
||||
channels = self.data.get('channels', {})
|
||||
def scid_to_channel_id(scid):
|
||||
for channel_id, channel_data in channels.items():
|
||||
if scid == channel_data.get('short_channel_id'):
|
||||
return channel_id
|
||||
raise KeyError(f"missing {scid=} in channels")
|
||||
|
||||
mpp_sets = self.data.get('received_mpp_htlcs', {})
|
||||
new_mpp_sets = {}
|
||||
for payment_key, mpp_set in mpp_sets.items():
|
||||
resolution, htlc_list, parent_set_key = mpp_set
|
||||
new_htlc_list = []
|
||||
for htlc_data_tuple in htlc_list:
|
||||
scid, update_add_htlc, onion = htlc_data_tuple
|
||||
channel_id = scid_to_channel_id(scid)
|
||||
new_htlc_list.append((channel_id, update_add_htlc, onion))
|
||||
new_mpp_sets[payment_key] = (resolution, new_htlc_list, parent_set_key)
|
||||
|
||||
self.data['received_mpp_htlcs'] = new_mpp_sets
|
||||
self.data['seed_version'] = 65
|
||||
|
||||
def _convert_version_66(self):
|
||||
"""Add invoice features to PaymentInfo"""
|
||||
if not self._is_upgrade_method_needed(65, 65):
|
||||
return
|
||||
|
||||
new_payment_infos = {}
|
||||
old_payment_infos = self.data.get('lightning_payments', {})
|
||||
for key, old_v in old_payment_infos.items():
|
||||
amount_msat, status, min_final_cltv_expiry, expiry, creation_ts = old_v
|
||||
invoice_features = 0x24100 # <VAR_ONION_REQ|PAYMENT_SECRET_REQ|BASIC_MPP_OPT>
|
||||
new_v = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts, invoice_features)
|
||||
new_payment_infos[key] = new_v
|
||||
|
||||
self.data['lightning_payments'] = new_payment_infos
|
||||
self.data['seed_version'] = 66
|
||||
|
||||
def _convert_imported(self):
|
||||
if not self._is_upgrade_method_needed(0, 13):
|
||||
return
|
||||
|
||||
@@ -359,6 +359,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
is_payment_bundle_complete = LNWallet.is_payment_bundle_complete
|
||||
delete_payment_bundle = LNWallet.delete_payment_bundle
|
||||
_process_htlc_log = LNWallet._process_htlc_log
|
||||
_get_invoice_features = LNWallet._get_invoice_features
|
||||
receive_requires_jit_channel = LNWallet.receive_requires_jit_channel
|
||||
can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel
|
||||
|
||||
|
||||
class MockTransport:
|
||||
@@ -594,6 +597,7 @@ class TestPeer(ElectrumTestCase):
|
||||
status=PR_UNPAID,
|
||||
min_final_cltv_delta=min_final_cltv_delta,
|
||||
expiry_delay=expiry or LN_EXPIRY_NEVER,
|
||||
invoice_features=invoice_features,
|
||||
)
|
||||
w2.save_payment_info(info)
|
||||
lnaddr1 = LnAddr(
|
||||
@@ -1067,6 +1071,46 @@ class TestPeerDirect(TestPeer):
|
||||
for _test_trampoline in [False, True]:
|
||||
await run_test(_test_trampoline)
|
||||
|
||||
async def test_reject_mpp_for_non_mpp_invoice(self):
|
||||
"""Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice"""
|
||||
async def run_test(test_trampoline):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
w1.config.TEST_FORCE_MPP = True # force alice to send mpp
|
||||
|
||||
if test_trampoline:
|
||||
await self._activate_trampoline(w1)
|
||||
await self._activate_trampoline(w2)
|
||||
# declare bob as trampoline node
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
|
||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||
self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_OPT))
|
||||
self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_REQ))
|
||||
|
||||
async def try_pay_invoice_with_mpp(pay_req: Invoice, w1=w1):
|
||||
result, log = await w1.pay_invoice(pay_req)
|
||||
if not result:
|
||||
raise PaymentFailure()
|
||||
raise PaymentDone()
|
||||
|
||||
async def f():
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p1.htlc_switch())
|
||||
await group.spawn(p2._message_loop())
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
await group.spawn(try_pay_invoice_with_mpp(pay_req))
|
||||
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await f()
|
||||
|
||||
for _test_trampoline in [False, True]:
|
||||
await run_test(_test_trampoline)
|
||||
|
||||
async def test_reject_multiple_payments_of_same_invoice(self):
|
||||
"""Tests that new htlcs paying an invoice that has already been paid will get rejected."""
|
||||
async def run_test(test_trampoline):
|
||||
@@ -1448,6 +1492,7 @@ class TestPeerDirect(TestPeer):
|
||||
async def run_test(test_trampoline: bool):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
bob_wallet.features |= LnFeatures.BASIC_MPP_OPT
|
||||
lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000)
|
||||
|
||||
if test_trampoline:
|
||||
|
||||
Reference in New Issue
Block a user