wallet: make sure payment requests are persisted
Fixes: after adding a payment request, if the process was killed, the payreq might get lost. In case of using the GUI, neither the callee nor the caller called wallet.save_db(). Unclear where wallet.save_db() should be called... Now each method tries to persist their changes by default, but as an optimisation, the caller can pass write_to_disk=False e.g. when calling multiple such methods and then call wallet.save_db() itself. If we had partial writes, which would either rm the need for wallet.save_db() or at least make it cheaper, this code might get simpler... related: https://github.com/spesmilo/electrum/pull/6435 related: https://github.com/spesmilo/electrum/issues/4823
This commit is contained in:
@@ -882,14 +882,12 @@ class Commands:
|
||||
expiration = int(expiration) if expiration else None
|
||||
req = wallet.make_payment_request(addr, amount, memo, expiration)
|
||||
wallet.add_payment_request(req)
|
||||
wallet.save_db()
|
||||
return wallet.export_request(req)
|
||||
|
||||
@command('wn')
|
||||
async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None):
|
||||
amount_sat = int(satoshis(amount))
|
||||
key = await wallet.lnworker._add_request_coro(amount_sat, memo, expiration)
|
||||
wallet.save_db()
|
||||
return wallet.get_formatted_request(key)
|
||||
|
||||
@command('w')
|
||||
@@ -913,9 +911,7 @@ class Commands:
|
||||
@command('w')
|
||||
async def rmrequest(self, address, wallet: Abstract_Wallet = None):
|
||||
"""Remove a payment request"""
|
||||
result = wallet.remove_payment_request(address)
|
||||
wallet.save_db()
|
||||
return result
|
||||
return wallet.remove_payment_request(address)
|
||||
|
||||
@command('w')
|
||||
async def clear_requests(self, wallet: Abstract_Wallet = None):
|
||||
|
||||
@@ -1697,7 +1697,9 @@ class LNWallet(LNWorker):
|
||||
self, *,
|
||||
amount_msat: Optional[int],
|
||||
message: str,
|
||||
expiry: int) -> Tuple[LnAddr, str]:
|
||||
expiry: int,
|
||||
write_to_disk: bool = True,
|
||||
) -> Tuple[LnAddr, str]:
|
||||
|
||||
timestamp = int(time.time())
|
||||
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
|
||||
@@ -1731,8 +1733,10 @@ class LNWallet(LNWorker):
|
||||
date=timestamp,
|
||||
payment_secret=derive_payment_secret_from_payment_preimage(payment_preimage))
|
||||
invoice = lnencode(lnaddr, self.node_keypair.privkey)
|
||||
self.save_preimage(payment_hash, payment_preimage)
|
||||
self.save_payment_info(info)
|
||||
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
return lnaddr, invoice
|
||||
|
||||
async def _add_request_coro(self, amount_sat: Optional[int], message, expiry: int) -> str:
|
||||
@@ -1740,17 +1744,21 @@ class LNWallet(LNWorker):
|
||||
lnaddr, invoice = await self.create_invoice(
|
||||
amount_msat=amount_msat,
|
||||
message=message,
|
||||
expiry=expiry)
|
||||
expiry=expiry,
|
||||
write_to_disk=False,
|
||||
)
|
||||
key = bh2u(lnaddr.paymenthash)
|
||||
req = LNInvoice.from_bech32(invoice)
|
||||
self.wallet.add_payment_request(req)
|
||||
self.wallet.add_payment_request(req, write_to_disk=False)
|
||||
self.wallet.set_label(key, message)
|
||||
self.wallet.save_db()
|
||||
return key
|
||||
|
||||
def save_preimage(self, payment_hash: bytes, preimage: bytes):
|
||||
def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
|
||||
assert sha256(preimage) == payment_hash
|
||||
self.preimages[bh2u(payment_hash)] = bh2u(preimage)
|
||||
self.wallet.save_db()
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
|
||||
def get_preimage(self, payment_hash: bytes) -> Optional[bytes]:
|
||||
r = self.preimages.get(bh2u(payment_hash))
|
||||
@@ -1764,12 +1772,13 @@ class LNWallet(LNWorker):
|
||||
amount_msat, direction, status = self.payments[key]
|
||||
return PaymentInfo(payment_hash, amount_msat, direction, status)
|
||||
|
||||
def save_payment_info(self, info: PaymentInfo) -> None:
|
||||
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
|
||||
key = info.payment_hash.hex()
|
||||
assert info.status in SAVED_PR_STATUS
|
||||
with self.lock:
|
||||
self.payments[key] = info.amount_msat, info.direction, info.status
|
||||
self.wallet.save_db()
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
|
||||
def check_received_mpp_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
|
||||
""" return MPP status: True (accepted), False (expired) or None """
|
||||
|
||||
@@ -2295,11 +2295,13 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
|
||||
key = req.rhash
|
||||
return key
|
||||
|
||||
def add_payment_request(self, req: Invoice):
|
||||
def add_payment_request(self, req: Invoice, *, write_to_disk: bool = True):
|
||||
key = self.get_key_for_receive_request(req, sanity_checks=True)
|
||||
message = req.message
|
||||
self.receive_requests[key] = req
|
||||
self.set_label(key, message) # should be a default label
|
||||
if write_to_disk:
|
||||
self.save_db()
|
||||
return req
|
||||
|
||||
def delete_request(self, key):
|
||||
@@ -2316,11 +2318,13 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
|
||||
elif self.lnworker:
|
||||
self.lnworker.delete_payment(key)
|
||||
|
||||
def remove_payment_request(self, addr):
|
||||
if addr not in self.receive_requests:
|
||||
return False
|
||||
self.receive_requests.pop(addr)
|
||||
return True
|
||||
def remove_payment_request(self, addr) -> bool:
|
||||
found = False
|
||||
if addr in self.receive_requests:
|
||||
found = True
|
||||
self.receive_requests.pop(addr)
|
||||
self.save_db()
|
||||
return found
|
||||
|
||||
def get_sorted_requests(self) -> List[Invoice]:
|
||||
""" sorted by timestamp """
|
||||
|
||||
Reference in New Issue
Block a user