diff --git a/electrum/interface.py b/electrum/interface.py index c1f34b4d8..42c38b223 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -485,8 +485,8 @@ class Interface(Logger): self.logger.warning(f"disconnecting due to {repr(e)}") self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True) finally: - await self.network.connection_down(self) self.got_disconnected.set() + await self.network.connection_down(self) # if was not 'ready' yet, schedule waiting coroutines: self.ready.cancel() return wrapper_func diff --git a/electrum/network.py b/electrum/network.py index 08d467bc7..247ecc538 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -246,7 +246,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): taskgroup: Optional[TaskGroup] interface: Optional[Interface] interfaces: Dict[ServerAddr, Interface] - _connecting: Set[ServerAddr] + _connecting_ifaces: Set[ServerAddr] + _closing_ifaces: Set[ServerAddr] default_server: ServerAddr _recent_servers: List[ServerAddr] @@ -321,10 +322,16 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): # the main server we are currently communicating with self.interface = None self.default_server_changed_event = asyncio.Event() - # set of servers we have an ongoing connection with - self.interfaces = {} + # Set of servers we have an ongoing connection with. + # For any ServerAddr, at most one corresponding Interface object + # can exist at any given time. Depending on the state of that Interface, + # the ServerAddr can be found in one of the following sets. + # Note: during a transition, the ServerAddr can appear in two sets momentarily. + self._connecting_ifaces = set() + self.interfaces = {} # these are the ifaces in "initialised and usable" state + self._closing_ifaces = set() + self.auto_connect = self.config.get('auto_connect', True) - self._connecting = set() self.proxy = None self._maybe_set_oneserver() @@ -551,7 +558,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): def _get_next_server_to_try(self) -> Optional[ServerAddr]: now = time.time() with self.interfaces_lock: - connected_servers = set(self.interfaces) | self._connecting + connected_servers = set(self.interfaces) | self._connecting_ifaces | self._closing_ifaces # First try from recent servers. (which are persisted) # As these are servers we successfully connected to recently, they are # most likely to work. This also makes servers "sticky". @@ -680,13 +687,9 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): # Stop any current interface in order to terminate subscriptions, # and to cancel tasks in interface.taskgroup. - # However, for headers sub, give preference to this interface - # over unknown ones, i.e. start it again right away. if old_server and old_server != server: # don't wait for old_interface to close as that might be slow: await self.taskgroup.spawn(self._close_interface(old_interface)) - if len(self.interfaces) <= self.num_server: - await self.taskgroup.spawn(self._run_new_interface(old_server)) if server not in self.interfaces: self.interface = None @@ -708,15 +711,23 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): if blockchain_updated: util.trigger_callback('blockchain_updated') - async def _close_interface(self, interface: Interface): - if interface: - with self.interfaces_lock: - if self.interfaces.get(interface.server) == interface: - self.interfaces.pop(interface.server) - if interface == self.interface: - self.interface = None + async def _close_interface(self, interface: Optional[Interface]): + if not interface: + return + if interface.server in self._closing_ifaces: + return + self._closing_ifaces.add(interface.server) + with self.interfaces_lock: + if self.interfaces.get(interface.server) == interface: + self.interfaces.pop(interface.server) + if interface == self.interface: + self.interface = None + try: # this can take some time if server/connection is slow: await interface.close() + await interface.got_disconnected.wait() + finally: + self._closing_ifaces.discard(interface.server) @with_recent_servers_lock def _add_recent_server(self, server: ServerAddr) -> None: @@ -732,8 +743,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' if not interface: return - # note: don't rely on interface.server for comparisons here - if interface == self.interface: + if interface.server == self.default_server: self._set_status('disconnected') await self._close_interface(interface) util.trigger_callback('network_updated') @@ -748,9 +758,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): @ignore_exceptions # do not kill outer taskgroup @log_exceptions async def _run_new_interface(self, server: ServerAddr): - if server in self.interfaces or server in self._connecting: + if (server in self.interfaces + or server in self._connecting_ifaces + or server in self._closing_ifaces): return - self._connecting.add(server) + self._connecting_ifaces.add(server) if server == self.default_server: self.logger.info(f"connecting to {server} as new interface") self._set_status('connecting') @@ -770,8 +782,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): assert server not in self.interfaces self.interfaces[server] = interface finally: - try: self._connecting.remove(server) - except KeyError: pass + self._connecting_ifaces.discard(server) if server == self.default_server: await self.switch_to_interface(server) @@ -1129,7 +1140,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): assert not self.taskgroup self.taskgroup = taskgroup = SilentTaskGroup() assert not self.interface and not self.interfaces - assert not self._connecting + assert not self._connecting_ifaces + assert not self._closing_ifaces self.logger.info('starting network') self._clear_addr_retry_times() self._set_proxy(deserialize_proxy(self.config.get('proxy'))) @@ -1173,7 +1185,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self.taskgroup = None self.interface = None self.interfaces = {} - self._connecting.clear() + self._connecting_ifaces.clear() + self._closing_ifaces.clear() if not full_shutdown: util.trigger_callback('network_updated') @@ -1197,7 +1210,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): async def _maintain_sessions(self): async def maybe_start_new_interfaces(): - for i in range(self.num_server - len(self.interfaces) - len(self._connecting)): + num_existing_ifaces = len(self.interfaces) + len(self._connecting_ifaces) + len(self._closing_ifaces) + for i in range(self.num_server - num_existing_ifaces): # FIXME this should try to honour "healthy spread of connected servers" server = self._get_next_server_to_try() if server: