1
0

Merge pull request #9265 from SomberNight/202410_ln_address_reuse_2

lnworker: reserve wallet addresses also for chan backups
This commit is contained in:
ThomasV
2024-11-13 10:59:03 +01:00
committed by GitHub
3 changed files with 39 additions and 7 deletions

View File

@@ -488,6 +488,14 @@ class AbstractChannel(Logger, ABC):
def can_be_deleted(self) -> bool:
pass
@abstractmethod
def get_wallet_addresses_channel_might_want_reserved(self) -> Sequence[str]:
"""Returns a list of addrs that the wallet should not use, to avoid address-reuse.
Typically, these addresses are wallet.is_mine, but that is not guaranteed,
in which case the wallet can just ignore those.
"""
pass
class ChannelBackup(AbstractChannel):
"""
@@ -639,6 +647,19 @@ class ChannelBackup(AbstractChannel):
ret.append(ChanCloseOption.REQUEST_REMOTE_FCLOSE)
return ret
def get_wallet_addresses_channel_might_want_reserved(self) -> Sequence[str]:
if self.is_imported:
# For v1 imported cbs, we have the local_payment_pubkey, which is
# directly used as p2wpkh() of static_remotekey channels.
# (for v0 imported cbs, the correct local_payment_pubkey is missing, and so
# we might calculate a different address here, which might not be wallet.is_mine,
# but that should be harmless)
our_payment_pubkey = self.config[LOCAL].payment_basepoint.pubkey
to_remote_address = make_commitment_output_to_remote_address(our_payment_pubkey)
return [to_remote_address]
else: # on-chain backup
return []
class Channel(AbstractChannel):
# note: try to avoid naming ctns/ctxs/etc as "current" and "pending".

View File

@@ -850,14 +850,16 @@ class LNWallet(LNWorker):
self._channels = {} # type: Dict[bytes, Channel]
channels = self.db.get_dict("channels")
for channel_id, c in random_shuffled_copy(channels.items()):
self._channels[bfh(channel_id)] = Channel(c, lnworker=self)
self._channels[bfh(channel_id)] = chan = Channel(c, lnworker=self)
self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
self._channel_backups = {} # type: Dict[bytes, ChannelBackup]
# order is important: imported should overwrite onchain
for name in ["onchain_channel_backups", "imported_channel_backups"]:
channel_backups = self.db.get_dict(name)
for channel_id, storage in channel_backups.items():
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
self._channel_backups[bfh(channel_id)] = cb = ChannelBackup(storage, lnworker=self)
self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
@@ -1341,8 +1343,7 @@ class LNWallet(LNWorker):
self.add_channel(chan)
channels_db = self.db.get_dict('channels')
channels_db[chan.channel_id.hex()] = chan.storage
for addr in chan.get_wallet_addresses_channel_might_want_reserved():
self.wallet.set_reserved_state_of_address(addr, reserved=True)
self.wallet.set_reserved_addresses_for_chan(chan, reserved=True)
try:
self.save_channel(chan)
except Exception:
@@ -2864,8 +2865,7 @@ class LNWallet(LNWorker):
with self.lock:
self._channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex())
for addr in chan.get_wallet_addresses_channel_might_want_reserved():
self.wallet.set_reserved_state_of_address(addr, reserved=False)
self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
util.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('wallet_updated', self.wallet)
@@ -2998,6 +2998,7 @@ class LNWallet(LNWorker):
with self.lock:
cb = ChannelBackup(cb_storage, lnworker=self)
self._channel_backups[channel_id] = cb
self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
self.wallet.save_db()
util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
@@ -3025,6 +3026,7 @@ class LNWallet(LNWorker):
raise Exception('Channel not found')
with self.lock:
self._channel_backups.pop(channel_id)
self.wallet.set_reserved_addresses_for_chan(chan, reserved=False)
self.wallet.save_db()
util.trigger_callback('channels_updated', self.wallet)
@@ -3111,6 +3113,7 @@ class LNWallet(LNWorker):
d = self.db.get_dict("onchain_channel_backups")
d[channel_id] = cb_storage
cb = ChannelBackup(cb_storage, lnworker=self)
self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
self.wallet.save_db()
with self.lock:
self._channel_backups[bfh(channel_id)] = cb

View File

@@ -93,6 +93,7 @@ if TYPE_CHECKING:
from .network import Network
from .exchange_rate import FxThread
from .submarine_swaps import SwapData
from .lnchannel import AbstractChannel
_logger = get_logger(__name__)
@@ -2032,13 +2033,20 @@ class Abstract_Wallet(ABC, Logger, EventListener):
def set_reserved_state_of_address(self, addr: str, *, reserved: bool) -> None:
if not self.is_mine(addr):
# silently ignore non-ismine addresses
return
with self.lock:
has_changed = (addr in self._reserved_addresses) != reserved
if reserved:
self._reserved_addresses.add(addr)
else:
self._reserved_addresses.discard(addr)
self.db.put('reserved_addresses', list(self._reserved_addresses))
if has_changed:
self.db.put('reserved_addresses', list(self._reserved_addresses))
def set_reserved_addresses_for_chan(self, chan: 'AbstractChannel', *, reserved: bool) -> None:
for addr in chan.get_wallet_addresses_channel_might_want_reserved():
self.set_reserved_state_of_address(addr, reserved=reserved)
def can_export(self):
return not self.is_watching_only() and hasattr(self.keystore, 'get_private_key')