1
0

refactor lnworker.pay_invoice to accept Invoice object instead of bolt11 string

rename lnworker._check_invoice to lnworker._check_bolt11_invoice
This commit is contained in:
Sander van Grieken
2025-02-20 16:05:45 +01:00
parent 4d4453821a
commit 6fdb6c93f7
8 changed files with 34 additions and 30 deletions

View File

@@ -1264,10 +1264,11 @@ class Commands(Logger):
@command('wnpl') @command('wnpl')
async def lnpay(self, invoice, timeout=120, password=None, wallet: Abstract_Wallet = None): async def lnpay(self, invoice, timeout=120, password=None, wallet: Abstract_Wallet = None):
lnworker = wallet.lnworker lnworker = wallet.lnworker
lnaddr = lnworker._check_invoice(invoice) lnaddr = lnworker._check_bolt11_invoice(invoice)
payment_hash = lnaddr.paymenthash payment_hash = lnaddr.paymenthash
wallet.save_invoice(Invoice.from_bech32(invoice)) invoice_obj = Invoice.from_bech32(invoice)
success, log = await lnworker.pay_invoice(invoice) wallet.save_invoice(invoice_obj)
success, log = await lnworker.pay_invoice(invoice_obj)
return { return {
'payment_hash': payment_hash.hex(), 'payment_hash': payment_hash.hex(),
'success': success, 'success': success,

View File

@@ -7,14 +7,15 @@ from PyQt6.QtCore import pyqtProperty, pyqtSignal, pyqtSlot, QObject, pyqtEnum,
from electrum.i18n import _ from electrum.i18n import _
from electrum.logging import get_logger from electrum.logging import get_logger
from electrum.invoices import (Invoice, PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, from electrum.invoices import (
PR_FAILED, PR_ROUTING, PR_UNCONFIRMED, PR_BROADCASTING, PR_BROADCAST, LN_EXPIRY_NEVER) Invoice, PR_UNPAID, PR_EXPIRED, PR_UNKNOWN, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_ROUTING, PR_UNCONFIRMED,
PR_BROADCASTING, PR_BROADCAST, LN_EXPIRY_NEVER
)
from electrum.transaction import PartialTxOutput, TxOutput from electrum.transaction import PartialTxOutput, TxOutput
from electrum.util import NotEnoughFunds, NoDynamicFeeEstimates
from electrum.lnutil import format_short_channel_id from electrum.lnutil import format_short_channel_id
from electrum.bitcoin import COIN, address_to_script from electrum.bitcoin import COIN, address_to_script
from electrum.paymentrequest import PaymentRequest from electrum.paymentrequest import PaymentRequest
from electrum.payment_identifier import (PaymentIdentifier, PaymentIdentifierState, PaymentIdentifierType) from electrum.payment_identifier import PaymentIdentifier, PaymentIdentifierState, PaymentIdentifierType
from .qetypes import QEAmount from .qetypes import QEAmount
from .qewallet import QEWallet from .qewallet import QEWallet
@@ -56,7 +57,7 @@ class QEInvoice(QObject, QtEventListener):
self._canPay = False self._canPay = False
self._key = None self._key = None
self._invoiceType = QEInvoice.Type.Invalid self._invoiceType = QEInvoice.Type.Invalid
self._effectiveInvoice = None self._effectiveInvoice = None # type: Optional[Invoice]
self._userinfo = '' self._userinfo = ''
self._lnprops = {} self._lnprops = {}
self._amount = QEAmount() self._amount = QEAmount()

View File

@@ -29,7 +29,7 @@ from .util import QtEventListener, qt_event_listener
if TYPE_CHECKING: if TYPE_CHECKING:
from electrum.wallet import Abstract_Wallet from electrum.wallet import Abstract_Wallet
from .qeinvoice import QEInvoice from electrum.invoices import Invoice
class QEWallet(AuthMixin, QObject, QtEventListener): class QEWallet(AuthMixin, QObject, QtEventListener):
@@ -640,12 +640,12 @@ class QEWallet(AuthMixin, QObject, QtEventListener):
self.paymentAuthRejected.emit() self.paymentAuthRejected.emit()
@auth_protect(message=_('Pay lightning invoice?'), reject='ln_auth_rejected') @auth_protect(message=_('Pay lightning invoice?'), reject='ln_auth_rejected')
def pay_lightning_invoice(self, invoice: 'QEInvoice'): def pay_lightning_invoice(self, invoice: 'Invoice'):
amount_msat = invoice.get_amount_msat() amount_msat = invoice.get_amount_msat()
def pay_thread(): def pay_thread():
try: try:
coro = self.wallet.lnworker.pay_invoice(invoice.lightning_invoice, amount_msat=amount_msat) coro = self.wallet.lnworker.pay_invoice(invoice, amount_msat=amount_msat)
fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop()) fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop())
fut.result() fut.result()
except Exception as e: except Exception as e:

View File

@@ -719,7 +719,7 @@ class SendTab(QWidget, MessageBoxMixin, Logger):
if not self.question(msg): if not self.question(msg):
return return
self.save_pending_invoice() self.save_pending_invoice()
coro = lnworker.pay_invoice(invoice.lightning_invoice, amount_msat=amount_msat) coro = lnworker.pay_invoice(invoice, amount_msat=amount_msat)
self.window.run_coroutine_from_thread(coro, _('Sending payment')) self.window.run_coroutine_from_thread(coro, _('Sending payment'))
def broadcast_transaction(self, tx: Transaction, *, payment_identifier: PaymentIdentifier = None): def broadcast_transaction(self, tx: Transaction, *, payment_identifier: PaymentIdentifier = None):

View File

@@ -670,7 +670,7 @@ class ElectrumGui(BaseElectrumGui, EventListener):
if not self.question(msg): if not self.question(msg):
return return
self.save_pending_invoice(invoice) self.save_pending_invoice(invoice)
coro = self.wallet.lnworker.pay_invoice(invoice.lightning_invoice, amount_msat=amount_msat) coro = self.wallet.lnworker.pay_invoice(invoice, amount_msat=amount_msat)
#self.window.run_coroutine_from_thread(coro, _('Sending payment')) #self.window.run_coroutine_from_thread(coro, _('Sending payment'))
self.show_message(_("Please wait..."), getchar=False) self.show_message(_("Please wait..."), getchar=False)

View File

@@ -1491,14 +1491,14 @@ class LNWallet(LNWorker):
@log_exceptions @log_exceptions
async def pay_invoice( async def pay_invoice(
self, invoice: str, *, self, invoice: Invoice, *,
amount_msat: int = None, amount_msat: int = None,
attempts: int = None, # used only in unit tests attempts: int = None, # used only in unit tests
full_path: LNPaymentPath = None, full_path: LNPaymentPath = None,
channels: Optional[Sequence[Channel]] = None, channels: Optional[Sequence[Channel]] = None,
) -> Tuple[bool, List[HtlcLog]]: ) -> Tuple[bool, List[HtlcLog]]:
bolt11 = invoice.lightning_invoice
lnaddr = self._check_invoice(invoice, amount_msat=amount_msat) lnaddr = self._check_bolt11_invoice(bolt11, amount_msat=amount_msat)
min_final_cltv_delta = lnaddr.get_min_final_cltv_delta() min_final_cltv_delta = lnaddr.get_min_final_cltv_delta()
payment_hash = lnaddr.paymenthash payment_hash = lnaddr.paymenthash
key = payment_hash.hex() key = payment_hash.hex()
@@ -1850,11 +1850,11 @@ class LNWallet(LNWorker):
except Exception: except Exception:
return None return None
def _check_invoice(self, invoice: str, *, amount_msat: int = None) -> LnAddr: def _check_bolt11_invoice(self, bolt11_invoice: str, *, amount_msat: int = None) -> LnAddr:
"""Parses and validates a bolt11 invoice str into a LnAddr. """Parses and validates a bolt11 invoice str into a LnAddr.
Includes pre-payment checks external to the parser. Includes pre-payment checks external to the parser.
""" """
addr = lndecode(invoice) addr = lndecode(bolt11_invoice)
if addr.is_expired(): if addr.is_expired():
raise InvoiceError(_("This invoice has expired")) raise InvoiceError(_("This invoice has expired"))
# check amount # check amount
@@ -2872,8 +2872,8 @@ class LNWallet(LNWorker):
fallback_address=None, fallback_address=None,
channels=[chan2], channels=[chan2],
) )
return await self.pay_invoice( invoice_obj = Invoice.from_bech32(invoice)
invoice, channels=[chan1]) return await self.pay_invoice(invoice_obj, channels=[chan1])
def can_receive_invoice(self, invoice: BaseInvoice) -> bool: def can_receive_invoice(self, invoice: BaseInvoice) -> bool:
assert invoice.is_lightning() assert invoice.is_lightning()

View File

@@ -298,7 +298,7 @@ class SwapManager(Logger):
self.invoices_to_pay[key] = 1000000000000 # lock self.invoices_to_pay[key] = 1000000000000 # lock
try: try:
invoice = self.wallet.get_invoice(key) invoice = self.wallet.get_invoice(key)
success, log = await self.lnworker.pay_invoice(invoice.lightning_invoice, attempts=10) success, log = await self.lnworker.pay_invoice(invoice, attempts=10)
except Exception as e: except Exception as e:
self.logger.info(f'exception paying {key}, will not retry') self.logger.info(f'exception paying {key}, will not retry')
self.invoices_to_pay.pop(key, None) self.invoices_to_pay.pop(key, None)
@@ -908,13 +908,13 @@ class SwapManager(Logger):
if locktime - self.network.get_local_height() <= MIN_LOCKTIME_DELTA: if locktime - self.network.get_local_height() <= MIN_LOCKTIME_DELTA:
raise Exception("rswap check failed: locktime too close") raise Exception("rswap check failed: locktime too close")
# verify invoice payment_hash # verify invoice payment_hash
lnaddr = self.lnworker._check_invoice(invoice) lnaddr = self.lnworker._check_bolt11_invoice(invoice)
invoice_amount = int(lnaddr.get_amount_sat()) invoice_amount = int(lnaddr.get_amount_sat())
if lnaddr.paymenthash != payment_hash: if lnaddr.paymenthash != payment_hash:
raise Exception("rswap check failed: inconsistent RHASH and invoice") raise Exception("rswap check failed: inconsistent RHASH and invoice")
# check that the lightning amount is what we requested # check that the lightning amount is what we requested
if fee_invoice: if fee_invoice:
fee_lnaddr = self.lnworker._check_invoice(fee_invoice) fee_lnaddr = self.lnworker._check_bolt11_invoice(fee_invoice)
invoice_amount += fee_lnaddr.get_amount_sat() invoice_amount += fee_lnaddr.get_amount_sat()
prepay_hash = fee_lnaddr.paymenthash prepay_hash = fee_lnaddr.paymenthash
else: else:
@@ -935,13 +935,15 @@ class SwapManager(Logger):
swap._zeroconf = zeroconf swap._zeroconf = zeroconf
# initiate fee payment. # initiate fee payment.
if fee_invoice: if fee_invoice:
asyncio.ensure_future(self.lnworker.pay_invoice(fee_invoice)) fee_invoice_obj = Invoice.from_bech32(fee_invoice)
asyncio.ensure_future(self.lnworker.pay_invoice(fee_invoice_obj))
# we return if we detect funding # we return if we detect funding
async def wait_for_funding(swap): async def wait_for_funding(swap):
while swap.funding_txid is None: while swap.funding_txid is None:
await asyncio.sleep(1) await asyncio.sleep(1)
# initiate main payment # initiate main payment
tasks = [asyncio.create_task(self.lnworker.pay_invoice(invoice, channels=channels)), asyncio.create_task(wait_for_funding(swap))] invoice_obj = Invoice.from_bech32(invoice)
tasks = [asyncio.create_task(self.lnworker.pay_invoice(invoice_obj, channels=channels)), asyncio.create_task(wait_for_funding(swap))]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
return swap.funding_txid return swap.funding_txid

View File

@@ -41,7 +41,7 @@ from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode, OnionRoutingFailure from electrum.lnonion import OnionFailureCode, OnionRoutingFailure
from electrum.lnutil import UpdateAddHtlc from electrum.lnutil import UpdateAddHtlc
from electrum.lnutil import LOCAL, REMOTE from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID from electrum.invoices import PR_PAID, PR_UNPAID, Invoice
from electrum.interface import GracefulDisconnect from electrum.interface import GracefulDisconnect
from electrum.simple_config import SimpleConfig from electrum.simple_config import SimpleConfig
from electrum.fee_policy import FeeTimeEstimates from electrum.fee_policy import FeeTimeEstimates
@@ -291,7 +291,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
get_preimage = LNWallet.get_preimage get_preimage = LNWallet.get_preimage
create_route_for_single_htlc = LNWallet.create_route_for_single_htlc create_route_for_single_htlc = LNWallet.create_route_for_single_htlc
create_routes_for_payment = LNWallet.create_routes_for_payment create_routes_for_payment = LNWallet.create_routes_for_payment
_check_invoice = LNWallet._check_invoice _check_bolt11_invoice = LNWallet._check_bolt11_invoice
pay_to_route = LNWallet.pay_to_route pay_to_route = LNWallet.pay_to_route
pay_to_node = LNWallet.pay_to_node pay_to_node = LNWallet.pay_to_node
pay_invoice = LNWallet.pay_invoice pay_invoice = LNWallet.pay_invoice
@@ -536,7 +536,7 @@ class TestPeer(ElectrumTestCase):
payment_hash: bytes = None, payment_hash: bytes = None,
invoice_features: LnFeatures = None, invoice_features: LnFeatures = None,
min_final_cltv_delta: int = None, min_final_cltv_delta: int = None,
) -> Tuple[LnAddr, str]: ) -> Tuple[LnAddr, Invoice]:
amount_btc = amount_msat/Decimal(COIN*1000) amount_btc = amount_msat/Decimal(COIN*1000)
if payment_preimage is None and not payment_hash: if payment_preimage is None and not payment_hash:
payment_preimage = os.urandom(32) payment_preimage = os.urandom(32)
@@ -571,7 +571,7 @@ class TestPeer(ElectrumTestCase):
) )
invoice = lnencode(lnaddr1, w2.node_keypair.privkey) invoice = lnencode(lnaddr1, w2.node_keypair.privkey)
lnaddr2 = lndecode(invoice) # unlike lnaddr1, this now has a pubkey set lnaddr2 = lndecode(invoice) # unlike lnaddr1, this now has a pubkey set
return lnaddr2, invoice return lnaddr2, Invoice.from_bech32(invoice)
async def _activate_trampoline(self, w: MockLNWallet): async def _activate_trampoline(self, w: MockLNWallet):
if w.network.channel_db: if w.network.channel_db:
@@ -1349,7 +1349,7 @@ class TestPeerDirect(TestPeer):
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = self.prepare_invoice(w2) lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_invoice(pay_req) lnaddr = w1._check_bolt11_invoice(pay_req.lightning_invoice)
shi = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0] shi = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0]
route, amount_msat = shi.route, shi.amount_msat route, amount_msat = shi.route, shi.amount_msat
assert amount_msat == lnaddr.get_amount_msat() assert amount_msat == lnaddr.get_amount_msat()