1
0

util: add fn run_sync_function_on_asyncio_thread

note: the return value is not propagated out.
It would be trivial to do that for the block=True case - but what about block=False?
This commit is contained in:
SomberNight
2025-04-08 19:53:49 +00:00
parent 216bfe3b50
commit 4b9d874d13
2 changed files with 40 additions and 19 deletions

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Optional, Dict, Union, Sequence, Tuple, Iterab
from decimal import Decimal
import math
import time
import concurrent.futures
import attr
import aiohttp
@@ -28,7 +27,7 @@ from .transaction import PartialTxInput, PartialTxOutput, PartialTransaction, Tr
from .transaction import script_GetOp, match_script_against_template, OPPushDataGeneric, OPPushDataPubkey
from .util import (log_exceptions, ignore_exceptions, BelowDustLimit, OldTaskGroup, age, ca_path,
gen_nostr_ann_pow, get_nostr_ann_pow_amount, make_aiohttp_proxy_connector,
get_running_loop, get_asyncio_loop, wait_for2)
get_running_loop, get_asyncio_loop, wait_for2, run_sync_function_on_asyncio_thread)
from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY
from .bitcoin import dust_threshold, DummyAddress
from .logging import Logger
@@ -953,19 +952,12 @@ class SwapManager(Logger):
self.trigger_pairs_updated_threadsafe()
def trigger_pairs_updated_threadsafe(self):
future = concurrent.futures.Future()
def trigger():
self.is_initialized.set()
self.pairs_updated.set()
self.pairs_updated.clear()
future.set_result(None)
asyncio_loop = get_asyncio_loop()
if get_running_loop() == asyncio_loop:
trigger() # this is running on the asyncio event loop
else:
asyncio_loop.call_soon_threadsafe(trigger)
# block until the event loop has run the trigger function
_ = future.result()
run_sync_function_on_asyncio_thread(trigger, block=True)
def server_maybe_trigger_liquidity_update(self) -> None:
"""

View File

@@ -53,6 +53,7 @@ from abc import abstractmethod, ABC
import socket
import enum
from contextlib import nullcontext
import traceback
import attr
import aiohttp
@@ -1709,6 +1710,41 @@ def _set_custom_task_factory(loop: asyncio.AbstractEventLoop):
loop.set_task_factory(factory)
def run_sync_function_on_asyncio_thread(func: Callable, *, block: bool) -> None:
"""Run a non-async fn on the asyncio thread. Can be called from any thread.
If the current thread is already the asyncio thread, func is guaranteed
to have been completed when this method returns.
For any other thread, we only wait for completion if `block` is True.
"""
assert not asyncio.iscoroutinefunction(func), "func must be a non-async function"
asyncio_loop = get_asyncio_loop()
if get_running_loop() == asyncio_loop: # we are running on the asyncio thread
func()
else: # non-asyncio thread
async def wrapper():
return func()
fut = asyncio.run_coroutine_threadsafe(wrapper(), loop=asyncio_loop)
if block:
fut.result()
else:
# add explicit logging of exceptions, otherwise they might get lost
tb1 = traceback.format_stack()[:-1]
tb1_str = "".join(tb1)
def on_done(fut_: concurrent.futures.Future):
assert fut_.done()
if fut_.cancelled():
_logger.debug(f"func cancelled. {func=}.")
elif exc := fut_.exception():
# note: We explicitly log the first part of the traceback, tb1_str.
# The second part gets logged by setting "exc_info".
_logger.error(
f"func errored. {func=}. {exc=}"
f"\n{tb1_str}", exc_info=exc)
fut.add_done_callback(on_done)
class OrderedDictWithIndex(OrderedDict):
"""An OrderedDict that keeps track of the positions of keys.
@@ -1894,14 +1930,7 @@ class CallbackManager(Logger):
self.logger.error(f"cb errored. {event=}. {exc=}", exc_info=exc)
fut.add_done_callback(on_done)
else: # non-async cb
# note: the cb needs to run in the asyncio thread
if get_running_loop() == loop:
# run callback immediately, so that it is guaranteed
# to have been executed when this method returns
callback(*args)
else:
# note: if cb raises, asyncio will log the exception
loop.call_soon_threadsafe(callback, *args)
run_sync_function_on_asyncio_thread(partial(callback, *args), block=False)
callback_mgr = CallbackManager()