From 125a921cc4b52b8e6e4d3e63aa6562cebe6d3baf Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 9 Dec 2025 14:31:12 +0100 Subject: [PATCH] lnworker: add invoice features to PaymentInfo class Adds the invoice features to the `PaymentInfo` class so we can check if the sender respects our requested features (e.g. if they tried to send mpp if we requested no mpp). --- electrum/lnworker.py | 51 ++++++++++++++++++++++++++----------------- electrum/wallet_db.py | 19 +++++++++++++++- tests/test_lnpeer.py | 4 ++++ 3 files changed, 53 insertions(+), 21 deletions(-) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 1db9d0de0..938b05ba8 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -134,6 +134,7 @@ class PaymentInfo: min_final_cltv_delta: int expiry_delay: int creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time())) + invoice_features: LnFeatures @property def expiration_ts(self): @@ -147,6 +148,7 @@ class PaymentInfo: assert isinstance(self.min_final_cltv_delta, int) assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0 assert isinstance(self.creation_ts, int) + assert isinstance(self.invoice_features, LnFeatures) def __post_init__(self): self.validate() @@ -903,8 +905,8 @@ class LNWallet(LNWorker): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - # lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts - self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]] + # "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -1628,6 +1630,7 @@ class LNWallet(LNWorker): status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=LN_EXPIRY_NEVER, + invoice_features=invoice_features, ) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) @@ -2332,6 +2335,16 @@ class LNWallet(LNWorker): route[-1].node_features |= invoice_features return route + def _get_invoice_features(self, amount_msat: Optional[int]) -> LnFeatures: + invoice_features = self.features.for_invoice() + if not self.uses_trampoline(): + invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM + needs_jit: bool = self.receive_requires_jit_channel(amount_msat) + if needs_jit: + # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc + invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ + return invoice_features + def clear_invoices_cache(self): self._bolt11_cache.clear() @@ -2351,15 +2364,8 @@ class LNWallet(LNWorker): assert amount_msat is None or amount_msat > 0 timestamp = int(time.time()) - needs_jit: bool = self.receive_requires_jit_channel(amount_msat) - routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit) - self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}") - invoice_features = self.features.for_invoice() - if not self.uses_trampoline(): - invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM - if needs_jit: - # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc - invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ + routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels) + self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, sat: {(amount_msat or 0) // 1000}") payment_secret = self.get_payment_secret(payment_info.payment_hash) amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None min_final_cltv_delta = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE @@ -2370,7 +2376,7 @@ class LNWallet(LNWorker): ('d', message), ('c', min_final_cltv_delta), ('x', payment_info.expiry_delay), - ('9', invoice_features), + ('9', payment_info.invoice_features), ('f', fallback_address), ] + routing_hints, date=timestamp, @@ -2401,13 +2407,15 @@ class LNWallet(LNWorker): payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) min_final_cltv_delta = min_final_cltv_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED + invoice_features = self._get_invoice_features(amount_msat) info = PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, - expiry_delay=exp_delay + expiry_delay=exp_delay, + invoice_features=invoice_features, ) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) @@ -2514,7 +2522,7 @@ class LNWallet(LNWorker): with self.lock: if key in self.payment_info: stored_tuple = self.payment_info[key] - amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple + amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features = stored_tuple return PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, @@ -2523,6 +2531,7 @@ class LNWallet(LNWorker): min_final_cltv_delta=min_final_cltv_delta, expiry_delay=expiry_delay, creation_ts=creation_ts, + invoice_features=LnFeatures(invoice_features), ) return None @@ -2533,14 +2542,15 @@ class LNWallet(LNWorker): min_final_cltv_delta: int, exp_delay: int, ): - amount = lightning_amount_sat * 1000 if lightning_amount_sat else None + amount_msat = lightning_amount_sat * 1000 if lightning_amount_sat else None info = PaymentInfo( payment_hash=payment_hash, - amount_msat=amount, + amount_msat=amount_msat, direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=exp_delay, + invoice_features=self._get_invoice_features(amount_msat), ) self.save_payment_info(info, write_to_disk=False) @@ -2574,7 +2584,7 @@ class LNWallet(LNWorker): if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}") - v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts + v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts, int(info.invoice_features) self.payment_info[info.db_key] = v if write_to_disk: self.wallet.save_db() @@ -2907,10 +2917,11 @@ class LNWallet(LNWorker): else: self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})') - def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False): + def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None): """calculate routing hints (BOLT-11 'r' field)""" routing_hints = [] - if needs_jit: + if self.receive_requires_jit_channel(amount_msat): + self.logger.debug(f"will request just-in-time channel") node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE) alias_or_scid = self.get_static_jit_scid_alias() routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)])) @@ -3093,7 +3104,7 @@ class LNWallet(LNWorker): # check if zeroconf is accepted and client has trusted zeroconf node configured return False try: - node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0] + node_id = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)[0] except ConnStringFormatError: # invalid connection string return False diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 068499c2d..688ab4c17 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 65 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 66 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -237,6 +237,7 @@ class WalletDBUpgrader(Logger): self._convert_version_63() self._convert_version_64() self._convert_version_65() + self._convert_version_66() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1315,6 +1316,22 @@ class WalletDBUpgrader(Logger): self.data['received_mpp_htlcs'] = new_mpp_sets self.data['seed_version'] = 65 + def _convert_version_66(self): + """Add invoice features to PaymentInfo""" + if not self._is_upgrade_method_needed(65, 65): + return + + new_payment_infos = {} + old_payment_infos = self.data.get('lightning_payments', {}) + for key, old_v in old_payment_infos.items(): + amount_msat, status, min_final_cltv_expiry, expiry, creation_ts = old_v + invoice_features = 147712 # + new_v = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts, invoice_features) + new_payment_infos[key] = new_v + + self.data['lightning_payments'] = new_payment_infos + self.data['seed_version'] = 66 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 57def01cb..68e4d129e 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -359,6 +359,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): is_payment_bundle_complete = LNWallet.is_payment_bundle_complete delete_payment_bundle = LNWallet.delete_payment_bundle _process_htlc_log = LNWallet._process_htlc_log + _get_invoice_features = LNWallet._get_invoice_features + receive_requires_jit_channel = LNWallet.receive_requires_jit_channel + can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel class MockTransport: @@ -594,6 +597,7 @@ class TestPeer(ElectrumTestCase): status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=expiry or LN_EXPIRY_NEVER, + invoice_features=invoice_features, ) w2.save_payment_info(info) lnaddr1 = LnAddr(