1
0

Remove LNBackups object: no longer needed since LNWorker is instantiated by default.

This commit is contained in:
ThomasV
2021-03-09 09:55:55 +01:00
parent ef661050c8
commit 652d10aa5f
5 changed files with 46 additions and 85 deletions

View File

@@ -599,6 +599,11 @@ class LNWallet(LNWorker):
for channel_id, c in random_shuffled_copy(channels.items()):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self._channel_backups = {} # type: Dict[bytes, Channel]
channel_backups = self.db.get_dict("channel_backups")
for channel_id, cb in random_shuffled_copy(channel_backups.items()):
self._channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
@@ -618,6 +623,12 @@ class LNWallet(LNWorker):
with self.lock:
return self._channels.copy()
@property
def channel_backups(self) -> Mapping[bytes, Channel]:
"""Returns a read-only copy of channels."""
with self.lock:
return self._channel_backups.copy()
def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]:
return self._channels.get(channel_id, None)
@@ -680,6 +691,8 @@ class LNWallet(LNWorker):
for chan in self.channels.values():
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
for cb in self.channel_backups.values():
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
for coro in [
self.maybe_listen(),
@@ -843,7 +856,8 @@ class LNWallet(LNWorker):
if chan.node_id == node_id}
def channel_state_changed(self, chan: Channel):
self.save_channel(chan)
if type(chan) is Channel:
self.save_channel(chan)
util.trigger_callback('channel', self.wallet, chan)
def save_channel(self, chan: Channel):
@@ -857,8 +871,14 @@ class LNWallet(LNWorker):
for chan in self.channels.values():
if chan.funding_outpoint.to_str() == txo:
return chan
for chan in self.channel_backups.values():
if chan.funding_outpoint.to_str() == txo:
return chan
async def on_channel_update(self, chan: Channel):
if type(chan) is ChannelBackup:
util.trigger_callback('channel', self.wallet, chan)
return
if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()):
self.logger.info(f"force-closing due to expiring htlcs")
@@ -1940,61 +1960,6 @@ class LNWallet(LNWorker):
peer = await self.add_peer(connect_str)
await peer.trigger_force_close(channel_id)
class LNBackups(Logger):
lnwatcher: Optional['LNWalletWatcher']
def __init__(self, wallet: 'Abstract_Wallet'):
Logger.__init__(self)
self.features = LNWALLET_FEATURES
self.lock = threading.RLock()
self.wallet = wallet
self.db = wallet.db
self.lnwatcher = None
self.channel_backups = {}
for channel_id, cb in random_shuffled_copy(self.db.get_dict("channel_backups").items()):
self.channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self)
@property
def sweep_address(self) -> str:
# TODO possible address-reuse
return self.wallet.get_new_sweep_address_for_channel()
def channel_state_changed(self, chan):
util.trigger_callback('channel', self.wallet, chan)
def peer_closed(self, chan):
pass
async def on_channel_update(self, chan):
util.trigger_callback('channel', self.wallet, chan)
def channel_by_txo(self, txo):
with self.lock:
channel_backups = list(self.channel_backups.values())
for chan in channel_backups:
if chan.funding_outpoint.to_str() == txo:
return chan
def on_peer_successfully_established(self, peer: Peer) -> None:
pass
def channels_for_peer(self, node_id):
return {}
def start_network(self, network: 'Network'):
assert network
self.lnwatcher = LNWalletWatcher(self, network)
self.lnwatcher.start_network(network)
self.network = network
for cb in self.channel_backups.values():
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def stop(self):
self.lnwatcher.stop()
self.lnwatcher = None
def import_channel_backup(self, data):
assert data.startswith('channel_backup:')
encrypted = data[15:]
@@ -2015,19 +1980,20 @@ class LNBackups(Logger):
d = self.db.get_dict("channel_backups")
if channel_id.hex() not in d:
raise Exception('Channel not found')
d.pop(channel_id.hex())
self.channel_backups.pop(channel_id)
with self.lock:
d.pop(channel_id.hex())
self._channel_backups.pop(channel_id)
self.wallet.save_db()
util.trigger_callback('channels_updated', self.wallet)
@log_exceptions
async def request_force_close(self, channel_id: bytes):
async def request_force_close_from_backup(self, channel_id: bytes):
cb = self.channel_backups[channel_id].cb
# TODO also try network addresses from gossip db (as it might have changed)
peer_addr = LNPeerAddr(cb.host, cb.port, cb.node_id)
transport = LNTransport(cb.privkey, peer_addr,
proxy=self.network.proxy)
peer = Peer(self, cb.node_id, transport)
peer = Peer(self, cb.node_id, transport, is_channel_backup=True)
async with TaskGroup() as group:
await group.spawn(peer._message_loop())
await group.spawn(peer.trigger_force_close(channel_id))