1
0

lnworker: differentiate PaymentInfo by direction

Allows storing two different payment info of the same payment hash by
including the direction into the db key.
We create and store PaymentInfo for sending attempts and for requests (receiving),
if we try to pay ourself (e.g. through a channel rebalance) the checks
in `save_payment_info` would prevent this and throw an exception.
By storing the PaymentInfos of outgoing and incoming payments separately in
the db this collision is avoided and it makes it easier to reason about
which PaymentInfo belongs where.
This commit is contained in:
f321x
2025-11-28 16:22:22 +01:00
parent 828fc569c9
commit 923d48f9db
12 changed files with 125 additions and 88 deletions

View File

@@ -70,7 +70,7 @@ from .wallet import (
)
from .address_synchronizer import TX_HEIGHT_LOCAL
from .mnemonic import Mnemonic
from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, MIN_FINAL_CLTV_DELTA_ACCEPTED,
from .lnutil import (channel_id_from_funding_tx, LnFeatures, SENT, RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED,
PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE)
from .plugin import run_hook, DeviceMgr, Plugins
from .version import ELECTRUM_VERSION
@@ -1402,7 +1402,7 @@ class Commands(Logger):
arg:int:min_final_cltv_expiry_delta:Optional min final cltv expiry delta (default: 294 blocks)
"""
assert len(payment_hash) == 64, f"Invalid payment hash length: {len(payment_hash)} != 64"
assert payment_hash not in wallet.lnworker.payment_info, "Payment hash already used!"
assert not wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), "Payment hash already used!"
assert payment_hash not in wallet.lnworker.dont_expire_htlcs, "Payment hash already used!"
assert wallet.lnworker.get_preimage(bfh(payment_hash)) is None, "Already got a preimage for this payment hash!"
assert MIN_FINAL_CLTV_DELTA_ACCEPTED < min_final_cltv_expiry_delta < 576, "Use a sane min_final_cltv_expiry_delta value"
@@ -1417,7 +1417,7 @@ class Commands(Logger):
min_final_cltv_delta=min_final_cltv_expiry_delta,
exp_delay=expiry,
)
info = wallet.lnworker.get_payment_info(bfh(payment_hash))
info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED)
lnaddr, invoice = wallet.lnworker.get_bolt11_invoice(
payment_info=info,
message=memo,
@@ -1443,12 +1443,11 @@ class Commands(Logger):
assert len(preimage) == 64, f"Invalid payment_hash length: {len(preimage)} != 64"
payment_hash: str = crypto.sha256(bfh(preimage)).hex()
assert payment_hash not in wallet.lnworker._preimages, f"Invoice {payment_hash=} already settled"
assert payment_hash in wallet.lnworker.payment_info, \
f"Couldn't find lightning invoice for {payment_hash=}"
info = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED)
assert info, f"Couldn't find lightning invoice for {payment_hash=}"
assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"Invoice {payment_hash=} not a hold invoice?"
assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \
f"MPP incomplete, cannot settle hold invoice {payment_hash} yet"
info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash))
assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0)
wallet.lnworker.save_preimage(bfh(payment_hash), bfh(preimage))
util.trigger_callback('wallet_updated', wallet)
@@ -1464,13 +1463,13 @@ class Commands(Logger):
arg:str:payment_hash:Payment hash in hex of the hold invoice
"""
assert payment_hash in wallet.lnworker.payment_info, \
assert wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED), \
f"Couldn't find lightning invoice for payment hash {payment_hash}"
assert payment_hash not in wallet.lnworker._preimages, "Cannot cancel anymore, preimage already given."
assert payment_hash in wallet.lnworker.dont_expire_htlcs, f"{payment_hash=} not a hold invoice?"
# set to PR_UNPAID so it can get deleted
wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID)
wallet.lnworker.delete_payment_info(payment_hash)
wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID, direction=RECEIVED)
wallet.lnworker.delete_payment_info(payment_hash, direction=RECEIVED)
wallet.set_label(payment_hash, None)
del wallet.lnworker.dont_expire_htlcs[payment_hash]
while wallet.lnworker.is_complete_mpp(bfh(payment_hash)):
@@ -1496,7 +1495,7 @@ class Commands(Logger):
arg:str:payment_hash:Payment hash in hex of the hold invoice
"""
assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64"
info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash))
info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash), direction=RECEIVED)
is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash))
amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000
result = {
@@ -1518,7 +1517,7 @@ class Commands(Logger):
elif wallet.lnworker.get_preimage_hex(payment_hash) is not None:
result["status"] = "settled"
plist = wallet.lnworker.get_payments(status='settled')[bfh(payment_hash)]
_dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(info, plist)
_dir, amount_msat, _fee, _ts = wallet.lnworker.get_payment_value(None, plist)
result["received_amount_sat"] = amount_msat // 1000
result['preimage'] = wallet.lnworker.get_preimage_hex(payment_hash)
if info is not None:

View File

@@ -8,7 +8,7 @@ from electrum.logging import get_logger
from electrum.invoices import (
PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED, LN_EXPIRY_NEVER
)
from electrum.lnutil import MIN_FUNDING_SAT
from electrum.lnutil import MIN_FUNDING_SAT, RECEIVED
from electrum.lnurl import LNURL3Data, request_lnurl_withdraw_callback, LNURLError
from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierType
from electrum.i18n import _
@@ -237,7 +237,7 @@ class QERequestDetails(QObject, QtEventListener):
address=None,
)
req = self._wallet.wallet.get_request(key)
info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash)
info = self._wallet.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED)
_lnaddr, b11_invoice = self._wallet.wallet.lnworker.get_bolt11_invoice(
payment_info=info,
message=req.get_message(),

View File

@@ -18,6 +18,7 @@ from electrum.util import (
NotEnoughFunds, NoDynamicFeeEstimates, parse_max_spend, UserCancelled, ChoiceItem,
UserFacingException,
)
from electrum.lnutil import RECEIVED
from electrum.invoices import PR_PAID, Invoice, PR_BROADCASTING, PR_BROADCAST
from electrum.transaction import Transaction, PartialTxInput, PartialTxOutput
from electrum.network import TxBroadcastError, BestEffortRequestFailed
@@ -979,7 +980,7 @@ class SendTab(QWidget, MessageBoxMixin, Logger):
address=None,
)
req = self.wallet.get_request(key)
info = self.wallet.lnworker.get_payment_info(req.payment_hash)
info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED)
_lnaddr, b11_invoice = self.wallet.lnworker.get_bolt11_invoice(
payment_info=info,
message=req.get_message(),

View File

@@ -2221,7 +2221,7 @@ class Peer(Logger, EventListener):
outer_onion_payment_secret=payment_secret_from_onion,
)
info = self.lnworker.get_payment_info(payment_hash)
info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED)
if info is None:
_log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}")
raise exc_incorrect_or_unknown_pd
@@ -3115,7 +3115,7 @@ class Peer(Logger, EventListener):
return None, None, fwd_cb
# -- from here on it's assumed this set is a payment for us (not something to forward) --
payment_info = self.lnworker.get_payment_info(payment_hash)
payment_info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED)
if payment_info is None:
_log_fail_reason(f"payment info has been deleted")
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None

View File

@@ -1089,6 +1089,7 @@ class HTLCOwner(IntEnum):
return HTLCOwner(super().__neg__())
# part of lightning_payments db keys
class Direction(IntEnum):
SENT = -1 # in the context of HTLCs: "offered" HTLCs
RECEIVED = 1 # in the context of HTLCs: "received" HTLCs

View File

@@ -45,7 +45,8 @@ from .fee_policy import (
FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
FEERATE_PER_KW_MIN_RELAY_LIGHTNING, FEE_LN_MINIMUM_ETA_TARGET
)
from .invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER, BaseInvoice
from .invoices import (Invoice, Request, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, LN_EXPIRY_NEVER,
BaseInvoice)
from .bitcoin import COIN, opcodes, make_op_return, address_to_scripthash, DummyAddress
from .bip32 import BIP32Node
from .address_synchronizer import TX_HEIGHT_LOCAL
@@ -120,7 +121,7 @@ class PaymentInfo:
"""Information required to handle incoming htlcs for a payment request"""
payment_hash: bytes
amount_msat: Optional[int]
direction: int
direction: lnutil.Direction
status: int
min_final_cltv_delta: int
expiry_delay: int
@@ -142,6 +143,13 @@ class PaymentInfo:
def __post_init__(self):
self.validate()
@property
def db_key(self) -> str:
return self.calc_db_key(payment_hash_hex=self.payment_hash.hex(), direction=self.direction)
@classmethod
def calc_db_key(cls, *, payment_hash_hex: str, direction: lnutil.Direction) -> str:
return f"{payment_hash_hex}:{int(direction)}"
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
@@ -887,8 +895,8 @@ class LNWallet(LNWorker):
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = LNWatcher(self)
self.lnrater: LNRater = None
# lightning_payments: RHASH -> amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]]
# lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]]
self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self._bolt11_cache = {}
# note: this sweep_address is only used as fallback; as it might result in address-reuse
@@ -1104,7 +1112,7 @@ class LNWallet(LNWorker):
return out
def get_payment_value(
self, info: Optional['PaymentInfo'],
self, sent_info: Optional['PaymentInfo'],
plist: List[HTLCWithStatus]
) -> Tuple[PaymentDirection, int, Optional[int], int]:
""" fee_msat is included in amount_msat"""
@@ -1112,7 +1120,7 @@ class LNWallet(LNWorker):
amount_msat = sum(int(x.direction) * x.htlc.amount_msat for x in plist)
if all(x.direction == SENT for x in plist):
direction = PaymentDirection.SENT
fee_msat = (- info.amount_msat - amount_msat) if info else None
fee_msat = (- sent_info.amount_msat - amount_msat) if sent_info else None
elif all(x.direction == RECEIVED for x in plist):
direction = PaymentDirection.RECEIVED
fee_msat = None
@@ -1135,12 +1143,12 @@ class LNWallet(LNWorker):
if len(plist) == 0:
continue
key = payment_hash.hex()
info = self.get_payment_info(payment_hash)
sent_info = self.get_payment_info(payment_hash, direction=SENT)
# note: just after successfully paying an invoice using MPP, amount and fee values might be shifted
# temporarily: the amount only considers 'settled' htlcs (see plist above), but we might also
# have some inflight htlcs still. Until all relevant htlcs settle, the amount will be lower than
# expected and the fee higher (the inflight htlcs will be effectively counted as fees).
direction, amount_msat, fee_msat, timestamp = self.get_payment_value(info, plist)
direction, amount_msat, fee_msat, timestamp = self.get_payment_value(sent_info, plist)
label = self.wallet.get_label_for_rhash(key)
if not label and direction == PaymentDirection.FORWARDING:
label = _('Forwarding')
@@ -1597,7 +1605,7 @@ class LNWallet(LNWorker):
invoice_features = lnaddr.get_features()
r_tags = lnaddr.get_routing_info('r')
amount_to_pay = lnaddr.get_amount_msat()
status = self.get_payment_status(payment_hash)
status = self.get_payment_status(payment_hash, direction=SENT)
if status == PR_PAID:
raise PaymentFailure(_("This invoice has been paid already"))
if status == PR_INFLIGHT:
@@ -2491,13 +2499,13 @@ class LNWallet(LNWorker):
preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b""
return preimage_bytes.hex() or None
def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
def get_payment_info(self, payment_hash: bytes, *, direction: lnutil.Direction) -> Optional[PaymentInfo]:
"""returns None if payment_hash is a payment we are forwarding"""
key = payment_hash.hex()
key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash.hex(), direction=direction)
with self.lock:
if key in self.payment_info:
stored_tuple = self.payment_info[key]
amount_msat, direction, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple
amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple
return PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
@@ -2542,7 +2550,7 @@ class LNWallet(LNWorker):
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
assert info.status in SAVED_PR_STATUS
with self.lock:
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
if old_info := self.get_payment_info(payment_hash=info.payment_hash, direction=info.direction):
if info == old_info:
return # already saved
if info.direction == SENT:
@@ -2551,8 +2559,8 @@ class LNWallet(LNWorker):
if info != dataclasses.replace(old_info, status=info.status):
# differs more than in status. let's fail
raise Exception(f"payment_hash already in use: {info=} != {old_info=}")
key = info.payment_hash.hex()
self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0
v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts
self.payment_info[info.db_key] = v
if write_to_disk:
self.wallet.save_db()
@@ -2698,13 +2706,15 @@ class LNWallet(LNWorker):
self.active_forwardings.pop(payment_key_hex, None)
self.forwarding_failures.pop(payment_key_hex, None)
def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash)
def get_payment_status(self, payment_hash: bytes, *, direction: lnutil.Direction) -> int:
info = self.get_payment_info(payment_hash, direction=direction)
return info.status if info else PR_UNPAID
def get_invoice_status(self, invoice: BaseInvoice) -> int:
invoice_id = invoice.rhash
status = self.get_payment_status(bfh(invoice_id))
assert isinstance(invoice, (Request, Invoice)), type(invoice)
direction = RECEIVED if isinstance(invoice, Request) else SENT
status = self.get_payment_status(bfh(invoice_id), direction=direction)
if status == PR_UNPAID and invoice_id in self.inflight_payments:
return PR_INFLIGHT
# status may be PR_FAILED
@@ -2718,24 +2728,24 @@ class LNWallet(LNWorker):
elif key in self.inflight_payments:
self.inflight_payments.remove(key)
if status in SAVED_PR_STATUS:
self.set_payment_status(bfh(key), status)
self.set_payment_status(bfh(key), status, direction=SENT)
util.trigger_callback('invoice_status', self.wallet, key, status)
self.logger.info(f"set_invoice_status {key}: {status}")
# liquidity changed
self.clear_invoices_cache()
def set_request_status(self, payment_hash: bytes, status: int) -> None:
if self.get_payment_status(payment_hash) == status:
if self.get_payment_status(payment_hash, direction=RECEIVED) == status:
return
self.set_payment_status(payment_hash, status)
self.set_payment_status(payment_hash, status, direction=RECEIVED)
request_id = payment_hash.hex()
req = self.wallet.get_request(request_id)
if req is None:
return
util.trigger_callback('request_status', self.wallet, request_id, status)
def set_payment_status(self, payment_hash: bytes, status: int) -> None:
info = self.get_payment_info(payment_hash)
def set_payment_status(self, payment_hash: bytes, status: int, *, direction: lnutil.Direction) -> None:
info = self.get_payment_info(payment_hash, direction=direction)
if info is None:
# if we are forwarding
return
@@ -2930,14 +2940,15 @@ class LNWallet(LNWorker):
cltv_delta)]))
return routing_hints
def delete_payment_info(self, payment_hash_hex: str):
def delete_payment_info(self, payment_hash_hex: str, *, direction: lnutil.Direction):
# This method is called when an invoice or request is deleted by the user.
# The GUI only lets the user delete invoices or requests that have not been paid.
# Once an invoice/request has been paid, it is part of the history,
# and get_lightning_history assumes that payment_info is there.
assert self.get_payment_status(bytes.fromhex(payment_hash_hex)) != PR_PAID
assert self.get_payment_status(bytes.fromhex(payment_hash_hex), direction=direction) != PR_PAID
with self.lock:
self.payment_info.pop(payment_hash_hex, None)
key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash_hex, direction=direction)
self.payment_info.pop(key, None)
def get_balance(self, *, frozen=False) -> Decimal:
with self.lock:
@@ -3185,7 +3196,7 @@ class LNWallet(LNWorker):
amount_msat=amount_msat,
exp_delay=3600,
)
info = self.get_payment_info(payment_hash)
info = self.get_payment_info(payment_hash, direction=RECEIVED)
lnaddr, invoice = self.get_bolt11_invoice(
payment_info=info,
message='rebalance',
@@ -3804,14 +3815,13 @@ class LNWallet(LNWorker):
- Alice sends htlc A->B->C, for 1 sat, with HASH1
- Bob must not release the preimage of HASH1
"""
payment_info = self.get_payment_info(payment_hash)
is_our_payreq = payment_info and payment_info.direction == RECEIVED
payment_info = self.get_payment_info(payment_hash, direction=RECEIVED)
# note: If we don't have the preimage for a payment request, then it must be a hold invoice.
# Hold invoices are created by other parties (e.g. a counterparty initiating a submarine swap),
# and it is the other party choosing the payment_hash. If we failed HTLCs with payment_hashes colliding
# with hold invoices, then a party that can make us save a hold invoice for an arbitrary hash could
# also make us fail arbitrary HTLCs.
return bool(is_our_payreq and self.get_preimage(payment_hash))
return bool(payment_info and self.get_preimage(payment_hash))
def create_onion_for_route(
self, *,

View File

@@ -42,6 +42,7 @@ from electrum.util import log_exceptions, ca_path, OldTaskGroup, get_asyncio_loo
get_running_loop
from electrum.invoices import Invoice, Request, PR_UNKNOWN, PR_PAID, BaseInvoice, PR_INFLIGHT
from electrum import constants
from electrum.lnutil import RECEIVED
if TYPE_CHECKING:
from aiohttp_socks import ProxyConnector
@@ -480,7 +481,7 @@ class NWCServer(Logger, EventListener):
address=None
)
req: Request = self.wallet.get_request(key)
info = self.wallet.lnworker.get_payment_info(req.payment_hash)
info = self.wallet.lnworker.get_payment_info(req.payment_hash, direction=RECEIVED)
try:
lnaddr, b11 = self.wallet.lnworker.get_bolt11_invoice(
payment_info=info,
@@ -537,7 +538,7 @@ class NWCServer(Logger, EventListener):
b11 = invoice.lightning_invoice
elif self.wallet.get_request(invoice.rhash):
direction = "incoming"
info = self.wallet.lnworker.get_payment_info(invoice.payment_hash)
info = self.wallet.lnworker.get_payment_info(invoice.payment_hash, direction=RECEIVED)
_, b11 = self.wallet.lnworker.get_bolt11_invoice(
payment_info=info,
message=invoice.message,
@@ -747,7 +748,7 @@ class NWCServer(Logger, EventListener):
request: Optional[Request] = self.wallet.get_request(key)
if not request or not request.is_lightning() or not status == PR_PAID:
return
info = self.wallet.lnworker.get_payment_info(request.payment_hash)
info = self.wallet.lnworker.get_payment_info(request.payment_hash, direction=RECEIVED)
_, b11 = self.wallet.lnworker.get_bolt11_invoice(
payment_info=info,
message=request.message,
@@ -947,7 +948,8 @@ class NWCServer(Logger, EventListener):
payments = self.wallet.lnworker.get_payments(status='settled')
plist = payments.get(payment_hash)
if plist:
info = self.wallet.lnworker.get_payment_info(payment_hash)
direction = plist[0].direction
info = self.wallet.lnworker.get_payment_info(payment_hash, direction=direction)
if info:
dir, amount, fee, ts = self.wallet.lnworker.get_payment_value(info, plist)
fee = abs(fee) if fee else None

View File

@@ -401,7 +401,7 @@ class SwapManager(Logger):
if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks:
# unregister_hold_invoice will fail pending htlcs if there is no preimage available
self.lnworker.unregister_hold_invoice(swap.payment_hash)
self.lnworker.delete_payment_info(swap.payment_hash.hex())
self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.RECEIVED)
self.lnworker.clear_invoices_cache()
self.lnwatcher.remove_callback(swap.lockup_address)
if not swap.is_funded():
@@ -413,9 +413,13 @@ class SwapManager(Logger):
self._swaps_by_lockup_address.pop(swap.lockup_address, None)
if swap.prepay_hash is not None:
self._prepayments.pop(swap.prepay_hash, None)
if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID:
self.lnworker.delete_payment_info(swap.prepay_hash.hex())
if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.RECEIVED) != PR_PAID:
self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.RECEIVED)
self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash)
if self.lnworker.get_payment_status(swap.prepay_hash, direction=lnutil.SENT) != PR_PAID:
self.lnworker.delete_payment_info(swap.prepay_hash.hex(), direction=lnutil.SENT)
if self.lnworker.get_payment_status(swap.payment_hash, direction=lnutil.SENT) != PR_PAID:
self.lnworker.delete_payment_info(swap.payment_hash.hex(), direction=lnutil.SENT)
@classmethod
def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]:
@@ -693,7 +697,7 @@ class SwapManager(Logger):
min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED,
exp_delay=300,
)
info = self.lnworker.get_payment_info(payment_hash)
info = self.lnworker.get_payment_info(payment_hash, direction=lnutil.RECEIVED)
lnaddr1, invoice = self.lnworker.get_bolt11_invoice(
payment_info=info,
message='Submarine swap',
@@ -712,7 +716,7 @@ class SwapManager(Logger):
min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED,
exp_delay=300,
)
info = self.lnworker.get_payment_info(prepay_hash)
info = self.lnworker.get_payment_info(prepay_hash, direction=lnutil.RECEIVED)
lnaddr2, prepay_invoice = self.lnworker.get_bolt11_invoice(
payment_info=info,
message='Submarine swap prepayment',

View File

@@ -76,7 +76,7 @@ from .invoices import BaseInvoice, Invoice, Request, PR_PAID, PR_UNPAID, PR_EXPI
from .contacts import Contacts
from .mnemonic import Mnemonic
from .lnworker import LNWallet
from .lnutil import MIN_FUNDING_SAT
from .lnutil import MIN_FUNDING_SAT, RECEIVED, SENT
from .lntransport import extract_nodeid
from .descriptor import Descriptor
from .txbatcher import TxBatcher
@@ -3014,7 +3014,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
return ''
amount_msat = req.get_amount_msat() or None
assert (amount_msat is None or amount_msat > 0), amount_msat
info = self.lnworker.get_payment_info(payment_hash)
info = self.lnworker.get_payment_info(payment_hash, direction=RECEIVED)
assert info.amount_msat == amount_msat, f"{info.amount_msat=} != {amount_msat=}"
lnaddr, invoice = self.lnworker.get_bolt11_invoice(
payment_info=info,
@@ -3074,7 +3074,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
if addr := req.get_address():
self._requests_addr_to_key[addr].discard(request_id)
if req.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(req.rhash)
self.lnworker.delete_payment_info(req.rhash, direction=RECEIVED)
if write_to_disk:
self.save_db()
@@ -3084,7 +3084,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
if inv is None:
return
if inv.is_lightning() and self.lnworker:
self.lnworker.delete_payment_info(inv.rhash)
self.lnworker.delete_payment_info(inv.rhash, direction=SENT)
if write_to_disk:
self.save_db()

View File

@@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException):
# seed_version is now used for the version of the wallet file
OLD_SEED_VERSION = 4 # electrum versions < 2.0
NEW_SEED_VERSION = 11 # electrum versions >= 2.0
FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent
FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
@@ -235,6 +235,7 @@ class WalletDBUpgrader(Logger):
self._convert_version_61()
self._convert_version_62()
self._convert_version_63()
self._convert_version_64()
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
def _convert_wallet_type(self):
@@ -1269,6 +1270,24 @@ class WalletDBUpgrader(Logger):
self.data['seed_version'] = 63
def _convert_version_64(self):
"""Key payment_info by "rhash:direction" instead of just rhash to allow storing a PaymentInfo
for each direction"""
if not self._is_upgrade_method_needed(63, 63):
return
new_payment_infos = {}
old_payment_infos = self.data.get('lightning_payments', {})
for payment_hash, old_values in old_payment_infos.items():
amount_msat, direction, status, min_final_cltv_expiry, expiry, creation_ts = old_values
# drop direction
new_values = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts)
new_key = f"{payment_hash}:{direction}"
new_payment_infos[new_key] = new_values # save new entry
self.data['lightning_payments'] = new_payment_infos
self.data['seed_version'] = 64
def _convert_imported(self):
if not self._is_upgrade_method_needed(0, 13):
return

View File

@@ -11,6 +11,7 @@ import shutil
import electrum
from electrum.commands import Commands, eval_bool
from electrum import storage, wallet
from electrum.lnutil import RECEIVED
from electrum.lnworker import RecvMPPResolution
from electrum.wallet import Abstract_Wallet
from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED
@@ -509,7 +510,7 @@ class TestCommandsTestnet(ElectrumTestCase):
)
invoice = lndecode(invoice=result['invoice'])
assert invoice.paymenthash.hex() == payment_hash
assert payment_hash in wallet.lnworker.payment_info
assert wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED)
assert payment_hash in wallet.lnworker.dont_expire_htlcs
assert invoice.get_amount_sat() == 10000
assert invoice.get_description() == "test"
@@ -520,7 +521,7 @@ class TestCommandsTestnet(ElectrumTestCase):
payment_hash=payment_hash,
wallet=wallet,
)
assert payment_hash not in wallet.lnworker.payment_info
assert not wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED)
assert payment_hash not in wallet.lnworker.dont_expire_htlcs
assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == ""
assert cancel_result['cancelled'] == payment_hash

View File

@@ -865,10 +865,10 @@ class TestPeerDirect(TestPeer):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
results = {}
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await w1.pay_invoice(pay_req)
if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
results[lnaddr] = PaymentDone()
else:
results[lnaddr] = PaymentFailure()
@@ -988,7 +988,7 @@ class TestPeerDirect(TestPeer):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects
lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]]
b11 = lnencode(lnaddr, w2.node_keypair.privkey)
@@ -1079,7 +1079,7 @@ class TestPeerDirect(TestPeer):
result, log = await w1.pay_invoice(pay_req)
assert result is True
# now pay the same invoice again, the payment should be rejected by w2
w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID)
w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID, direction=lnutil.SENT)
result, log = await w1.pay_invoice(pay_req)
if not result:
# w1.pay_invoice returned a payment failure as the payment got rejected by w2
@@ -1224,8 +1224,8 @@ class TestPeerDirect(TestPeer):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED))
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
p1.pay(
@@ -1297,7 +1297,7 @@ class TestPeerDirect(TestPeer):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
p1.pay(
@@ -1997,11 +1997,11 @@ class TestPeerDirect(TestPeer):
w2.dont_settle_htlcs[pay_req.rhash] = None
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
if result is True:
self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs)
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
return PaymentDone()
else:
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
@@ -2067,10 +2067,10 @@ class TestPeerDirect(TestPeer):
w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
return PaymentDone()
else:
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
@@ -2210,12 +2210,12 @@ class TestPeerForwarding(TestPeer):
return split_amount_normal(total_amount, num_parts)
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
with mock.patch('electrum.mpp_split.split_amount_normal',
side_effect=mocked_split_amount_normal):
result, log = await graph.workers['bob'].pay_invoice(pay_req)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
async def f():
async with OldTaskGroup() as group:
@@ -2242,10 +2242,10 @@ class TestPeerForwarding(TestPeer):
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph'])
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
raise PaymentDone()
async def f():
async with OldTaskGroup() as group:
@@ -2309,10 +2309,10 @@ class TestPeerForwarding(TestPeer):
graph.workers['carol'].network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE = True
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertFalse(result)
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
raise PaymentDone()
async def f():
@@ -2336,11 +2336,11 @@ class TestPeerForwarding(TestPeer):
async def pay(lnaddr, pay_req):
self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL))
self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2)
self.assertEqual(2, len(log))
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id],
[edge.short_channel_id for edge in log[0].route])
self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id],
@@ -2436,11 +2436,11 @@ class TestPeerForwarding(TestPeer):
amount_to_pay = 100_000_000
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3)
self.assertTrue(result)
self.assertEqual(2, len(log))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code)
liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints
@@ -2507,14 +2507,14 @@ class TestPeerForwarding(TestPeer):
assert alice_w.network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay)
self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure)
self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await alice_w.pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding:
# reset to previous state, sleep 2s so that the second htlc can time out
graph.workers['bob'].enable_htlc_forwarding = True
await asyncio.sleep(2)
if result:
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
# check mpp is cleaned up
async with OldTaskGroup() as g:
for peer in peers:
@@ -2642,7 +2642,7 @@ class TestPeerForwarding(TestPeer):
dest_w = graph.workers[destination_name]
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await sender_w.pay_invoice(pay_req, attempts=attempts)
async with OldTaskGroup() as g:
for peer in peers:
@@ -2653,7 +2653,7 @@ class TestPeerForwarding(TestPeer):
for peer in peers:
self.assertEqual(len(peer.lnworker.active_forwardings), 0)
if result:
self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
raise PaymentDone()
else:
raise NoPathFound()
@@ -2875,7 +2875,7 @@ class TestPeerForwarding(TestPeer):
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code)
self.assertFalse(result, msg=log)