From 923d48f9db0ca87bed40c3a145875d5ffe4303a3 Mon Sep 17 00:00:00 2001 From: f321x Date: Fri, 28 Nov 2025 16:22:22 +0100 Subject: [PATCH] lnworker: differentiate PaymentInfo by direction Allows storing two different payment info of the same payment hash by including the direction into the db key. We create and store PaymentInfo for sending attempts and for requests (receiving), if we try to pay ourself (e.g. through a channel rebalance) the checks in `save_payment_info` would prevent this and throw an exception. By storing the PaymentInfos of outgoing and incoming payments separately in the db this collision is avoided and it makes it easier to reason about which PaymentInfo belongs where. --- electrum/commands.py | 21 ++++----- electrum/gui/qml/qerequestdetails.py | 4 +- electrum/gui/qt/send_tab.py | 3 +- electrum/lnpeer.py | 4 +- electrum/lnutil.py | 1 + electrum/lnworker.py | 70 ++++++++++++++++------------ electrum/plugins/nwc/nwcserver.py | 10 ++-- electrum/submarine_swaps.py | 14 ++++-- electrum/wallet.py | 8 ++-- electrum/wallet_db.py | 21 ++++++++- tests/test_commands.py | 5 +- tests/test_lnpeer.py | 52 ++++++++++----------- 12 files changed, 125 insertions(+), 88 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 6aef50cec..6482b0b3d 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -70,7 +70,7 @@ from .wallet import ( ) from .address_synchronizer import TX_HEIGHT_LOCAL from .mnemonic import Mnemonic -from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, MIN_FINAL_CLTV_DELTA_ACCEPTED, +from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE) from .plugin import run_hook, DeviceMgr, Plugins from .version import ELECTRUM_VERSION @@ -1402,7 +1402,7 @@ class Commands(Logger): arg:int:min_final_cltv_expiry_delta:Optional min final cltv expiry delta (default: 294 blocks) """ assert len(payment_hash) == 64, f"Invalid payment hash length: {len(payment_hash)} != 64" - assert payment_hash not in wallet.lnworker.payment_info, "Payment hash already used!" + assert not wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), "Payment hash already used!" assert payment_hash not in wallet.lnworker.dont_expire_htlcs, "Payment hash already used!" assert wallet.lnworker.get_preimage(bfh(payment_hash)) is None, "Already got a preimage for this payment hash!" assert MIN_FINAL_CLTV_DELTA_ACCEPTED < min_final_cltv_expiry_delta < 576, "Use a sane min_final_cltv_expiry_delta value" @@ -1417,7 +1417,7 @@ class Commands(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta, exp_delay=expiry, ) - info = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) lnaddr, invoice = wallet.lnworker.get_bolt11_invoice( payment_info=info, message=memo, @@ -1443,12 +1443,11 @@ class Commands(Logger): assert len(preimage) == 64, f"Invalid payment_hash length: {len(preimage)} != 64" payment_hash: str = crypto.sha256(bfh(preimage)).hex() assert payment_hash not in wallet.lnworker._preimages, f"Invoice {payment_hash=} already settled" - assert payment_hash in wallet.lnworker.payment_info, \ - f"Couldn't find lightning invoice for {payment_hash=}" + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) + assert info, f"Couldn't find lightning invoice for {payment_hash=}" assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"Invoice {payment_hash=} not a hold invoice?" assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \ f"MPP incomplete, cannot settle hold invoice {payment_hash} yet" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0) wallet.lnworker.save_preimage(bfh(payment_hash), bfh(preimage)) util.trigger_callback('wallet_updated', wallet) @@ -1464,13 +1463,13 @@ class Commands(Logger): arg:str:payment_hash:Payment hash in hex of the hold invoice """ - assert payment_hash in wallet.lnworker.payment_info, \ + assert wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), \ f"Couldn't find lightning invoice for payment hash {payment_hash}" assert payment_hash not in wallet.lnworker._preimages, "Cannot cancel anymore, preimage already given." assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"{payment_hash=} not a hold invoice?" # set to PR_UNPAID so it can get deleted - wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID) - wallet.lnworker.delete_payment_info(payment_hash) + wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID, direction=RECEIVED) + wallet.lnworker.delete_payment_info(payment_hash, direction=RECEIVED) wallet.set_label(payment_hash, None) del wallet.lnworker.dont_expire_htlcs[payment_hash] while wallet.lnworker.is_complete_mpp(bfh(payment_hash)): @@ -1496,7 +1495,7 @@ class Commands(Logger): arg:str:payment_hash:Payment hash in hex of the hold invoice """ assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash)) amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000 result = { @@ -1518,7 +1517,7 @@ class Commands(Logger): elif wallet.lnworker.get_preimage_hex(payment_hash) is not None: result["status"] = "settled" plist = wallet.lnworker.get_payments(status='settled')[bfh(payment_hash)] - _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(info, plist) + _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(None, plist) result["received_amount_sat"] = amount_msat // 1000 result['preimage'] = wallet.lnworker.get_preimage_hex(payment_hash) if info is not None: diff --git a/electrum/gui/qml/qerequestdetails.py b/electrum/gui/qml/qerequestdetails.py index d0d5f29f4..288ef167c 100644 --- a/electrum/gui/qml/qerequestdetails.py +++ b/electrum/gui/qml/qerequestdetails.py @@ -8,7 +8,7 @@ from electrum.logging import get_logger from electrum.invoices import ( PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED, LN_EXPIRY_NEVER ) -from electrum.lnutil import MIN_FUNDING_SAT +from electrum.lnutil import MIN_FUNDING_SAT, RECEIVED from electrum.lnurl import LNURL3Data, request_lnurl_withdraw_callback, LNURLError from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierType from electrum.i18n import _ @@ -237,7 +237,7 @@ class QERequestDetails(QObject, QtEventListener): address=None, ) req = self._wallet.wallet.get_request(key) - info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash) + info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self._wallet.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/gui/qt/send_tab.py b/electrum/gui/qt/send_tab.py index 6c927b11d..2cbcdec9c 100644 --- a/electrum/gui/qt/send_tab.py +++ b/electrum/gui/qt/send_tab.py @@ -18,6 +18,7 @@ from electrum.util import ( NotEnoughFunds, NoDynamicFeeEstimates, parse_max_spend, UserCancelled, ChoiceItem, UserFacingException, ) +from electrum.lnutil import RECEIVED from electrum.invoices import PR_PAID, Invoice, PR_BROADCASTING, PR_BROADCAST from electrum.transaction import Transaction, PartialTxInput, PartialTxOutput from electrum.network import TxBroadcastError, BestEffortRequestFailed @@ -979,7 +980,7 @@ class SendTab(QWidget, MessageBoxMixin, Logger): address=None, ) req = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 1e8da558b..591ce47ce 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2221,7 +2221,7 @@ class Peer(Logger, EventListener): outer_onion_payment_secret=payment_secret_from_onion, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if info is None: _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") raise exc_incorrect_or_unknown_pd @@ -3115,7 +3115,7 @@ class Peer(Logger, EventListener): return None, None, fwd_cb # -- from here on it's assumed this set is a payment for us (not something to forward) -- - payment_info = self.lnworker.get_payment_info(payment_hash) + payment_info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if payment_info is None: _log_fail_reason(f"payment info has been deleted") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 0d30d3718..70256a74d 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1089,6 +1089,7 @@ class HTLCOwner(IntEnum): return HTLCOwner(super().__neg__()) +# part of lightning_payments db keys class Direction(IntEnum): SENT = -1 # in the context of HTLCs: "offered" HTLCs RECEIVED = 1 # in the context of HTLCs: "received" HTLCs diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3074abab..f96640cb1 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -45,7 +45,8 @@ from .fee_policy import ( FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET ) -from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice +from .invoices import (Invoice, Request, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, + BaseInvoice) from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress from .bip32 import BIP32Node from .address_synchronizer import TX_HEIGHT_LOCAL @@ -120,7 +121,7 @@ class PaymentInfo: """Information required to handle incoming htlcs for a payment request""" payment_hash: bytes amount_msat: Optional[int] - direction: int + direction: lnutil.Direction status: int min_final_cltv_delta: int expiry_delay: int @@ -142,6 +143,13 @@ class PaymentInfo: def __post_init__(self): self.validate() + @property + def db_key(self) -> str: + return self.calc_db_key(payment_hash_hex=self.payment_hash.hex(), direction=self.direction) + + @classmethod + def calc_db_key(cls, *, payment_hash_hex: str, direction: lnutil.Direction) -> str: + return f"{payment_hash_hex}:{int(direction)}" SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -887,8 +895,8 @@ class LNWallet(LNWorker): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - # lightning_payments: RHASH -> amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts - self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] + # lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -1104,7 +1112,7 @@ class LNWallet(LNWorker): return out def get_payment_value( - self, info: Optional['PaymentInfo'], + self, sent_info: Optional['PaymentInfo'], plist: List[HTLCWithStatus] ) -> Tuple[PaymentDirection, int, Optional[int], int]: """ fee_msat is included in amount_msat""" @@ -1112,7 +1120,7 @@ class LNWallet(LNWorker): amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist) if all(x.direction == SENT for x in plist): direction = PaymentDirection.SENT - fee_msat = (- info.amount_msat - amount_msat) if info else None + fee_msat = (- sent_info.amount_msat - amount_msat) if sent_info else None elif all(x.direction == RECEIVED for x in plist): direction = PaymentDirection.RECEIVED fee_msat = None @@ -1135,12 +1143,12 @@ class LNWallet(LNWorker): if len(plist) == 0: continue key = payment_hash.hex() - info = self.get_payment_info(payment_hash) + sent_info = self.get_payment_info(payment_hash, direction=SENT) # note: just after successfully paying an invoice using MPP, amount and fee values might be shifted # temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also # have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than # expected and the fee higher (the inflight htlcs will be effectively counted as fees). - direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist) + direction, amount_msat, fee_msat, timestamp = self.get_payment_value(sent_info, plist) label = self.wallet.get_label_for_rhash(key) if not label and direction == PaymentDirection.FORWARDING: label = _('Forwarding') @@ -1597,7 +1605,7 @@ class LNWallet(LNWorker): invoice_features = lnaddr.get_features() r_tags = lnaddr.get_routing_info('r') amount_to_pay = lnaddr.get_amount_msat() - status = self.get_payment_status(payment_hash) + status = self.get_payment_status(payment_hash, direction=SENT) if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) if status == PR_INFLIGHT: @@ -2491,13 +2499,13 @@ class LNWallet(LNWorker): preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b"" return preimage_bytes.hex() or None - def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]: + def get_payment_info(self, payment_hash: bytes, *, direction: lnutil.Direction) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding""" - key = payment_hash.hex() + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash.hex(), direction=direction) with self.lock: if key in self.payment_info: stored_tuple = self.payment_info[key] - amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple + amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple return PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, @@ -2542,7 +2550,7 @@ class LNWallet(LNWorker): def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS with self.lock: - if old_info := self.get_payment_info(payment_hash=info.payment_hash): + if old_info := self.get_payment_info(payment_hash=info.payment_hash, direction=info.direction): if info == old_info: return # already saved if info.direction == SENT: @@ -2551,8 +2559,8 @@ class LNWallet(LNWorker): if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}") - key = info.payment_hash.hex() - self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0 + v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts + self.payment_info[info.db_key] = v if write_to_disk: self.wallet.save_db() @@ -2698,13 +2706,15 @@ class LNWallet(LNWorker): self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) - def get_payment_status(self, payment_hash: bytes) -> int: - info = self.get_payment_info(payment_hash) + def get_payment_status(self, payment_hash: bytes, *, direction: lnutil.Direction) -> int: + info = self.get_payment_info(payment_hash, direction=direction) return info.status if info else PR_UNPAID def get_invoice_status(self, invoice: BaseInvoice) -> int: invoice_id = invoice.rhash - status = self.get_payment_status(bfh(invoice_id)) + assert isinstance(invoice, (Request, Invoice)), type(invoice) + direction = RECEIVED if isinstance(invoice, Request) else SENT + status = self.get_payment_status(bfh(invoice_id), direction=direction) if status == PR_UNPAID and invoice_id in self.inflight_payments: return PR_INFLIGHT # status may be PR_FAILED @@ -2718,24 +2728,24 @@ class LNWallet(LNWorker): elif key in self.inflight_payments: self.inflight_payments.remove(key) if status in SAVED_PR_STATUS: - self.set_payment_status(bfh(key), status) + self.set_payment_status(bfh(key), status, direction=SENT) util.trigger_callback('invoice_status', self.wallet, key, status) self.logger.info(f"set_invoice_status {key}: {status}") # liquidity changed self.clear_invoices_cache() def set_request_status(self, payment_hash: bytes, status: int) -> None: - if self.get_payment_status(payment_hash) == status: + if self.get_payment_status(payment_hash, direction=RECEIVED) == status: return - self.set_payment_status(payment_hash, status) + self.set_payment_status(payment_hash, status, direction=RECEIVED) request_id = payment_hash.hex() req = self.wallet.get_request(request_id) if req is None: return util.trigger_callback('request_status', self.wallet, request_id, status) - def set_payment_status(self, payment_hash: bytes, status: int) -> None: - info = self.get_payment_info(payment_hash) + def set_payment_status(self, payment_hash: bytes, status: int, *, direction: lnutil.Direction) -> None: + info = self.get_payment_info(payment_hash, direction=direction) if info is None: # if we are forwarding return @@ -2930,14 +2940,15 @@ class LNWallet(LNWorker): cltv_delta)])) return routing_hints - def delete_payment_info(self, payment_hash_hex: str): + def delete_payment_info(self, payment_hash_hex: str, *, direction: lnutil.Direction): # This method is called when an invoice or request is deleted by the user. # The GUI only lets the user delete invoices or requests that have not been paid. # Once an invoice/request has been paid, it is part of the history, # and get_lightning_history assumes that payment_info is there. - assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID + assert self.get_payment_status(bytes.fromhex(payment_hash_hex), direction=direction) != PR_PAID with self.lock: - self.payment_info.pop(payment_hash_hex, None) + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash_hex, direction=direction) + self.payment_info.pop(key, None) def get_balance(self, *, frozen=False) -> Decimal: with self.lock: @@ -3185,7 +3196,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, exp_delay=3600, ) - info = self.get_payment_info(payment_hash) + info = self.get_payment_info(payment_hash, direction=RECEIVED) lnaddr, invoice = self.get_bolt11_invoice( payment_info=info, message='rebalance', @@ -3804,14 +3815,13 @@ class LNWallet(LNWorker): - Alice sends htlc A->B->C, for 1 sat, with HASH1 - Bob must not release the preimage of HASH1 """ - payment_info = self.get_payment_info(payment_hash) - is_our_payreq = payment_info and payment_info.direction == RECEIVED + payment_info = self.get_payment_info(payment_hash, direction=RECEIVED) # note: If we don't have the preimage for a payment request, then it must be a hold invoice. # Hold invoices are created by other parties (e.g. a counterparty initiating a submarine swap), # and it is the other party choosing the payment_hash. If we failed HTLCs with payment_hashes colliding # with hold invoices, then a party that can make us save a hold invoice for an arbitrary hash could # also make us fail arbitrary HTLCs. - return bool(is_our_payreq and self.get_preimage(payment_hash)) + return bool(payment_info and self.get_preimage(payment_hash)) def create_onion_for_route( self, *, diff --git a/electrum/plugins/nwc/nwcserver.py b/electrum/plugins/nwc/nwcserver.py index c2cd925f5..1ed021ff8 100644 --- a/electrum/plugins/nwc/nwcserver.py +++ b/electrum/plugins/nwc/nwcserver.py @@ -42,6 +42,7 @@ from electrum.util import log_exceptions, ca_path, OldTaskGroup, get_asyncio_loo get_running_loop from electrum.invoices import Invoice, Request, PR_UNKNOWN, PR_PAID, BaseInvoice, PR_INFLIGHT from electrum import constants +from electrum.lnutil import RECEIVED if TYPE_CHECKING: from aiohttp_socks import ProxyConnector @@ -480,7 +481,7 @@ class NWCServer(Logger, EventListener): address=None ) req: Request = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) try: lnaddr, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, @@ -537,7 +538,7 @@ class NWCServer(Logger, EventListener): b11 = invoice.lightning_invoice elif self.wallet.get_request(invoice.rhash): direction = "incoming" - info = self.wallet.lnworker.get_payment_info(invoice.payment_hash) + info = self.wallet.lnworker.get_payment_info(invoice.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=invoice.message, @@ -747,7 +748,7 @@ class NWCServer(Logger, EventListener): request: Optional[Request] = self.wallet.get_request(key) if not request or not request.is_lightning() or not status == PR_PAID: return - info = self.wallet.lnworker.get_payment_info(request.payment_hash) + info = self.wallet.lnworker.get_payment_info(request.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=request.message, @@ -947,7 +948,8 @@ class NWCServer(Logger, EventListener): payments = self.wallet.lnworker.get_payments(status='settled') plist = payments.get(payment_hash) if plist: - info = self.wallet.lnworker.get_payment_info(payment_hash) + direction = plist[0].direction + info = self.wallet.lnworker.get_payment_info(payment_hash, direction=direction) if info: dir, amount, fee, ts = self.wallet.lnworker.get_payment_value(info, plist) fee = abs(fee) if fee else None diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 5d6275f8e..8ea53df59 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -401,7 +401,7 @@ class SwapManager(Logger): if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - self.lnworker.delete_payment_info(swap.payment_hash.hex()) + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_funded(): @@ -413,9 +413,13 @@ class SwapManager(Logger): self._swaps_by_lockup_address.pop(swap.lockup_address, None) if swap.prepay_hash is not None: self._prepayments.pop(swap.prepay_hash, None) - if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: - self.lnworker.delete_payment_info(swap.prepay_hash.hex()) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.RECEIVED) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.SENT) + if self.lnworker.get_payment_status(swap.payment_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.SENT) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -693,7 +697,7 @@ class SwapManager(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=lnutil.RECEIVED) lnaddr1, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap', @@ -712,7 +716,7 @@ class SwapManager(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(prepay_hash) + info = self.lnworker.get_payment_info(prepay_hash, direction=lnutil.RECEIVED) lnaddr2, prepay_invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap prepayment', diff --git a/electrum/wallet.py b/electrum/wallet.py index 32f60119b..88002e96a 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -76,7 +76,7 @@ from .invoices import BaseInvoice, Invoice, Request, PR_PAID, PR_UNPAID, PR_EXPI from .contacts import Contacts from .mnemonic import Mnemonic from .lnworker import LNWallet -from .lnutil import MIN_FUNDING_SAT +from .lnutil import MIN_FUNDING_SAT, RECEIVED, SENT from .lntransport import extract_nodeid from .descriptor import Descriptor from .txbatcher import TxBatcher @@ -3014,7 +3014,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): return '' amount_msat = req.get_amount_msat() or None assert (amount_msat is None or amount_msat > 0), amount_msat - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) assert info.amount_msat == amount_msat, f"{info.amount_msat=} != {amount_msat=}" lnaddr, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, @@ -3074,7 +3074,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): if addr := req.get_address(): self._requests_addr_to_key[addr].discard(request_id) if req.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(req.rhash) + self.lnworker.delete_payment_info(req.rhash, direction=RECEIVED) if write_to_disk: self.save_db() @@ -3084,7 +3084,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): if inv is None: return if inv.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(inv.rhash) + self.lnworker.delete_payment_info(inv.rhash, direction=SENT) if write_to_disk: self.save_db() diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index bf22f03fb..325eb36f3 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -235,6 +235,7 @@ class WalletDBUpgrader(Logger): self._convert_version_61() self._convert_version_62() self._convert_version_63() + self._convert_version_64() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1269,6 +1270,24 @@ class WalletDBUpgrader(Logger): self.data['seed_version'] = 63 + def _convert_version_64(self): + """Key payment_info by "rhash:direction" instead of just rhash to allow storing a PaymentInfo + for each direction""" + if not self._is_upgrade_method_needed(63, 63): + return + + new_payment_infos = {} + old_payment_infos = self.data.get('lightning_payments', {}) + for payment_hash, old_values in old_payment_infos.items(): + amount_msat, direction, status, min_final_cltv_expiry, expiry, creation_ts = old_values + # drop direction + new_values = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts) + new_key = f"{payment_hash}:{direction}" + new_payment_infos[new_key] = new_values # save new entry + + self.data['lightning_payments'] = new_payment_infos + self.data['seed_version'] = 64 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 3016263e1..969bfb024 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,6 +11,7 @@ import shutil import electrum from electrum.commands import Commands, eval_bool from electrum import storage, wallet +from electrum.lnutil import RECEIVED from electrum.lnworker import RecvMPPResolution from electrum.wallet import Abstract_Wallet from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED @@ -509,7 +510,7 @@ class TestCommandsTestnet(ElectrumTestCase): ) invoice = lndecode(invoice=result['invoice']) assert invoice.paymenthash.hex() == payment_hash - assert payment_hash in wallet.lnworker.payment_info + assert wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash in wallet.lnworker.dont_expire_htlcs assert invoice.get_amount_sat() == 10000 assert invoice.get_description() == "test" @@ -520,7 +521,7 @@ class TestCommandsTestnet(ElectrumTestCase): payment_hash=payment_hash, wallet=wallet, ) - assert payment_hash not in wallet.lnworker.payment_info + assert not wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash not in wallet.lnworker.dont_expire_htlcs assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == "" assert cancel_result['cancelled'] == payment_hash diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 3fe6a3729..57def01cb 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -865,10 +865,10 @@ class TestPeerDirect(TestPeer): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) results = {} async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await w1.pay_invoice(pay_req) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) results[lnaddr] = PaymentDone() else: results[lnaddr] = PaymentFailure() @@ -988,7 +988,7 @@ class TestPeerDirect(TestPeer): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]] b11 = lnencode(lnaddr, w2.node_keypair.privkey) @@ -1079,7 +1079,7 @@ class TestPeerDirect(TestPeer): result, log = await w1.pay_invoice(pay_req) assert result is True # now pay the same invoice again, the payment should be rejected by w2 - w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID) + w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID, direction=lnutil.SENT) result, log = await w1.pay_invoice(pay_req) if not result: # w1.pay_invoice returned a payment failure as the payment got rejected by w2 @@ -1224,8 +1224,8 @@ class TestPeerDirect(TestPeer): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1297,7 +1297,7 @@ class TestPeerDirect(TestPeer): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1997,11 +1997,11 @@ class TestPeerDirect(TestPeer): w2.dont_settle_htlcs[pay_req.rhash] = None async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2067,10 +2067,10 @@ class TestPeerDirect(TestPeer): w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20 async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2210,12 +2210,12 @@ class TestPeerForwarding(TestPeer): return split_amount_normal(total_amount, num_parts) async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) with mock.patch('electrum.mpp_split.split_amount_normal', side_effect=mocked_split_amount_normal): result, log = await graph.workers['bob'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) async def f(): async with OldTaskGroup() as group: @@ -2242,10 +2242,10 @@ class TestPeerForwarding(TestPeer): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() async def f(): async with OldTaskGroup() as group: @@ -2309,10 +2309,10 @@ class TestPeerForwarding(TestPeer): graph.workers['carol'].network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE = True peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertFalse(result) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): @@ -2336,11 +2336,11 @@ class TestPeerForwarding(TestPeer): async def pay(lnaddr, pay_req): self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL)) self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL)) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2) self.assertEqual(2, len(log)) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id], [edge.short_channel_id for edge in log[0].route]) self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id], @@ -2436,11 +2436,11 @@ class TestPeerForwarding(TestPeer): amount_to_pay = 100_000_000 peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3) self.assertTrue(result) self.assertEqual(2, len(log)) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints @@ -2507,14 +2507,14 @@ class TestPeerForwarding(TestPeer): assert alice_w.network.channel_db is not None lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay) self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) - self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) if not bob_forwarding: # reset to previous state, sleep 2s so that the second htlc can time out graph.workers['bob'].enable_htlc_forwarding = True await asyncio.sleep(2) if result: - self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) # check mpp is cleaned up async with OldTaskGroup() as g: for peer in peers: @@ -2642,7 +2642,7 @@ class TestPeerForwarding(TestPeer): dest_w = graph.workers[destination_name] async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await sender_w.pay_invoice(pay_req, attempts=attempts) async with OldTaskGroup() as g: for peer in peers: @@ -2653,7 +2653,7 @@ class TestPeerForwarding(TestPeer): for peer in peers: self.assertEqual(len(peer.lnworker.active_forwardings), 0) if result: - self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() else: raise NoPathFound() @@ -2875,7 +2875,7 @@ class TestPeerForwarding(TestPeer): peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code) self.assertFalse(result, msg=log)