1
0

Improve _mythread checks (#7403)

* Improve _mythread checks

* Create get_running_loop util
This commit is contained in:
MrNaif2018
2021-07-15 17:52:25 +03:00
committed by GitHub
parent c5129ee447
commit aafa74ed08
3 changed files with 12 additions and 14 deletions

View File

@@ -14,13 +14,8 @@ from typing import TYPE_CHECKING, Dict, NamedTuple, Tuple, List, Optional
import sys
import time
if sys.version_info[:2] >= (3, 7):
from asyncio import get_running_loop
else:
from asyncio import _get_running_loop as get_running_loop # noqa: F401
from .logging import Logger
from .util import profiler
from .util import profiler, get_running_loop
from .lnrouter import fee_for_edge_msat
from .lnutil import LnFeatures, ln_compare_features, IncompatibleLightningFeatures

View File

@@ -275,11 +275,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.asyncio_loop = asyncio.get_event_loop()
assert self.asyncio_loop.is_running(), "event loop not running"
try:
self._loop_thread = self.asyncio_loop._mythread # type: threading.Thread # only used for sanity checks
except AttributeError as e:
self.logger.warning(f"asyncio loop does not have _mythread set: {e!r}")
self._loop_thread = None
assert isinstance(config, SimpleConfig), f"config should be a SimpleConfig instead of {type(config)}"
self.config = config
@@ -387,7 +382,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.path_finder = None
def run_from_another_thread(self, coro, *, timeout=None):
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
assert util.get_running_loop() != self.asyncio_loop, 'must not be called from network thread'
fut = asyncio.run_coroutine_threadsafe(coro, self.asyncio_loop)
return fut.result(timeout)
@@ -1318,7 +1313,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
def send_http_on_proxy(cls, method, url, **kwargs):
network = cls.get_instance()
if network:
assert network._loop_thread is not threading.currentThread()
assert util.get_running_loop() != network.asyncio_loop
loop = network.asyncio_loop
else:
loop = asyncio.get_event_loop()

View File

@@ -1279,7 +1279,6 @@ def create_and_start_event_loop() -> Tuple[asyncio.AbstractEventLoop,
args=(stopping_fut,),
name='EventLoop')
loop_thread.start()
loop._mythread = loop_thread
return loop, stopping_fut, loop_thread
@@ -1626,3 +1625,12 @@ class nullcontext:
async def __aexit__(self, *excinfo):
pass
def get_running_loop():
"""Mimics _get_running_loop convenient functionality for sanity checks on all python versions"""
if sys.version_info < (3, 7):
return asyncio._get_running_loop()
try:
return asyncio.get_running_loop()
except RuntimeError:
return None