synchronizer: fix rare race where synchronizer could get stuck
This commit is contained in:
@@ -74,8 +74,9 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
self.unverified_channel_info[short_channel_id] = msg
|
self.unverified_channel_info[short_channel_id] = msg
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _start_tasks(self):
|
async def _run_tasks(self, *, taskgroup):
|
||||||
async with self.taskgroup as group:
|
await super()._run_tasks(taskgroup=taskgroup)
|
||||||
|
async with taskgroup as group:
|
||||||
await group.spawn(self.main)
|
await group.spawn(self.main)
|
||||||
|
|
||||||
async def main(self):
|
async def main(self):
|
||||||
|
|||||||
@@ -74,9 +74,10 @@ class SynchronizerBase(NetworkJobOnDefaultServer):
|
|||||||
self.add_queue = asyncio.Queue()
|
self.add_queue = asyncio.Queue()
|
||||||
self.status_queue = asyncio.Queue()
|
self.status_queue = asyncio.Queue()
|
||||||
|
|
||||||
async def _start_tasks(self):
|
async def _run_tasks(self, *, taskgroup):
|
||||||
|
await super()._run_tasks(taskgroup=taskgroup)
|
||||||
try:
|
try:
|
||||||
async with self.taskgroup as group:
|
async with taskgroup as group:
|
||||||
await group.spawn(self.send_subscriptions())
|
await group.spawn(self.send_subscriptions())
|
||||||
await group.spawn(self.handle_status())
|
await group.spawn(self.handle_status())
|
||||||
await group.spawn(self.main())
|
await group.spawn(self.main())
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from ipaddress import IPv4Address, IPv6Address
|
|||||||
import random
|
import random
|
||||||
import secrets
|
import secrets
|
||||||
import functools
|
import functools
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -1163,7 +1164,7 @@ class SilentTaskGroup(TaskGroup):
|
|||||||
return super().spawn(*args, **kwargs)
|
return super().spawn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class NetworkJobOnDefaultServer(Logger):
|
class NetworkJobOnDefaultServer(Logger, ABC):
|
||||||
"""An abstract base class for a job that runs on the main network
|
"""An abstract base class for a job that runs on the main network
|
||||||
interface. Every time the main interface changes, the job is
|
interface. Every time the main interface changes, the job is
|
||||||
restarted, and some of its internals are reset.
|
restarted, and some of its internals are reset.
|
||||||
@@ -1179,8 +1180,10 @@ class NetworkJobOnDefaultServer(Logger):
|
|||||||
self._network_request_semaphore = asyncio.Semaphore(100)
|
self._network_request_semaphore = asyncio.Semaphore(100)
|
||||||
|
|
||||||
self._reset()
|
self._reset()
|
||||||
asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
|
# every time the main interface changes, restart:
|
||||||
register_callback(self._restart, ['default_server_changed'])
|
register_callback(self._restart, ['default_server_changed'])
|
||||||
|
# also schedule a one-off restart now, as there might already be a main interface:
|
||||||
|
asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
|
||||||
|
|
||||||
def _reset(self):
|
def _reset(self):
|
||||||
"""Initialise fields. Called every time the underlying
|
"""Initialise fields. Called every time the underlying
|
||||||
@@ -1190,13 +1193,17 @@ class NetworkJobOnDefaultServer(Logger):
|
|||||||
|
|
||||||
async def _start(self, interface: 'Interface'):
|
async def _start(self, interface: 'Interface'):
|
||||||
self.interface = interface
|
self.interface = interface
|
||||||
await interface.taskgroup.spawn(self._start_tasks)
|
await interface.taskgroup.spawn(self._run_tasks(taskgroup=self.taskgroup))
|
||||||
|
|
||||||
async def _start_tasks(self):
|
@abstractmethod
|
||||||
"""Start tasks in self.taskgroup. Called every time the underlying
|
async def _run_tasks(self, *, taskgroup: TaskGroup) -> None:
|
||||||
|
"""Start tasks in taskgroup. Called every time the underlying
|
||||||
server connection changes.
|
server connection changes.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError() # implemented by subclasses
|
# If self.taskgroup changed, don't start tasks. This can happen if we have
|
||||||
|
# been restarted *just now*, i.e. after the _run_tasks coroutine object was created.
|
||||||
|
if taskgroup != self.taskgroup:
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
unregister_callback(self._restart)
|
unregister_callback(self._restart)
|
||||||
|
|||||||
@@ -58,8 +58,9 @@ class SPV(NetworkJobOnDefaultServer):
|
|||||||
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
|
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
|
||||||
self.requested_merkle = set() # txid set of pending requests
|
self.requested_merkle = set() # txid set of pending requests
|
||||||
|
|
||||||
async def _start_tasks(self):
|
async def _run_tasks(self, *, taskgroup):
|
||||||
async with self.taskgroup as group:
|
await super()._run_tasks(taskgroup=taskgroup)
|
||||||
|
async with taskgroup as group:
|
||||||
await group.spawn(self.main)
|
await group.spawn(self.main)
|
||||||
|
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user