lnworker.peers: fix threading issues
This commit is contained in:
@@ -7,7 +7,7 @@ import os
|
||||
from decimal import Decimal
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union
|
||||
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping
|
||||
import threading
|
||||
import socket
|
||||
import json
|
||||
@@ -150,8 +150,9 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
max_retry_delay_urgent=300,
|
||||
init_retry_delay_urgent=4,
|
||||
)
|
||||
self.lock = threading.RLock()
|
||||
self.node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
|
||||
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
|
||||
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
|
||||
self.taskgroup = SilentTaskGroup()
|
||||
# set some feature flags as baseline for both LNWallet and LNGossip
|
||||
# note that e.g. DATA_LOSS_PROTECT is needed for LNGossip as many peers require it
|
||||
@@ -161,6 +162,12 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self.features |= LnFeatures.VAR_ONION_OPT
|
||||
self.features |= LnFeatures.PAYMENT_SECRET_OPT
|
||||
|
||||
@property
|
||||
def peers(self) -> Mapping[bytes, Peer]:
|
||||
"""Returns a read-only copy of peers."""
|
||||
with self.lock:
|
||||
return self._peers.copy()
|
||||
|
||||
def channels_for_peer(self, node_id):
|
||||
return {}
|
||||
|
||||
@@ -180,7 +187,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self.logger.info('handshake failure from incoming connection')
|
||||
return
|
||||
peer = Peer(self, node_id, transport)
|
||||
self.peers[node_id] = peer
|
||||
self._peers[node_id] = peer
|
||||
await self.taskgroup.spawn(peer.main_loop())
|
||||
try:
|
||||
# FIXME: server.close(), server.wait_closed(), etc... ?
|
||||
@@ -205,7 +212,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
now = time.time()
|
||||
if len(self.peers) >= NUM_PEERS_TARGET:
|
||||
if len(self._peers) >= NUM_PEERS_TARGET:
|
||||
continue
|
||||
peers = await self._get_next_peers_to_try()
|
||||
for peer in peers:
|
||||
@@ -213,8 +220,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
|
||||
async def _add_peer(self, host: str, port: int, node_id: bytes) -> Peer:
|
||||
if node_id in self.peers:
|
||||
return self.peers[node_id]
|
||||
if node_id in self._peers:
|
||||
return self._peers[node_id]
|
||||
port = int(port)
|
||||
peer_addr = LNPeerAddr(host, port, node_id)
|
||||
transport = LNTransport(self.node_keypair.privkey, peer_addr)
|
||||
@@ -222,12 +229,12 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self.logger.info(f"adding peer {peer_addr}")
|
||||
peer = Peer(self, node_id, transport)
|
||||
await self.taskgroup.spawn(peer.main_loop())
|
||||
self.peers[node_id] = peer
|
||||
self._peers[node_id] = peer
|
||||
return peer
|
||||
|
||||
def peer_closed(self, peer: Peer) -> None:
|
||||
if peer.pubkey in self.peers:
|
||||
self.peers.pop(peer.pubkey)
|
||||
if peer.pubkey in self._peers:
|
||||
self._peers.pop(peer.pubkey)
|
||||
|
||||
def num_peers(self) -> int:
|
||||
return sum([p.is_initialized() for p in self.peers.values()])
|
||||
@@ -282,7 +289,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
for peer in recent_peers:
|
||||
if not peer:
|
||||
continue
|
||||
if peer.pubkey in self.peers:
|
||||
if peer.pubkey in self._peers:
|
||||
continue
|
||||
if not self._can_retry_addr(peer, now=now):
|
||||
continue
|
||||
@@ -442,7 +449,6 @@ class LNWallet(LNWorker):
|
||||
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
|
||||
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
|
||||
self.sweep_address = wallet.get_receiving_address()
|
||||
self.lock = threading.RLock()
|
||||
self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted)
|
||||
self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state
|
||||
# used in tests
|
||||
@@ -680,12 +686,12 @@ class LNWallet(LNWorker):
|
||||
await self.try_force_closing(chan.channel_id)
|
||||
|
||||
elif chan.get_state() == ChannelState.FUNDED:
|
||||
peer = self.peers.get(chan.node_id)
|
||||
peer = self._peers.get(chan.node_id)
|
||||
if peer and peer.is_initialized():
|
||||
peer.send_funding_locked(chan)
|
||||
|
||||
elif chan.get_state() == ChannelState.OPEN:
|
||||
peer = self.peers.get(chan.node_id)
|
||||
peer = self._peers.get(chan.node_id)
|
||||
if peer:
|
||||
await peer.maybe_update_fee(chan)
|
||||
conf = self.lnwatcher.get_tx_height(chan.funding_outpoint.txid).conf
|
||||
@@ -736,7 +742,7 @@ class LNWallet(LNWorker):
|
||||
@log_exceptions
|
||||
async def add_peer(self, connect_str: str) -> Peer:
|
||||
node_id, rest = extract_nodeid(connect_str)
|
||||
peer = self.peers.get(node_id)
|
||||
peer = self._peers.get(node_id)
|
||||
if not peer:
|
||||
if rest is not None:
|
||||
host, port = split_host_port(rest)
|
||||
@@ -842,7 +848,7 @@ class LNWallet(LNWorker):
|
||||
async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentAttemptLog:
|
||||
short_channel_id = route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
peer = self.peers.get(route[0].node_id)
|
||||
peer = self._peers.get(route[0].node_id)
|
||||
if not peer:
|
||||
raise Exception('Dropped peer')
|
||||
await peer.initialized
|
||||
@@ -1238,7 +1244,7 @@ class LNWallet(LNWorker):
|
||||
|
||||
async def close_channel(self, chan_id):
|
||||
chan = self.channels[chan_id]
|
||||
peer = self.peers[chan.node_id]
|
||||
peer = self._peers[chan.node_id]
|
||||
return await peer.close_channel(chan_id)
|
||||
|
||||
async def force_close_channel(self, chan_id):
|
||||
@@ -1299,7 +1305,7 @@ class LNWallet(LNWorker):
|
||||
# reestablish
|
||||
if not chan.should_try_to_reestablish_peer():
|
||||
continue
|
||||
peer = self.peers.get(chan.node_id, None)
|
||||
peer = self._peers.get(chan.node_id, None)
|
||||
if peer:
|
||||
await peer.taskgroup.spawn(peer.reestablish_channel(chan))
|
||||
else:
|
||||
|
||||
@@ -124,6 +124,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
|
||||
@property
|
||||
def peers(self):
|
||||
return self._peers
|
||||
|
||||
@property
|
||||
def _peers(self):
|
||||
return {self.remote_keypair.pubkey: self.peer}
|
||||
|
||||
def channels_for_peer(self, pubkey):
|
||||
|
||||
Reference in New Issue
Block a user