diff --git a/electrum/commands.py b/electrum/commands.py index 6aef50cec..6482b0b3d 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -70,7 +70,7 @@ from .wallet import ( ) from .address_synchronizer import TX_HEIGHT_LOCAL from .mnemonic import Mnemonic -from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, MIN_FINAL_CLTV_DELTA_ACCEPTED, +from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE) from .plugin import run_hook, DeviceMgr, Plugins from .version import ELECTRUM_VERSION @@ -1402,7 +1402,7 @@ class Commands(Logger): arg:int:min_final_cltv_expiry_delta:Optional min final cltv expiry delta (default: 294 blocks) """ assert len(payment_hash) == 64, f"Invalid payment hash length: {len(payment_hash)} != 64" - assert payment_hash not in wallet.lnworker.payment_info, "Payment hash already used!" + assert not wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), "Payment hash already used!" assert payment_hash not in wallet.lnworker.dont_expire_htlcs, "Payment hash already used!" assert wallet.lnworker.get_preimage(bfh(payment_hash)) is None, "Already got a preimage for this payment hash!" assert MIN_FINAL_CLTV_DELTA_ACCEPTED < min_final_cltv_expiry_delta < 576, "Use a sane min_final_cltv_expiry_delta value" @@ -1417,7 +1417,7 @@ class Commands(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta, exp_delay=expiry, ) - info = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) lnaddr, invoice = wallet.lnworker.get_bolt11_invoice( payment_info=info, message=memo, @@ -1443,12 +1443,11 @@ class Commands(Logger): assert len(preimage) == 64, f"Invalid payment_hash length: {len(preimage)} != 64" payment_hash: str = crypto.sha256(bfh(preimage)).hex() assert payment_hash not in wallet.lnworker._preimages, f"Invoice {payment_hash=} already settled" - assert payment_hash in wallet.lnworker.payment_info, \ - f"Couldn't find lightning invoice for {payment_hash=}" + info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) + assert info, f"Couldn't find lightning invoice for {payment_hash=}" assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"Invoice {payment_hash=} not a hold invoice?" assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \ f"MPP incomplete, cannot settle hold invoice {payment_hash} yet" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0) wallet.lnworker.save_preimage(bfh(payment_hash), bfh(preimage)) util.trigger_callback('wallet_updated', wallet) @@ -1464,13 +1463,13 @@ class Commands(Logger): arg:str:payment_hash:Payment hash in hex of the hold invoice """ - assert payment_hash in wallet.lnworker.payment_info, \ + assert wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), \ f"Couldn't find lightning invoice for payment hash {payment_hash}" assert payment_hash not in wallet.lnworker._preimages, "Cannot cancel anymore, preimage already given." assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"{payment_hash=} not a hold invoice?" # set to PR_UNPAID so it can get deleted - wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID) - wallet.lnworker.delete_payment_info(payment_hash) + wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID, direction=RECEIVED) + wallet.lnworker.delete_payment_info(payment_hash, direction=RECEIVED) wallet.set_label(payment_hash, None) del wallet.lnworker.dont_expire_htlcs[payment_hash] while wallet.lnworker.is_complete_mpp(bfh(payment_hash)): @@ -1496,7 +1495,7 @@ class Commands(Logger): arg:str:payment_hash:Payment hash in hex of the hold invoice """ assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64" - info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) + info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED) is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash)) amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000 result = { @@ -1518,7 +1517,7 @@ class Commands(Logger): elif wallet.lnworker.get_preimage_hex(payment_hash) is not None: result["status"] = "settled" plist = wallet.lnworker.get_payments(status='settled')[bfh(payment_hash)] - _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(info, plist) + _dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(None, plist) result["received_amount_sat"] = amount_msat // 1000 result['preimage'] = wallet.lnworker.get_preimage_hex(payment_hash) if info is not None: diff --git a/electrum/gui/qml/qerequestdetails.py b/electrum/gui/qml/qerequestdetails.py index d0d5f29f4..288ef167c 100644 --- a/electrum/gui/qml/qerequestdetails.py +++ b/electrum/gui/qml/qerequestdetails.py @@ -8,7 +8,7 @@ from electrum.logging import get_logger from electrum.invoices import ( PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED, LN_EXPIRY_NEVER ) -from electrum.lnutil import MIN_FUNDING_SAT +from electrum.lnutil import MIN_FUNDING_SAT, RECEIVED from electrum.lnurl import LNURL3Data, request_lnurl_withdraw_callback, LNURLError from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierType from electrum.i18n import _ @@ -237,7 +237,7 @@ class QERequestDetails(QObject, QtEventListener): address=None, ) req = self._wallet.wallet.get_request(key) - info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash) + info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self._wallet.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/gui/qt/send_tab.py b/electrum/gui/qt/send_tab.py index 6c927b11d..2cbcdec9c 100644 --- a/electrum/gui/qt/send_tab.py +++ b/electrum/gui/qt/send_tab.py @@ -18,6 +18,7 @@ from electrum.util import ( NotEnoughFunds, NoDynamicFeeEstimates, parse_max_spend, UserCancelled, ChoiceItem, UserFacingException, ) +from electrum.lnutil import RECEIVED from electrum.invoices import PR_PAID, Invoice, PR_BROADCASTING, PR_BROADCAST from electrum.transaction import Transaction, PartialTxInput, PartialTxOutput from electrum.network import TxBroadcastError, BestEffortRequestFailed @@ -979,7 +980,7 @@ class SendTab(QWidget, MessageBoxMixin, Logger): address=None, ) req = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) _lnaddr, b11_invoice = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=req.get_message(), diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 1e8da558b..591ce47ce 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2221,7 +2221,7 @@ class Peer(Logger, EventListener): outer_onion_payment_secret=payment_secret_from_onion, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if info is None: _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") raise exc_incorrect_or_unknown_pd @@ -3115,7 +3115,7 @@ class Peer(Logger, EventListener): return None, None, fwd_cb # -- from here on it's assumed this set is a payment for us (not something to forward) -- - payment_info = self.lnworker.get_payment_info(payment_hash) + payment_info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) if payment_info is None: _log_fail_reason(f"payment info has been deleted") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 0d30d3718..70256a74d 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1089,6 +1089,7 @@ class HTLCOwner(IntEnum): return HTLCOwner(super().__neg__()) +# part of lightning_payments db keys class Direction(IntEnum): SENT = -1 # in the context of HTLCs: "offered" HTLCs RECEIVED = 1 # in the context of HTLCs: "received" HTLCs diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3074abab..f96640cb1 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -45,7 +45,8 @@ from .fee_policy import ( FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET ) -from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice +from .invoices import (Invoice, Request, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, + BaseInvoice) from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress from .bip32 import BIP32Node from .address_synchronizer import TX_HEIGHT_LOCAL @@ -120,7 +121,7 @@ class PaymentInfo: """Information required to handle incoming htlcs for a payment request""" payment_hash: bytes amount_msat: Optional[int] - direction: int + direction: lnutil.Direction status: int min_final_cltv_delta: int expiry_delay: int @@ -142,6 +143,13 @@ class PaymentInfo: def __post_init__(self): self.validate() + @property + def db_key(self) -> str: + return self.calc_db_key(payment_hash_hex=self.payment_hash.hex(), direction=self.direction) + + @classmethod + def calc_db_key(cls, *, payment_hash_hex: str, direction: lnutil.Direction) -> str: + return f"{payment_hash_hex}:{int(direction)}" SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -887,8 +895,8 @@ class LNWallet(LNWorker): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - # lightning_payments: RHASH -> amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts - self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] + # lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -1104,7 +1112,7 @@ class LNWallet(LNWorker): return out def get_payment_value( - self, info: Optional['PaymentInfo'], + self, sent_info: Optional['PaymentInfo'], plist: List[HTLCWithStatus] ) -> Tuple[PaymentDirection, int, Optional[int], int]: """ fee_msat is included in amount_msat""" @@ -1112,7 +1120,7 @@ class LNWallet(LNWorker): amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist) if all(x.direction == SENT for x in plist): direction = PaymentDirection.SENT - fee_msat = (- info.amount_msat - amount_msat) if info else None + fee_msat = (- sent_info.amount_msat - amount_msat) if sent_info else None elif all(x.direction == RECEIVED for x in plist): direction = PaymentDirection.RECEIVED fee_msat = None @@ -1135,12 +1143,12 @@ class LNWallet(LNWorker): if len(plist) == 0: continue key = payment_hash.hex() - info = self.get_payment_info(payment_hash) + sent_info = self.get_payment_info(payment_hash, direction=SENT) # note: just after successfully paying an invoice using MPP, amount and fee values might be shifted # temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also # have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than # expected and the fee higher (the inflight htlcs will be effectively counted as fees). - direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist) + direction, amount_msat, fee_msat, timestamp = self.get_payment_value(sent_info, plist) label = self.wallet.get_label_for_rhash(key) if not label and direction == PaymentDirection.FORWARDING: label = _('Forwarding') @@ -1597,7 +1605,7 @@ class LNWallet(LNWorker): invoice_features = lnaddr.get_features() r_tags = lnaddr.get_routing_info('r') amount_to_pay = lnaddr.get_amount_msat() - status = self.get_payment_status(payment_hash) + status = self.get_payment_status(payment_hash, direction=SENT) if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) if status == PR_INFLIGHT: @@ -2491,13 +2499,13 @@ class LNWallet(LNWorker): preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b"" return preimage_bytes.hex() or None - def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]: + def get_payment_info(self, payment_hash: bytes, *, direction: lnutil.Direction) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding""" - key = payment_hash.hex() + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash.hex(), direction=direction) with self.lock: if key in self.payment_info: stored_tuple = self.payment_info[key] - amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple + amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple return PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, @@ -2542,7 +2550,7 @@ class LNWallet(LNWorker): def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS with self.lock: - if old_info := self.get_payment_info(payment_hash=info.payment_hash): + if old_info := self.get_payment_info(payment_hash=info.payment_hash, direction=info.direction): if info == old_info: return # already saved if info.direction == SENT: @@ -2551,8 +2559,8 @@ class LNWallet(LNWorker): if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}") - key = info.payment_hash.hex() - self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0 + v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts + self.payment_info[info.db_key] = v if write_to_disk: self.wallet.save_db() @@ -2698,13 +2706,15 @@ class LNWallet(LNWorker): self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) - def get_payment_status(self, payment_hash: bytes) -> int: - info = self.get_payment_info(payment_hash) + def get_payment_status(self, payment_hash: bytes, *, direction: lnutil.Direction) -> int: + info = self.get_payment_info(payment_hash, direction=direction) return info.status if info else PR_UNPAID def get_invoice_status(self, invoice: BaseInvoice) -> int: invoice_id = invoice.rhash - status = self.get_payment_status(bfh(invoice_id)) + assert isinstance(invoice, (Request, Invoice)), type(invoice) + direction = RECEIVED if isinstance(invoice, Request) else SENT + status = self.get_payment_status(bfh(invoice_id), direction=direction) if status == PR_UNPAID and invoice_id in self.inflight_payments: return PR_INFLIGHT # status may be PR_FAILED @@ -2718,24 +2728,24 @@ class LNWallet(LNWorker): elif key in self.inflight_payments: self.inflight_payments.remove(key) if status in SAVED_PR_STATUS: - self.set_payment_status(bfh(key), status) + self.set_payment_status(bfh(key), status, direction=SENT) util.trigger_callback('invoice_status', self.wallet, key, status) self.logger.info(f"set_invoice_status {key}: {status}") # liquidity changed self.clear_invoices_cache() def set_request_status(self, payment_hash: bytes, status: int) -> None: - if self.get_payment_status(payment_hash) == status: + if self.get_payment_status(payment_hash, direction=RECEIVED) == status: return - self.set_payment_status(payment_hash, status) + self.set_payment_status(payment_hash, status, direction=RECEIVED) request_id = payment_hash.hex() req = self.wallet.get_request(request_id) if req is None: return util.trigger_callback('request_status', self.wallet, request_id, status) - def set_payment_status(self, payment_hash: bytes, status: int) -> None: - info = self.get_payment_info(payment_hash) + def set_payment_status(self, payment_hash: bytes, status: int, *, direction: lnutil.Direction) -> None: + info = self.get_payment_info(payment_hash, direction=direction) if info is None: # if we are forwarding return @@ -2930,14 +2940,15 @@ class LNWallet(LNWorker): cltv_delta)])) return routing_hints - def delete_payment_info(self, payment_hash_hex: str): + def delete_payment_info(self, payment_hash_hex: str, *, direction: lnutil.Direction): # This method is called when an invoice or request is deleted by the user. # The GUI only lets the user delete invoices or requests that have not been paid. # Once an invoice/request has been paid, it is part of the history, # and get_lightning_history assumes that payment_info is there. - assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID + assert self.get_payment_status(bytes.fromhex(payment_hash_hex), direction=direction) != PR_PAID with self.lock: - self.payment_info.pop(payment_hash_hex, None) + key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash_hex, direction=direction) + self.payment_info.pop(key, None) def get_balance(self, *, frozen=False) -> Decimal: with self.lock: @@ -3185,7 +3196,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, exp_delay=3600, ) - info = self.get_payment_info(payment_hash) + info = self.get_payment_info(payment_hash, direction=RECEIVED) lnaddr, invoice = self.get_bolt11_invoice( payment_info=info, message='rebalance', @@ -3804,14 +3815,13 @@ class LNWallet(LNWorker): - Alice sends htlc A->B->C, for 1 sat, with HASH1 - Bob must not release the preimage of HASH1 """ - payment_info = self.get_payment_info(payment_hash) - is_our_payreq = payment_info and payment_info.direction == RECEIVED + payment_info = self.get_payment_info(payment_hash, direction=RECEIVED) # note: If we don't have the preimage for a payment request, then it must be a hold invoice. # Hold invoices are created by other parties (e.g. a counterparty initiating a submarine swap), # and it is the other party choosing the payment_hash. If we failed HTLCs with payment_hashes colliding # with hold invoices, then a party that can make us save a hold invoice for an arbitrary hash could # also make us fail arbitrary HTLCs. - return bool(is_our_payreq and self.get_preimage(payment_hash)) + return bool(payment_info and self.get_preimage(payment_hash)) def create_onion_for_route( self, *, diff --git a/electrum/plugins/nwc/nwcserver.py b/electrum/plugins/nwc/nwcserver.py index c2cd925f5..1ed021ff8 100644 --- a/electrum/plugins/nwc/nwcserver.py +++ b/electrum/plugins/nwc/nwcserver.py @@ -42,6 +42,7 @@ from electrum.util import log_exceptions, ca_path, OldTaskGroup, get_asyncio_loo get_running_loop from electrum.invoices import Invoice, Request, PR_UNKNOWN, PR_PAID, BaseInvoice, PR_INFLIGHT from electrum import constants +from electrum.lnutil import RECEIVED if TYPE_CHECKING: from aiohttp_socks import ProxyConnector @@ -480,7 +481,7 @@ class NWCServer(Logger, EventListener): address=None ) req: Request = self.wallet.get_request(key) - info = self.wallet.lnworker.get_payment_info(req.payment_hash) + info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED) try: lnaddr, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, @@ -537,7 +538,7 @@ class NWCServer(Logger, EventListener): b11 = invoice.lightning_invoice elif self.wallet.get_request(invoice.rhash): direction = "incoming" - info = self.wallet.lnworker.get_payment_info(invoice.payment_hash) + info = self.wallet.lnworker.get_payment_info(invoice.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=invoice.message, @@ -747,7 +748,7 @@ class NWCServer(Logger, EventListener): request: Optional[Request] = self.wallet.get_request(key) if not request or not request.is_lightning() or not status == PR_PAID: return - info = self.wallet.lnworker.get_payment_info(request.payment_hash) + info = self.wallet.lnworker.get_payment_info(request.payment_hash, direction=RECEIVED) _, b11 = self.wallet.lnworker.get_bolt11_invoice( payment_info=info, message=request.message, @@ -947,7 +948,8 @@ class NWCServer(Logger, EventListener): payments = self.wallet.lnworker.get_payments(status='settled') plist = payments.get(payment_hash) if plist: - info = self.wallet.lnworker.get_payment_info(payment_hash) + direction = plist[0].direction + info = self.wallet.lnworker.get_payment_info(payment_hash, direction=direction) if info: dir, amount, fee, ts = self.wallet.lnworker.get_payment_value(info, plist) fee = abs(fee) if fee else None diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 5d6275f8e..8ea53df59 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -401,7 +401,7 @@ class SwapManager(Logger): if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - self.lnworker.delete_payment_info(swap.payment_hash.hex()) + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_funded(): @@ -413,9 +413,13 @@ class SwapManager(Logger): self._swaps_by_lockup_address.pop(swap.lockup_address, None) if swap.prepay_hash is not None: self._prepayments.pop(swap.prepay_hash, None) - if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: - self.lnworker.delete_payment_info(swap.prepay_hash.hex()) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.RECEIVED) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.RECEIVED) self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) + if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.SENT) + if self.lnworker.get_payment_status(swap.payment_hash, direction=lnutil.SENT) != PR_PAID: + self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.SENT) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -693,7 +697,7 @@ class SwapManager(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=lnutil.RECEIVED) lnaddr1, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap', @@ -712,7 +716,7 @@ class SwapManager(Logger): min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) - info = self.lnworker.get_payment_info(prepay_hash) + info = self.lnworker.get_payment_info(prepay_hash, direction=lnutil.RECEIVED) lnaddr2, prepay_invoice = self.lnworker.get_bolt11_invoice( payment_info=info, message='Submarine swap prepayment', diff --git a/electrum/wallet.py b/electrum/wallet.py index 32f60119b..88002e96a 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -76,7 +76,7 @@ from .invoices import BaseInvoice, Invoice, Request, PR_PAID, PR_UNPAID, PR_EXPI from .contacts import Contacts from .mnemonic import Mnemonic from .lnworker import LNWallet -from .lnutil import MIN_FUNDING_SAT +from .lnutil import MIN_FUNDING_SAT, RECEIVED, SENT from .lntransport import extract_nodeid from .descriptor import Descriptor from .txbatcher import TxBatcher @@ -3014,7 +3014,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): return '' amount_msat = req.get_amount_msat() or None assert (amount_msat is None or amount_msat > 0), amount_msat - info = self.lnworker.get_payment_info(payment_hash) + info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED) assert info.amount_msat == amount_msat, f"{info.amount_msat=} != {amount_msat=}" lnaddr, invoice = self.lnworker.get_bolt11_invoice( payment_info=info, @@ -3074,7 +3074,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): if addr := req.get_address(): self._requests_addr_to_key[addr].discard(request_id) if req.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(req.rhash) + self.lnworker.delete_payment_info(req.rhash, direction=RECEIVED) if write_to_disk: self.save_db() @@ -3084,7 +3084,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): if inv is None: return if inv.is_lightning() and self.lnworker: - self.lnworker.delete_payment_info(inv.rhash) + self.lnworker.delete_payment_info(inv.rhash, direction=SENT) if write_to_disk: self.save_db() diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index bf22f03fb..325eb36f3 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -235,6 +235,7 @@ class WalletDBUpgrader(Logger): self._convert_version_61() self._convert_version_62() self._convert_version_63() + self._convert_version_64() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1269,6 +1270,24 @@ class WalletDBUpgrader(Logger): self.data['seed_version'] = 63 + def _convert_version_64(self): + """Key payment_info by "rhash:direction" instead of just rhash to allow storing a PaymentInfo + for each direction""" + if not self._is_upgrade_method_needed(63, 63): + return + + new_payment_infos = {} + old_payment_infos = self.data.get('lightning_payments', {}) + for payment_hash, old_values in old_payment_infos.items(): + amount_msat, direction, status, min_final_cltv_expiry, expiry, creation_ts = old_values + # drop direction + new_values = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts) + new_key = f"{payment_hash}:{direction}" + new_payment_infos[new_key] = new_values # save new entry + + self.data['lightning_payments'] = new_payment_infos + self.data['seed_version'] = 64 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 3016263e1..969bfb024 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,6 +11,7 @@ import shutil import electrum from electrum.commands import Commands, eval_bool from electrum import storage, wallet +from electrum.lnutil import RECEIVED from electrum.lnworker import RecvMPPResolution from electrum.wallet import Abstract_Wallet from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED @@ -509,7 +510,7 @@ class TestCommandsTestnet(ElectrumTestCase): ) invoice = lndecode(invoice=result['invoice']) assert invoice.paymenthash.hex() == payment_hash - assert payment_hash in wallet.lnworker.payment_info + assert wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash in wallet.lnworker.dont_expire_htlcs assert invoice.get_amount_sat() == 10000 assert invoice.get_description() == "test" @@ -520,7 +521,7 @@ class TestCommandsTestnet(ElectrumTestCase): payment_hash=payment_hash, wallet=wallet, ) - assert payment_hash not in wallet.lnworker.payment_info + assert not wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED) assert payment_hash not in wallet.lnworker.dont_expire_htlcs assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == "" assert cancel_result['cancelled'] == payment_hash diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 3fe6a3729..57def01cb 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -865,10 +865,10 @@ class TestPeerDirect(TestPeer): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) results = {} async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await w1.pay_invoice(pay_req) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) results[lnaddr] = PaymentDone() else: results[lnaddr] = PaymentFailure() @@ -988,7 +988,7 @@ class TestPeerDirect(TestPeer): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]] b11 = lnencode(lnaddr, w2.node_keypair.privkey) @@ -1079,7 +1079,7 @@ class TestPeerDirect(TestPeer): result, log = await w1.pay_invoice(pay_req) assert result is True # now pay the same invoice again, the payment should be rejected by w2 - w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID) + w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID, direction=lnutil.SENT) result, log = await w1.pay_invoice(pay_req) if not result: # w1.pay_invoice returned a payment failure as the payment got rejected by w2 @@ -1224,8 +1224,8 @@ class TestPeerDirect(TestPeer): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1297,7 +1297,7 @@ class TestPeerDirect(TestPeer): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route p1.pay( @@ -1997,11 +1997,11 @@ class TestPeerDirect(TestPeer): w2.dont_settle_htlcs[pay_req.rhash] = None async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2067,10 +2067,10 @@ class TestPeerDirect(TestPeer): w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20 async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) if result is True: - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) return PaymentDone() else: self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) @@ -2210,12 +2210,12 @@ class TestPeerForwarding(TestPeer): return split_amount_normal(total_amount, num_parts) async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) with mock.patch('electrum.mpp_split.split_amount_normal', side_effect=mocked_split_amount_normal): result, log = await graph.workers['bob'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) async def f(): async with OldTaskGroup() as group: @@ -2242,10 +2242,10 @@ class TestPeerForwarding(TestPeer): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() async def f(): async with OldTaskGroup() as group: @@ -2309,10 +2309,10 @@ class TestPeerForwarding(TestPeer): graph.workers['carol'].network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE = True peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertFalse(result) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): @@ -2336,11 +2336,11 @@ class TestPeerForwarding(TestPeer): async def pay(lnaddr, pay_req): self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL)) self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL)) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2) self.assertEqual(2, len(log)) self.assertTrue(result) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id], [edge.short_channel_id for edge in log[0].route]) self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id], @@ -2436,11 +2436,11 @@ class TestPeerForwarding(TestPeer): amount_to_pay = 100_000_000 peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3) self.assertTrue(result) self.assertEqual(2, len(log)) - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code) liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints @@ -2507,14 +2507,14 @@ class TestPeerForwarding(TestPeer): assert alice_w.network.channel_db is not None lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay) self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) - self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) if not bob_forwarding: # reset to previous state, sleep 2s so that the second htlc can time out graph.workers['bob'].enable_htlc_forwarding = True await asyncio.sleep(2) if result: - self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) # check mpp is cleaned up async with OldTaskGroup() as g: for peer in peers: @@ -2642,7 +2642,7 @@ class TestPeerForwarding(TestPeer): dest_w = graph.workers[destination_name] async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await sender_w.pay_invoice(pay_req, attempts=attempts) async with OldTaskGroup() as g: for peer in peers: @@ -2653,7 +2653,7 @@ class TestPeerForwarding(TestPeer): for peer in peers: self.assertEqual(len(peer.lnworker.active_forwardings), 0) if result: - self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) raise PaymentDone() else: raise NoPathFound() @@ -2875,7 +2875,7 @@ class TestPeerForwarding(TestPeer): peers = graph.peers.values() async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) result, log = await graph.workers['alice'].pay_invoice(pay_req) self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code) self.assertFalse(result, msg=log)