1
0

lnworker: add invoice features to PaymentInfo class

Adds the invoice features to the `PaymentInfo` class so we can check if
the sender respects our requested features (e.g. if they tried to send
mpp if we requested no mpp).
This commit is contained in:
f321x
2025-12-09 14:31:12 +01:00
parent 5be598b808
commit 125a921cc4
3 changed files with 53 additions and 21 deletions

View File

@@ -134,6 +134,7 @@ class PaymentInfo:
min_final_cltv_delta: int
expiry_delay: int
creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time()))
invoice_features: LnFeatures
@property
def expiration_ts(self):
@@ -147,6 +148,7 @@ class PaymentInfo:
assert isinstance(self.min_final_cltv_delta, int)
assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0
assert isinstance(self.creation_ts, int)
assert isinstance(self.invoice_features, LnFeatures)
def __post_init__(self):
self.validate()
@@ -903,8 +905,8 @@ class LNWallet(LNWorker):
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = LNWatcher(self)
self.lnrater: LNRater = None
# lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]]
# "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]]
self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self._bolt11_cache = {}
# note: this sweep_address is only used as fallback; as it might result in address-reuse
@@ -1628,6 +1630,7 @@ class LNWallet(LNWorker):
status=PR_UNPAID,
min_final_cltv_delta=min_final_cltv_delta,
expiry_delay=LN_EXPIRY_NEVER,
invoice_features=invoice_features,
)
self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description())
@@ -2332,6 +2335,16 @@ class LNWallet(LNWorker):
route[-1].node_features |= invoice_features
return route
def _get_invoice_features(self, amount_msat: Optional[int]) -> LnFeatures:
invoice_features = self.features.for_invoice()
if not self.uses_trampoline():
invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
if needs_jit:
# jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
return invoice_features
def clear_invoices_cache(self):
self._bolt11_cache.clear()
@@ -2351,15 +2364,8 @@ class LNWallet(LNWorker):
assert amount_msat is None or amount_msat > 0
timestamp = int(time.time())
needs_jit: bool = self.receive_requires_jit_channel(amount_msat)
routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit)
self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}")
invoice_features = self.features.for_invoice()
if not self.uses_trampoline():
invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
if needs_jit:
# jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels)
self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, sat: {(amount_msat or 0) // 1000}")
payment_secret = self.get_payment_secret(payment_info.payment_hash)
amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
min_final_cltv_delta = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE
@@ -2370,7 +2376,7 @@ class LNWallet(LNWorker):
('d', message),
('c', min_final_cltv_delta),
('x', payment_info.expiry_delay),
('9', invoice_features),
('9', payment_info.invoice_features),
('f', fallback_address),
] + routing_hints,
date=timestamp,
@@ -2401,13 +2407,15 @@ class LNWallet(LNWorker):
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
min_final_cltv_delta = min_final_cltv_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED
invoice_features = self._get_invoice_features(amount_msat)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
min_final_cltv_delta=min_final_cltv_delta,
expiry_delay=exp_delay
expiry_delay=exp_delay,
invoice_features=invoice_features,
)
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
self.save_payment_info(info, write_to_disk=False)
@@ -2514,7 +2522,7 @@ class LNWallet(LNWorker):
with self.lock:
if key in self.payment_info:
stored_tuple = self.payment_info[key]
amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple
amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features = stored_tuple
return PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
@@ -2523,6 +2531,7 @@ class LNWallet(LNWorker):
min_final_cltv_delta=min_final_cltv_delta,
expiry_delay=expiry_delay,
creation_ts=creation_ts,
invoice_features=LnFeatures(invoice_features),
)
return None
@@ -2533,14 +2542,15 @@ class LNWallet(LNWorker):
min_final_cltv_delta: int,
exp_delay: int,
):
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
amount_msat = lightning_amount_sat * 1000 if lightning_amount_sat else None
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
min_final_cltv_delta=min_final_cltv_delta,
expiry_delay=exp_delay,
invoice_features=self._get_invoice_features(amount_msat),
)
self.save_payment_info(info, write_to_disk=False)
@@ -2574,7 +2584,7 @@ class LNWallet(LNWorker):
if info != dataclasses.replace(old_info, status=info.status):
# differs more than in status. let's fail
raise Exception(f"payment_hash already in use: {info=} != {old_info=}")
v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts
v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts, int(info.invoice_features)
self.payment_info[info.db_key] = v
if write_to_disk:
self.wallet.save_db()
@@ -2907,10 +2917,11 @@ class LNWallet(LNWorker):
else:
self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})')
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False):
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
"""calculate routing hints (BOLT-11 'r' field)"""
routing_hints = []
if needs_jit:
if self.receive_requires_jit_channel(amount_msat):
self.logger.debug(f"will request just-in-time channel")
node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)
alias_or_scid = self.get_static_jit_scid_alias()
routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)]))
@@ -3093,7 +3104,7 @@ class LNWallet(LNWorker):
# check if zeroconf is accepted and client has trusted zeroconf node configured
return False
try:
node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0]
node_id = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)[0]
except ConnStringFormatError:
# invalid connection string
return False

View File

@@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException):
# seed_version is now used for the version of the wallet file
OLD_SEED_VERSION = 4 # electrum versions < 2.0
NEW_SEED_VERSION = 11 # electrum versions >= 2.0
FINAL_SEED_VERSION = 65 # electrum >= 2.7 will set this to prevent
FINAL_SEED_VERSION = 66 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
@@ -237,6 +237,7 @@ class WalletDBUpgrader(Logger):
self._convert_version_63()
self._convert_version_64()
self._convert_version_65()
self._convert_version_66()
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
def _convert_wallet_type(self):
@@ -1315,6 +1316,22 @@ class WalletDBUpgrader(Logger):
self.data['received_mpp_htlcs'] = new_mpp_sets
self.data['seed_version'] = 65
def _convert_version_66(self):
"""Add invoice features to PaymentInfo"""
if not self._is_upgrade_method_needed(65, 65):
return
new_payment_infos = {}
old_payment_infos = self.data.get('lightning_payments', {})
for key, old_v in old_payment_infos.items():
amount_msat, status, min_final_cltv_expiry, expiry, creation_ts = old_v
invoice_features = 147712 # <VAR_ONION_REQ|PAYMENT_SECRET_REQ|BASIC_MPP_OPT: 0x24100>
new_v = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts, invoice_features)
new_payment_infos[key] = new_v
self.data['lightning_payments'] = new_payment_infos
self.data['seed_version'] = 66
def _convert_imported(self):
if not self._is_upgrade_method_needed(0, 13):
return

View File

@@ -359,6 +359,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
is_payment_bundle_complete = LNWallet.is_payment_bundle_complete
delete_payment_bundle = LNWallet.delete_payment_bundle
_process_htlc_log = LNWallet._process_htlc_log
_get_invoice_features = LNWallet._get_invoice_features
receive_requires_jit_channel = LNWallet.receive_requires_jit_channel
can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel
class MockTransport:
@@ -594,6 +597,7 @@ class TestPeer(ElectrumTestCase):
status=PR_UNPAID,
min_final_cltv_delta=min_final_cltv_delta,
expiry_delay=expiry or LN_EXPIRY_NEVER,
invoice_features=invoice_features,
)
w2.save_payment_info(info)
lnaddr1 = LnAddr(